Commit 8597e7d

Small updates on MoE configuration
1 parent ac74931 commit 8597e7d

File tree

3 files changed: +7 -7 lines changed


docs/reference/core_concepts/moe_configuration.md

Lines changed: 3 additions & 3 deletions
@@ -50,7 +50,7 @@ Dropping:
 `first_num_dense_layers`: The number of initial dense layers before the first MoE layer is introduced.

-`float32_weight_sum`: If enabled, performs the summation of expert weights using float32 precision for improved numerical stability.
+`float32_weight_sum`: If enabled, performs the summation of expert weights using float32 precision for improved numerical stability. Recommended specifically when lower precision types cause convergence or quality issues.

 ### Routing Mechanism

 `use_random_routing`: If enabled, ignores the gate logits and routes tokens to random experts. This is designed to simulate load balancing for debugging and performance testing purposes.
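
The `float32_weight_sum` behavior described above can be illustrated with a minimal sketch (not the MaxText implementation; the function name, argument names, and shapes are assumptions): the per-expert outputs for each token are combined with the router weights, and the weighted sum is accumulated in float32 before casting back.

```python
import jax.numpy as jnp

def combine_expert_outputs(expert_outputs, router_weights, float32_weight_sum=True):
  """Weighted sum of per-expert outputs for each token (illustrative only).

  expert_outputs: [tokens, experts_per_token, hidden], e.g. bfloat16 activations.
  router_weights: [tokens, experts_per_token] gate weights for the chosen experts.
  """
  if float32_weight_sum:
    # Accumulate in float32 for numerical stability, then cast back to the
    # activation dtype.
    out = jnp.einsum(
        "teh,te->th",
        expert_outputs.astype(jnp.float32),
        router_weights.astype(jnp.float32),
    )
    return out.astype(expert_outputs.dtype)
  return jnp.einsum("teh,te->th", expert_outputs, router_weights)
```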
@@ -80,9 +80,9 @@ Dropping:
 * Value > 0: Enforces a strict capacity limit; tokens exceeding this limit are dropped.
 * Value = -1: Dropless with dense matrix multiplication, which is computationally expensive and typically used only as a baseline.

-`use_custom_sort_vjp`: If enabled, use a custom Vector-Jacobian Product (VJP) sort for efficient backward pass processing in sparse matmul.
+`use_custom_sort_vjp`: If enabled, use a custom Vector-Jacobian Product (VJP) sort for efficient backward pass processing in sparse matmul. Recommended to replace the inefficient scatter-add generated by `jax.numpy.take` in the backward pass.

-`mlp_bias`: If enabled, add bias terms within the expert MLP layers.
+`mlp_bias`: If enabled, add learnable bias terms to the expert MLP matmuls. Originally implemented to support the GPT-OSS model architecture.

 `use_batch_split_schedule` (experimental): If enabled, split batch into micro-batches to hide communications.
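
The `use_custom_sort_vjp` note above concerns how tokens are gathered into expert order for the sparse matmul. A rough sketch of the idea, assuming the gather indices form a permutation (`sort_activations` and the shapes are illustrative, not the actual MaxText kernel): the backward pass can itself be a gather with the inverse permutation instead of the scatter-add that `jax.numpy.take`'s default gradient rule produces.

```python
import jax
import jax.numpy as jnp

def sort_activations(x, sort_indices):
  """Gather rows of x into expert-sorted order, with a sort-based backward rule."""

  @jax.custom_vjp
  def gather(x):
    return jnp.take(x, sort_indices, axis=0)

  def gather_fwd(x):
    return gather(x), None  # no residuals needed; sort_indices is closed over

  def gather_bwd(_, g):
    # sort_indices is a permutation, so its inverse (argsort) lets the backward
    # pass be another cheap gather rather than a scatter-add.
    return (jnp.take(g, jnp.argsort(sort_indices), axis=0),)

  gather.defvjp(gather_fwd, gather_bwd)
  return gather(x)
```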

src/MaxText/configs/base.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ logits_dot_in_fp32: false # whether to use fp32 in logits_dense or shared_embed
 cast_logits_to_fp32: true # whether to cast the logits to fp32. the higher precision is generally beneficial, but it can vary slightly.
 float32_qk_product: false # in dot_product attention, whether to cast to fp32 the inputs to qk product
 float32_logits: false # in dot_product attention, whether to cast to fp32 the inputs to softmax
-float32_weight_sum: true # whether to use full fp32 precision for weight_sum during final unpermute in moe
+float32_weight_sum: true # whether to use full fp32 precision to sum expert weights for numerical stability

 # multi-token prediction configs
 # the number of auxiliary prediction layers to use for mtp.
@@ -179,7 +179,7 @@ sparse_matmul: true
 capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
 load_balance_loss_weight: 0.01 # weight for the load balance loss
 use_random_routing: false # whether to use random routing for debug/test purpose
-use_custom_sort_vjp: true # whether to use a custom sort vjp for sparse matmul ops
+use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
 use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism
 # tunable tiling dimensions used for mlp gmm
 # megablox/jax ragged dot - supports forward pass only (6 configs: `wi_tile_fwd...` and `wo_tile_fwd_...`)
@@ -212,7 +212,7 @@ expert_shard_attention_option: "fsdp"
 # when moe weight matrices are sharded on both fsdp and fsdp-transpose axes, use two separate all-gather calls
 moe_fsdp_use_two_stage_all_gather: false
-# shard the moe weights on num_expert_dim. this can be performanct when num_expert % fdsp_parallisum
+# shard the expert dimension of the MoE MLP weights on the fsdp axis; recommended when num_experts is a multiple of fsdp_parallelism
 fsdp_shard_on_exp: False
 # use fsdp and fsdp_transpose axes for sharding the moe weights
 use_2d_fsdp_sharding: False
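
As an illustration of the `fsdp_shard_on_exp` comment above, the sketch below shows what sharding the expert dimension on an FSDP mesh axis looks like in plain JAX; the mesh layout, axis name, and weight shapes are assumptions, not MaxText's actual sharding setup.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional "fsdp" mesh over all available devices (illustrative).
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))

# Expert MLP weight: [num_experts, emb_dim, mlp_dim]; assumes num_experts is a
# multiple of the fsdp axis size, as the config comment recommends.
num_experts, emb_dim, mlp_dim = 8, 1024, 4096
wi = jnp.zeros((num_experts, emb_dim, mlp_dim), dtype=jnp.bfloat16)

# fsdp_shard_on_exp-style layout: split the leading expert dimension across the
# fsdp axis and keep each expert's matrix unsharded on the other dimensions.
wi = jax.device_put(wi, NamedSharding(mesh, P("fsdp", None, None)))
```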

src/MaxText/configs/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,7 @@ class MoEGeneral(BaseModel):
   )
   float32_weight_sum: bool = Field(
       True,
-      description="Whether to use full fp32 precision for weight_sum during final unpermute in MoE.",
+      description="Whether to use full fp32 precision to sum expert weights for numerical stability.",
   )
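
As a small, self-contained illustration of how a `Field` default like this behaves, here is a sketch using a simplified stand-in for `MoEGeneral` (assuming Pydantic is installed; this is not MaxText's config loader):

```python
from pydantic import BaseModel, Field

class MoEGeneralSketch(BaseModel):
  """Simplified stand-in for the MoEGeneral config model."""
  float32_weight_sum: bool = Field(
      True,
      description="Whether to use full fp32 precision to sum expert weights for numerical stability.",
  )

print(MoEGeneralSketch().float32_weight_sum)                          # True (default)
print(MoEGeneralSketch(float32_weight_sum=False).float32_weight_sum)  # False (explicit override)
```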