From b8716f3ec23c93b310b569ea94bef8a59877bd1d Mon Sep 17 00:00:00 2001 From: Shadi Noghabi Date: Tue, 11 Nov 2025 09:27:09 -0800 Subject: [PATCH] add automodel to Tunix PiperOrigin-RevId: 830948333 --- tests/cli/utils/model_test.py | 83 ++++++++++- tests/models/naming_test.py | 30 ++++ tunix/cli/utils/model.py | 265 +++++++--------------------------- tunix/models/automodel.py | 204 ++++++++++++++++++++++++++ tunix/models/gemma3/model.py | 40 +++++ tunix/models/naming.py | 2 + 6 files changed, 406 insertions(+), 218 deletions(-) create mode 100644 tunix/models/automodel.py diff --git a/tests/cli/utils/model_test.py b/tests/cli/utils/model_test.py index 1cbc4eabe..62cdd2bd3 100644 --- a/tests/cli/utils/model_test.py +++ b/tests/cli/utils/model_test.py @@ -15,6 +15,7 @@ from absl.testing import absltest from absl.testing import parameterized from tunix.cli.utils import model +from tunix.models import automodel @parameterized.named_parameters( @@ -62,22 +63,82 @@ testcase_name="gemma3-270m", model_name="gemma3-270m", ), + dict( + testcase_name="gemma3-270m-it", + model_name="gemma3-270m-it", + ), dict( testcase_name="gemma3-1b", model_name="gemma3-1b", ), + dict( + testcase_name="gemma3-1b-it", + model_name="gemma3-1b-it", + ), dict( testcase_name="gemma3-4b", model_name="gemma3-4b", ), + dict( + testcase_name="gemma3-4b-it", + model_name="gemma3-4b-it", + ), dict( testcase_name="gemma3-12b", model_name="gemma3-12b", ), + dict( + testcase_name="gemma3-12b-it", + model_name="gemma3-12b-it", + ), dict( testcase_name="gemma3-27b", model_name="gemma3-27b", ), + dict( + testcase_name="gemma3-27b-it", + model_name="gemma3-27b-it", + ), + dict( + testcase_name="gemma-3-270m", + model_name="gemma-3-270m", + ), + dict( + testcase_name="gemma-3-270m-it", + model_name="gemma-3-270m-it", + ), + dict( + testcase_name="gemma-3-1b", + model_name="gemma-3-1b", + ), + dict( + testcase_name="gemma-3-1b-it", + model_name="gemma-3-1b-it", + ), + dict( + testcase_name="gemma-3-4b", + model_name="gemma-3-4b", + ), + dict( + testcase_name="gemma-3-4b-it", + model_name="gemma-3-4b-it", + ), + dict( + testcase_name="gemma-3-12b", + model_name="gemma-3-12b", + ), + dict( + testcase_name="gemma-3-12b-it", + model_name="gemma-3-12b-it", + ), + dict( + testcase_name="gemma-3-27b", + model_name="gemma-3-27b", + ), + dict( + testcase_name="gemma-3-27b-it", + model_name="gemma-3-27b-it", + ), dict( testcase_name="llama3-70b", model_name="llama3-70b", @@ -118,11 +179,10 @@ testcase_name="qwen2.5-math-1.5b", model_name="qwen2.5-math-1.5b", ), - # TODO(b/451662153): support deepseek model name parsing - # dict( - # testcase_name="deepseek-r1-distill-qwen-1.5b", - # model_name="deepseek-r1-distill-qwen-1.5b", - # ), + dict( + testcase_name="deepseek-r1-distill-qwen-1.5b", + model_name="deepseek-r1-distill-qwen-1.5b", + ), dict( testcase_name="qwen3-0.6b", model_name="qwen3-0.6b", @@ -148,13 +208,20 @@ class ModelTest(parameterized.TestCase): def test_obtain_model_params_valid(self, model_name: str): - model.obtain_model_params(model_name) + automodel.obtain_model_params(model_name) def test_create_model_dynamically_routing(self, model_name: str): - model_module = model.get_model_module(model_name) + params_module = automodel.get_model_module( + model_name, model.ModelModule.PARAMS + ) if not model_name.startswith("gemma"): # TODO(b/444572467) - getattr(model_module, "create_model_from_safe_tensors") + getattr(params_module, "create_model_from_safe_tensors") + + model_lib_module = automodel.get_model_module( + model_name, 
automodel.ModelModule.MODEL + ) + getattr(model_lib_module, "ModelConfig") if __name__ == "__main__": diff --git a/tests/models/naming_test.py b/tests/models/naming_test.py index 01d0cae4f..5931861e1 100644 --- a/tests/models/naming_test.py +++ b/tests/models/naming_test.py @@ -99,30 +99,60 @@ class ModelTestInfo: id='gemma3_270m', category='gemma3', ), + 'gemma-3-270m-it': ModelTestInfo( + family='gemma3', + version='270m_it', + id='gemma3_270m_it', + category='gemma3', + ), 'gemma-3-1b': ModelTestInfo( family='gemma3', version='1b', id='gemma3_1b', category='gemma3', ), + 'gemma-3-1b-it': ModelTestInfo( + family='gemma3', + version='1b_it', + id='gemma3_1b_it', + category='gemma3', + ), 'gemma-3-4b': ModelTestInfo( family='gemma3', version='4b', id='gemma3_4b', category='gemma3', ), + 'gemma-3-4b-it': ModelTestInfo( + family='gemma3', + version='4b_it', + id='gemma3_4b_it', + category='gemma3', + ), 'gemma-3-12b': ModelTestInfo( family='gemma3', version='12b', id='gemma3_12b', category='gemma3', ), + 'gemma-3-12b-it': ModelTestInfo( + family='gemma3', + version='12b_it', + id='gemma3_12b_it', + category='gemma3', + ), 'gemma-3-27b': ModelTestInfo( family='gemma3', version='27b', id='gemma3_27b', category='gemma3', ), + 'gemma-3-27b-it': ModelTestInfo( + family='gemma3', + version='27b_it', + id='gemma3_27b_it', + category='gemma3', + ), 'llama3-70b': ModelTestInfo( family='llama3', version='70b', diff --git a/tunix/cli/utils/model.py b/tunix/cli/utils/model.py index f5c32b361..a06280c12 100644 --- a/tunix/cli/utils/model.py +++ b/tunix/cli/utils/model.py @@ -14,192 +14,35 @@ """Utilities for creating and managing models in Tunix CLI.""" import gc -import importlib import os -import re from typing import Any, Tuple from absl import logging +import flax from flax import nnx import jax import jax.numpy as jnp from orbax import checkpoint as ocp import qwix from tunix.generate import tokenizer_adapter as tokenizer_lib -from tunix.models.gemma import model as gemma_lib -from tunix.models.gemma import params as gemma_params_lib -from tunix.models.gemma3 import model as gemma3_lib -from tunix.models.gemma3 import params as gemma3_params_lib -from tunix.models.llama3 import model as llama3_lib -from tunix.models.qwen2 import model as qwen2_lib -from tunix.models.qwen3 import model as qwen3_lib +from tunix.models import automodel +from tunix.models import naming from tunix.rl import reshard +ModelModule = automodel.ModelModule -# Map prefixes to the target object containing the methods. 
-CONFIG_MAP = { - 'gemma': gemma_lib.ModelConfig, - 'gemma1.1': gemma_lib.ModelConfig, - 'gemma2': gemma_lib.ModelConfig, - 'gemma3': gemma3_lib.ModelConfig, - 'llama3': llama3_lib.ModelConfig, - 'llama3.1': llama3_lib.ModelConfig, - 'llama3.2': llama3_lib.ModelConfig, - 'qwen2.5': qwen2_lib.ModelConfig, - 'qwen3': qwen3_lib.ModelConfig, -} - -_BASE_MODULE_PATH = 'tunix.models' # pylint: disable=invalid-name - - -def get_model_module(model_name: str) -> Any: - """Dynamically imports the parameter module based on the model name.""" - # Extract the base model type (e.g., "qwen2", "llama3") - match = re.match(r'^[a-zA-Z0-9]+', model_name) - if not match: - raise ValueError(f'Invalid model name format: {model_name}') - model_type = match.group(0) - # Construct the full module path, e.g.,.path.to.your.models.qwen2.params - if model_name.startswith('gemma1') or model_name.startswith('gemma2'): - model_type = 'gemma' - module_path = f'{_BASE_MODULE_PATH}.{model_type}.params' - try: - print(f'Attempting to import: {module_path}') - model_module = importlib.import_module(module_path) - return model_module - except ImportError as exc: # Capture the original exeception as 'exc' - raise ImportError( - f'Could not import module for model type: {model_type} ' - f'at path: {module_path}. Please check BASE_MODULE_PATH ' - 'and ensure the module exists and is a dependency.' - ) from exc - - -def create_model_dynamically( - model_name: str, file_dir: str, model_config: Any, mesh: jax.sharding.Mesh -) -> Any: - """Dynamically imports the correct module and calls `create_model_from_safe_tensors` based on the model_name. - - Args: - model_name: The name of the model (e.g., "qwen2.5-0.5b", "llama3.2-3b"). - file_dir: Directory containing the safe tensors. - model_config: Model configuration object. - mesh: Mesh object for device layout. - - Returns: - The result of the create_model_from_safe_tensors call. - - Raises: - ValueError: If the model_name is invalid. - ImportError: If the required model module cannot be found. - AttributeError: If create_model_from_safe_tensors is not in the module. - """ - model_module = get_model_module(model_name) - - try: - create_fn = getattr(model_module, 'create_model_from_safe_tensors') - except AttributeError as exc: - raise AttributeError( - "'create_model_from_safe_tensors' not found in module " - f'{model_module.__name__} for model {model_name}' - ) from exc - - logging.info( - 'Calling %s.create_model_from_safe_tensors', model_module.__name__ - ) - return create_fn(file_dir=file_dir, config=model_config, mesh=mesh) - - -def _get_version(model_name: str, matched_prefix: str) -> str: - """Extracts the version string from the model name.""" - if not model_name.startswith(matched_prefix): - return '' - - suffix = model_name[len(matched_prefix) :] - - # Remove leading separator (- or .) if present - if suffix.startswith('-') or suffix.startswith('.'): - suffix = suffix[1:] - - if not suffix: - return '' - - return suffix.replace('.', '_').replace('-', '_') - - -def obtain_model_params(model_name: str) -> Any: - """Dynamically calls a configuration function based on the model_string. - - The routing to the correct module/class instance is based on the longest - matching prefix of model_name found in CONFIG_MAP. - Hyphens and dots in the model_name are converted to underscores - to form the function name. - - Args: - model_name: The string indicating which model config function to call - (e.g., "gemma-2b", "llama3.1-8b", "qwen2.5-0.5b"). 
-
-  Returns:
-    The result from calling the dynamically determined function.
-
-  Raises:
-    ValueError: If the model_string doesn't match any known prefix.
-    AttributeError: If the derived function name does not exist in the target
-      object.
-    TypeError: If the attribute found on the target object is not callable.
-  """
-  target_obj = None
-  matched_prefix = ''
-
-  # Find the longest matching prefix
-  for prefix, obj in CONFIG_MAP.items():
-    if model_name.startswith(prefix):
-      if len(prefix) > len(matched_prefix):
-        matched_prefix = prefix
-        target_obj = obj
-
-  if not target_obj:
-    raise ValueError(f'Unsupported model string prefix for: {model_name}')
-
-  logging.info('Routing %s using prefix %s', model_name, matched_prefix)
-
-  family_snake = matched_prefix.replace('-', '_').replace('.', '_')
-  core_version = _get_version(model_name, matched_prefix)
-
-  if not core_version:
-    raise ValueError(
-        f"Could not extract core version from '{model_name}' "
-        f"for prefix '{matched_prefix}'."
-    )
-
-  function_name = f'{family_snake}_{core_version}'
-
-  if not hasattr(target_obj, function_name):
-    raise AttributeError(
-        f"Error: Function '{function_name}' not found on the target object "
-        f"for prefix '{matched_prefix}'. Target object type: {type(target_obj)}"
-    )
-
-  method_to_call = getattr(target_obj, function_name)
-
-  if not callable(method_to_call):
-    raise TypeError(
-        f"Error: Attribute '{function_name}' on the target object is not"
-        ' callable.'
-    )
-
-  logging.info(
-      'Attempting to call: %s() on object of type %s',
-      function_name,
-      type(target_obj),
-  )
-  return method_to_call()
+# TODO(b/462808330): Handle sharding overrides better.
+if hasattr(flax.config, 'flax_always_shard_variable'):
+  flax.config.update('flax_always_shard_variable', False)
 
 
 def _get_base_model(model_config: dict[str, Any], mesh: jax.sharding.Mesh):
   """Get the base model from the intermediate checkpoint."""
-  model_params = obtain_model_params(model_config['model_name'])
+  model_params = automodel.obtain_model_params(model_config['model_name'])
+  model_lib_module = automodel.get_model_module(
+      model_config['model_name'], ModelModule.MODEL
+  )
   abs_model: nnx.Module = nnx.eval_shape(
-      lambda: gemma_lib.Transformer(
+      lambda: model_lib_module.Transformer(
           model_params, rngs=nnx.Rngs(model_config.get('rng_seed', 0))
       )
   )
@@ -279,42 +122,51 @@ def _gemma_conversion(
   return _get_base_model(model_config, mesh)
 
 
-def _get_model_version_suffix(model_name: str) -> str:
-  """Extracts the version/variant suffix from a model name string.
+def _create_model_from_checkpoint(
+    ckpt_path: str, model_name: str, mesh: jax.sharding.Mesh
+) -> Tuple[nnx.Module, Any]:
+  """Creates a model from a checkpoint.
 
-  The function is based on the following examples:
-  - "gemma2-2b-it" -> "2-2b-it"
-  - "gemma2-2b" -> "2-2b"
-  - "gemma-2b" -> "2b"
+  Loads the checkpoint at `ckpt_path` and builds the model. Used for models
+  whose params module provides `create_model_from_checkpoint`, such as
+  Gemma3.
 
   Args:
-    model_name: The full model name string.
+    ckpt_path: The path to the checkpoint.
+    model_name: The name of the model (e.g., "qwen2.5-0.5b", "llama3.2-3b").
+    mesh: Mesh object for device layout.
 
   Returns:
-    The version/variant suffix string.
+    A tuple containing:
+      - model: The loaded nnx.Module.
+      - model_params: The model parameters.
+ """ + model_params = automodel.obtain_model_params(model_name) + params_lib = automodel.get_model_module(model_name, ModelModule.PARAMS) + model = params_lib.create_model_from_checkpoint(ckpt_path, model_params, mesh) + return model, model_params + - Raises: - ValueError: If the model_name does not match a known pattern or - unsupported model family. +def _create_model_from_params( + params_path: str, model_name: str +) -> Tuple[nnx.Module, Any]: + """Loads Gemma params and creates a model. + + This function is used to create a model from a params path. It is used for + models that have support for `from_params` in their Transformer module, + such as Gemma and Gemma2. """ - if model_name.startswith('gemma'): - # Pattern 1: Matches names like "gemma2-2b-it", "gemma7b", etc. - # Captures the part starting with the first digit after "gemma". - match = re.match(r'^gemma(\d.*)$', model_name) - if match: - return match.group(1) - - # Pattern 2: Matches names like "gemma-2b", "gemma-7b-it", etc. - # Captures the part after "gemma-". - match = re.match(r'^gemma-(.+)$', model_name) - if match: - return match.group(1) - - # If neither pattern matches - raise ValueError(f'Unrecognized gemma model format: {model_name}') - else: - # This part can be extended for other model families like "llama", etc. - raise ValueError(f'Unsupported model family for: {model_name}') + params_lib = automodel.get_model_module(model_name, ModelModule.PARAMS) + model_params = params_lib.load_and_format_params(params_path) + model_module_lib = automodel.get_model_module(model_name, ModelModule.MODEL) + model_family, version = naming.split(model_name) + # TODO(b/451662153): have gemma2 version handling done better in naming.py + if model_family == 'gemma2': + version = f'2-{version}' + model = model_module_lib.Transformer.from_params( + model_params, version=version + ) + return model, model_params def create_tokenizer(tokenizer_config, tokenizer_path: str | None): @@ -367,9 +219,8 @@ def create_model( if model_name.startswith('gemma3') and model_source == 'gcs': ckpt_path = model_config['model_id'] - model_params = obtain_model_params(model_name) - model = gemma3_params_lib.create_model_from_checkpoint( - ckpt_path, model_params, mesh + model, model_params = _create_model_from_checkpoint( + ckpt_path, model_name, mesh ) tokenizer_path = 'gs://gemma-data/tokenizers/tokenizer_gemma3.model' @@ -395,10 +246,7 @@ def nnx_conversion(): suffix = '-'.join(model_name.split('-')[1:]) params_path = os.path.join(ckpt_path, suffix) - params = gemma_params_lib.load_and_format_params(params_path) - model = gemma_lib.Transformer.from_params( - params, version=_get_model_version_suffix(model_name) - ) + model, params = _create_model_from_params(params_path, model_name) return _gemma_conversion(model_config, model, params, mesh) if skip_nnx_conversion: @@ -430,14 +278,11 @@ def nnx_conversion(): if not model_params: # pick corresponding config based on model version - model_params = obtain_model_params(model_name) - with mesh: - model = create_model_dynamically( + model = automodel.AutoModel.from_pretrained( model_name, model_config['model_download_path'], - model_params, - mesh, + mesh=mesh, ) if model_config.get('lora_config'): diff --git a/tunix/models/automodel.py b/tunix/models/automodel.py new file mode 100644 index 000000000..f2035d27b --- /dev/null +++ b/tunix/models/automodel.py @@ -0,0 +1,204 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except 
in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""AutoModel class."""
+
+import enum
+import importlib
+from typing import Any
+
+from absl import logging
+import jax
+from tunix.models import naming
+
+
+_BASE_MODULE_PATH = 'tunix.models'
+
+
+class ModelModule(enum.Enum):
+  """Specifies the type of model module to import."""
+
+  MODEL = 'model'
+  PARAMS = 'params'
+
+
+def get_model_module(model_name: str, module_type: ModelModule) -> Any:
+  """Dynamically imports a model module (e.g., 'model' or 'params')."""
+  model_config_category = naming.get_model_config_category(model_name)
+  module_path = (
+      f'{_BASE_MODULE_PATH}.{model_config_category}.{module_type.value}'
+  )
+  try:
+    logging.info('Attempting to import: %s', module_path)
+    model_lib_module = importlib.import_module(module_path)
+    return model_lib_module
+  except ImportError as exc:
+    raise ImportError(
+        'Could not import module for model config category: '
+        f'{model_config_category} at path: {module_path}. Please check '
+        '_BASE_MODULE_PATH and ensure the module exists and is a dependency.'
+    ) from exc
+
+
+def obtain_model_params(model_name: str) -> Any:
+  """Dynamically calls a model configuration function based on model_name.
+
+  The model name is resolved to a config id via
+  `naming.get_model_config_id`; that id is then looked up as a classmethod
+  on the `ModelConfig` class of the dynamically imported model module and
+  called with no arguments.
+
+  Args:
+    model_name: The string indicating which model config function to call
+      (e.g., "gemma-2b", "llama3.1-8b", "qwen2.5-0.5b").
+
+  Returns:
+    The result from calling the dynamically determined function.
+
+  Raises:
+    ValueError: If model_name doesn't match any known model family.
+    AttributeError: If the derived function name does not exist in the target
+      object.
+    TypeError: If the attribute found on the target object is not callable.
+  """
+  config_id = naming.get_model_config_id(model_name)
+  model_lib_module = get_model_module(model_name, ModelModule.MODEL)
+  target_obj = model_lib_module.ModelConfig
+
+  if not hasattr(target_obj, config_id):
+    raise AttributeError(
+        f"Error: Function '{config_id}' not found on the target object "
+        f"for model '{model_name}'. Target object type: {type(target_obj)}"
+    )
+
+  method_to_call = getattr(target_obj, config_id)
+
+  if not callable(method_to_call):
+    raise TypeError(
+        f"Error: Attribute '{config_id}' on the target object is not callable."
+    )
+
+  logging.info(
+      'Attempting to call: %s() on object of type %s',
+      config_id,
+      type(target_obj),
+  )
+  return method_to_call()
+
+
+def _create_model_from_safe_tensors_dynamically(
+    model_name: str, file_dir: str, model_config: Any, mesh: jax.sharding.Mesh
+) -> Any:
+  """Dynamically imports the correct module and calls `create_model_from_safe_tensors` based on the model_name.
+
+  Args:
+    model_name: The name of the model (e.g., "qwen2.5-0.5b", "llama3.2-3b").
+    file_dir: Directory containing the safe tensors.
+    model_config: Model configuration object.
+    mesh: Mesh object for device layout.
+
+  Returns:
+    The result of the create_model_from_safe_tensors call.
+
+  Raises:
+    ValueError: If the model_name is invalid.
+    ImportError: If the required model module cannot be found.
+    AttributeError: If create_model_from_safe_tensors is not in the module.
+  """
+
+  model_config_category = naming.get_model_config_category(model_name)
+  if model_config_category.startswith('gemma'):
+    # TODO(b/444572467): Remove this check once Gemma safetensors works
+    raise NotImplementedError(
+        'Gemma safetensors loading is not supported in AutoModel. Please use'
+        ' the original model module.'
+    )
+  params_module = get_model_module(model_name, ModelModule.PARAMS)
+
+  try:
+    create_fn = getattr(params_module, 'create_model_from_safe_tensors')
+  except AttributeError as exc:
+    raise AttributeError(
+        "'create_model_from_safe_tensors' not found in module "
+        f'{params_module.__name__} for model {model_name}'
+    ) from exc
+
+  logging.info(
+      'Calling %s.create_model_from_safe_tensors', params_module.__name__
+  )
+  return create_fn(file_dir=file_dir, config=model_config, mesh=mesh)
+
+
+class AutoModel:
+  """A generic model class that will be instantiated as one of the model classes of the library.
+
+  This class provides a way to instantiate a model from a configuration, or
+  load a pretrained model from a file directory.
+  It relies on dynamic imports based on the model name to load the correct
+  model class and parameters.
+
+  Example:
+    To load a pretrained model from safe tensors:
+    ```
+    model = AutoModel.from_pretrained("qwen2.5-0.5b", "/path/to/weights", mesh)
+    ```
+
+    To instantiate a model from a config:
+    ```
+    model_params = obtain_model_params("qwen2.5-0.5b")
+    model = AutoModel.from_config("qwen2.5-0.5b", model_params)
+    ```
+  """
+
+  def __init__(self):
+    raise EnvironmentError(
+        'AutoModel is designed to be instantiated using class methods '
+        'like `from_pretrained()` or `from_config()`.'
+    )
+
+  @classmethod
+  def from_config(cls, model_name: str, *args, **kwargs):
+    """Instantiates one of the model classes of the library from a configuration.
+
+    Args:
+      model_name: The name of the model (e.g., "qwen2.5-0.5b", "llama3.2-3b").
+      *args: Positional arguments to pass to the model constructor.
+      **kwargs: Keyword arguments to pass to the model constructor.
+
+    Returns:
+      An instance of the model class (e.g., Transformer).
+    """
+    model_lib_module = get_model_module(model_name, ModelModule.MODEL)
+    return model_lib_module.Transformer(*args, **kwargs)
+
+  @classmethod
+  def from_pretrained(
+      cls, model_name: str, file_dir: str, mesh: jax.sharding.Mesh
+  ):
+    """Instantiates one of the pretrained model classes of the library from a file directory.
+
+    This method loads model weights from safe tensors found in `file_dir`.
+
+    Args:
+      model_name: The name of the model (e.g., "qwen2.5-0.5b", "llama3.2-3b").
+      file_dir: Directory containing the safe tensors.
+      mesh: Mesh object for device layout.
+
+    Returns:
+      An instance of the model class with pretrained weights loaded.
+ """ + model_params = obtain_model_params(model_name) + return _create_model_from_safe_tensors_dynamically( + model_name, file_dir, model_params, mesh + ) diff --git a/tunix/models/gemma3/model.py b/tunix/models/gemma3/model.py index 846625076..59c56ec28 100644 --- a/tunix/models/gemma3/model.py +++ b/tunix/models/gemma3/model.py @@ -131,6 +131,14 @@ def gemma3_270m( shd_config=sharding_config, ) + @classmethod + def gemma3_270m_it( + cls, + sharding_config: ShardingConfig = ShardingConfig.get_default_sharding(), + ) -> 'ModelConfig': + """Gemma3-270M instruction-tuned text-only config.""" + return cls.gemma3_270m(sharding_config=sharding_config) + @classmethod def gemma3_1b( cls, @@ -150,6 +158,14 @@ def gemma3_1b( shd_config=sharding_config, ) + @classmethod + def gemma3_1b_it( + cls, + sharding_config: ShardingConfig = ShardingConfig.get_default_sharding(), + ) -> 'ModelConfig': + """Gemma3-1B instruction-tuned text-only config.""" + return cls.gemma3_1b(sharding_config=sharding_config) + @classmethod def gemma3_4b( cls, @@ -171,6 +187,14 @@ def gemma3_4b( shd_config=sharding_config, ) + @classmethod + def gemma3_4b_it( + cls, + sharding_config: ShardingConfig = ShardingConfig.get_default_sharding(), + ) -> 'ModelConfig': + """Gemma3-4B instruction-tuned text-only config.""" + return cls.gemma3_4b(sharding_config=sharding_config) + @classmethod def gemma3_12b( cls, @@ -193,6 +217,14 @@ def gemma3_12b( shd_config=sharding_config, ) + @classmethod + def gemma3_12b_it( + cls, + sharding_config: ShardingConfig = ShardingConfig.get_default_sharding(), + ) -> 'ModelConfig': + """Gemma3-12B instruction-tuned text-only config.""" + return cls.gemma3_12b(sharding_config=sharding_config) + @classmethod def gemma3_27b( cls, @@ -215,6 +247,14 @@ def gemma3_27b( shd_config=sharding_config, ) + @classmethod + def gemma3_27b_it( + cls, + sharding_config: ShardingConfig = ShardingConfig.get_default_sharding(), + ) -> 'ModelConfig': + """Gemma3-27B instruction-tuned text-only config.""" + return cls.gemma3_27b(sharding_config=sharding_config) + def shard(x: jnp.ndarray, s: Tuple[str, ...]): mesh = pxla.thread_resources.env.physical_mesh diff --git a/tunix/models/naming.py b/tunix/models/naming.py index 9b7574919..6ac74e041 100644 --- a/tunix/models/naming.py +++ b/tunix/models/naming.py @@ -74,6 +74,8 @@ class _ModelFamilyInfo: 'gemma': _ModelFamilyInfo(family='gemma', config_category='gemma'), 'gemma1.1': _ModelFamilyInfo(family='gemma1_1', config_category='gemma'), 'gemma2': _ModelFamilyInfo(family='gemma2', config_category='gemma'), + # Support both gemma3 and gemma-3 as model prefixes. + 'gemma3': _ModelFamilyInfo(family='gemma3', config_category='gemma3'), 'gemma-3': _ModelFamilyInfo(family='gemma3', config_category='gemma3'), 'llama3': _ModelFamilyInfo(family='llama3', config_category='llama3'), 'llama3.1': _ModelFamilyInfo(family='llama3_1', config_category='llama3'),