Skip to content

Commit bd25164

Browse files
s-noghabi (The tunix Authors)
authored and committed
[DNS] debug failure
PiperOrigin-RevId: 834954246
1 parent 59bcc89 commit bd25164

File tree

7 files changed

+151
-106
lines changed

7 files changed

+151
-106
lines changed

tests/cli/utils/model_test.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,26 @@
7878
testcase_name="gemma3-27b",
7979
model_name="gemma3-27b",
8080
),
81+
dict(
82+
testcase_name="gemma-3-270m",
83+
model_name="gemma-3-270m",
84+
),
85+
dict(
86+
testcase_name="gemma-3-1b",
87+
model_name="gemma-3-1b",
88+
),
89+
dict(
90+
testcase_name="gemma-3-4b",
91+
model_name="gemma-3-4b",
92+
),
93+
dict(
94+
testcase_name="gemma-3-12b",
95+
model_name="gemma-3-12b",
96+
),
97+
dict(
98+
testcase_name="gemma-3-27b",
99+
model_name="gemma-3-27b",
100+
),
81101
dict(
82102
testcase_name="llama3-70b",
83103
model_name="llama3-70b",
@@ -118,11 +138,10 @@
118138
testcase_name="qwen2.5-math-1.5b",
119139
model_name="qwen2.5-math-1.5b",
120140
),
121-
# TODO(b/451662153): support deepseek model name parsing
122-
# dict(
123-
# testcase_name="deepseek-r1-distill-qwen-1.5b",
124-
# model_name="deepseek-r1-distill-qwen-1.5b",
125-
# ),
141+
dict(
142+
testcase_name="deepseek-r1-distill-qwen-1.5b",
143+
model_name="deepseek-r1-distill-qwen-1.5b",
144+
),
126145
dict(
127146
testcase_name="qwen3-0.6b",
128147
model_name="qwen3-0.6b",
@@ -151,10 +170,15 @@ def test_obtain_model_params_valid(self, model_name: str):
151170
model.obtain_model_params(model_name)
152171

153172
def test_create_model_dynamically_routing(self, model_name: str):
154-
model_module = model.get_model_module(model_name)
173+
params_module = model.get_model_module(model_name, model.ModelModule.PARAMS)
155174
if not model_name.startswith("gemma"):
156175
# TODO(b/444572467)
157-
getattr(model_module, "create_model_from_safe_tensors")
176+
getattr(params_module, "create_model_from_safe_tensors")
177+
178+
model_lib_module = model.get_model_module(
179+
model_name, model.ModelModule.MODEL
180+
)
181+
getattr(model_lib_module, "ModelConfig")
158182

159183

160184
if __name__ == "__main__":

tunix/cli/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ def create_optimizer(
345345
" https://optax.readthedocs.io/en/latest/api/optimizers.html#optimizers"
346346
) from e
347347

348+
logging.info("[SHADI] optimizer_config: %s", optimizer_config)
348349
# Handle learning rate, potentially creating a schedule
349350
learning_rate_val = self._create_learning_rate(
350351
optimizer_config, config_path_info

tunix/cli/debug.ipynb

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 12,
6+
"metadata": {
7+
"id": "Jylqh_larGRI"
8+
},
9+
"outputs": [],
10+
"source": [
11+
"from etils import ecolab\n",
12+
"from flax import nnx\n",
13+
"import jax\n",
14+
"import antlr4\n",
15+
"\n",
16+
"with ecolab.adhoc():\n",
17+
" # import antlr4\n",
18+
" import omegaconf\n",
19+
" #from tunix.cli import peft_main"
20+
]
21+
},
22+
{
23+
"metadata": {
24+
"id": "WmP5N2WNstIi"
25+
},
26+
"cell_type": "code",
27+
"source": [
28+
"pipelines = peft_main.PeftPipeline(\n",
29+
" argv=[\n",
30+
" 'third_party/py/tunix/cli/peft_main',\n",
31+
" '--config=third_party/py/tunix/cli/configs/peft_config.py:gemini_v3_s_1m_128',\n",
32+
" '--workdir=/tmp/peft_debug',\n",
33+
" ]\n",
34+
")\n",
35+
"pipelines.run_peft_trainer()\n"
36+
],
37+
"outputs": [],
38+
"execution_count": null
39+
}
40+
],
41+
"metadata": {
42+
"colab": {
43+
"private_outputs": true
44+
}
45+
},
46+
"nbformat": 4,
47+
"nbformat_minor": 0
48+
}

tunix/cli/peft_main.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,6 @@ def gen_model_input_fn(x: peft_trainer.TrainingInput):
6666
}
6767

6868
my_gen_model_input_fn = gen_model_input_fn
69-
trainer = trainer.with_gen_model_input_fn(my_gen_model_input_fn)
70-
7169
train_ds, eval_ds = data_lib.create_datasets(
7270
dataset_name=self.config['dataset_name'],
7371
global_batch_size=self.config['batch_size'],
@@ -77,6 +75,14 @@ def gen_model_input_fn(x: peft_trainer.TrainingInput):
7775
)
7876

7977
with mesh:
78+
trainer = peft_trainer.PeftTrainer(
79+
model,
80+
optimizer,
81+
peft_trainer.TrainingConfig(
82+
**self.obtain_training_config_dict('training_config')
83+
),
84+
)
85+
trainer = trainer.with_gen_model_input_fn(my_gen_model_input_fn)
8086
trainer.train(train_ds, eval_ds)
8187

8288

0 commit comments

Comments (0)