Skip to content

Commit 5b53e44

Browse files
s-noghabiThe tunix Authors
authored andcommitted
[DNS] debug failure
PiperOrigin-RevId: 834954246
1 parent a6716c3 commit 5b53e44

File tree

6 files changed

+211
-58
lines changed

6 files changed

+211
-58
lines changed

tests/cli/utils/model_test.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,26 @@
7878
testcase_name="gemma3-27b",
7979
model_name="gemma3-27b",
8080
),
81+
dict(
82+
testcase_name="gemma-3-270m",
83+
model_name="gemma-3-270m",
84+
),
85+
dict(
86+
testcase_name="gemma-3-1b",
87+
model_name="gemma-3-1b",
88+
),
89+
dict(
90+
testcase_name="gemma-3-4b",
91+
model_name="gemma-3-4b",
92+
),
93+
dict(
94+
testcase_name="gemma-3-12b",
95+
model_name="gemma-3-12b",
96+
),
97+
dict(
98+
testcase_name="gemma-3-27b",
99+
model_name="gemma-3-27b",
100+
),
81101
dict(
82102
testcase_name="llama3-70b",
83103
model_name="llama3-70b",
@@ -118,11 +138,10 @@
118138
testcase_name="qwen2.5-math-1.5b",
119139
model_name="qwen2.5-math-1.5b",
120140
),
121-
# TODO(b/451662153): support deepseek model name parsing
122-
# dict(
123-
# testcase_name="deepseek-r1-distill-qwen-1.5b",
124-
# model_name="deepseek-r1-distill-qwen-1.5b",
125-
# ),
141+
dict(
142+
testcase_name="deepseek-r1-distill-qwen-1.5b",
143+
model_name="deepseek-r1-distill-qwen-1.5b",
144+
),
126145
dict(
127146
testcase_name="qwen3-0.6b",
128147
model_name="qwen3-0.6b",
@@ -151,10 +170,15 @@ def test_obtain_model_params_valid(self, model_name: str):
151170
model.obtain_model_params(model_name)
152171

153172
def test_create_model_dynamically_routing(self, model_name: str):
154-
model_module = model.get_model_module(model_name)
173+
params_module = model.get_model_module(model_name, model.ModelModule.PARAMS)
155174
if not model_name.startswith("gemma"):
156175
# TODO(b/444572467)
157-
getattr(model_module, "create_model_from_safe_tensors")
176+
getattr(params_module, "create_model_from_safe_tensors")
177+
178+
model_lib_module = model.get_model_module(
179+
model_name, model.ModelModule.MODEL
180+
)
181+
getattr(model_lib_module, "ModelConfig")
158182

159183

160184
if __name__ == "__main__":

tunix/cli/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ def create_optimizer(
345345
" https://optax.readthedocs.io/en/latest/api/optimizers.html#optimizers"
346346
) from e
347347

348+
logging.info("[SHADI] optimizer_config: %s", optimizer_config)
348349
# Handle learning rate, potentially creating a schedule
349350
learning_rate_val = self._create_learning_rate(
350351
optimizer_config, config_path_info

tunix/cli/debug.ipynb

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 12,
6+
"metadata": {
7+
"id": "Jylqh_larGRI"
8+
},
9+
"outputs": [],
10+
"source": [
11+
"from etils import ecolab\n",
12+
"from flax import nnx\n",
13+
"import jax\n",
14+
"import antlr4\n",
15+
"\n",
16+
"with ecolab.adhoc():\n",
17+
" # import antlr4\n",
18+
" import omegaconf\n",
19+
" #from tunix.cli import peft_main"
20+
]
21+
},
22+
{
23+
"metadata": {
24+
"id": "WmP5N2WNstIi"
25+
},
26+
"cell_type": "code",
27+
"source": [
28+
"pipelines = peft_main.PeftPipeline(\n",
29+
" argv=[\n",
30+
" 'third_party/py/tunix/cli/peft_main',\n",
31+
" '--config=third_party/py/tunix/cli/configs/peft_config.py:gemini_v3_s_1m_128',\n",
32+
" '--workdir=/tmp/peft_debug',\n",
33+
" ]\n",
34+
")\n",
35+
"pipelines.run_peft_trainer()\n"
36+
],
37+
"outputs": [],
38+
"execution_count": null
39+
}
40+
],
41+
"metadata": {
42+
"colab": {
43+
"private_outputs": true
44+
}
45+
},
46+
"nbformat": 4,
47+
"nbformat_minor": 0
48+
}

tunix/cli/utils/model.py

Lines changed: 128 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import enum
1516
import gc
1617
import importlib
1718
import os
@@ -24,53 +25,36 @@
2425
from orbax import checkpoint as ocp
2526
import qwix
2627
from tunix.generate import tokenizer_adapter as tokenizer_lib
27-
from tunix.models.gemma import model as gemma_lib
28-
from tunix.models.gemma import params as gemma_params_lib
29-
from tunix.models.gemma3 import model as gemma3_lib
30-
from tunix.models.gemma3 import params as gemma3_params_lib
31-
from tunix.models.llama3 import model as llama3_lib
32-
from tunix.models.qwen2 import model as qwen2_lib
33-
from tunix.models.qwen3 import model as qwen3_lib
28+
from tunix.models import naming
3429
from tunix.oss import utils as oss_utils
3530
from tunix.rl import reshard
3631

3732

38-
# Map prefixes to the target object containing the methods.
39-
CONFIG_MAP = {
40-
'gemma': gemma_lib.ModelConfig,
41-
'gemma1.1': gemma_lib.ModelConfig,
42-
'gemma2': gemma_lib.ModelConfig,
43-
'gemma3': gemma3_lib.ModelConfig,
44-
'llama3': llama3_lib.ModelConfig,
45-
'llama3.1': llama3_lib.ModelConfig,
46-
'llama3.2': llama3_lib.ModelConfig,
47-
'qwen2.5': qwen2_lib.ModelConfig,
48-
'qwen3': qwen3_lib.ModelConfig,
49-
}
50-
5133
_BASE_MODULE_PATH = 'tunix.models' # pylint: disable=invalid-name
5234

5335

54-
def get_model_module(model_name: str) -> Any:
55-
"""Dynamically imports the parameter module based on the model name."""
56-
# Extract the base model type (e.g., "qwen2", "llama3")
57-
match = re.match(r'^[a-zA-Z0-9]+', model_name)
58-
if not match:
59-
raise ValueError(f'Invalid model name format: {model_name}')
60-
model_type = match.group(0)
61-
# Construct the full module path, e.g.,.path.to.your.models.qwen2.params
62-
if model_name.startswith('gemma1') or model_name.startswith('gemma2'):
63-
model_type = 'gemma'
64-
module_path = f'{_BASE_MODULE_PATH}.{model_type}.params'
36+
class ModelModule(enum.Enum):
37+
"""Specifies the type of model module to import."""
38+
39+
MODEL = 'model'
40+
PARAMS = 'params'
41+
42+
43+
def get_model_module(model_name: str, module_type: ModelModule) -> Any:
44+
"""Dynamically imports a model module (e.g., 'model' or 'params')."""
45+
model_config_category = naming.get_model_config_category(model_name)
46+
module_path = (
47+
f'{_BASE_MODULE_PATH}.{model_config_category}.{module_type.value}'
48+
)
6549
try:
66-
print(f'Attempting to import: {module_path}')
67-
model_module = importlib.import_module(module_path)
68-
return model_module
69-
except ImportError as exc: # Capture the original exeception as 'exc'
50+
logging.info('Attempting to import: %s', module_path)
51+
model_lib_module = importlib.import_module(module_path)
52+
return model_lib_module
53+
except ImportError as exc:
7054
raise ImportError(
71-
f'Could not import module for model type: {model_type} '
72-
f'at path: {module_path}. Please check BASE_MODULE_PATH '
73-
'and ensure the module exists and is a dependency.'
55+
'Could not import module for model config category: '
56+
f'{model_config_category} at path: {module_path}. Please check '
57+
'BASE_MODULE_PATH and ensure the module exists and is a dependency.'
7458
) from exc
7559

7660

@@ -93,22 +77,92 @@ def create_model_dynamically(
9377
ImportError: If the required model module cannot be found.
9478
AttributeError: If create_model_from_safe_tensors is not in the module.
9579
"""
96-
model_module = get_model_module(model_name)
80+
params_module = get_model_module(model_name, ModelModule.PARAMS)
9781

9882
try:
99-
create_fn = getattr(model_module, 'create_model_from_safe_tensors')
83+
create_fn = getattr(params_module, 'create_model_from_safe_tensors')
10084
except AttributeError as exc:
10185
raise AttributeError(
10286
"'create_model_from_safe_tensors' not found in module "
103-
f'{model_module.__name__} for model {model_name}'
87+
f'{params_module.__name__} for model {model_name}'
10488
) from exc
10589

10690
logging.info(
107-
'Calling %s.create_model_from_safe_tensors', model_module.__name__
91+
'Calling %s.create_model_from_safe_tensors', params_module.__name__
10892
)
10993
return create_fn(file_dir=file_dir, config=model_config, mesh=mesh)
11094

11195

96+
def obtain_model_params(model_name: str) -> Any:
97+
"""Dynamically calls a configuration function based on the model_string.
98+
99+
The routing to the correct module/class instance is based on the longest
100+
matching prefix of model_name found in CONFIG_MAP.
101+
Hyphens and dots in the model_name are converted to underscores
102+
to form the function name.
103+
104+
Args:
105+
model_name: The string indicating which model config function to call
106+
(e.g., "gemma-2b", "llama3.1-8b", "qwen2.5-0.5b").
107+
108+
Returns:
109+
The result from calling the dynamically determined function.
110+
111+
Raises:
112+
ValueError: If the model_string doesn't match any known prefix.
113+
AttributeError: If the derived function name does not exist in the target
114+
object.
115+
TypeError: If the attribute found on the target object is not callable.
116+
"""
117+
config_id = naming.get_model_config_id(model_name)
118+
model_lib_module = get_model_module(model_name, ModelModule.MODEL)
119+
target_obj = model_lib_module.ModelConfig
120+
logging.info('[SHADI] model_lib_module: %s', type(model_lib_module))
121+
logging.info('[SHADI] target_obj: %s', type(target_obj))
122+
logging.info('[SHADI] config_id: %s', config_id)
123+
logging.info('[SHADI] model_name: %s', model_name)
124+
if not hasattr(target_obj, config_id):
125+
raise AttributeError(
126+
f"Error: Function '{config_id}' not found on the target object "
127+
f"for model '{model_name}'. Target object type: {type(target_obj)}"
128+
)
129+
130+
method_to_call = getattr(target_obj, config_id)
131+
132+
if not callable(method_to_call):
133+
raise TypeError(
134+
f"Error: Attribute '{config_id}' on the target object is not callable."
135+
)
136+
137+
logging.info(
138+
'Attempting to call: %s() on object of type %s',
139+
config_id,
140+
type(target_obj),
141+
)
142+
return method_to_call()
143+
144+
145+
from tunix.models.gemma import model as gemma_lib
146+
from tunix.models.gemma import params as gemma_params_lib
147+
from tunix.models.gemma3 import model as gemma3_lib
148+
from tunix.models.gemma3 import params as gemma3_params_lib
149+
from tunix.models.llama3 import model as llama3_lib
150+
from tunix.models.qwen2 import model as qwen2_lib
151+
from tunix.models.qwen3 import model as qwen3_lib
152+
# Map prefixes to the target object containing the methods.
153+
CONFIG_MAP = {
154+
'gemma': gemma_lib.ModelConfig,
155+
'gemma1.1': gemma_lib.ModelConfig,
156+
'gemma2': gemma_lib.ModelConfig,
157+
'gemma3': gemma3_lib.ModelConfig,
158+
'llama3': llama3_lib.ModelConfig,
159+
'llama3.1': llama3_lib.ModelConfig,
160+
'llama3.2': llama3_lib.ModelConfig,
161+
'qwen2.5': qwen2_lib.ModelConfig,
162+
'qwen3': qwen3_lib.ModelConfig,
163+
}
164+
165+
112166
def _get_version(model_name: str, matched_prefix: str) -> str:
113167
"""Extracts the version string from the model name."""
114168
if not model_name.startswith(matched_prefix):
@@ -126,7 +180,7 @@ def _get_version(model_name: str, matched_prefix: str) -> str:
126180
return suffix.replace('.', '_').replace('-', '_')
127181

128182

129-
def obtain_model_params(model_name: str) -> Any:
183+
def obtain_model_params_original(model_name: str) -> Any:
130184
"""Dynamically calls a configuration function based on the model_string.
131185
132186
The routing to the correct module/class instance is based on the longest
@@ -173,6 +227,13 @@ def obtain_model_params(model_name: str) -> Any:
173227

174228
function_name = f'{family_snake}_{core_version}'
175229

230+
logging.info('[SHADI] function_name: %s', function_name)
231+
logging.info('[SHADI] matched_prefix: %s', matched_prefix)
232+
logging.info('[SHADI] family_snake: %s', family_snake)
233+
logging.info('[SHADI] core_version: %s', core_version)
234+
logging.info('[SHADI] target_obj: %s', type(target_obj))
235+
logging.info('[SHADI] config_id/function_name: %s', function_name)
236+
logging.info('[SHADI] model_name: %s', model_name)
176237
if not hasattr(target_obj, function_name):
177238
raise AttributeError(
178239
f"Error: Function '{function_name}' not found on the target object "
@@ -197,9 +258,12 @@ def obtain_model_params(model_name: str) -> Any:
197258

198259
def _get_base_model(model_config: dict[str, Any], mesh: jax.sharding.Mesh):
199260
"""Get the base model from the intermediate checkpoint."""
200-
model_params = obtain_model_params(model_config['model_name'])
261+
model_params = obtain_model_params_original(model_config['model_name'])
262+
model_lib_module = get_model_module(
263+
model_config['model_name'], ModelModule.MODEL
264+
)
201265
abs_model: nnx.Module = nnx.eval_shape(
202-
lambda: gemma_lib.Transformer(
266+
lambda: model_lib_module.Transformer(
203267
model_params, rngs=nnx.Rngs(model_config.get('rng_seed', 0))
204268
)
205269
)
@@ -363,11 +427,15 @@ def create_model(
363427
tokenizer_path: str = tokenizer_config['tokenizer_path']
364428
model_name = model_config['model_name']
365429
model_source = model_config['model_source']
366-
430+
logging.info(
431+
'[SHADI] model_source: %s, model_name: %s', model_source, model_name
432+
)
433+
logging.info('[SHADI] model_config: %s', model_config)
367434
if model_name.startswith('gemma3') and model_source == 'gcs':
368435

369436
ckpt_path = model_config['model_id']
370-
model_params = obtain_model_params(model_name)
437+
model_params = obtain_model_params_original(model_name)
438+
gemma3_params_lib = get_model_module(model_name, ModelModule.PARAMS)
371439
model = gemma3_params_lib.create_model_from_checkpoint(
372440
ckpt_path, model_params, mesh
373441
)
@@ -394,10 +462,14 @@ def nnx_conversion():
394462
suffix = '-'.join(model_name.split('-')[1:])
395463
params_path = os.path.join(ckpt_path, suffix)
396464

465+
gemma_params_lib = get_model_module(model_name, ModelModule.PARAMS)
397466
params = gemma_params_lib.load_and_format_params(params_path)
398-
model = gemma_lib.Transformer.from_params(
399-
params, version=_get_model_version_suffix(model_name)
400-
)
467+
gemma_lib = get_model_module(model_name, ModelModule.MODEL)
468+
model_family, version = naming.split(model_name)
469+
# TODO(b/451662153): have gemma2 version handling done better in naming.py
470+
if model_family == 'gemma2':
471+
version = f'2-{version}'
472+
model = gemma_lib.Transformer.from_params(params, version=version)
401473
return _gemma_conversion(model_config, model, params, mesh)
402474

403475
if skip_nnx_conversion:
@@ -419,6 +491,7 @@ def nnx_conversion():
419491
elif model_source == 'huggingface':
420492
# for all other model
421493
oss_utils.hf_pipeline(model_config)
494+
logging.info('[SHADI] inside huggingface')
422495
else:
423496
logging.error(
424497
'Unsupported workflow: from %s to download %s.',
@@ -428,9 +501,13 @@ def nnx_conversion():
428501

429502
if not model_params:
430503
# pick corresponding config based on model version
431-
model_params = obtain_model_params(model_name)
504+
logging.info('[SHADI] inside if not model_params')
505+
model_params = obtain_model_params_original(model_name)
432506

433507
with mesh:
508+
logging.info(
509+
'[SHADI] inside create_model_dynamically with mesh, %s', mesh
510+
)
434511
model = create_model_dynamically(
435512
model_name,
436513
model_config['model_download_path'],

0 commit comments

Comments
 (0)