Commit a34d0aa

s-noghabi authored and The tunix Authors committed
update gemma2 run script to use the config yaml
PiperOrigin-RevId: 840872977
1 parent 3e605d8 commit a34d0aa

File tree: 5 files changed (+111 −49 lines)
Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+reference_model_config:
+  model_name: "gemma2-2b-it"
+  model_id: "google/gemma-2/flax/gemma2-2b-it"
+  model_source: "kaggle"
+  mesh:
+    shape: "(2,4)"
+    axis_names: "('fsdp','tp')"
+  rng_seed: 42
+actor_model_config:
+  lora_config:
+    rank: 64
+    alpha: 64.0
+    module_path: ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum"
+  mesh:
+    shape: "(2,4)"
+    axis_names: "('fsdp','tp')"
+rollout_model_config:
+  mesh:
+    shape: "(2,4)"
+    axis_names: "('fsdp','tp')"
+tokenizer_config:
+  tokenizer_type: "sentencepiece"
+  add_bos: False
+dataset_name: "gsm8k"
+tfds_download: False
+batch_size: 1
+num_batches: 3738
+num_train_epochs: 1
+train_fraction: 1.0
+num_test_batches: 100
+rl_training_config:
+  actor_optimizer_config:
+    opt_type: "adamw"
+    peak_value: 3e-6
+    schedule_type: "warmup_cosine_decay_schedule"
+    init_value: 0.0
+    end_value: 0.0
+    b1: 0.9
+    b2: 0.99
+    weight_decay: 0.1
+    max_grad_norm: 0.1
+    warmup_ratio: 0.1
+    warmup_steps: 374
+    decay_steps: 3738
+  eval_every_n_steps: 10
+  metrics_logging_options:
+    flush_every_n_steps: 20
+  checkpointing_options:
+    save_interval_steps: 500
+    max_to_keep: 4
+  profiler_options: {}
+  max_steps: 3738
+rollout_config:
+  total_generation_steps: 768
+  max_prompt_length: 256
+  temperature: 0.9
+  top_p: 1.0
+  top_k: 50
+rollout_engine: "vanilla"
+offload_to_cpu: False
+grpo_config:
+  num_generations: 2
+  num_iterations: 1
+  beta: 0.08
+  epsilon: 0.2
+reward_functions:
+  - "tunix/cli/reward_fn/gsm8k.py"

examples/rl/grpo/gsm8k/run_gemma2_2b.sh

Lines changed: 2 additions & 44 deletions

@@ -40,58 +40,16 @@ echo "Rounded warmup steps: $warmup_steps"
 
 python3 -m tunix.cli.grpo_main \
 base_config.yaml \
-reference_model_config.model_name="gemma2-2b-it" \
-reference_model_config.model_id="google/gemma-2/flax/gemma2-2b-it" \
-reference_model_config.model_source="kaggle" \
+configs/gemma2_2b.yaml \
 reference_model_config.model_download_path="/tmp/models/gemma2-2b" \
 reference_model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/1" \
-reference_model_config.mesh.shape="(2,4)" \
-reference_model_config.mesh.axis_names="('fsdp','tp')" \
-reference_model_config.rng_seed=42 \
-actor_model_config.lora_config.rank=64 \
-actor_model_config.lora_config.alpha=64.0 \
-actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \
-actor_model_config.mesh.shape="(2,4)" \
-actor_model_config.mesh.axis_names="('fsdp','tp')" \
-rollout_model_config.mesh.shape="(2,4)" \
-rollout_model_config.mesh.axis_names="('fsdp','tp')" \
 tokenizer_config.tokenizer_path="/tmp/models/gemma2-2b/models/google/gemma-2/flax/gemma2-2b-it/1/tokenizer.model" \
-tokenizer_config.tokenizer_type="sentencepiece" \
-tokenizer_config.add_bos=false \
-dataset_name="gsm8k" \
 batch_size=$batch_size \
 num_batches=$num_batches \
-num_test_batches=100 \
 num_train_epochs=$num_train_epochs \
-rl_training_config.actor_optimizer_config.opt_type="adamw" \
-rl_training_config.actor_optimizer_config.peak_value=3e-6 \
-rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
-rl_training_config.actor_optimizer_config.init_value=0.0 \
-rl_training_config.actor_optimizer_config.end_value=0.0 \
 rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \
 rl_training_config.actor_optimizer_config.warmup_steps=$warmup_steps \
 rl_training_config.actor_optimizer_config.decay_steps=$max_steps \
-rl_training_config.actor_optimizer_config.b1=0.9 \
-rl_training_config.actor_optimizer_config.b2=0.99 \
-rl_training_config.actor_optimizer_config.weight_decay=0.1 \
-rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
-rl_training_config.eval_every_n_steps=10 \
 rl_training_config.max_steps=$max_steps \
-rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/grpo" \
-rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
-rl_training_config.checkpointing_options.save_interval_steps=500 \
-rl_training_config.checkpointing_options.max_to_keep=4 \
-rl_training_config.profiler_options={} \
-rollout_config.total_generation_steps=768 \
-rollout_config.max_prompt_length=256 \
-rollout_config.temperature=0.9 \
-rollout_config.top_p=1.0 \
-rollout_config.top_k=50 \
-rollout_engine="vanilla" \
-offload_to_cpu=false \
-grpo_config.num_generations=2 \
-grpo_config.num_iterations=1 \
-grpo_config.beta=0.08 \
-grpo_config.epsilon=0.2 \
-reward_functions="['tunix/cli/reward_fn/gsm8k.py']"
+rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/grpo"
 
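After this change the script passes only the shared base_config.yaml, the model-specific configs/gemma2_2b.yaml, and the handful of values that genuinely vary per run (download and checkpoint paths, tokenizer path, batch sizing, warmup/decay steps, and the TensorBoard log dir); everything else now comes from the YAML above. A minimal sketch of the layering this implies, assuming later sources override earlier ones as in OmegaConf's merge semantics (this mirrors, but is not, the resolution implemented in tunix.cli.config):

```python
# Minimal sketch of "base config, then model config, then CLI overrides".
# Not the tunix implementation; OmegaConf merge is used purely to illustrate.
from omegaconf import OmegaConf

base = OmegaConf.load("base_config.yaml")          # shared defaults
model = OmegaConf.load("configs/gemma2_2b.yaml")   # the new file above
overrides = OmegaConf.from_dotlist([               # run-specific values from the script
    "reference_model_config.model_download_path=/tmp/models/gemma2-2b",
    "rl_training_config.max_steps=3738",
    "rl_training_config.metrics_logging_options.log_dir=/tmp/tensorboard/grpo",
])

cfg = OmegaConf.merge(base, model, overrides)      # later sources win
print(cfg.rl_training_config.max_steps)            # 3738
```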

tests/cli/config_test.py

Lines changed: 11 additions & 2 deletions

@@ -108,9 +108,9 @@ def test_override_training_config_simple(self):
         "training_config.eval_every_n_steps=10",
     ]
     hp = config.initialize(argv)
-
+
     config_dict = cast(Dict[str, Any], hp.config)
-
+
     self.assertEqual(config_dict["training_config"]["max_steps"], 150)
     self.assertEqual(
         config_dict["training_config"]["data_sharding_axis"], ["fsdp", "dp"]
@@ -464,6 +464,15 @@ def test_obtain_reward_fn_relative_path(self):
     finally:
       os.chdir(original_cwd)
 
+  def test_obtain_reward_fn_file_not_found(self):
+    hp = self.initialize_config(
+        ["reward_functions=['tunix/cli/reward_fn/non_existent.py']"]
+    )
+    with self.assertRaisesRegex(
+        ImportError, "Failed to execute module non_existent"
+    ):
+      hp.obtain_reward_fn()
+
 
 if __name__ == "__main__":
   if "HF_TOKEN" not in os.environ:

tunix/cli/config.py

Lines changed: 9 additions & 3 deletions

@@ -36,6 +36,12 @@
 
 # Define a prefix for environment variables that can override YAML keys
 _TUNIX_PREFIX = "T_"
+_SUPPORTED_MODEL_SOURCES = [
+    "kaggle",
+    "huggingface",
+    "gcs",
+    "",
+]
 
 
 def yaml_key_to_env_key(s: str) -> str:
@@ -171,10 +177,10 @@ def _validate_model_source(self, raw_keys: collections.OrderedDict[str, Any]):
     model_source = model_config.get("model_source")
     intermediate_ckpt = model_config.get("intermediate_ckpt_dir")
 
-    if model_source not in ["kaggle", "huggingface", "gcs", ""]:
+    if model_source not in _SUPPORTED_MODEL_SOURCES:
       raise ValueError(
-          f"Invalid model_source: {model_source}. Must be 'kaggle',"
-          " 'huggingface', 'gcs' or ''."
+          f"Invalid model_source: {model_source}. Must be one of"
+          f" {_SUPPORTED_MODEL_SOURCES}."
       )
 
     if model_source in ["kaggle", "huggingface"] and not intermediate_ckpt:
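Factoring the allowed sources into _SUPPORTED_MODEL_SOURCES keeps the membership check and the error message in sync automatically. For example (the value "s3" is hypothetical), an unsupported source would now raise roughly: ValueError: Invalid model_source: s3. Must be one of ['kaggle', 'huggingface', 'gcs', ''].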

tunix/cli/grpo_main.py

Lines changed: 22 additions & 0 deletions

@@ -14,6 +14,10 @@
 
 """Main entry point for GRPO training."""
 from absl import app
+from absl import flags
+from absl import logging
+import jax
+import omegaconf
 from tunix.cli import config
 from tunix.cli.utils import model as model_lib
 from tunix.examples.data import math_dataset as data_lib
@@ -22,6 +26,10 @@
 from tunix.rl.grpo.grpo_learner import GrpoConfig
 from tunix.rl.rollout import base_rollout
 
+_PATHWAYS_BNS = flags.DEFINE_string(
+    "pathways_bns", None, "BNS address of the Pathways server."
+)
+
 
 class GrpoPipeline(config.HyperParameters):
   """Class for running the GRPO trainer."""
@@ -132,8 +140,22 @@ def run_grpo_trainer(self):
     grpo_trainer.train(dataset)
 
 
+def _setup_jax_pathways(pathways_bns: str):
+  """Sets up Jax with Pathways."""
+  flags.FLAGS.pathways_ifrt = True
+  jax.config.update("jax_xla_backend", "pathways")
+  jax.config.update("jax_backend_target", pathways_bns)
+
+
 def main(argv, **kwargs):
+  if _PATHWAYS_BNS.value:
+    _setup_jax_pathways(_PATHWAYS_BNS.value)
   pipeline = GrpoPipeline(argv, **kwargs)
+  logging.info(
+      "--- Launching GRPO pipeline with following config ---\n"
+      "%s\n--------------------------",
+      omegaconf.OmegaConf.to_yaml(omegaconf.OmegaConf.create(pipeline.config)),
+  )
   pipeline.run_grpo_trainer()
 
 
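Two behavioral additions land here. First, the entry point gains a --pathways_bns absl flag; when it is set, the process is pointed at a Pathways backend before the pipeline is built (pathways_ifrt is enabled and jax_xla_backend / jax_backend_target are updated to "pathways" and the given address). Second, the fully resolved config is logged as YAML at startup via omegaconf.OmegaConf.to_yaml, which makes it easy to confirm which values came from base_config.yaml, the model YAML, or the command-line overrides.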

0 commit comments