Commit dbed671

Merge branch 'shuningjin-xlml-ds' of github.com:AI-Hypercomputer/maxtext into shuningjin-xlml-ds
2 parents: 39462d5 + 4f5a1c4, commit dbed671

16 files changed: 246 additions & 116 deletions

end_to_end/tpu/deepseek/Run_DeepSeek.md

Lines changed: 1 addition & 2 deletions
@@ -151,7 +151,7 @@ python3 -m MaxText.decode src/MaxText/configs/base.yml \
 max_target_length=1024 \
 tokenizer_type=huggingface \
 tokenizer_path=deepseek-ai/DeepSeek-V3 \
-attention=flash \
+attention=dot_product \
 dtype=bfloat16 \
 weight_dtype=bfloat16 \
 megablox=False \
@@ -197,7 +197,6 @@ python3 -m tests.forward_pass_logit_checker \
 model_name=deepseek2-16b \
 max_prefill_predict_length=4 \
 max_target_length=4 \
-dataset_type=synthetic \
 scan_layers=false \
 sparse_matmul=False \
 dtype=float32 \
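
With this change, the full decode command in Run_DeepSeek.md reads roughly as sketched below. Only the flags visible in the hunk above (config path, max_target_length, tokenizer settings, attention=dot_product, dtypes, megablox=False) come from the doc; the remaining flags and the checkpoint placeholder are assumptions borrowed from the v2-16b test script further down, not quoted from the file:

python3 -m MaxText.decode src/MaxText/configs/base.yml \
  load_parameters_path=<unscanned_checkpoint_path>/0/items \
  run_name=decode \
  per_device_batch_size=1 \
  max_prefill_predict_length=100 \
  max_target_length=1024 \
  tokenizer_type=huggingface \
  tokenizer_path=deepseek-ai/DeepSeek-V3 \
  attention=dot_product \
  dtype=bfloat16 \
  weight_dtype=bfloat16 \
  megablox=False \
  sparse_matmul=False \
  scan_layers=False \
  prompt="I love to"

Switching decode from attention=flash to attention=dot_product also matches the decode command added to the test script below, which uses attention=dot_product as well.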

end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh

Lines changed: 64 additions & 34 deletions
@@ -3,49 +3,79 @@
 # This file is documentation for how to get started with DeepSeek v2-Lite on v5p-8.
 
 # The flow of this file is as follows:
-# 1. Convert the checkpoint downloaded from HuggingFace to make it compatible with MaxText.
-# 2. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding.
-# 3. Run logits check test between Huggingface and MaxText.
-# 4. Run pre-training, fine-tuning, and decoding.
+# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16):
+#    Scanned format is better for training; unscanned format is better for decoding.
+# 2. Run logit check, pre-training, fine-tuning, and decoding.
+
+# Example Usage: export HF_TOKEN=<huggingface_access_token>; export BASE_OUTPUT_PATH=<GCS_bucket_path>; bash test_deepseek.sh
+
+# The golden logit can be generated by:
+# python3 -m MaxText.scratch_code.generate_hf_golden_logits --model-id=deepseek-ai/DeepSeek-V2-Lite --output-path=golden_data_deepseek2-16b.jsonl --prompts='I love to;Today is a;What is the' --hf-model-path=$local_bf16_path --not-trust-remote-code
 
 set -ex
-idx=$(date +%Y-%m-%d-%H-%M)
+
 export MODEL_NAME='deepseek2-16b'
 export TOKENIZER_PATH='deepseek-ai/DeepSeek-V2-Lite'
 
-# Installing torch for deps in forward_pass_logit_checker.py
+# Installing torch for checkpoint conversion and forward_pass_logit_checker.py
 python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
 
-# Step 1:
-# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \
-# Non-Googlers please remember to use separate GCS paths for uploading model weights from HuggingFace ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET).
-# Non-Googlers please remember to point these variables to GCS buckets that you own, this script uses internal buckets for testing.
-# You can use the HuggingFace checkpoint at https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite
-export CHKPT_BUCKET=gs://maxtext-deepseek/deepseek2-16b/hf
-export MODEL_BUCKET=gs://maxtext-deepseek/deepseek2-16b
-JAX_PLATFORMS=cpu python3 -m MaxText.convert_deepseek_ckpt --base_model_path ${CHKPT_BUCKET} --maxtext_model_path ${MODEL_BUCKET}/${idx} --model_size ${MODEL_NAME}
+# e.g., $HOME/maxtext/src/MaxText
+export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"
 
-# Step 2:
-# Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format.
-JAX_PLATFORMS=cpu python MaxText.convert_deepseek_unscanned_ckpt --base_model_path ${CHKPT_BUCKET} --maxtext_model_path ${MODEL_BUCKET}/${idx}/unscanned --model_size ${MODEL_NAME}
+if [ -z "${BASE_OUTPUT_PATH}" ]; then
+  # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing.
+  # this bucket will store all the files generated by MaxText during a run
+  export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
+  echo "BASE_OUTPUT_PATH is not set"
+fi
+BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/}
+echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}
 
-# Step 3:
-export UNSCANNED_CKPT_PATH=${MODEL_BUCKET}/${idx}/unscanned/0/items
-python3 -m tests.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_type=huggingface tokenizer_path=deepseek-ai/DeepSeek-V2-Lite load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false sparse_matmul=False dtype=float32 activations_in_float32=true matmul_precision=high --max_kl_div=2e-4
+# Step 1: Checkpoint conversion
+# You can use the HuggingFace checkpoint at https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite, and dequantize it to bf16
+# Assume HF checkpoints are uploaded to GCS bucket at CKPT_BUCKET
+# Non-Googlers please remember to point `CKPT_BUCKET` to GCS buckets that you own
+# Copying the HF checkpoint into a local directory `/tmp` -- you are free to use a different directory
+if [ -z "${CKPT_DISK_LOCATION}" ]; then
+  export CKPT_BUCKET=gs://maxtext-deepseek/deepseek2-16b/hf
+  gcloud storage cp -r ${CKPT_BUCKET} /tmp
+  export CKPT_DISK_LOCATION=/tmp/hf
+fi
 
-# Step 4:
+# 1.1 Convert checkpoint to `scanned` format, more suitable for training
+JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.convert_deepseek_family_ckpt --base_model_path ${CKPT_DISK_LOCATION} --maxtext_model_path ${BASE_OUTPUT_PATH}/scanned --model_size ${MODEL_NAME}
+
+# 1.2 Convert checkpoint to `unscanned` format, more suitable for decoding
+JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.convert_deepseek_family_unscanned_ckpt --base_model_path ${CKPT_DISK_LOCATION} --maxtext_model_path ${BASE_OUTPUT_PATH}/unscanned --model_size ${MODEL_NAME}
+
+# Step 2:
+# We define the checkpoint paths. This way it is easier to use these paths in the `train.py` and `decode.py` commands
+export SCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/scanned/0/items
+export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items
 # Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
 export DATASET_PATH=gs://maxtext-dataset
-# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own, this bucket will store all the files generated by MaxText during a run
-export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
-# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. This way it is easier to use this path in the `train.py` and `decode.py` commands
-export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${idx}/0/items
-
-# Run pre-training - matmul implementation
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} run_name=matmul_pre_training per_device_batch_size=4 enable_checkpointing=false model_name=${MODEL_NAME} ici_fsdp_parallelism=4 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False sparse_matmul=False dataset_type=synthetic
-# Run fine-tuning - matmul implementation
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=matmul_fine_tuning per_device_batch_size=4 enable_checkpointing=false model_name=${MODEL_NAME} ici_fsdp_parallelism=4 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False sparse_matmul=False enable_checkpointing=true
-# Run supervised fine-tuning - matmul implementation
-python3 -m MaxText.sft_trainer "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/sft.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=matmul_supervised_fine_tuning per_device_batch_size=4 enable_checkpointing=false model_name=${MODEL_NAME} steps=5 max_target_length=1024 async_checkpointing=false tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False sparse_matmul=False enable_checkpointing=true ici_expert_parallelism=4 ici_fsdp_parallelism=1 dataset_type=hf
-# Run decoding - matmul implementation
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=decode per_device_batch_size=1 enable_checkpointing=false model_name=${MODEL_NAME} max_prefill_predict_length=100 max_target_length=1024 tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False sparse_matmul=False ici_tensor_parallelism=4 ici_fsdp_parallelism=1 prompt="I love to" scan_layers=False
+
+# Test whether the forward pass logits match the golden logits
+# default golden_logits_path=/deps/src/MaxText/test_assets/golden_data_{MODEL_NAME}.jsonl, copied from gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl
+GOLDEN_LOGITS_DISK_LOCATION="/deps/src/MaxText/test_assets/golden_data_${MODEL_NAME}.jsonl"
+if [ ! -f "${GOLDEN_LOGITS_DISK_LOCATION}" ]; then
+  GOLDEN_LOGITS_PATH="gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl"
+  GOLDEN_LOGITS_DISK_LOCATION=/tmp/golden_data.jsonl
+  gcloud storage cp ${GOLDEN_LOGITS_PATH} ${GOLDEN_LOGITS_DISK_LOCATION}
+fi
+
+python3 -m tests.forward_pass_logit_checker ${MAXTEXT_PKG_DIR}/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=true attention=dot_product per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 async_checkpointing=false sparse_matmul=false ici_fsdp_parallelism=1 ici_expert_parallelism=4 checkpoint_storage_concurrent_gb=1024 weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=highest float32_logits=true float32_qk_product=true --golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION} --atol=1e-4 --rtol=1e-4 --max_kl_div=5e-6
+
+# Run pre-training - megablox implementation
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_pre_training model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=synthetic enable_checkpointing=false attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=4
+
+# Run fine-tuning - megablox implementation
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_path=${DATASET_PATH} enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4 checkpoint_storage_concurrent_gb=1024
+
+# Run supervised fine-tuning - megablox implementation
+# python3 -m MaxText.sft_trainer "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/sft.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_supervised_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=hf enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4 checkpoint_storage_concurrent_gb=1024
+
+# Run decoding - megablox implementation
+# Note decode requires the access token for huggingface tokenizer even if the model is not gated
+python3 -m MaxText.decode ${MAXTEXT_PKG_DIR}/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=512 max_target_length=1024 ici_fsdp_parallelism=1 ici_tensor_parallelism=4 ici_expert_parallelism=1 checkpoint_storage_concurrent_gb=1024 mla_naive_kvcache=false prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is "
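
For reference, a typical invocation of this updated v2-16b script might look like the following; the token and bucket values are placeholders you would supply, since without them the script falls back to the internal test buckets and copies CKPT_BUCKET into /tmp:

export HF_TOKEN=<huggingface_access_token>
export BASE_OUTPUT_PATH=gs://<your-bucket>/deepseek2-16b
# optional: point at an already-downloaded bf16 HuggingFace checkpoint to skip the gcloud copy
export CKPT_DISK_LOCATION=/path/to/local/deepseek-v2-lite-bf16
bash end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh

The converted checkpoints then land under ${BASE_OUTPUT_PATH}/scanned/0/items and ${BASE_OUTPUT_PATH}/unscanned/0/items, which is what SCANNED_CKPT_PATH and UNSCANNED_CKPT_PATH point to in the logit-check, fine-tuning, and decode commands above.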
Lines changed: 42 additions & 0 deletions (new file)
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+# This file is documentation for how to get started with DeepSeek v3.
+
+# This file runs Step 1 on CPU.
+# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16):
+#    Scanned format is better for training; unscanned format is better for decoding.
+# 2. Run logit check, pre-training, fine-tuning, and decoding.
+
+set -ex
+
+export MODEL_NAME='deepseek3-671b'
+export TOKENIZER_PATH='deepseek-ai/DeepSeek-V3'
+
+# Installing torch for checkpoint conversion and forward_pass_logit_checker.py
+python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
+
+if [ -z "${BASE_OUTPUT_PATH}" ]; then
+  # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing.
+  # this bucket will store all the files generated by MaxText during a run
+  export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
+  echo "BASE_OUTPUT_PATH is not set"
+fi
+BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/}
+echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}
+
+# Step 1: Checkpoint conversion
+# You can use the HuggingFace checkpoint at https://huggingface.co/deepseek-ai/DeepSeek-V3, and dequantize it to bf16
+# Assume HF checkpoints are uploaded to GCS bucket at CKPT_BUCKET
+# Non-Googlers please remember to point `CKPT_BUCKET` to GCS buckets that you own
+# Copying the HF checkpoint into a local directory `/tmp` -- you are free to use a different directory
+if [ -z "${CKPT_DISK_LOCATION}" ]; then
+  export CKPT_BUCKET=gs://maxtext-deepseek/deepseek3-671b/hf
+  gcloud storage cp -r ${CKPT_BUCKET} /tmp
+  export CKPT_DISK_LOCATION=/tmp/hf
+fi
+
+# 1.1 Convert checkpoint to `scanned` format, more suitable for training
+JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.convert_deepseek_family_ckpt --base_model_path ${CKPT_DISK_LOCATION} --maxtext_model_path ${BASE_OUTPUT_PATH}/scanned --model_size ${MODEL_NAME}
+
+# 1.2 Convert checkpoint to `unscanned` format, more suitable for decoding
+JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.convert_deepseek_family_unscanned_ckpt --base_model_path ${CKPT_DISK_LOCATION} --maxtext_model_path ${BASE_OUTPUT_PATH}/unscanned --model_size ${MODEL_NAME}
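
This new v3 script stops after Step 1 (checkpoint conversion, run entirely on CPU via JAX_PLATFORMS=cpu) and does not itself define downstream paths. Presumably, follow-on logit-check, training, and decode steps would consume the converted checkpoints the same way the v2-16b script does; the exports below mirror that script and are an assumption about the v3 flow, not something this file contains:

# paths produced by Step 1 above (mirrors the v2-16b script's exports)
export SCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/scanned/0/items
export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items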
