65 changes: 65 additions & 0 deletions examples/rl/grpo/gsm8k/configs/gemma2_2b.yaml
@@ -0,0 +1,65 @@
reference_model_config:
  model_name: "gemma2-2b-it"
  model_id: "google/gemma-2/flax/gemma2-2b-it"
  model_source: "kaggle"
  mesh:
    shape: "(2,4)"
    axis_names: "('fsdp','tp')"
  rng_seed: 42
actor_model_config:
  lora_config:
    rank: 64
    alpha: 64.0
    module_path: ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum"
  mesh:
    shape: "(2,4)"
    axis_names: "('fsdp','tp')"
rollout_model_config:
  mesh:
    shape: "(2,4)"
    axis_names: "('fsdp','tp')"
tokenizer_config:
  tokenizer_type: "sentencepiece"
  add_bos: False
dataset_name: "gsm8k"
batch_size: 1
num_batches: 3738
num_test_batches: 100
num_train_epochs: 1
rl_training_config:
  actor_optimizer_config:
    opt_type: "adamw"
    peak_value: 3e-6
    schedule_type: "warmup_cosine_decay_schedule"
    init_value: 0.0
    end_value: 0.0
    warmup_ratio: 0.1
    warmup_steps: 374
    decay_steps: 3738
    b1: 0.9
    b2: 0.99
    weight_decay: 0.1
    max_grad_norm: 0.1
  eval_every_n_steps: 10
  max_steps: 3738
  metrics_logging_options:
    flush_every_n_steps: 20
  checkpointing_options:
    save_interval_steps: 500
    max_to_keep: 4
  profiler_options: {}
rollout_config:
  total_generation_steps: 768
  max_prompt_length: 256
  temperature: 0.9
  top_p: 1.0
  top_k: 50
rollout_engine: "vanilla"
offload_to_cpu: False
grpo_config:
  num_generations: 2
  num_iterations: 1
  beta: 0.08
  epsilon: 0.2
reward_functions:
  - "tunix/cli/reward_fn/gsm8k.py"
46 changes: 2 additions & 44 deletions examples/rl/grpo/gsm8k/run_gemma2_2b.sh
@@ -40,58 +40,16 @@ echo "Rounded warmup steps: $warmup_steps"

python3 -m tunix.cli.grpo_main \
base_config.yaml \
reference_model_config.model_name="gemma2-2b-it" \
reference_model_config.model_id="google/gemma-2/flax/gemma2-2b-it" \
reference_model_config.model_source="kaggle" \
override_config_file=examples/rl/grpo/gsm8k/configs/gemma2_2b.yaml \
reference_model_config.model_download_path="/tmp/models/gemma2-2b" \
reference_model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/1" \
reference_model_config.mesh.shape="(2,4)" \
reference_model_config.mesh.axis_names="('fsdp','tp')" \
reference_model_config.rng_seed=42 \
actor_model_config.lora_config.rank=64 \
actor_model_config.lora_config.alpha=64.0 \
actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \
actor_model_config.mesh.shape="(2,4)" \
actor_model_config.mesh.axis_names="('fsdp','tp')" \
rollout_model_config.mesh.shape="(2,4)" \
rollout_model_config.mesh.axis_names="('fsdp','tp')" \
tokenizer_config.tokenizer_path="/tmp/models/gemma2-2b/models/google/gemma-2/flax/gemma2-2b-it/1/tokenizer.model" \
tokenizer_config.tokenizer_type="sentencepiece" \
tokenizer_config.add_bos=false \
dataset_name="gsm8k" \
batch_size=$batch_size \
num_batches=$num_batches \
num_test_batches=100 \
num_train_epochs=$num_train_epochs \
rl_training_config.actor_optimizer_config.opt_type="adamw" \
rl_training_config.actor_optimizer_config.peak_value=3e-6 \
rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
rl_training_config.actor_optimizer_config.init_value=0.0 \
rl_training_config.actor_optimizer_config.end_value=0.0 \
rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \
rl_training_config.actor_optimizer_config.warmup_steps=$warmup_steps \
rl_training_config.actor_optimizer_config.decay_steps=$max_steps \
rl_training_config.actor_optimizer_config.b1=0.9 \
rl_training_config.actor_optimizer_config.b2=0.99 \
rl_training_config.actor_optimizer_config.weight_decay=0.1 \
rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
rl_training_config.eval_every_n_steps=10 \
rl_training_config.max_steps=$max_steps \
rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/grpo" \
rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
rl_training_config.checkpointing_options.save_interval_steps=500 \
rl_training_config.checkpointing_options.max_to_keep=4 \
rl_training_config.profiler_options={} \
rollout_config.total_generation_steps=768 \
rollout_config.max_prompt_length=256 \
rollout_config.temperature=0.9 \
rollout_config.top_p=1.0 \
rollout_config.top_k=50 \
rollout_engine="vanilla" \
offload_to_cpu=false \
grpo_config.num_generations=2 \
grpo_config.num_iterations=1 \
grpo_config.beta=0.08 \
grpo_config.epsilon=0.2 \
reward_functions="['tunix/cli/reward_fn/gsm8k.py']"
rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/grpo"

13 changes: 11 additions & 2 deletions tests/cli/config_test.py
@@ -108,9 +108,9 @@ def test_override_training_config_simple(self):
"training_config.eval_every_n_steps=10",
]
hp = config.initialize(argv)

config_dict = cast(Dict[str, Any], hp.config)

self.assertEqual(config_dict["training_config"]["max_steps"], 150)
self.assertEqual(
config_dict["training_config"]["data_sharding_axis"], ["fsdp", "dp"]
@@ -464,6 +464,15 @@ def test_obtain_reward_fn_relative_path(self):
    finally:
      os.chdir(original_cwd)

  def test_obtain_reward_fn_file_not_found(self):
    hp = self.initialize_config(
        ["reward_functions=['tunix/cli/reward_fn/non_existent.py']"]
    )
    with self.assertRaisesRegex(
        ImportError, "Failed to execute module non_existent"
    ):
      hp.obtain_reward_fn()


if __name__ == "__main__":
  if "HF_TOKEN" not in os.environ:
90 changes: 77 additions & 13 deletions tunix/cli/config.py
@@ -36,6 +36,12 @@

# Define a prefix for environment variables that can override YAML keys
_TUNIX_PREFIX = "T_"
_SUPPORTED_MODEL_SOURCES = [
    "kaggle",
    "huggingface",
    "gcs",
    "",
]


def yaml_key_to_env_key(s: str) -> str:
@@ -81,26 +87,85 @@ def get_project_root() -> Path:
  return Path.cwd()


def _dict_to_cli_args(d: Any, parent_key: str = "", sep: str = ".") -> List[str]:
  """Converts a dictionary to a list of CLI arguments."""
  items = []
  for k, v in d.items():
    new_key = f"{parent_key}{sep}{k}" if parent_key else k
    if isinstance(v, collections.abc.Mapping) or isinstance(
        v, omegaconf.DictConfig
    ):
      if v:
        items.extend(_dict_to_cli_args(v, new_key, sep=sep))
      else:
        items.append(f"{new_key}={{}}")
    else:
      items.append(f"{new_key}={v}")
  return items
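For illustration, a rough input/output sketch of the flattening performed by _dict_to_cli_args above (the values are made up):

# Nested mappings become dotted key=value strings; empty mappings are kept as "key={}".
_dict_to_cli_args({"rollout_config": {"temperature": 0.9}, "profiler_options": {}})
# -> ["rollout_config.temperature=0.9", "profiler_options={}"]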


class HyperParameters:
  """Loads, merges, overrides, validates, and prepares the configuration for pipeline execution."""
  """Loads, merges, overrides, validates, and prepares the configuration for pipeline execution.

  Configurations are merged from multiple sources. The following order of
  precedence applies, with later sources overriding earlier ones:
  1. Base Config File: The first positional argument, path to a YAML config
     file.
  2. Config File Override: An optional `override_config_file=/path/to/file.yaml`
     argument. Values in this file override values in the base config file.
  3. CLI Arguments: `key=value` pairs provided as arguments override values
     from both the base config and the config file override.

  Environment variables prefixed with `T_` can also be used to set parameters,
  but it is an error to set a parameter via both an environment variable and
  a command-line argument or override file.
  """

  def __init__(self, argv: list[str], **kwargs):
    # Use omegaconf.OmegaConf.from_cli to capture CLI arguments.

    dotenv.load_dotenv()
    raw_keys = collections.OrderedDict()
    config_name = argv[1]
    raw_data_from_yaml = self._load_config_from_yaml(config_name)

    if len(argv) < 2 or "=" in argv[1]:
      raise ValueError(
          "The first argument must be a path to a base config file."
      )

    base_config_file = argv[1]
    # Handle relative paths used in examples
    if base_config_file == "base_config.yaml":
      base_config_file = pathlib.Path(__file__).parent / base_config_file
    raw_data_from_yaml = self._load_config_from_yaml(base_config_file)
    self._validate_env_variable(raw_data_from_yaml)

    override_config_file = None
    overrides = []
    for arg in argv[2:]:
      if arg.startswith("override_config_file="):
        if override_config_file is not None:
          raise ValueError("Only one override_config_file argument is allowed.")
        override_config_file = arg.split("=", 1)[1]
      else:
        overrides.append(arg)

    self.replace_keys = {
        "lora_config",
        "training_config",
        "optimizer_config",
        "profiler_options",
        "rl_training_config",
    }

    file_overrides = []
    if override_config_file:
      next_conf = self._load_config_from_yaml(override_config_file)
      file_overrides.extend(_dict_to_cli_args(next_conf))

    overrides = file_overrides + overrides

    keys_from_env_and_command_line = self._update_from_env_and_command_line(
        raw_keys, raw_data_from_yaml, argv, **kwargs
        raw_keys, raw_data_from_yaml, overrides, **kwargs
    )
    logging.info(
        "Updating keys from env and command line: %s",
@@ -171,10 +236,10 @@ def _validate_model_source(self, raw_keys: collections.OrderedDict[str, Any]):
    model_source = model_config.get("model_source")
    intermediate_ckpt = model_config.get("intermediate_ckpt_dir")

    if model_source not in ["kaggle", "huggingface", "gcs", ""]:
    if model_source not in _SUPPORTED_MODEL_SOURCES:
      raise ValueError(
          f"Invalid model_source: {model_source}. Must be 'kaggle',"
          " 'huggingface', 'gcs' or ''."
          f"Invalid model_source: {model_source}. Must be one of"
          f" {_SUPPORTED_MODEL_SOURCES}."
      )

    if model_source in ["kaggle", "huggingface"] and not intermediate_ckpt:
@@ -541,12 +606,12 @@ def _update_from_env_and_command_line(
      self,
      raw_keys: collections.OrderedDict[str, Any],
      raw_data_from_yaml: dict[str, Any],
      argv: list[str],
      overrides: list[str],
      **kwargs,
  ):
    """Update the configuration from command line."""

    cli_cfg = omegaconf.OmegaConf.from_cli(argv[2:])
    cli_cfg = omegaconf.OmegaConf.from_cli(overrides)

    raw_data_from_cmd_line = omegaconf.OmegaConf.to_container(
        cli_cfg, resolve=True
@@ -694,14 +759,13 @@ def _validate_env_variable(self, raw_data_from_yaml):
f"We received env {environment_var} but it isn't all uppercase."
)

  def _load_config_from_yaml(self, config_name: str):
  def _load_config_from_yaml(self, config_path: str):
    """Try Loading and validate the configuration from the YAML file."""

    path = pathlib.Path(__file__).parent / config_name
    try:
      config_oconf = omegaconf.OmegaConf.load(path)
      config_oconf = omegaconf.OmegaConf.load(config_path)
    except FileNotFoundError as e:
      raise ValueError(f"Config {config_name} not found.") from e
      raise ValueError(f"Config {config_path} not found.") from e

    return config_oconf

21 changes: 21 additions & 0 deletions tunix/cli/grpo_main.py
@@ -14,6 +14,9 @@

"""Main entry point for GRPO training."""
from absl import app
from absl import flags
from absl import logging
import jax
from tunix.cli import config
from tunix.cli.utils import model as model_lib
from tunix.examples.data import math_dataset as data_lib
@@ -22,6 +25,10 @@
from tunix.rl.grpo.grpo_learner import GrpoConfig
from tunix.rl.rollout import base_rollout

_PATHWAYS_BNS = flags.DEFINE_string(
    "pathways_bns", None, "BNS address of the Pathways server."
)


class GrpoPipeline(config.HyperParameters):
"""Class for running the GRPO trainer."""
@@ -132,8 +139,22 @@ def run_grpo_trainer(self):
    grpo_trainer.train(dataset)


def _setup_jax_pathways(pathways_bns: str):
  """Sets up Jax with Pathways."""
  flags.FLAGS.pathways_ifrt = True
  jax.config.update("jax_xla_backend", "pathways")
  jax.config.update("jax_backend_target", pathways_bns)


def main(argv, **kwargs):
  if _PATHWAYS_BNS.value:
    _setup_jax_pathways(_PATHWAYS_BNS.value)
  pipeline = GrpoPipeline(argv, **kwargs)
  logging.info(
      "--- Launching GRPO pipeline with following config ---\n"
      "%s\n--------------------------",
      pipeline.config,
  )
  pipeline.run_grpo_trainer()
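Presumably the new flag is supplied alongside the usual positional config arguments, e.g. something like `python3 -m tunix.cli.grpo_main --pathways_bns=<address> base_config.yaml override_config_file=...`; this invocation is an assumption, since no script in this diff passes the flag.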

