1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import enum
1516import gc
1617import importlib
1718import os
2425from orbax import checkpoint as ocp
2526import qwix
2627from tunix .generate import tokenizer_adapter as tokenizer_lib
27- from tunix .models .gemma import model as gemma_lib
28- from tunix .models .gemma import params as gemma_params_lib
29- from tunix .models .gemma3 import model as gemma3_lib
30- from tunix .models .gemma3 import params as gemma3_params_lib
31- from tunix .models .llama3 import model as llama3_lib
32- from tunix .models .qwen2 import model as qwen2_lib
33- from tunix .models .qwen3 import model as qwen3_lib
28+ from tunix .models import naming
3429from tunix .oss import utils as oss_utils
3530from tunix .rl import reshard
3631
3732
38- # Map prefixes to the target object containing the methods.
39- CONFIG_MAP = {
40- 'gemma' : gemma_lib .ModelConfig ,
41- 'gemma1.1' : gemma_lib .ModelConfig ,
42- 'gemma2' : gemma_lib .ModelConfig ,
43- 'gemma3' : gemma3_lib .ModelConfig ,
44- 'llama3' : llama3_lib .ModelConfig ,
45- 'llama3.1' : llama3_lib .ModelConfig ,
46- 'llama3.2' : llama3_lib .ModelConfig ,
47- 'qwen2.5' : qwen2_lib .ModelConfig ,
48- 'qwen3' : qwen3_lib .ModelConfig ,
49- }
50-
5133_BASE_MODULE_PATH = 'tunix.models' # pylint: disable=invalid-name
5234
5335
54- def get_model_module (model_name : str ) -> Any :
55- """Dynamically imports the parameter module based on the model name."""
56- # Extract the base model type (e.g., "qwen2", "llama3")
57- match = re .match (r'^[a-zA-Z0-9]+' , model_name )
58- if not match :
59- raise ValueError (f'Invalid model name format: { model_name } ' )
60- model_type = match .group (0 )
61- # Construct the full module path, e.g.,.path.to.your.models.qwen2.params
62- if model_name .startswith ('gemma1' ) or model_name .startswith ('gemma2' ):
63- model_type = 'gemma'
64- module_path = f'{ _BASE_MODULE_PATH } .{ model_type } .params'
36+ class ModelModule (enum .Enum ):
37+ """Specifies the type of model module to import."""
38+
39+ MODEL = 'model'
40+ PARAMS = 'params'
41+
42+
43+ def get_model_module (model_name : str , module_type : ModelModule ) -> Any :
44+ """Dynamically imports a model module (e.g., 'model' or 'params')."""
45+ model_config_category = naming .get_model_config_category (model_name )
46+ module_path = (
47+ f'{ _BASE_MODULE_PATH } .{ model_config_category } .{ module_type .value } '
48+ )
6549 try :
66- print ( f 'Attempting to import: { module_path } ' )
67- model_module = importlib .import_module (module_path )
68- return model_module
69- except ImportError as exc : # Capture the original exeception as 'exc'
50+ logging . info ( 'Attempting to import: %s' , module_path )
51+ model_lib_module = importlib .import_module (module_path )
52+ return model_lib_module
53+ except ImportError as exc :
7054 raise ImportError (
71- f 'Could not import module for model type: { model_type } '
72- f'at path: { module_path } . Please check BASE_MODULE_PATH '
73- 'and ensure the module exists and is a dependency.'
55+ 'Could not import module for model config category: '
56+ f'{ model_config_category } at path: { module_path } . Please check '
57+ 'BASE_MODULE_PATH and ensure the module exists and is a dependency.'
7458 ) from exc
7559
7660
@@ -93,22 +77,92 @@ def create_model_dynamically(
9377 ImportError: If the required model module cannot be found.
9478 AttributeError: If create_model_from_safe_tensors is not in the module.
9579 """
96- model_module = get_model_module (model_name )
80+ params_module = get_model_module (model_name , ModelModule . PARAMS )
9781
9882 try :
99- create_fn = getattr (model_module , 'create_model_from_safe_tensors' )
83+ create_fn = getattr (params_module , 'create_model_from_safe_tensors' )
10084 except AttributeError as exc :
10185 raise AttributeError (
10286 "'create_model_from_safe_tensors' not found in module "
103- f'{ model_module .__name__ } for model { model_name } '
87+ f'{ params_module .__name__ } for model { model_name } '
10488 ) from exc
10589
10690 logging .info (
107- 'Calling %s.create_model_from_safe_tensors' , model_module .__name__
91+ 'Calling %s.create_model_from_safe_tensors' , params_module .__name__
10892 )
10993 return create_fn (file_dir = file_dir , config = model_config , mesh = mesh )
11094
11195
96+ def obtain_model_params (model_name : str ) -> Any :
97+ """Dynamically calls a configuration function based on the model_string.
98+
99+ The routing to the correct module/class instance is based on the longest
100+ matching prefix of model_name found in CONFIG_MAP.
101+ Hyphens and dots in the model_name are converted to underscores
102+ to form the function name.
103+
104+ Args:
105+ model_name: The string indicating which model config function to call
106+ (e.g., "gemma-2b", "llama3.1-8b", "qwen2.5-0.5b").
107+
108+ Returns:
109+ The result from calling the dynamically determined function.
110+
111+ Raises:
112+ ValueError: If the model_string doesn't match any known prefix.
113+ AttributeError: If the derived function name does not exist in the target
114+ object.
115+ TypeError: If the attribute found on the target object is not callable.
116+ """
117+ config_id = naming .get_model_config_id (model_name )
118+ model_lib_module = get_model_module (model_name , ModelModule .MODEL )
119+ target_obj = model_lib_module .ModelConfig
120+ logging .info ('[SHADI] model_lib_module: %s' , type (model_lib_module ))
121+ logging .info ('[SHADI] target_obj: %s' , type (target_obj ))
122+ logging .info ('[SHADI] config_id: %s' , config_id )
123+ logging .info ('[SHADI] model_name: %s' , model_name )
124+ if not hasattr (target_obj , config_id ):
125+ raise AttributeError (
126+ f"Error: Function '{ config_id } ' not found on the target object "
127+ f"for model '{ model_name } '. Target object type: { type (target_obj )} "
128+ )
129+
130+ method_to_call = getattr (target_obj , config_id )
131+
132+ if not callable (method_to_call ):
133+ raise TypeError (
134+ f"Error: Attribute '{ config_id } ' on the target object is not callable."
135+ )
136+
137+ logging .info (
138+ 'Attempting to call: %s() on object of type %s' ,
139+ config_id ,
140+ type (target_obj ),
141+ )
142+ return method_to_call ()
143+
144+
145+ from tunix .models .gemma import model as gemma_lib
146+ from tunix .models .gemma import params as gemma_params_lib
147+ from tunix .models .gemma3 import model as gemma3_lib
148+ from tunix .models .gemma3 import params as gemma3_params_lib
149+ from tunix .models .llama3 import model as llama3_lib
150+ from tunix .models .qwen2 import model as qwen2_lib
151+ from tunix .models .qwen3 import model as qwen3_lib
152+ # Map prefixes to the target object containing the methods.
153+ CONFIG_MAP = {
154+ 'gemma' : gemma_lib .ModelConfig ,
155+ 'gemma1.1' : gemma_lib .ModelConfig ,
156+ 'gemma2' : gemma_lib .ModelConfig ,
157+ 'gemma3' : gemma3_lib .ModelConfig ,
158+ 'llama3' : llama3_lib .ModelConfig ,
159+ 'llama3.1' : llama3_lib .ModelConfig ,
160+ 'llama3.2' : llama3_lib .ModelConfig ,
161+ 'qwen2.5' : qwen2_lib .ModelConfig ,
162+ 'qwen3' : qwen3_lib .ModelConfig ,
163+ }
164+
165+
112166def _get_version (model_name : str , matched_prefix : str ) -> str :
113167 """Extracts the version string from the model name."""
114168 if not model_name .startswith (matched_prefix ):
@@ -126,7 +180,7 @@ def _get_version(model_name: str, matched_prefix: str) -> str:
126180 return suffix .replace ('.' , '_' ).replace ('-' , '_' )
127181
128182
129- def obtain_model_params (model_name : str ) -> Any :
183+ def obtain_model_params_original (model_name : str ) -> Any :
130184 """Dynamically calls a configuration function based on the model_string.
131185
132186 The routing to the correct module/class instance is based on the longest
@@ -173,6 +227,13 @@ def obtain_model_params(model_name: str) -> Any:
173227
174228 function_name = f'{ family_snake } _{ core_version } '
175229
230+ logging .info ('[SHADI] function_name: %s' , function_name )
231+ logging .info ('[SHADI] matched_prefix: %s' , matched_prefix )
232+ logging .info ('[SHADI] family_snake: %s' , family_snake )
233+ logging .info ('[SHADI] core_version: %s' , core_version )
234+ logging .info ('[SHADI] target_obj: %s' , type (target_obj ))
235+ logging .info ('[SHADI] config_id/function_name: %s' , function_name )
236+ logging .info ('[SHADI] model_name: %s' , model_name )
176237 if not hasattr (target_obj , function_name ):
177238 raise AttributeError (
178239 f"Error: Function '{ function_name } ' not found on the target object "
@@ -197,9 +258,12 @@ def obtain_model_params(model_name: str) -> Any:
197258
198259def _get_base_model (model_config : dict [str , Any ], mesh : jax .sharding .Mesh ):
199260 """Get the base model from the intermediate checkpoint."""
200- model_params = obtain_model_params (model_config ['model_name' ])
261+ model_params = obtain_model_params_original (model_config ['model_name' ])
262+ model_lib_module = get_model_module (
263+ model_config ['model_name' ], ModelModule .MODEL
264+ )
201265 abs_model : nnx .Module = nnx .eval_shape (
202- lambda : gemma_lib .Transformer (
266+ lambda : model_lib_module .Transformer (
203267 model_params , rngs = nnx .Rngs (model_config .get ('rng_seed' , 0 ))
204268 )
205269 )
@@ -363,11 +427,15 @@ def create_model(
363427 tokenizer_path : str = tokenizer_config ['tokenizer_path' ]
364428 model_name = model_config ['model_name' ]
365429 model_source = model_config ['model_source' ]
366-
430+ logging .info (
431+ '[SHADI] model_source: %s, model_name: %s' , model_source , model_name
432+ )
433+ logging .info ('[SHADI] model_config: %s' , model_config )
367434 if model_name .startswith ('gemma3' ) and model_source == 'gcs' :
368435
369436 ckpt_path = model_config ['model_id' ]
370- model_params = obtain_model_params (model_name )
437+ model_params = obtain_model_params_original (model_name )
438+ gemma3_params_lib = get_model_module (model_name , ModelModule .PARAMS )
371439 model = gemma3_params_lib .create_model_from_checkpoint (
372440 ckpt_path , model_params , mesh
373441 )
@@ -394,10 +462,14 @@ def nnx_conversion():
394462 suffix = '-' .join (model_name .split ('-' )[1 :])
395463 params_path = os .path .join (ckpt_path , suffix )
396464
465+ gemma_params_lib = get_model_module (model_name , ModelModule .PARAMS )
397466 params = gemma_params_lib .load_and_format_params (params_path )
398- model = gemma_lib .Transformer .from_params (
399- params , version = _get_model_version_suffix (model_name )
400- )
467+ gemma_lib = get_model_module (model_name , ModelModule .MODEL )
468+ model_family , version = naming .split (model_name )
469+ # TODO(b/451662153): have gemma2 version handling done better in naming.py
470+ if model_family == 'gemma2' :
471+ version = f'2-{ version } '
472+ model = gemma_lib .Transformer .from_params (params , version = version )
401473 return _gemma_conversion (model_config , model , params , mesh )
402474
403475 if skip_nnx_conversion :
@@ -419,6 +491,7 @@ def nnx_conversion():
419491 elif model_source == 'huggingface' :
420492 # for all other model
421493 oss_utils .hf_pipeline (model_config )
494+ logging .info ('[SHADI] inside huggingface' )
422495 else :
423496 logging .error (
424497 'Unsupported workflow: from %s to download %s.' ,
@@ -428,9 +501,13 @@ def nnx_conversion():
428501
429502 if not model_params :
430503 # pick corresponding config based on model version
431- model_params = obtain_model_params (model_name )
504+ logging .info ('[SHADI] inside if not model_params' )
505+ model_params = obtain_model_params_original (model_name )
432506
433507 with mesh :
508+ logging .info (
509+ '[SHADI] inside create_model_dynamically with mesh, %s' , mesh
510+ )
434511 model = create_model_dynamically (
435512 model_name ,
436513 model_config ['model_download_path' ],
0 commit comments