@@ -1,27 +1,20 @@
 # %%
+
 # [WIP] Reproduction of [DeepScaleR](https://pretty-radio-b75.notion.site/DeepScaleR-Surpassing-O1-Preview-with-a-1-5B-Model-by-Scaling-RL-19681902c1468005bed8ca303013a4e2) with the single-turn agentic framework.

 import contextlib
 import functools
 import json
-import logging
 import os
-from pprint import pprint
-import re

-from etils import ecolab
 from flax import nnx
 import grain
-import humanize
 import jax
 from jax import numpy as jnp
 import optax
-from orbax import checkpoint as ocp
 import qwix
 from tqdm.auto import tqdm

-from GOOGLE_INTERNAL_PACKAGE_PATH.pyglib import gfile
-from etils import ecolab
 import optax
 from orbax import checkpoint as ocp
@@ -47,18 +40,16 @@
 from tunix.rl.agentic.rewards import reward
 from tunix.rl.agentic.trajectory import trajectory_collect_engine
 from tunix.rl.agentic.parser.chat_template_parser import parser
-from flax import nnx
 import jax
 import numpy as np
 from tunix.rl.experimental.agentic_grpo_learner import GRPOConfig, GRPOLearner
-from tunix.models.qwen2 import params
-from tunix.models.qwen2 import model
 from tunix.rl import rl_cluster as rl_cluster_lib
-from tunix.sft import utils
 from tunix.rl.rollout import base_rollout
 from tunix.sft import metrics_logger
 from tunix.sft import utils as sft_utils
 from tunix.utils import math_rewards
+from tunix.utils import compat
+
 # %%
 # ====== Data ======
 TRAIN_FRACTION = 1.0
@@ -69,6 +60,7 @@
 # ====== LoRA ======
 RANK = 64
 ALPHA = 64.0
+TRAIN_WITH_LORA = False

 # ====== Sharding ======
 MESH = [(2, 4), ("fsdp", "tp")]
@@ -85,7 +77,7 @@
 # The number of responses the policy generates for a given prompt within a
 # single training step. This corresponds to `G` in Algorithm 1 of the GRPO
 # paper; the "group" in GRPO comes from here.
-NUM_GENERATIONS = 1
+NUM_GENERATIONS = 2

 # === other GRPO configs ===
 # The number of iterations per batch (𝜇 in GRPO Algorithm 1).
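An aside on the `G` above (illustration only, not tunix's code path): GRPO computes each response's advantage relative to the statistics of its own group of NUM_GENERATIONS samples. A minimal self-contained sketch:

import numpy as np

def group_advantages(rewards, num_generations):
  # rewards laid out as [prompt0 x G, prompt1 x G, ...]; one row per group.
  groups = rewards.reshape(-1, num_generations)
  mean = groups.mean(axis=-1, keepdims=True)
  std = groups.std(axis=-1, keepdims=True)
  # Normalize within the group; epsilon guards degenerate all-equal groups.
  return ((groups - mean) / (std + 1e-6)).reshape(-1)

# Two prompts, G=2: advantages come out ~[+1, -1, 0, 0].
print(group_advantages(np.array([1.0, 0.0, 0.5, 0.5]), 2))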
@@ -99,11 +91,11 @@
 EPSILON = 0.2

 # ====== Training ======
-BATCH_SIZE = 128
-MINI_BATCH_SIZE = 64
-ROLLOUT_MICRO_BATCH_SIZE = 8
-LOGPS_MICRO_BATCH_SIZE = 8
-NUM_BATCHES = 30
+BATCH_SIZE = 32
+MINI_BATCH_SIZE = 32
+# ROLLOUT_MICRO_BATCH_SIZE = 8
+# LOGPS_MICRO_BATCH_SIZE = 8
+NUM_BATCHES = 100
 # Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
 # increased to a maximum of 330 (with a batch size of 4).
 NUM_TEST_BATCHES = 50
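Reading the new batch knobs together (the accumulation semantics here are assumed from the knob names, not confirmed by tunix): each optimizer batch is split into mini-batch updates, and each mini-batch is processed one example at a time since `train_micro_batch_size=1` is set further down.

# Hypothetical sanity-check arithmetic for the values above.
assert BATCH_SIZE % MINI_BATCH_SIZE == 0
mini_batches_per_step = BATCH_SIZE // MINI_BATCH_SIZE  # 32 // 32 = 1
grad_accum_steps = MINI_BATCH_SIZE // 1                # 32 micro-steps at train_micro_batch_size=1
rollouts_per_batch = BATCH_SIZE * NUM_GENERATIONS      # 32 * 2 = 64 sampled responses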
@@ -264,14 +256,44 @@ def process_item(item):
 
 # %%
 mesh = jax.make_mesh(
-    (1, 4),
-    ("fsdp", "tp"),
+    *MESH,
     axis_types=(jax.sharding.AxisType.Auto,) * len(("fsdp", "tp")),
 )
 config = model_lib.ModelConfig.deepseek_r1_distill_qwen_1_5b()
-print("model_path: ", MODEL_PATH)
-qwen2 = params_lib.create_model_from_safe_tensors(MODEL_PATH, config, mesh, dtype=jnp.float32)
-# nnx.display(model)
+print("MODEL_PATH: ", MODEL_PATH)
+qwen2_ref = params_lib.create_model_from_safe_tensors(MODEL_PATH, config, mesh, dtype=jnp.float32)
+# nnx.display(qwen2_ref)
+
+
+# %%
+def get_lora_model(base_model, model_mesh):
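+  # Wrap `base_model` with LoRA adapters (rank=RANK, alpha=ALPHA) on every
+  # module whose path matches the regex below, then shard the wrapped state.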
+  lora_provider = qwix.LoraProvider(
+      module_path=(
+          ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
+          ".*attn_vec_einsum"
+      ),
+      rank=RANK,
+      alpha=ALPHA,
+  )
+
+  model_input = base_model.get_model_input()
+  lora_model = qwix.apply_lora_to_model(
+      base_model, lora_provider, **model_input
+  )
+
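+  # Re-shard the wrapped model: read its state, derive partition specs from
+  # the nnx sharding annotations, constrain the arrays to the mesh, write back.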
+  with compat.set_mesh(model_mesh):
+    state = nnx.state(lora_model)
+    pspecs = nnx.get_partition_spec(state)
+    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
+    nnx.update(lora_model, sharded_state)
+
+  return lora_model
+
+# %%
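+# With LoRA, the actor is the reference model wrapped with trainable adapters;
+# otherwise a second full copy of the checkpoint is loaded so that the
+# trainable actor does not alias the frozen reference.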
+if TRAIN_WITH_LORA:
+  qwen2_actor = get_lora_model(qwen2_ref, mesh)
+else:
+  qwen2_actor = params_lib.create_model_from_safe_tensors(MODEL_PATH, config, mesh, dtype=jnp.float32)
 
 # %%
 show_hbm_usage()
@@ -335,6 +357,8 @@ def process_item(item):
     actor_optimizer=optimizer,
     eval_every_n_steps=EVAL_EVERY_N_STEPS,
     max_steps=MAX_STEPS,
+    mini_batch_size=MINI_BATCH_SIZE,
+    train_micro_batch_size=1,  # larger than 1 causes HBM OOM
     # metrics logging
     metrics_logging_options=metrics_logging_options,
     # checkpoint saving
@@ -353,19 +377,20 @@ def process_item(item):
 )

 grpo_config = GRPOConfig(
-    num_generations=2,
+    num_generations=NUM_GENERATIONS,
     num_iterations=NUM_ITERATIONS,
     beta=BETA,
     epsilon=EPSILON,
     system_prompt="",
+    max_concurrency=8,
 )

 # %%
 # RL cluster
-with mesh:
+with compat.set_mesh(mesh):
   rl_cluster = rl_cluster_lib.RLCluster(
-      actor=qwen2,
-      reference=qwen2,
+      actor=qwen2_actor,
+      reference=qwen2_ref,
       tokenizer=tokenizer,
       cluster_config=cluster_config,
   )
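The diff ends before the learner itself is constructed. As a purely hypothetical sketch of the follow-on cell, assuming the experimental GRPOLearner is wired like tunix's stable GRPO learner (the constructor keywords and the `reward_fn`/`train_dataset` names are assumptions, not part of this commit):

grpo_learner = GRPOLearner(
    rl_cluster=rl_cluster,    # cluster built above (assumed keyword)
    reward_fns=[reward_fn],   # hypothetical reward callable, e.g. from math_rewards
    grpo_config=grpo_config,
)
with compat.set_mesh(mesh):
  grpo_learner.train(train_dataset)  # hypothetical dataset iterator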