Skip to content

Commit 1e5ed04

Browse files
s-noghabi, The tunix Authors
authored and committed
[DNS] debug failure
PiperOrigin-RevId: 834954246
1 parent 7852510 commit 1e5ed04

File tree

9 files changed

+286
-241
lines changed

9 files changed

+286
-241
lines changed

.github/workflows/cpu-tests.yml

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -26,54 +26,54 @@ jobs:
2626
# copybara:strip_begin
2727
# LINT.IfChange()
2828
# copybara:strip_end
29-
run:
30-
# copybara:strip_begin
31-
# LINT.ThenChange(Internal path for github_actions, don't change this line.)
32-
# copybara:strip_end
33-
runs-on: ubuntu-latest
34-
steps:
35-
- uses: actions/checkout@v4
36-
- uses: actions/setup-python@v4
37-
with:
38-
python-version: '3.11'
29+
# run:
30+
# # copybara:strip_begin
31+
# # LINT.ThenChange(Internal path for github_actions, don't change this line.)
32+
# # copybara:strip_end
33+
# runs-on: ubuntu-latest
34+
# steps:
35+
# - uses: actions/checkout@v4
36+
# - uses: actions/setup-python@v4
37+
# with:
38+
# python-version: '3.11'
3939

40-
- name: Download the tunix wheel
41-
uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0
42-
with:
43-
name: tunix-wheel
40+
# - name: Download the tunix wheel
41+
# uses: actions/download-artifact@634f93cb2916e3fdff6788551b99b062d0335ce0 # v5.0.0
42+
# with:
43+
# name: tunix-wheel
4444

45-
- name: Install the tunix wheel
46-
run: |
47-
python -m pip install --upgrade pip
48-
python -m pip install google_tunix-*-py3-none-any.whl
49-
python -m pip install pytest
45+
# - name: Install the tunix wheel
46+
# run: |
47+
# python -m pip install --upgrade pip
48+
# python -m pip install google_tunix-*-py3-none-any.whl
49+
# python -m pip install pytest
5050

51-
- name: Verify Tunix imports from installed package
52-
run: |
53-
python3 -c "
54-
import tunix
55-
import tunix.models
56-
import tunix.generate
57-
import tunix.sft
58-
import tunix.distillation
59-
import tunix.rl
51+
# - name: Verify Tunix imports from installed package
52+
# run: |
53+
# python3 -c "
54+
# import tunix
55+
# import tunix.models
56+
# import tunix.generate
57+
# import tunix.sft
58+
# import tunix.distillation
59+
# import tunix.rl
6060

61-
assert tunix.__version__ != '0.0.0.dev0', 'Tunix version not set correctly'
62-
print('All tunix modules imported successfully and version is', tunix.__version__)
63-
"
64-
- name: Run agentic RL tests
65-
run: |
66-
python -m pytest tests/rl/agentic/ -v --tb=short
61+
# assert tunix.__version__ != '0.0.0.dev0', 'Tunix version not set correctly'
62+
# print('All tunix modules imported successfully and version is', tunix.__version__)
63+
# "
64+
# - name: Run agentic RL tests
65+
# run: |
66+
# python -m pytest tests/rl/agentic/ -v --tb=short
6767

68-
- name: Run Cli utils tests
69-
run: |
70-
python -m pytest tests/cli/utils/ -v --tb=short
68+
# - name: Run Cli utils tests
69+
# run: |
70+
# python -m pytest tests/cli/utils/ -v --tb=short
7171

72-
- name: Run model alignment tests
73-
run: |
74-
python -m pip install torch
75-
python -m pytest tests/model_alignment/ -v --tb=short
72+
# - name: Run model alignment tests
73+
# run: |
74+
# python -m pip install torch
75+
# python -m pytest tests/model_alignment/ -v --tb=short
7676

77-
- name: Run perf tests
78-
run: |
79-
python -m pytest tests/perf/ -v --tb=short
77+
# - name: Run perf tests
78+
# run: |
79+
# python -m pytest tests/perf/ -v --tb=short

.github/workflows/tpu-tests.yml

Lines changed: 91 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -33,97 +33,97 @@ env:
3333
HF_HUB_ENABLE_HF_TRANSFER: "1"
3434

3535
jobs:
36-
run_prod:
37-
runs-on: [linux-x86-ct5lp-224-8tpu]
38-
environment: testing
39-
container:
40-
image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest
41-
options: --privileged
42-
env:
43-
CLOUD_TPU_ACCELERATOR: v5e-8
44-
JAX_PLATFORMS: tpu
45-
steps:
46-
47-
# Cache Hugging Face hub
48-
- name: Cache HF hub
49-
uses: actions/cache@v4
50-
with:
51-
path: ~/.cache/huggingface
52-
key: hf-${{ runner.os }}-${{ hashFiles('pyproject.toml', 'requirements*.txt', 'constraints*.txt') }}
53-
restore-keys: |
54-
hf-${{ runner.os }}-
55-
56-
- name: Checkout code
57-
uses: actions/checkout@v4
58-
with:
59-
fetch-depth: 0
60-
61-
- name: Install tunix dependencies
62-
run: |
63-
pip install --upgrade pip
64-
pip install -e .[prod] --force-reinstall
65-
pip install pytest pytest-xdist
66-
67-
- name: Verify TPU availability
68-
run: |
69-
python -c "
70-
import jax
71-
print(f'JAX version: {jax.__version__}')
72-
print(f'JAX devices: {jax.devices()}')
73-
74-
# Check if we have TPU devices specifically
75-
devices = jax.devices()
76-
has_tpu = len(devices) > 0 and all(device.platform == 'tpu' for device in devices)
77-
print(f'TPU available: {has_tpu}')
78-
79-
if not has_tpu:
80-
print('ERROR: No TPU devices found! Expected TPU devices but got:', [device.platform for device in devices])
81-
exit(1)
82-
else:
83-
print(f'SUCCESS: Found {len(devices)} TPU device(s)')
84-
"
85-
86-
- name: Run tunix model tests
87-
run: |
88-
python -m pytest tests/models/ -v --tb=short -m "not cpu_only and not gpu_only"
89-
90-
- name: Run tunix generation tests (PASSED only)
91-
run: |
92-
# tokenizer_adapter_test requires access to gated repo
93-
# TODO(b/459824938) Add back test_logprobs_extraction_with_missing_token after fixing the issue
94-
python -m pytest tests/generate/ -v --tb=short \
95-
--ignore=tests/generate/vllm_sampler_test.py \
96-
--ignore=tests/generate/vllm_driver_test.py \
97-
--ignore=tests/generate/tokenizer_adapter_test.py \
98-
--ignore=tests/generate/sglang_jax_sampler_test.py \
99-
--ignore=tests/generate/utils_test.py
100-
101-
python -m pytest tests/generate/utils_test.py -k "not test_logprobs_extraction_with_missing_token"
102-
103-
- name: Run tunix SFT tests
104-
run: |
105-
python -m pytest tests/sft/ -v --tb=short
106-
107-
- name: Run tunix distillation tests
108-
run: |
109-
python -m pytest tests/distillation/ -v --tb=short
110-
111-
- name: Run tunix RL tests
112-
run: |
113-
# RL common tests that passed
114-
# b/448133814: test_grpo_with_lora_model fails
115-
python -m pytest tests/rl/ -v --tb=short -k "not test_grpo_with_lora_model" --ignore=tests/rl/experimental/agentic
116-
117-
- name: Run tunix tests not covered by the above categories
118-
run: |
119-
# This category is to catch tests added but not covered by CI yet. Whenever you add new folders under tests/, please add a new category above and skip those tests here.
120-
python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ || code=$?
121-
if [ "${code:-0}" = "5" ]; then
122-
echo "No tests collected (expected)."
123-
exit 0
124-
else
125-
exit "${code:-0}"
126-
fi
36+
# run_prod:
37+
# runs-on: [linux-x86-ct5lp-224-8tpu]
38+
# environment: testing
39+
# container:
40+
# image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/tpu:latest
41+
# options: --privileged
42+
# env:
43+
# CLOUD_TPU_ACCELERATOR: v5e-8
44+
# JAX_PLATFORMS: tpu
45+
# steps:
46+
47+
# # Cache Hugging Face hub
48+
# - name: Cache HF hub
49+
# uses: actions/cache@v4
50+
# with:
51+
# path: ~/.cache/huggingface
52+
# key: hf-${{ runner.os }}-${{ hashFiles('pyproject.toml', 'requirements*.txt', 'constraints*.txt') }}
53+
# restore-keys: |
54+
# hf-${{ runner.os }}-
55+
56+
# - name: Checkout code
57+
# uses: actions/checkout@v4
58+
# with:
59+
# fetch-depth: 0
60+
61+
# - name: Install tunix dependencies
62+
# run: |
63+
# pip install --upgrade pip
64+
# pip install -e .[prod] --force-reinstall
65+
# pip install pytest pytest-xdist
66+
67+
# - name: Verify TPU availability
68+
# run: |
69+
# python -c "
70+
# import jax
71+
# print(f'JAX version: {jax.__version__}')
72+
# print(f'JAX devices: {jax.devices()}')
73+
74+
# # Check if we have TPU devices specifically
75+
# devices = jax.devices()
76+
# has_tpu = len(devices) > 0 and all(device.platform == 'tpu' for device in devices)
77+
# print(f'TPU available: {has_tpu}')
78+
79+
# if not has_tpu:
80+
# print('ERROR: No TPU devices found! Expected TPU devices but got:', [device.platform for device in devices])
81+
# exit(1)
82+
# else:
83+
# print(f'SUCCESS: Found {len(devices)} TPU device(s)')
84+
# "
85+
86+
# - name: Run tunix model tests
87+
# run: |
88+
# python -m pytest tests/models/ -v --tb=short -m "not cpu_only and not gpu_only"
89+
90+
# - name: Run tunix generation tests (PASSED only)
91+
# run: |
92+
# # tokenizer_adapter_test requires access to gated repo
93+
# # TODO(b/459824938) Add back test_logprobs_extraction_with_missing_token after fixing the issue
94+
# python -m pytest tests/generate/ -v --tb=short \
95+
# --ignore=tests/generate/vllm_sampler_test.py \
96+
# --ignore=tests/generate/vllm_driver_test.py \
97+
# --ignore=tests/generate/tokenizer_adapter_test.py \
98+
# --ignore=tests/generate/sglang_jax_sampler_test.py \
99+
# --ignore=tests/generate/utils_test.py
100+
101+
# python -m pytest tests/generate/utils_test.py -k "not test_logprobs_extraction_with_missing_token"
102+
103+
# - name: Run tunix SFT tests
104+
# run: |
105+
# python -m pytest tests/sft/ -v --tb=short
106+
107+
# - name: Run tunix distillation tests
108+
# run: |
109+
# python -m pytest tests/distillation/ -v --tb=short
110+
111+
# - name: Run tunix RL tests
112+
# run: |
113+
# # RL common tests that passed
114+
# # b/448133814: test_grpo_with_lora_model fails
115+
# python -m pytest tests/rl/ -v --tb=short -k "not test_grpo_with_lora_model" --ignore=tests/rl/experimental/agentic
116+
117+
# - name: Run tunix tests not covered by the above categories
118+
# run: |
119+
# # This category is to catch tests added but not covered by CI yet. Whenever you add new folders under tests/, please add a new category above and skip those tests here.
120+
# python -m pytest tests/ -v --tb=short --ignore=tests/perf/ --ignore=tests/model_alignment/ --ignore=tests/models/ --ignore=tests/cli/ --ignore=tests/generate/ --ignore=tests/sft/ --ignore=tests/distillation/ --ignore=tests/rl/ || code=$?
121+
# if [ "${code:-0}" = "5" ]; then
122+
# echo "No tests collected (expected)."
123+
# exit 0
124+
# else
125+
# exit "${code:-0}"
126+
# fi
127127

128128
run_dev:
129129
if: ${{ github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository }}

tests/cli/utils/model_test.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,26 @@
7878
testcase_name="gemma3-27b",
7979
model_name="gemma3-27b",
8080
),
81+
dict(
82+
testcase_name="gemma-3-270m",
83+
model_name="gemma-3-270m",
84+
),
85+
dict(
86+
testcase_name="gemma-3-1b",
87+
model_name="gemma-3-1b",
88+
),
89+
dict(
90+
testcase_name="gemma-3-4b",
91+
model_name="gemma-3-4b",
92+
),
93+
dict(
94+
testcase_name="gemma-3-12b",
95+
model_name="gemma-3-12b",
96+
),
97+
dict(
98+
testcase_name="gemma-3-27b",
99+
model_name="gemma-3-27b",
100+
),
81101
dict(
82102
testcase_name="llama3-70b",
83103
model_name="llama3-70b",
@@ -118,11 +138,10 @@
118138
testcase_name="qwen2.5-math-1.5b",
119139
model_name="qwen2.5-math-1.5b",
120140
),
121-
# TODO(b/451662153): support deepseek model name parsing
122-
# dict(
123-
# testcase_name="deepseek-r1-distill-qwen-1.5b",
124-
# model_name="deepseek-r1-distill-qwen-1.5b",
125-
# ),
141+
dict(
142+
testcase_name="deepseek-r1-distill-qwen-1.5b",
143+
model_name="deepseek-r1-distill-qwen-1.5b",
144+
),
126145
dict(
127146
testcase_name="qwen3-0.6b",
128147
model_name="qwen3-0.6b",
@@ -151,10 +170,15 @@ def test_obtain_model_params_valid(self, model_name: str):
151170
model.obtain_model_params(model_name)
152171

153172
def test_create_model_dynamically_routing(self, model_name: str):
154-
model_module = model.get_model_module(model_name)
173+
params_module = model.get_model_module(model_name, model.ModelModule.PARAMS)
155174
if not model_name.startswith("gemma"):
156175
# TODO(b/444572467)
157-
getattr(model_module, "create_model_from_safe_tensors")
176+
getattr(params_module, "create_model_from_safe_tensors")
177+
178+
model_lib_module = model.get_model_module(
179+
model_name, model.ModelModule.MODEL
180+
)
181+
getattr(model_lib_module, "ModelConfig")
158182

159183

160184
if __name__ == "__main__":

tunix/cli/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ def create_optimizer(
345345
" https://optax.readthedocs.io/en/latest/api/optimizers.html#optimizers"
346346
) from e
347347

348+
logging.info("[SHADI] optimizer_config: %s", optimizer_config)
348349
# Handle learning rate, potentially creating a schedule
349350
learning_rate_val = self._create_learning_rate(
350351
optimizer_config, config_path_info

0 commit comments

Comments
 (0)