Commit a6716c3

lc5211 authored and The tunix Authors committed

[Tunix] Update Dockerfile and deepscaler trainer script to separate the trainer model and the ref model.
PiperOrigin-RevId: 835298189
1 parent 61cd709 commit a6716c3

File tree

2 files changed: +63 -40 lines changed

  Dockerfile
  examples/deepscaler/train_deepscaler_nb.py


Dockerfile

Lines changed: 11 additions & 13 deletions
@@ -1,35 +1,33 @@
-# Start FROM a base image
-FROM ubuntu:22.04
+# Base image with Python 3.12
+FROM python:3.12-slim

 # Set environment variables to non-interactive to avoid prompts during installation
 ENV DEBIAN_FRONTEND=noninteractive
 ENV TZ=Etc/UTC

-# Use the official Python 3.12 image based on Debian Stable (Bookworm)
-FROM python:3.12-slim
+# Install system dependencies, including Python 3 and pip
+RUN apt-get update && \
+    apt-get install -y git python3 python3-pip && \
+    rm -rf /var/lib/apt/lists/*

-# Python 3.12, pip, and venv are already included.
 # Upgrade pip
 RUN python3 -m pip install --upgrade pip

-# Install other system dependencies
-RUN apt-get update && apt-get install -y --no-install-recommends \
-    git \
-    openjdk-21-jdk \
-    maven \
-    && rm -rf /var/lib/apt/lists/*
-
-
 # Create a virtual environment
 RUN python3.12 -m venv /opt/venv
 ENV PATH="/opt/venv/bin:$PATH"

+# Upgrade pip
+RUN pip install --upgrade pip

 RUN pip install git+https://github.com/ayaka14732/jax-smi.git
 RUN pip install git+https://github.com/AI-Hypercomputer/pathways-utils.git
+# If you encounter a checkpoint issue, try using the following old version of pathways-utils.
+# RUN pip install git+https://github.com/AI-Hypercomputer/pathways-utils.git@b72729bb152b7b3426299405950b3af300d765a9#egg=pathwaysutils
 RUN pip install gcsfs
 RUN pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

+RUN pip install --upgrade wandb

 # Set the working directory
 WORKDIR /app
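
Note (not part of the commit): a quick way to smoke-test the resulting image is to confirm that the `jax[tpu]` install can see accelerators. A minimal check, assuming the container is run on a TPU host:

```python
# Hypothetical smoke test; run inside the built container on a TPU host.
import jax

print(jax.__version__)
print(jax.devices())  # expect TpuDevice entries if libtpu is wired up correctly
```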

examples/deepscaler/train_deepscaler_nb.py

Lines changed: 52 additions & 27 deletions
@@ -1,27 +1,20 @@
 # %%
+
 # [WIP] Reproduction of [Deepscaler](https://pretty-radio-b75.notion.site/DeepScaleR-Surpassing-O1-Preview-with-a-1-5B-Model-by-Scaling-RL-19681902c1468005bed8ca303013a4e2) with Single-turn Agentic framework.

 import contextlib
 import functools
 import json
-import logging
 import os
-from pprint import pprint
-import re

-from etils import ecolab
 from flax import nnx
 import grain
-import humanize
 import jax
 from jax import numpy as jnp
 import optax
-from orbax import checkpoint as ocp
 import qwix
 from tqdm.auto import tqdm

-from GOOGLE_INTERNAL_PACKAGE_PATH.pyglib import gfile
-from etils import ecolab
 import optax
 from orbax import checkpoint as ocp

@@ -47,18 +40,16 @@
 from tunix.rl.agentic.rewards import reward
 from tunix.rl.agentic.trajectory import trajectory_collect_engine
 from tunix.rl.agentic.parser.chat_template_parser import parser
-from flax import nnx
 import jax
 import numpy as np
 from tunix.rl.experimental.agentic_grpo_learner import GRPOConfig, GRPOLearner
-from tunix.models.qwen2 import params
-from tunix.models.qwen2 import model
 from tunix.rl import rl_cluster as rl_cluster_lib
-from tunix.sft import utils
 from tunix.rl.rollout import base_rollout
 from tunix.sft import metrics_logger
 from tunix.sft import utils as sft_utils
 from tunix.utils import math_rewards
+from tunix.utils import compat
+
 # %%
 # ====== Data ======
 TRAIN_FRACTION = 1.0
@@ -69,6 +60,7 @@
 # ====== LoRA ======
 RANK = 64
 ALPHA = 64.0
+TRAIN_WITH_LORA = False

 # ====== Sharding ======
 MESH = [(2, 4), ("fsdp", "tp")]
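
Note (editorial, not commit content): `MESH = [(2, 4), ("fsdp", "tp")]` describes a 2-way FSDP by 4-way tensor-parallel layout over 8 devices; a later hunk unpacks it into `jax.make_mesh`. A standalone sketch of what that produces:

```python
# Illustrative only; mirrors the MESH constant from the diff.
import jax

MESH = [(2, 4), ("fsdp", "tp")]  # 2-way fsdp x 4-way tp -> needs 8 devices
mesh = jax.make_mesh(*MESH)
print(mesh.shape)  # expected: {'fsdp': 2, 'tp': 4}
```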
@@ -85,7 +77,7 @@
 # The number of times the policy generates multiple responses for a given prompt
 # within a single training step. This corresponds to `G` in Algorithm 1 in the
 # paper. The "group" in GRPO comes from here.
-NUM_GENERATIONS = 1
+NUM_GENERATIONS = 2

 # === other GRPO configs ===
 # The number of iterations per batch (𝜇 in GRPO algo 1).
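
Note on why `NUM_GENERATIONS` moves from 1 to 2: GRPO normalizes rewards within each group of `G` responses to the same prompt, so with `G = 1` every reward equals its group mean and all advantages are identically zero, leaving no learning signal. A sketch of the usual group-relative advantage computation (illustrative, not the tunix implementation):

```python
import numpy as np

def group_relative_advantages(rewards: np.ndarray, num_generations: int) -> np.ndarray:
  """Normalizes rewards within each group of G responses to one prompt."""
  groups = rewards.reshape(-1, num_generations)  # [num_prompts, G]
  mean = groups.mean(axis=1, keepdims=True)
  std = groups.std(axis=1, keepdims=True)
  return ((groups - mean) / (std + 1e-6)).reshape(-1)  # all zeros when G == 1
```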
@@ -99,11 +91,11 @@
 EPSILON = 0.2

 # ====== Training ======
-BATCH_SIZE = 128
-MINI_BATCH_SIZE = 64
-ROLLOUT_MICRO_BATCH_SIZE = 8
-LOGPS_MICRO_BATCH_SIZE = 8
-NUM_BATCHES = 30
+BATCH_SIZE = 32
+MINI_BATCH_SIZE = 32
+# ROLLOUT_MICRO_BATCH_SIZE = 8
+# LOGPS_MICRO_BATCH_SIZE = 8
+NUM_BATCHES = 100
 # Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
 # increased to a max. of 330 (if batch size is 4).
 NUM_TEST_BATCHES = 50
@@ -264,14 +256,44 @@ def process_item(item):

 # %%
 mesh = jax.make_mesh(
-    (1, 4),
-    ("fsdp", "tp"),
+    *MESH,
     axis_types=(jax.sharding.AxisType.Auto,) * len(("fsdp", "tp")),
 )
 config = model_lib.ModelConfig.deepseek_r1_distill_qwen_1_5b()
-print("model_path: ", MODEL_PATH)
-qwen2 = params_lib.create_model_from_safe_tensors(MODEL_PATH, config, mesh, dtype=jnp.float32)
-# nnx.display(model)
+print("MODEL_PATH: ", MODEL_PATH)
+qwen2_ref = params_lib.create_model_from_safe_tensors(MODEL_PATH, config, mesh, dtype=jnp.float32)
+# nnx.display(qwen2_ref)
+
+
+# %%
+def get_lora_model(base_model, model_mesh):
+  lora_provider = qwix.LoraProvider(
+      module_path=(
+          ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
+          ".*attn_vec_einsum"
+      ),
+      rank=RANK,
+      alpha=ALPHA,
+  )
+
+  model_input = base_model.get_model_input()
+  lora_model = qwix.apply_lora_to_model(
+      base_model, lora_provider, **model_input
+  )
+
+  with compat.set_mesh(model_mesh):
+    state = nnx.state(lora_model)
+    pspecs = nnx.get_partition_spec(state)
+    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
+    nnx.update(lora_model, sharded_state)
+
+  return lora_model
+
+# %%
+if TRAIN_WITH_LORA:
+  qwen2_actor = get_lora_model(qwen2_ref, mesh)
+else:
+  qwen2_actor = params_lib.create_model_from_safe_tensors(MODEL_PATH, config, mesh, dtype=jnp.float32)

 # %%
 show_hbm_usage()
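
This hunk is the point of the commit: the actor (`qwen2_actor`, trained) and the reference (`qwen2_ref`, frozen) are now distinct objects. With `TRAIN_WITH_LORA = True` the actor wraps `qwen2_ref` with LoRA adapters, so only the low-rank deltas train and no second full copy of the base weights is needed; otherwise a separate full model is loaded. The frozen reference exists to anchor the KL penalty in the GRPO objective. A sketch of a per-token KL penalty using the k3 estimator common in GRPO implementations (an assumption; tunix's exact form may differ):

```python
import jax.numpy as jnp

def kl_penalty(actor_logps: jnp.ndarray, ref_logps: jnp.ndarray) -> jnp.ndarray:
  """k3 estimator of KL(actor || ref): low variance and always >= 0."""
  log_ratio = ref_logps - actor_logps
  return jnp.exp(log_ratio) - log_ratio - 1.0
```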
@@ -335,6 +357,8 @@ def process_item(item):
     actor_optimizer=optimizer,
     eval_every_n_steps=EVAL_EVERY_N_STEPS,
     max_steps=MAX_STEPS,
+    mini_batch_size=MINI_BATCH_SIZE,
+    train_micro_batch_size=1,  # larger than 1 will cause OOM on HBM
     # metrics logging
     metrics_logging_options=metrics_logging_options,
     # checkpoint saving
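
Assuming the conventional semantics, where micro-batches are accumulated into a mini-batch before each optimizer step (tunix's internals may differ), the new settings imply 32 accumulation steps per update, with `train_micro_batch_size=1` acting as the HBM guard the inline comment describes:

```python
# Illustrative arithmetic only; names mirror the config above.
BATCH_SIZE = 32             # prompts drawn per training batch
MINI_BATCH_SIZE = 32        # prompts per optimizer update
TRAIN_MICRO_BATCH_SIZE = 1  # prompts per forward/backward pass

accum_steps = MINI_BATCH_SIZE // TRAIN_MICRO_BATCH_SIZE  # 32
updates_per_batch = BATCH_SIZE // MINI_BATCH_SIZE        # 1
```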
@@ -353,19 +377,20 @@ def process_item(item):
 )

 grpo_config = GRPOConfig(
-    num_generations=2,
+    num_generations=NUM_GENERATIONS,
     num_iterations=NUM_ITERATIONS,
     beta=BETA,
     epsilon=EPSILON,
     system_prompt="",
+    max_concurrency=8,
 )

 # %%
 # RL cluster
-with mesh:
+with compat.set_mesh(mesh):
   rl_cluster = rl_cluster_lib.RLCluster(
-      actor=qwen2,
-      reference=qwen2,
+      actor=qwen2_actor,
+      reference=qwen2_ref,
       tokenizer=tokenizer,
       cluster_config=cluster_config,
   )
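
`with mesh:` becomes `with compat.set_mesh(mesh):`, presumably so one script works across JAX versions that install a mesh context differently. A hypothetical shim of that shape (the real `tunix.utils.compat` may be implemented differently):

```python
import contextlib

import jax


@contextlib.contextmanager
def set_mesh(mesh):
  """Installs `mesh` as the ambient mesh across JAX versions (assumed shape)."""
  if hasattr(jax, "set_mesh"):
    # Newer JAX exposes an explicit setter usable as a context manager.
    with jax.set_mesh(mesh):
      yield
  else:
    # Older JAX: Mesh objects are themselves context managers.
    with mesh:
      yield
```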
