<!--
Copyright 2023–2026 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# MaxText Model Bringup: Community Contributor Guide

This documentation acts as the primary resource for efficiently integrating new models into the MaxText ecosystem. It provides the community with a standardized workflow covering architectural analysis, feature implementation, and correctness verification.

## 1. Architecture Analysis

The first phase involves determining how the new model's architecture aligns with MaxText's existing capabilities.

**Input Data Pipeline**: MaxText supports HuggingFace, Grain, and TFDS pipelines ([details](data_input_pipeline)). While synthetic data is typically used for initial performance benchmarks, the framework supports multiple modalities, including text and image (audio and video are work in progress).

**Tokenizer**: Supported options include `TikTokenTokenizer` (Llama3), `SentencePieceTokenizer` (Gemma, Llama2, Mistral), and `HFTokenizer` (DeepSeek and other newer models).

**Self-Attention & RoPE**: Available mechanisms include optimized Flash Attention (supporting MHA, GQA, and MQA), Multi-head Latent Attention (MLA), and Gated DeltaNet. MaxText also supports Regular, Llama, and Yarn variations of Rotary Positional Embeddings (RoPE).

**Multi-Layer Perceptron (MLP)**: The framework supports both traditional dense models and Mixture of Experts (MoE) architectures, including configurations for routed and shared experts.

**Normalization**: Most implementations utilize RMSNorm, which can be configured before or after attention/MLP layers.

**Decoder Layers**: Models can have multiple decoder layers with varying structures. The trend has evolved from entirely dense layers to purely MoE layers, and now towards a mix of both.

## 2. (Optional) Feature Implementation

This step can be bypassed if the current MaxText codebase already supports all components required for your model architecture. However, if your model introduces unique logic or requires specific code refactoring, these modifications should be completed before you begin converting checkpoints.

**Sharding**: MaxText supports both auto and explicit sharding modes and provides dedicated sharding functions. We recommend that developers use MaxText-specific sharding functions, such as `MaxText.sharding.maybe_shard_with_name`, instead of default JAX sharding hints like `jax.lax.with_sharding_constraint`, for better performance.
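
For orientation, here is a minimal JAX sketch of where such a hint is applied. It uses the raw JAX primitive purely for illustration; in MaxText code you would call the repository helper (such as `MaxText.sharding.maybe_shard_with_name`; check the `MaxText.sharding` module for its exact signature) at the same point.

```
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative 1-D device mesh; MaxText builds its own mesh from the config.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

def mlp_block(x):
  # Hint that activations stay sharded along the batch ("data") axis.
  # In MaxText, prefer the repo's sharding helper over this raw constraint.
  return jax.lax.with_sharding_constraint(
      x, NamedSharding(mesh, PartitionSpec("data", None)))
```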

## 3. Checkpoint Conversion

While most open-source models are distributed in Safetensors or PyTorch formats, MaxText requires conversion to the [Orbax](https://orbax.readthedocs.io/en/latest/) format.

There are [two primary formats](../reference/core_concepts/checkpoints.md) for Orbax checkpoints within MaxText, and while both are technically compatible with training and inference, we recommend following these performance-optimized guidelines (a short sketch of the difference follows the list):

- **Scanned Format**: Recommended for **training** as it stacks layers for efficient processing via `jax.lax.scan`. To enable this, set `scan_layers=True`.
- **Unscanned Format**: Recommended for **inference** to simplify loading individual layer parameters. To enable this, set `scan_layers=False`.
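
The difference is easiest to see in code. In the sketch below (shapes and names are illustrative, not MaxText's actual parameter tree), the scanned layout stacks per-layer weights along a leading `num_layers` axis so a single `jax.lax.scan` drives the whole decoder stack, while the unscanned layout keeps one array per layer.

```
import jax
import jax.numpy as jnp

num_layers, d_model = 4, 8
h0 = jnp.ones((2, d_model))

# Scanned layout: per-layer weights stacked along a leading num_layers axis.
stacked_w = jnp.ones((num_layers, d_model, d_model))

def layer_step(h, w):
  return h @ w, None

h_scanned, _ = jax.lax.scan(layer_step, h0, stacked_w)

# Unscanned layout: one array per layer, applied in a Python loop.
per_layer_w = [jnp.ones((d_model, d_model)) for _ in range(num_layers)]
h_unscanned = h0
for w in per_layer_w:
  h_unscanned = h_unscanned @ w
```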

### 3.1 Create Mapping

Success starts with a clear map. You must align the parameter names from your source checkpoints (Safetensors/PyTorch) with the corresponding MaxText internal names; a hypothetical example follows the list below.

- You can use the provided [script](../../src/MaxText/utils/ckpt_conversion/inspect_checkpoint.py) to print out the keys and shapes of your original `.safetensors` or `.pth` files.
- To see the target structure, you can initiate a pre-training run to save a randomly initialized checkpoint for inspection, and use the same script above.
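
As an illustration, such a mapping is often just a dictionary from source names to MaxText names. The keys and values below are hypothetical; the real ones must come from the inspection steps above.

```
# Hypothetical mapping; actual names depend on your model and on MaxText's parameter tree.
PARAM_MAPPING = {
    "model.embed_tokens.weight": "params/token_embedder/embedding",
    "model.layers.{layer}.self_attn.q_proj.weight": "params/decoder/layers/self_attention/query/kernel",
    "model.layers.{layer}.mlp.gate_proj.weight": "params/decoder/layers/mlp/wi_0/kernel",
    "model.norm.weight": "params/decoder/decoder_norm/scale",
}
```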

### 3.2 Write Script

Use existing model scripts within the repository as templates to tailor the conversion logic for your specific architecture. We strongly recommend using the [checkpoint conversion utility](../../src/MaxText/utils/ckpt_conversion) rather than the [standalone scripts](../../src/MaxText/utils/ckpt_scripts).

### 3.3 Verify Compatibility

You can now fine-tune in MaxText using your converted scanned checkpoint, or perform decoding with your unscanned checkpoint. This assumes no compiler or shape errors are encountered.

## 4. Unit Tests

Before verifying the entire model, it is highly recommended to perform unit tests on new layers or features. This approach helps isolate potential bugs early by comparing the MaxText JAX implementation against a reference PyTorch implementation—typically from HuggingFace or the original model release. More examples can be found [here](https://github.com/search?q=repo%3AAI-Hypercomputer%2Fmaxtext+vs_reference&type=code).

Core strategy (a minimal sketch follows this list):

- **Instantiate Layers**: Create instances of both the MaxText JAX layer and the corresponding PyTorch layer.
- **Copy Weights**: Initialize the PyTorch layer and copy its exact weights to the JAX layer instance. This ensures both start from the same state. Pay close attention to weight naming and potential shape differences (e.g., transposing Linear/MLP weights).
- **Prepare Input**: Create identical input data (e.g., random tensors) for both frameworks. Use a helper to convert PyTorch tensors to JAX arrays.
- **Forward Pass**: Run the input through both layers. Remember to set the PyTorch model to evaluation mode (`model_pt.eval()`) to disable dropout and other training-only behavior.
- **Compare Outputs**: Convert the PyTorch output to a JAX array (or NumPy array) and use `numpy.testing.assert_allclose` to check if the outputs are numerically close within a specified tolerance (atol, rtol).
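
Here is a self-contained sketch of this pattern, using a plain linear projection as a stand-in for a real layer; an actual test would instantiate the MaxText module and the HuggingFace/PyTorch reference layer instead.

```
import numpy as np
import jax.numpy as jnp
import torch

# Stand-in reference layer; in practice this is the HF / original PyTorch layer.
layer_pt = torch.nn.Linear(in_features=8, out_features=4, bias=False)
layer_pt.eval()

# Copy the exact weights. PyTorch stores Linear weights as (out, in);
# a JAX kernel is typically (in, out), hence the transpose.
kernel = jnp.asarray(layer_pt.weight.detach().numpy().T)

# Identical input for both frameworks.
x_np = np.random.default_rng(0).standard_normal((2, 8)).astype(np.float32)
with torch.no_grad():
  out_pt = layer_pt(torch.from_numpy(x_np)).numpy()
out_jax = jnp.dot(jnp.asarray(x_np), kernel, precision="highest")

np.testing.assert_allclose(np.asarray(out_jax), out_pt, atol=1e-6, rtol=1e-6)
```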

## 5. End-to-end correctness

The duration of this verification process varies. If you're working with a small model, you can iterate rapidly on your development machine. To verify a model's correctness, use the two strategies below: comparing forward logits and running evaluation benchmarks.

### 5.1 Compare Forward Logits

This is the primary verification for training, using a small set of input prompts. Typically, we obtain logits from both a reference implementation and MaxText, then compare their divergence to assess consistency.

When running the comparison script, using the flags `dtype=float32 activations_in_float32=true matmul_precision=high` should result in a smaller divergence. Ideally, the comparison will meet the criteria of `--max_kl_div=1e-4` or a combination of `--atol=1e-02` and `--rtol=1e-02`.

If your model is already supported by HuggingFace, you can run this [forward_pass_logit_checker.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/utils/forward_pass_logit_checker.py) script to compare logits.
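
For intuition, here is a minimal sketch of the kind of comparison such a checker performs; the file names are placeholders, and the linked script remains the supported tool.

```
import numpy as np
import jax.numpy as jnp
from jax.nn import log_softmax

# Placeholder files holding (batch, seq_len, vocab) logits from each implementation.
logits_ref = jnp.asarray(np.load("reference_logits.npy"))
logits_max = jnp.asarray(np.load("maxtext_logits.npy"))

# Elementwise closeness, mirroring --atol / --rtol.
np.testing.assert_allclose(np.asarray(logits_max), np.asarray(logits_ref),
                           atol=1e-2, rtol=1e-2)

# Per-token KL divergence (one common convention: KL(reference || maxtext)),
# mirroring --max_kl_div.
log_p = log_softmax(logits_ref, axis=-1)
log_q = log_softmax(logits_max, axis=-1)
kl = jnp.sum(jnp.exp(log_p) * (log_p - log_q), axis=-1)
assert float(jnp.max(kl)) < 1e-4
```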

### 5.2 Eval Benchmark

MaxText integrates with benchmark libraries like [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) and [evalchemy](https://github.com/mlfoundations/evalchemy) to facilitate rapid verification of common inference scores ([guide](../../benchmarks/api_server)). This is particularly useful for validating decoding outputs or assessing model performance when logits deviate slightly from reference values.

## 6. Completion Checklist

Please ensure all items on the following checklist are completed before finalizing your MaxText model bring-up support:

1. Core Functionality

- [ ] Implement all newly required features and add unit tests against the reference implementation.
- [ ] Update FLOP calculations if new components are added.

2. Checkpoint Conversion

- [ ] Complete the checkpoint conversion from the source format to Orbax for both the scanned and unscanned versions.

3. Model & Integration Verification

- [ ] Verify model forward logits using common prompts from the repository, for both the scanned and unscanned versions.
- [ ] Compare common downstream benchmark scores (e.g., MMLU) against the reference.

4. Documentation

- [ ] Create a user guide and post an announcement in the MaxText repo.

## Community Q&A (FAQ)

**Q: How do I debug code inside a JAX JIT function?**

**A:** Standard debuggers like VSCode cannot step through `jitted` functions. Use `jax.debug.print("{item}", item=...)` to inspect values during execution.
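
A minimal example (the function and values are illustrative):

```
import jax
import jax.numpy as jnp

@jax.jit
def forward(x):
  h = jnp.tanh(x)
  # Printed at runtime even though `h` is a traced value inside jit.
  jax.debug.print("intermediate activations: {item}", item=h)
  return h * 2.0

forward(jnp.arange(4.0))
```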

**Q: How to debug the correctness of logits?**

**A:** If you run `forward_pass_logit_checker.py` to compare reference logits with your implementation and find divergence, we highly recommend checking the logits after the first layer using the same prompt. A few things to start with (a hypothetical example of the first point follows this list):

- Weight loading logic: The sequence of operations during weight loading is critical; for instance, you may need to split or reshape components before you concatenate them to ensure the final matrix matches the reference structure.
- Naming conflicts: Ensure that internal variables and functions have unique names; simple naming collisions can prevent RoPE or QK-norm settings from being passed correctly to the attention mechanism.
- Query scaling: To avoid scaling issues, query scaling was moved from the checkpoint conversion phase to the forward pass using the `query_pre_attn_scalar` argument.
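
As a hypothetical example of the first point, a fused attention projection from the source checkpoint may need to be split and transposed before it matches the target layout; all shapes and names below are invented for illustration.

```
import numpy as np

num_heads, head_dim, d_model = 4, 8, 32

# Hypothetical fused QKV weight stored as (3 * d_model, d_model) in the source checkpoint.
w_fused = np.random.default_rng(0).standard_normal((3 * d_model, d_model)).astype(np.float32)

# Split into Q, K, V first, then transpose to (in, out) order and reshape the
# output dimension into (num_heads, head_dim) to match a per-head kernel layout.
w_q, w_k, w_v = np.split(w_fused, 3, axis=0)
w_q_target = w_q.T.reshape(d_model, num_heads, head_dim)
```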

**Q: My model is too large for my development machine. What should I do?**

**A:** You can create a smaller version of the model (fewer layers or smaller hidden dimensions) to iterate on your local box before moving to a larger cluster.

**Q: How do I store logits from a JAX JIT function?**

**A:** We generally use `jax.debug.print()` to print results; however, those values are often truncated, so saving them directly to a file is recommended when comparing final logits, especially when running on GKE clusters.

```
import jax
import jax.numpy as jnp

def save_with_jit(x):
  jnp.save("your_file_name.npy", x)  # runs on the host with the materialized value

jax.debug.callback(save_with_jit, your_variable)  # call inside the jitted function
```

**Q: Do I need to create a scanned checkpoint before an unscanned one?**

**A:** No, you can directly generate an unscanned checkpoint from the source weights using the conversion utility.