diff --git a/tests/cli/utils/model_test.py b/tests/cli/utils/model_test.py index 1cbc4eab..301ecca4 100644 --- a/tests/cli/utils/model_test.py +++ b/tests/cli/utils/model_test.py @@ -62,22 +62,42 @@ testcase_name="gemma3-270m", model_name="gemma3-270m", ), + dict( + testcase_name="gemma3-270m-it", + model_name="gemma3-270m-it", + ), dict( testcase_name="gemma3-1b", model_name="gemma3-1b", ), + dict( + testcase_name="gemma3-1b-it", + model_name="gemma3-1b-it", + ), dict( testcase_name="gemma3-4b", model_name="gemma3-4b", ), + dict( + testcase_name="gemma3-4b-it", + model_name="gemma3-4b-it", + ), dict( testcase_name="gemma3-12b", model_name="gemma3-12b", ), + dict( + testcase_name="gemma3-12b-it", + model_name="gemma3-12b-it", + ), dict( testcase_name="gemma3-27b", model_name="gemma3-27b", ), + dict( + testcase_name="gemma3-27b-it", + model_name="gemma3-27b-it", + ), dict( testcase_name="llama3-70b", model_name="llama3-70b", diff --git a/tunix/models/gemma3/model.py b/tunix/models/gemma3/model.py index 84662507..c1c5d628 100644 --- a/tunix/models/gemma3/model.py +++ b/tunix/models/gemma3/model.py @@ -131,6 +131,10 @@ def gemma3_270m( shd_config=sharding_config, ) + @classmethod + def gemma3_270m_it(cls, **kwargs): + return cls.gemma3_270m(**kwargs) + @classmethod def gemma3_1b( cls, @@ -150,6 +154,10 @@ def gemma3_1b( shd_config=sharding_config, ) + @classmethod + def gemma3_1b_it(cls, **kwargs): + return cls.gemma3_1b(**kwargs) + @classmethod def gemma3_4b( cls, @@ -171,6 +179,10 @@ def gemma3_4b( shd_config=sharding_config, ) + @classmethod + def gemma3_4b_it(cls, **kwargs): + return cls.gemma3_4b(**kwargs) + @classmethod def gemma3_12b( cls, @@ -193,6 +205,10 @@ def gemma3_12b( shd_config=sharding_config, ) + @classmethod + def gemma3_12b_it(cls, **kwargs): + return cls.gemma3_12b(**kwargs) + @classmethod def gemma3_27b( cls, @@ -215,6 +231,10 @@ def gemma3_27b( shd_config=sharding_config, ) + @classmethod + def gemma3_27b_it(cls, **kwargs): + return cls.gemma3_27b(**kwargs) + def shard(x: jnp.ndarray, s: Tuple[str, ...]): mesh = pxla.thread_resources.env.physical_mesh