Skip to content

Conversation

@shuningjin
Copy link
Collaborator

@shuningjin shuningjin commented Dec 11, 2025

Description

Previously we have gpt-oss, orbax(scan) -> hf: #2647, which is tested for both 20b and 120b.

Fix: b/459541579

  • gpt-oss, hf -> orbax(scan)
  • gpt-oss, hf -> orbax(unscan)
  • gpt-oss, orbax(unscan) -> hf

Fix: b/452392132

  • implement weight splitting for hf->orbax (i.e., hf to many maxtext key)

Fix: b/452391921

  • verify interleaved scan pattern for hf->orbax

What this does

GPT-OSS / weight splitting

to_maxtext.py

  • allow hf to many mt
    • assume mt keys have same shape, hook function return a tensor stacked in last dim
    • accomodate lazy tensor: unoptimized, hf is repeated loaded for each mt
  • allow loading local hf checkpoint
    • the remote hf checkpoint is quantized for some models (e.g., gpt-oss - mxfp4, deepseek - fp8), yet we are using local de-quantized hf version (usually bf16) for conversion
    • accomodate lazy tensor

param_mapping.py, gpt-oss

  • implement interleave function for hf to many mt
  • add unscan version of mapping and hook

Refactor

readme

  • update supported model
  • update arguments

to_maxtext.py

  • refactor condition for single axis stack using config.scan_layers (using num_layer -> config.scan_layers)
  • factor out get_maxtext_dict
  • add time

utils.py

  • refactor condition for single axis stack using config.scan_layers (to_huggingface, using name -> config.scan_layers)
  • move _check_param_map_keys from to_huggingface.py to here , so it can be reused by to_maxtext.py
  • normalize chained hook (remove reversal to avoid confusion)

param_mapping.py

  • normalize chained hook (update hook order and comment, llama3.1, gemma2, mixtral)
  • remove nested hook definition

Tests

All tests are on CPU.

1 HF -> orbax (gpt-oss-20b)

since we made non-trivial changes to lazy tensor implementation, also test lazy mode

HF -> orbax (scan)

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=gpt-oss-20b scan_layers=true \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product \
--hf_model_path=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-bf16-v2

https://paste.googleplex.com/4888544332087296

3.56 min

CKPT=gs://runner-maxtext-logs/2025-12-26-21-51
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=gpt-oss-20b \
load_parameters_path=$CKPT/0/items \
scan_layers=true \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
--hf_model_path=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-bf16-v2 \
tokenizer_path=openai/gpt-oss-20b tokenizer_type=huggingface \
skip_jax_distributed_system=True

https://paste.googleplex.com/5272274628378624

HF -> orbax (scan), lazy load

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=gpt-oss-20b scan_layers=true \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product \
--hf_model_path=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-bf16-v2 \
--lazy_load_tensors=true

https://paste.googleplex.com/6192468888518656

4.96 min

CKPT=gs://runner-maxtext-logs/2025-12-26-21-58
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=gpt-oss-20b \
load_parameters_path=$CKPT/0/items \
scan_layers=true \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
--hf_model_path=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-bf16-v2 \
tokenizer_path=openai/gpt-oss-20b tokenizer_type=huggingface \
skip_jax_distributed_system=True

https://paste.googleplex.com/5022180226236416

HF -> orbax (unscan)

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=gpt-oss-20b scan_layers=false \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product \
--hf_model_path=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-bf16-v2

https://paste.googleplex.com/6000687559344128

CKPT=gs://runner-maxtext-logs/2025-12-26-22-20
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=gpt-oss-20b \
load_parameters_path=$CKPT/0/items \
scan_layers=false \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
--hf_model_path=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-bf16-v2 \
tokenizer_path=openai/gpt-oss-20b tokenizer_type=huggingface \
skip_jax_distributed_system=True

https://paste.googleplex.com/6128993298415616

2 orbax -> HF (gpt-oss-20b), unscan

ID=$(date +%Y-%m-%d-%H-%M-%S); \
python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml \
model_name=gpt-oss-20b \
load_parameters_path=gs://shuningjin-multipod-dev/gpt-oss-20b/unscan-bf16-v2-2025-09-02-01-16-00/0/items \
base_output_directory=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-hf-$ID \
scan_layers=false \
attention=dot_product skip_jax_distributed_system=True \
weight_dtype=bfloat16 checkpoint_storage_concurrent_gb=1024

https://paste.googleplex.com/5483624130543616

HF_PATH=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-hf-2025-12-26-22-56-40
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=gpt-oss-20b \
load_parameters_path=gs://shuningjin-multipod-dev/gpt-oss-20b/unscan-bf16-v2-2025-09-02-01-16-00/0/items \
scan_layers=false \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
--hf_model_path=$HF_PATH \
tokenizer_path=openai/gpt-oss-20b tokenizer_type=huggingface \
skip_jax_distributed_system=True

https://paste.googleplex.com/4854884102963200

3 orbax -> HF (gpt-oss-120b), unscan

MAXTEXT_CKPT=gs://shuningjin-multipod-dev/gpt-oss-120b/unscan-bf16-v2-2025-09-04-10-55-39/0/items
ID=$(date +%Y-%m-%d-%H-%M-%S); \
python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml \
model_name=gpt-oss-120b \
load_parameters_path=$MAXTEXT_CKPT \
base_output_directory=/home/shuningjin/gpt-oss-120b/gpt-oss-120b-hf-$ID \
scan_layers=false \
attention=dot_product skip_jax_distributed_system=True \
weight_dtype=bfloat16 checkpoint_storage_concurrent_gb=1024

https://paste.googleplex.com/5823699976585216
/home/shuningjin/gpt-oss-120b/gpt-oss-120b-hf-2026-01-05-23-19-02

MAXTEXT_CKPT=gs://shuningjin-multipod-dev/gpt-oss-120b/unscan-bf16-v2-2025-09-04-10-55-39/0/items
HF_PATH=/home/shuningjin/gpt-oss-120b/gpt-oss-120b-hf-2026-01-05-23-19-02
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=gpt-oss-120b \
load_parameters_path=$MAXTEXT_CKPT \
scan_layers=false \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
--hf_model_path=$HF_PATH \
tokenizer_path=openai/gpt-oss-120b tokenizer_type=huggingface \
skip_jax_distributed_system=True

https://paste.googleplex.com/6546254270169088

4 HF -> orbax (gpt-oss-120b), scan

# sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'
# watch -n 1 free -h
BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=gpt-oss-120b scan_layers=true \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product \
--hf_model_path=/home/shuningjin/gpt-oss-120b/gpt-oss-120b-bf16-v2
MAXTEXT_CKPT=gs://runner-maxtext-logs/2026-01-06-20-09/0/items 
HF_PATH=/home/shuningjin/gpt-oss-120b/gpt-oss-120b-bf16-v2
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=gpt-oss-120b \
checkpoint_storage_concurrent_gb=1024 \
load_parameters_path=$MAXTEXT_CKPT \
scan_layers=true \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
--hf_model_path=$HF_PATH \
tokenizer_path=openai/gpt-oss-120b tokenizer_type=huggingface hf_access_token=$HF_TOKEN \
skip_jax_distributed_system=True

https://paste.googleplex.com/6019766257057792

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@shuningjin shuningjin changed the title [WIP] checkpoint util: gpt-oss, orbax -> hf [WIP] checkpoint util: gpt-oss, hf -> orbax Dec 11, 2025
@codecov
Copy link

codecov bot commented Dec 23, 2025

@shuningjin shuningjin changed the title [WIP] checkpoint util: gpt-oss, hf -> orbax Checkpoint utility: gpt-oss, hf to orbax Dec 26, 2025
@shuningjin shuningjin marked this pull request as ready for review December 26, 2025 23:45
@github-actions
Copy link

🤖 Hi @shuningjin, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📋 Review Summary

This pull request introduces a significant and valuable set of enhancements to the checkpoint conversion utilities. The refactoring improves code modularity by moving shared functions to a central utils.py, and the new functionality, such as support for local Hugging Face models and unscanned GPT-OSS models, greatly increases the flexibility of these tools.

🔍 General Feedback

  • Positive: The addition of more granular timing logs is a great improvement for performance analysis and debugging. The updated documentation in the README provides a much clearer overview of supported models and conversion paths.
  • Good Refactoring: Moving check_param_map_keys to utils.py and introducing get_maxtext_model_info in to_maxtext.py are excellent changes that improve code organization and reusability.

I have left a couple of minor comments, one regarding a bug in the timing calculation and another for a small docstring clarification. Overall, this is a solid contribution.

Copy link
Collaborator

@hengtaoguo hengtaoguo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the great work!

Copy link
Collaborator

@RissyRan RissyRan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! If you could wrap your added codes into helper functions among that 8 cases, which will be really appreciated :) Also, let's create a bug to track the rest refactor.

@shuningjin shuningjin changed the title Checkpoint utility: gpt-oss, hf to orbax Checkpoint utility: add gpt-oss to_maxtext & refactor code Jan 5, 2026
@shuningjin
Copy link
Collaborator Author

shuningjin commented Jan 5, 2026

Test llama3.1-8b & llama3.1-70b

Why: (1) Not in readme, unclear if it is tested. (2) I refactored chained hook function, need validate.

Test llama3.1-8b

1 scan to_maxtext

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=llama3.1-8b scan_layers=true \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product

https://paste.googleplex.com/4998663032143872

gs://runner-maxtext-logs/2026-01-05-07-48/0/items

CKPT=gs://runner-maxtext-logs/2026-01-05-07-48/0/items
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=llama3.1-8b \
load_parameters_path=$CKPT \
scan_layers=true \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
tokenizer_path=meta-llama/Llama-3.1-8B tokenizer_type=huggingface hf_access_token=$HF_TOKEN \
skip_jax_distributed_system=True \
--hf_model_path=meta-llama/Llama-3.1-8B

https://paste.googleplex.com/5362702983757824

2 scan to_huggingface

MAXTEXT_SCAN=gs://runner-maxtext-logs/2026-01-05-07-48/0/items
ID=$(date +%Y-%m-%d-%H-%M-%S); \
python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml \
model_name=llama3.1-8b \
load_parameters_path=$MAXTEXT_SCAN \
base_output_directory=/home/shuningjin/tmp/llama3.1-8b-hf-$ID \
scan_layers=true \
attention=dot_product skip_jax_distributed_system=True \
weight_dtype=bfloat16 checkpoint_storage_concurrent_gb=1024

https://paste.googleplex.com/6262739150438400

/home/shuningjin/tmp/llama3.1-8b-hf-2026-01-05-08-07-40

MAXTEXT_SCAN=gs://runner-maxtext-logs/2026-01-05-07-48/0/items
HF_PATH=/home/shuningjin/tmp/llama3.1-8b-hf-2026-01-05-08-07-40
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=llama3.1-8b \
load_parameters_path=$MAXTEXT_SCAN \
scan_layers=true \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
--hf_model_path=$HF_PATH \
tokenizer_path=meta-llama/Llama-3.1-8B tokenizer_type=huggingface hf_access_token=$HF_TOKEN \
skip_jax_distributed_system=True

https://paste.googleplex.com/4895288806473728

3 unscan to_maxtext

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=llama3.1-8b scan_layers=false \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product

https://paste.googleplex.com/5925712177528832

gs://runner-maxtext-logs/2026-01-05-08-44/0/items

CKPT=gs://runner-maxtext-logs/2026-01-05-08-44/0/items
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=llama3.1-8b \
load_parameters_path=$CKPT \
scan_layers=false \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
tokenizer_path=meta-llama/Llama-3.1-8B tokenizer_type=huggingface hf_access_token=$HF_TOKEN \
skip_jax_distributed_system=True \
--hf_model_path=meta-llama/Llama-3.1-8B

https://paste.googleplex.com/5328587924307968

4 unscan to_huggingface

MAXTEXT_CKPT=gs://runner-maxtext-logs/2026-01-05-08-44/0/items
ID=$(date +%Y-%m-%d-%H-%M-%S); \
python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml \
model_name=llama3.1-8b \
load_parameters_path=$MAXTEXT_CKPT \
base_output_directory=/home/shuningjin/tmp/llama3.1-8b-hf-$ID \
scan_layers=false \
attention=dot_product skip_jax_distributed_system=True \
weight_dtype=bfloat16 checkpoint_storage_concurrent_gb=1024

https://paste.googleplex.com/4849285143330816

/home/shuningjin/tmp/llama3.1-8b-hf-2026-01-05-08-50-33

MAXTEXT_CKPT=gs://runner-maxtext-logs/2026-01-05-08-44/0/items
HF_PATH=/home/shuningjin/tmp/llama3.1-8b-hf-2026-01-05-08-50-33
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=llama3.1-8b \
load_parameters_path=$MAXTEXT_CKPT \
scan_layers=false \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
--hf_model_path=$HF_PATH \
tokenizer_path=meta-llama/Llama-3.1-8B tokenizer_type=huggingface hf_access_token=$HF_TOKEN \
skip_jax_distributed_system=True

https://paste.googleplex.com/4851388989440000

Test llama3.1-70b

1 scan to_maxtext

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=llama3.1-70b scan_layers=true \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True

https://paste.googleplex.com/5830490068221952

gs://runner-maxtext-logs/2026-01-05-19-22/0/items

2 scan to_huggingface

MAXTEXT_CKPT=gs://runner-maxtext-logs/2026-01-05-19-22/0/items 
HF_PATH=meta-llama/Llama-3.1-70B
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=llama3.1-70b \
load_parameters_path=$MAXTEXT_CKPT \
scan_layers=true \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
--hf_model_path=$HF_PATH \
tokenizer_path=meta-llama/Llama-3.1-70B tokenizer_type=huggingface hf_access_token=$HF_TOKEN \
skip_jax_distributed_system=True

https://paste.googleplex.com/4999300700569600

MAXTEXT_CKPT=gs://runner-maxtext-logs/2026-01-05-19-22/0/items
ID=$(date +%Y-%m-%d-%H-%M-%S); \
python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml \
model_name=llama3.1-70b \
load_parameters_path=$MAXTEXT_CKPT \
base_output_directory=/home/shuningjin/tmp/llama3.1-70b-hf-$ID \
scan_layers=true \
attention=dot_product skip_jax_distributed_system=True \
weight_dtype=bfloat16 checkpoint_storage_concurrent_gb=1024

https://paste.googleplex.com/4753248986726400

/home/shuningjin/tmp/llama3.1-70b-hf-2026-01-05-22-42-55

MAXTEXT_CKPT=gs://runner-maxtext-logs/2026-01-05-19-22/0/items 
HF_PATH=/home/shuningjin/tmp/llama3.1-70b-hf-2026-01-05-22-42-55
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=llama3.1-70b \
load_parameters_path=$MAXTEXT_CKPT \
scan_layers=true \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
--hf_model_path=$HF_PATH \
tokenizer_path=meta-llama/Llama-3.1-70B tokenizer_type=huggingface hf_access_token=$HF_TOKEN \
skip_jax_distributed_system=True

https://paste.googleplex.com/4647903924715520

@shuningjin
Copy link
Collaborator Author

Test qwen3-30b

Why: (1) validate for readme (2) validate as refactored hf param stacking

1 Scan to_maxtext

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=qwen3-30b-a3b scan_layers=true \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product

gs://runner-maxtext-logs/2026-01-05-16-40/0/items

https://paste.googleplex.com/5478434014887936

MAXTEXT_CKPT=gs://runner-maxtext-logs/2026-01-05-16-40/0/items
HF_PATH=Qwen/Qwen3-30B-A3B-Thinking-2507
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=qwen3-30b-a3b \
load_parameters_path=$MAXTEXT_CKPT \
scan_layers=true \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.15 \
--run_hf_model=True \
--hf_model_path=$HF_PATH \
tokenizer_path=Qwen/Qwen3-30B-A3B-Thinking-2507 tokenizer_type=huggingface hf_access_token=$HF_TOKEN \
skip_jax_distributed_system=True

https://paste.googleplex.com/6224942968471552

2 Scan to_huggingface

MAXTEXT_CKPT=gs://runner-maxtext-logs/2026-01-05-16-40/0/items
ID=$(date +%Y-%m-%d-%H-%M-%S); \
python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml \
model_name=qwen3-30b-a3b \
load_parameters_path=$MAXTEXT_CKPT \
base_output_directory=/home/shuningjin/tmp/qwen3-30b-a3b-hf-$ID \
scan_layers=true \
attention=dot_product skip_jax_distributed_system=True \
weight_dtype=bfloat16 checkpoint_storage_concurrent_gb=1024

https://paste.googleplex.com/6328627622969344

/home/shuningjin/tmp/qwen3-30b-a3b-hf-2026-01-05-17-16-41

MAXTEXT_CKPT=gs://runner-maxtext-logs/2026-01-05-16-40/0/items
HF_PATH=/home/shuningjin/tmp/qwen3-30b-a3b-hf-2026-01-05-17-16-41
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=qwen3-30b-a3b \
load_parameters_path=$MAXTEXT_CKPT \
scan_layers=true \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.15 \
--run_hf_model=True \
--hf_model_path=$HF_PATH \
tokenizer_path=Qwen/Qwen3-30B-A3B-Thinking-2507 tokenizer_type=huggingface hf_access_token=$HF_TOKEN \
skip_jax_distributed_system=True

https://paste.googleplex.com/6311964089384960

3 Unscan to_maxtext

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=qwen3-30b-a3b scan_layers=false \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product

https://paste.googleplex.com/6271451139276800

gs://runner-maxtext-logs/2026-01-05-17-34/0/items

MAXTEXT_CKPT=gs://runner-maxtext-logs/2026-01-05-17-34/0/items
HF_PATH=Qwen/Qwen3-30B-A3B-Thinking-2507
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=qwen3-30b-a3b \
load_parameters_path=$MAXTEXT_CKPT \
scan_layers=false \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.15 \
--run_hf_model=True \
--hf_model_path=$HF_PATH \
tokenizer_path=Qwen/Qwen3-30B-A3B-Thinking-2507 tokenizer_type=huggingface hf_access_token=$HF_TOKEN \
skip_jax_distributed_system=True

https://paste.googleplex.com/6272882302910464

4 Unscan to_huggingface

MAXTEXT_CKPT=gs://runner-maxtext-logs/2026-01-05-17-34/0/items
ID=$(date +%Y-%m-%d-%H-%M-%S); \
python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml \
model_name=qwen3-30b-a3b \
load_parameters_path=$MAXTEXT_CKPT \
base_output_directory=/home/shuningjin/tmp/qwen3-30b-a3b-hf-$ID \
scan_layers=false \
attention=dot_product skip_jax_distributed_system=True \
weight_dtype=bfloat16 checkpoint_storage_concurrent_gb=1024

https://paste.googleplex.com/6161672899264512
/home/shuningjin/tmp/qwen3-30b-a3b-hf-2026-01-05-17-52-24

MAXTEXT_CKPT=gs://runner-maxtext-logs/2026-01-05-17-34/0/items
HF_PATH=/home/shuningjin/tmp/qwen3-30b-a3b-hf-2026-01-05-17-52-24
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=qwen3-30b-a3b \
load_parameters_path=$MAXTEXT_CKPT \
scan_layers=false \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.15 \
--run_hf_model=True \
--hf_model_path=$HF_PATH \
tokenizer_path=Qwen/Qwen3-30B-A3B-Thinking-2507 tokenizer_type=huggingface hf_access_token=$HF_TOKEN \
skip_jax_distributed_system=True

https://paste.googleplex.com/6247622660718592

@shuningjin shuningjin force-pushed the shuningjin-ckpt-gpt branch 2 times, most recently from bd2c830 to f8ce9a5 Compare January 5, 2026 23:57
@shuningjin
Copy link
Collaborator Author

shuningjin commented Jan 6, 2026

Thanks! If you could wrap your added codes into helper functions among that 8 cases, which will be really appreciated :) Also, let's create a bug to track the rest refactor.

I did an initial round of refactor, and improved the comment/docstring. This should help clarify maxtext key forms and hf value forms.

I tested llama3.1-8b, llama3.1-70b, and qwen3-30b. This is to ensure the refactor does not break other models.

I tested gpt-oss-120b: to_huggingface works well, test added to description; to_maxtext is very slow, will defer its testing after optimization -- follow up.

Readme is updated to track models supported and distinguish sizes tested/untested.

Copy link
Collaborator

@RissyRan RissyRan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the change!

@shuningjin shuningjin force-pushed the shuningjin-ckpt-gpt branch from 95feb36 to 171a3c6 Compare January 6, 2026 22:42
@copybara-service copybara-service bot merged commit f0bf728 into main Jan 6, 2026
31 of 33 checks passed
@copybara-service copybara-service bot deleted the shuningjin-ckpt-gpt branch January 6, 2026 23:49
@shuningjin
Copy link
Collaborator Author

I tested gpt-oss-120b: to_huggingface works well, test added to description; to_maxtext is very slow, will defer its testing after optimization -- follow up.

I additionally tested gpt-oss-120b to_maxtext, added to description. It takes 120min. Will follow up with optimization.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants