diff --git a/examples/rl/grpo/gsm8k/configs/gemma2_2b.yaml b/examples/rl/grpo/gsm8k/configs/gemma2_2b.yaml
new file mode 100644
index 00000000..372c1198
--- /dev/null
+++ b/examples/rl/grpo/gsm8k/configs/gemma2_2b.yaml
@@ -0,0 +1,65 @@
+reference_model_config:
+  model_name: "gemma2-2b-it"
+  model_id: "google/gemma-2/flax/gemma2-2b-it"
+  model_source: "kaggle"
+  mesh:
+    shape: "(2,4)"
+    axis_names: "('fsdp','tp')"
+  rng_seed: 42
+actor_model_config:
+  lora_config:
+    rank: 64
+    alpha: 64.0
+    module_path: ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum"
+  mesh:
+    shape: "(2,4)"
+    axis_names: "('fsdp','tp')"
+rollout_model_config:
+  mesh:
+    shape: "(2,4)"
+    axis_names: "('fsdp','tp')"
+tokenizer_config:
+  tokenizer_type: "sentencepiece"
+  add_bos: False
+dataset_name: "gsm8k"
+batch_size: 1
+num_batches: 3738
+num_test_batches: 100
+num_train_epochs: 1
+rl_training_config:
+  actor_optimizer_config:
+    opt_type: "adamw"
+    peak_value: 3e-6
+    schedule_type: "warmup_cosine_decay_schedule"
+    init_value: 0.0
+    end_value: 0.0
+    warmup_ratio: 0.1
+    warmup_steps: 374
+    decay_steps: 3738
+    b1: 0.9
+    b2: 0.99
+    weight_decay: 0.1
+    max_grad_norm: 0.1
+  eval_every_n_steps: 10
+  max_steps: 3738
+  metrics_logging_options:
+    flush_every_n_steps: 20
+  checkpointing_options:
+    save_interval_steps: 500
+    max_to_keep: 4
+  profiler_options: {}
+rollout_config:
+  total_generation_steps: 768
+  max_prompt_length: 256
+  temperature: 0.9
+  top_p: 1.0
+  top_k: 50
+rollout_engine: "vanilla"
+offload_to_cpu: False
+grpo_config:
+  num_generations: 2
+  num_iterations: 1
+  beta: 0.08
+  epsilon: 0.2
+reward_functions:
+  - "tunix/cli/reward_fn/gsm8k.py"
diff --git a/examples/rl/grpo/gsm8k/run_gemma2_2b.sh b/examples/rl/grpo/gsm8k/run_gemma2_2b.sh
index 3230de29..6e1b1ce8 100755
--- a/examples/rl/grpo/gsm8k/run_gemma2_2b.sh
+++ b/examples/rl/grpo/gsm8k/run_gemma2_2b.sh
@@ -40,58 +40,16 @@ echo "Rounded warmup steps: $warmup_steps"
 
 python3 -m tunix.cli.grpo_main \
   base_config.yaml \
-  reference_model_config.model_name="gemma2-2b-it" \
-  reference_model_config.model_id="google/gemma-2/flax/gemma2-2b-it" \
-  reference_model_config.model_source="kaggle" \
+  override_config_file=examples/rl/grpo/gsm8k/configs/gemma2_2b.yaml \
   reference_model_config.model_download_path="/tmp/models/gemma2-2b" \
   reference_model_config.intermediate_ckpt_dir="/tmp/intermediate_ckpt/1" \
-  reference_model_config.mesh.shape="(2,4)" \
-  reference_model_config.mesh.axis_names="('fsdp','tp')" \
-  reference_model_config.rng_seed=42 \
-  actor_model_config.lora_config.rank=64 \
-  actor_model_config.lora_config.alpha=64.0 \
-  actor_model_config.lora_config.module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum" \
-  actor_model_config.mesh.shape="(2,4)" \
-  actor_model_config.mesh.axis_names="('fsdp','tp')" \
-  rollout_model_config.mesh.shape="(2,4)" \
-  rollout_model_config.mesh.axis_names="('fsdp','tp')" \
   tokenizer_config.tokenizer_path="/tmp/models/gemma2-2b/models/google/gemma-2/flax/gemma2-2b-it/1/tokenizer.model" \
-  tokenizer_config.tokenizer_type="sentencepiece" \
-  tokenizer_config.add_bos=false \
-  dataset_name="gsm8k" \
   batch_size=$batch_size \
   num_batches=$num_batches \
-  num_test_batches=100 \
   num_train_epochs=$num_train_epochs \
-  rl_training_config.actor_optimizer_config.opt_type="adamw" \
-  rl_training_config.actor_optimizer_config.peak_value=3e-6 \
-  rl_training_config.actor_optimizer_config.schedule_type="warmup_cosine_decay_schedule" \
-  rl_training_config.actor_optimizer_config.init_value=0.0 \
-  rl_training_config.actor_optimizer_config.end_value=0.0 \
   rl_training_config.actor_optimizer_config.warmup_ratio=$warmup_ratio \
   rl_training_config.actor_optimizer_config.warmup_steps=$warmup_steps \
   rl_training_config.actor_optimizer_config.decay_steps=$max_steps \
-  rl_training_config.actor_optimizer_config.b1=0.9 \
-  rl_training_config.actor_optimizer_config.b2=0.99 \
-  rl_training_config.actor_optimizer_config.weight_decay=0.1 \
-  rl_training_config.actor_optimizer_config.max_grad_norm=0.1 \
-  rl_training_config.eval_every_n_steps=10 \
   rl_training_config.max_steps=$max_steps \
-  rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/grpo" \
-  rl_training_config.metrics_logging_options.flush_every_n_steps=20 \
-  rl_training_config.checkpointing_options.save_interval_steps=500 \
-  rl_training_config.checkpointing_options.max_to_keep=4 \
-  rl_training_config.profiler_options={} \
-  rollout_config.total_generation_steps=768 \
-  rollout_config.max_prompt_length=256 \
-  rollout_config.temperature=0.9 \
-  rollout_config.top_p=1.0 \
-  rollout_config.top_k=50 \
-  rollout_engine="vanilla" \
-  offload_to_cpu=false \
-  grpo_config.num_generations=2 \
-  grpo_config.num_iterations=1 \
-  grpo_config.beta=0.08 \
-  grpo_config.epsilon=0.2 \
-  reward_functions="['tunix/cli/reward_fn/gsm8k.py']"
+  rl_training_config.metrics_logging_options.log_dir="/tmp/tensorboard/grpo"
diff --git a/tests/cli/config_test.py b/tests/cli/config_test.py
index d8dbb1c3..aa21b1ee 100644
--- a/tests/cli/config_test.py
+++ b/tests/cli/config_test.py
@@ -108,9 +108,9 @@ def test_override_training_config_simple(self):
         "training_config.eval_every_n_steps=10",
     ]
     hp = config.initialize(argv)
-
+
     config_dict = cast(Dict[str, Any], hp.config)
-
+
     self.assertEqual(config_dict["training_config"]["max_steps"], 150)
     self.assertEqual(
         config_dict["training_config"]["data_sharding_axis"], ["fsdp", "dp"]
@@ -464,6 +464,15 @@ def test_obtain_reward_fn_relative_path(self):
     finally:
       os.chdir(original_cwd)
 
+  def test_obtain_reward_fn_file_not_found(self):
+    hp = self.initialize_config(
+        ["reward_functions=['tunix/cli/reward_fn/non_existent.py']"]
+    )
+    with self.assertRaisesRegex(
+        ImportError, "Failed to execute module non_existent"
+    ):
+      hp.obtain_reward_fn()
+
 
 if __name__ == "__main__":
   if "HF_TOKEN" not in os.environ:
diff --git a/tunix/cli/config.py b/tunix/cli/config.py
index cf1a8f31..760c95c1 100644
--- a/tunix/cli/config.py
+++ b/tunix/cli/config.py
@@ -36,6 +36,12 @@
 
 # Define a prefix for environment variables that can override YAML keys
 _TUNIX_PREFIX = "T_"
+_SUPPORTED_MODEL_SOURCES = [
+    "kaggle",
+    "huggingface",
+    "gcs",
+    "",
+]
 
 
 def yaml_key_to_env_key(s: str) -> str:
@@ -81,17 +87,68 @@ def get_project_root() -> Path:
   return Path.cwd()
 
 
+def _dict_to_cli_args(d: Any, parent_key: str = "", sep: str = ".") -> List[str]:
+  """Converts a dictionary to a list of CLI arguments."""
+  items = []
+  for k, v in d.items():
+    new_key = f"{parent_key}{sep}{k}" if parent_key else k
+    if isinstance(v, collections.abc.Mapping) or isinstance(
+        v, omegaconf.DictConfig
+    ):
+      if v:
+        items.extend(_dict_to_cli_args(v, new_key, sep=sep))
+      else:
+        items.append(f"{new_key}={{}}")
+    else:
+      items.append(f"{new_key}={v}")
+  return items
+
+
 class HyperParameters:
-  """Loads, merges, overrides, validates, and prepares the configuration for pipeline execution."""
+  """Loads, merges, overrides, validates, and prepares the configuration for pipeline execution.
+
+  Configurations are merged from multiple sources. The following order of
+  precedence applies, with later sources overriding earlier ones:
+  1. Base Config File: The first positional argument, path to a YAML config
+     file.
+  2. Config File Override: An optional `override_config_file=/path/to/file.yaml`
+     argument. Values in this file override values in the base config file.
+  3. CLI Arguments: `key=value` pairs provided as arguments override values
+     from both the base config and the config file override.
+
+  Environment variables prefixed with `T_` can also be used to set parameters,
+  but it is an error to set a parameter via both an environment variable and
+  a command-line argument or override file.
+  """
 
   def __init__(self, argv: list[str], **kwargs):
     # Use omegaconf.OmegaConf.from_cli to capture CLI arguments.
     dotenv.load_dotenv()
     raw_keys = collections.OrderedDict()
-    config_name = argv[1]
-    raw_data_from_yaml = self._load_config_from_yaml(config_name)
+
+    if len(argv) < 2 or "=" in argv[1]:
+      raise ValueError(
+          "The first argument must be a path to a base config file."
+      )
+
+    base_config_file = argv[1]
+    # Handle relative paths used in examples
+    if base_config_file == "base_config.yaml":
+      base_config_file = pathlib.Path(__file__).parent / base_config_file
+    raw_data_from_yaml = self._load_config_from_yaml(base_config_file)
     self._validate_env_variable(raw_data_from_yaml)
+
+    override_config_file = None
+    overrides = []
+    for arg in argv[2:]:
+      if arg.startswith("override_config_file="):
+        if override_config_file is not None:
+          raise ValueError("Only one override_config_file argument is allowed.")
+        override_config_file = arg.split("=", 1)[1]
+      else:
+        overrides.append(arg)
+
     self.replace_keys = {
         "lora_config",
         "training_config",
@@ -99,8 +156,16 @@ def __init__(self, argv: list[str], **kwargs):
         "profiler_options",
         "rl_training_config",
     }
+
+    file_overrides = []
+    if override_config_file:
+      next_conf = self._load_config_from_yaml(override_config_file)
+      file_overrides.extend(_dict_to_cli_args(next_conf))
+
+    overrides = file_overrides + overrides
+
     keys_from_env_and_command_line = self._update_from_env_and_command_line(
-        raw_keys, raw_data_from_yaml, argv, **kwargs
+        raw_keys, raw_data_from_yaml, overrides, **kwargs
     )
     logging.info(
         "Updating keys from env and command line: %s",
@@ -171,10 +236,10 @@ def _validate_model_source(self, raw_keys: collections.OrderedDict[str, Any]):
     model_source = model_config.get("model_source")
     intermediate_ckpt = model_config.get("intermediate_ckpt_dir")
 
-    if model_source not in ["kaggle", "huggingface", "gcs", ""]:
+    if model_source not in _SUPPORTED_MODEL_SOURCES:
       raise ValueError(
-          f"Invalid model_source: {model_source}. Must be 'kaggle',"
-          " 'huggingface', 'gcs' or ''."
+          f"Invalid model_source: {model_source}. Must be one of"
+          f" {_SUPPORTED_MODEL_SOURCES}."
       )
 
     if model_source in ["kaggle", "huggingface"] and not intermediate_ckpt:
@@ -541,12 +606,12 @@ def _update_from_env_and_command_line(
       self,
       raw_keys: collections.OrderedDict[str, Any],
       raw_data_from_yaml: dict[str, Any],
-      argv: list[str],
+      overrides: list[str],
       **kwargs,
   ):
     """Update the configuration from command line."""
 
-    cli_cfg = omegaconf.OmegaConf.from_cli(argv[2:])
+    cli_cfg = omegaconf.OmegaConf.from_cli(overrides)
 
     raw_data_from_cmd_line = omegaconf.OmegaConf.to_container(
         cli_cfg, resolve=True
@@ -694,14 +759,13 @@ def _validate_env_variable(self, raw_data_from_yaml):
           f"We received env {environment_var} but it isn't all uppercase."
       )
 
-  def _load_config_from_yaml(self, config_name: str):
+  def _load_config_from_yaml(self, config_path: str):
     """Try Loading and validate the configuration from the YAML file."""
-    path = pathlib.Path(__file__).parent / config_name
     try:
-      config_oconf = omegaconf.OmegaConf.load(path)
+      config_oconf = omegaconf.OmegaConf.load(config_path)
     except FileNotFoundError as e:
-      raise ValueError(f"Config {config_name} not found.") from e
+      raise ValueError(f"Config {config_path} not found.") from e
 
     return config_oconf
diff --git a/tunix/cli/grpo_main.py b/tunix/cli/grpo_main.py
index 4b527be4..5f578570 100644
--- a/tunix/cli/grpo_main.py
+++ b/tunix/cli/grpo_main.py
@@ -14,6 +14,9 @@
 """Main entry point for GRPO training."""
 
 from absl import app
+from absl import flags
+from absl import logging
+import jax
 from tunix.cli import config
 from tunix.cli.utils import model as model_lib
 from tunix.examples.data import math_dataset as data_lib
@@ -22,6 +25,10 @@
 from tunix.rl.grpo.grpo_learner import GrpoConfig
 from tunix.rl.rollout import base_rollout
 
+_PATHWAYS_BNS = flags.DEFINE_string(
+    "pathways_bns", None, "BNS address of the Pathways server."
+)
+
 
 class GrpoPipeline(config.HyperParameters):
   """Class for running the GRPO trainer."""
@@ -132,8 +139,22 @@ def run_grpo_trainer(self):
     grpo_trainer.train(dataset)
 
 
+def _setup_jax_pathways(pathways_bns: str):
+  """Sets up Jax with Pathways."""
+  flags.FLAGS.pathways_ifrt = True
+  jax.config.update("jax_xla_backend", "pathways")
+  jax.config.update("jax_backend_target", pathways_bns)
+
+
 def main(argv, **kwargs):
+  if _PATHWAYS_BNS.value:
+    _setup_jax_pathways(_PATHWAYS_BNS.value)
   pipeline = GrpoPipeline(argv, **kwargs)
+  logging.info(
+      "--- Launching GRPO pipeline with the following config ---\n"
+      "%s\n--------------------------",
+      pipeline.config,
+  )
   pipeline.run_grpo_trainer()
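
A minimal invocation sketch of the layered configuration this change introduces, assuming the example paths above; the grpo_config.num_generations=4 argument is hypothetical and is only there to illustrate the precedence order (explicit key=value arguments beat the override file, which beats the base config):

  # Illustrative only: base_config.yaml provides defaults, the override file
  # provides the gemma2-2b recipe, and the trailing key=value pair (a made-up
  # value; the override file sets num_generations to 2) wins over both.
  python3 -m tunix.cli.grpo_main \
    base_config.yaml \
    override_config_file=examples/rl/grpo/gsm8k/configs/gemma2_2b.yaml \
    grpo_config.num_generations=4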