
Commit 607418a

Small updates on MoE configuration
1 parent ac74931 commit 607418a

3 files changed (+23, -19 lines)


docs/reference/core_concepts/moe_configuration.md

Lines changed: 6 additions & 6 deletions
@@ -50,7 +50,7 @@ Dropping:

`first_num_dense_layers`: The number of initial dense layers before the first MoE layer is introduced.

-`float32_weight_sum`: If enabled, performs the summation of expert weights using float32 precision for improved numerical stability.
+`float32_weight_sum`: If enabled, performs the summation of expert weights using float32 precision for improved numerical stability. Recommended specifically when lower precision types cause convergence or quality issues.

### Routing Mechanism
`use_random_routing`: If enabled, ignores the gate logits and routes tokens to random experts. This is designed to simulate load balancing for debugging and performance testing purposes.
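To make the `float32_weight_sum` change above concrete, here is a minimal JAX sketch of the idea: sum each token's expert outputs with their routing weights in float32 and cast back afterwards. The `combine_expert_outputs` helper, shapes, and argument names are illustrative assumptions, not the MaxText implementation.

```python
import jax.numpy as jnp

def combine_expert_outputs(expert_outputs, routing_weights, float32_weight_sum=True):
  """Illustrative weighted sum of per-expert outputs (not the MaxText code).

  expert_outputs: [tokens, top_k, emb], typically bf16; routing_weights: [tokens, top_k].
  """
  if float32_weight_sum:
    # Accumulate in float32 for numerical stability, then cast back to the activation dtype.
    summed = jnp.einsum(
        "tke,tk->te",
        expert_outputs.astype(jnp.float32),
        routing_weights.astype(jnp.float32),
    )
    return summed.astype(expert_outputs.dtype)
  return jnp.einsum("tke,tk->te", expert_outputs, routing_weights)
```

The flag matters mainly in low-precision training, where accumulating the top-k weighted sum in bfloat16 can be lossy enough to show up as convergence or quality issues.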
@@ -80,11 +80,11 @@ Dropping:
* Value > 0: Enforces a strict capacity limit; tokens exceeding this limit are dropped.
* Value = -1: Dropless with dense matrix multiplication, which is computationally expensive and typically used only as a baseline.

-`use_custom_sort_vjp`: If enabled, use a custom Vector-Jacobian Product (VJP) sort for efficient backward pass processing in sparse matmul.
+`use_custom_sort_vjp`: If enabled, use a custom Vector-Jacobian Product (VJP) sort for efficient backward pass processing in sparse matmul. Recommended because it replaces the inefficient scatter-add generated by `jax.numpy.take` in the backward pass.

-`mlp_bias`: If enabled, add bias terms within the expert MLP layers.
+`mlp_bias`: If enabled, add learnable bias terms for MLP matmul. Originally implemented to support the GPT-OSS model architecture.

-`use_batch_split_schedule` (experimental): If enabled, split batch into micro-batches to hide communications.
+`use_batch_split_schedule` (experimental): If enabled, split the batch into micro-batches to hide communications, which can yield performance benefits.

## 2. Sharding
`expert_shard_attention_option`: Determines how the "expert" axis is interpreted when sharding attention layers. Options include:
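The rationale for `use_custom_sort_vjp` above is that the default backward pass of a gather such as `jax.numpy.take` is a scatter-add, whereas the token sort in sparse matmul is a permutation, so the same gradient can be computed as a cheap gather with the inverse permutation. The sketch below only demonstrates that equivalence with plain `jax.vjp`; it is not MaxText's custom VJP, and the toy shapes and `permute_tokens` name are made up.

```python
import jax
import jax.numpy as jnp

def permute_tokens(x, sort_idx):
  # Forward pass of the token sort: gather rows in sorted order.
  return jnp.take(x, sort_idx, axis=0)

x = jnp.arange(12.0).reshape(6, 2)        # toy token activations
sort_idx = jnp.array([3, 0, 5, 1, 4, 2])  # permutation produced by the routing sort

# Default autodiff: the VJP of jnp.take is a scatter-add into the input.
out, vjp_fn = jax.vjp(permute_tokens, x, sort_idx)
cotangent = jnp.arange(1.0, 13.0).reshape(6, 2)
grad_default, _ = vjp_fn(cotangent)

# Because sort_idx is a permutation, the same gradient is another gather
# with the inverse permutation, avoiding the scatter-add entirely.
grad_via_gather = jnp.take(cotangent, jnp.argsort(sort_idx), axis=0)
assert jnp.allclose(grad_default, grad_via_gather)
```

Registering the gather-based rule as a custom VJP (e.g. via `jax.custom_vjp`) lets the backward pass reuse this cheaper form instead of the scatter-add that would otherwise be emitted.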
@@ -93,9 +93,9 @@ Dropping:

`use_ring_of_experts` (experimental): This feature requires expert parallelism. If enabled, it replaces the standard two All-to-All communications with All-Gather in dispatch and Reduce-Scatter in collect. By gathering inputs across all shards, it allows for local routing and Top-K calculations, followed by result aggregation via Reduce-Scatter. This approach is particularly effective for models with a large Top-K, as it gathers activations before they are replicated k times to reduce communication.

-`moe_fsdp_use_two_stage_all_gather`: If enabled, splits the All-Gather operation for MoE weights into two separate stages when using FSDP/FSDP-transpose sharding. This is preferred when 3D All-Gather support is unavailable.
+`moe_fsdp_use_two_stage_all_gather`: If enabled, split the All-Gather operation for MoE weights into two separate stages when using FSDP/FSDP-transpose sharding. This is preferred when 3D All-Gather support is unavailable.

-`fsdp_shard_on_exp`: If enabled, shard MLP weights on expert dimension instead of embedding dimension during FSDP sharding.
+`fsdp_shard_on_exp`: If enabled, shard the expert dimension of the MLP weights on the FSDP axis. Recommended when `num_experts` is a multiple of `fsdp_parallelism`.

## 3. Performance Tuning
These parameters provide granular control over the tiling dimensions for sparse matmul Pallas kernel.
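For intuition on `fsdp_shard_on_exp`, the toy sketch below places a `[num_experts, emb, mlp]` weight either with its embedding dimension sharded on an `fsdp` mesh axis (the default) or with its expert dimension sharded instead. The single-axis mesh, sizes, and variable names are assumptions for illustration, not MaxText's actual mesh layout.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

num_experts, emb, mlp = 8, 1024, 4096  # made-up sizes
fsdp = jax.device_count()              # toy 1-D "fsdp" mesh over all local devices
mesh = Mesh(np.array(jax.devices()).reshape(fsdp), axis_names=("fsdp",))

# Default: shard the embedding dimension of the expert weights on the fsdp axis.
shard_on_embed = NamedSharding(mesh, P(None, "fsdp", None))
# fsdp_shard_on_exp: shard the expert dimension instead. This only divides evenly
# when num_experts is a multiple of the fsdp axis size, hence the recommendation.
shard_on_expert = NamedSharding(mesh, P("fsdp", None, None))

w = jnp.zeros((num_experts, emb, mlp), dtype=jnp.bfloat16)
w = jax.device_put(w, shard_on_expert if num_experts % fsdp == 0 else shard_on_embed)
```

With expert-dimension sharding each fsdp shard holds whole experts, which is why the flag is only recommended when `num_experts` divides evenly across the axis.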

src/MaxText/configs/base.yml

Lines changed: 5 additions & 6 deletions
@@ -157,7 +157,7 @@ logits_dot_in_fp32: false # whether to use fp32 in logits_dense or shared_embed
cast_logits_to_fp32: true # whether to cast the logits to fp32. the higher precision is generally beneficial, but it can vary slightly.
float32_qk_product: false # in dot_product attention, whether to cast to fp32 the inputs to qk product
float32_logits: false # in dot_product attention, whether to cast to fp32 the inputs to softmax
-float32_weight_sum: true # whether to use full fp32 precision for weight_sum during final unpermute in moe
+float32_weight_sum: true # whether to use full fp32 precision to sum expert weights for numerical stability

# multi-token prediction configs
# the number of auxiliary prediction layers to use for mtp.
@@ -179,7 +179,7 @@ sparse_matmul: true
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
load_balance_loss_weight: 0.01 # weight for the load balance loss
use_random_routing: false # whether to use random routing for debug/test purpose
-use_custom_sort_vjp: true # whether to use a custom sort vjp for sparse matmul ops
+use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism
# tunable tiling dimensions used for mlp gmm
# megablox/jax ragged dot - supports forward pass only (6 configs: `wi_tile_fwd...` and `wo_tile_fwd_...`)
@@ -212,7 +212,7 @@ expert_shard_attention_option: "fsdp"

# when moe weight matrices are sharded on both fsdp and fsdp-transpose axes, use two separate all-gather calls
moe_fsdp_use_two_stage_all_gather: false
-# shard the moe weights on num_expert_dim. this can be performanct when num_expert % fdsp_parallisum
+# shard the expert dimension of the MLP weights on the FSDP axis; recommended when num_experts is a multiple of fsdp_parallelism
fsdp_shard_on_exp: False
# use fsdp and fsdp_transpose axes for sharding the moe weights
use_2d_fsdp_sharding: False
@@ -224,13 +224,12 @@ shared_experts: 1
routed_scaling_factor: 1.0 # scaling factor for routing scores
routed_score_func: "" # scoring function for routing
routed_bias: False # a flag if a learnable bias is added for routing
-mlp_bias: False # a flag if a learnable bias is added for MLP matmul
+mlp_bias: False # a flag if a learnable bias is added for MLP matmul; originally implemented to support the GPT-OSS model architecture
n_routing_groups: -1 # number of groups for routing, disabled by default
topk_routing_group: -1 # number of top groups to route inputs. For EP,
# Splits the batch to allow for better scheduling when using expert parallelism by overlapping the
# all-to-all communication with compute. Currently only implemented with DeepSeek sparse layers.
-use_batch_split_schedule: False # whether to use batch split schedule
-# sending activations to a maximum of topk_routing_group distinct devices can yield performance benefits.
+use_batch_split_schedule: False # whether to split the batch into micro-batches to hide communications, which can yield performance benefits

# For complex architectures like llama4 there are repeated sets of
# inhomogeneous layers. E.g. maverick uses [dense+rope, moe+rope, dense+rope, moe+nope]
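As a rough illustration of the `use_batch_split_schedule` flag above, the sketch splits the batch into micro-batches before running a stand-in MoE block; processing micro-batches independently under `jit` gives the compiler room to overlap one micro-batch's all-to-all with another's compute. `moe_block`, the shapes, and `num_splits` are placeholders, not the DeepSeek sparse layer this flag currently targets.

```python
import jax
import jax.numpy as jnp

def moe_block(x):
  # Placeholder for a sparse MoE layer whose dispatch/collect use all-to-all.
  return jnp.tanh(x)

def batch_split_moe(x, num_splits=2):
  # Split the batch into micro-batches and process them independently so the
  # communication of one micro-batch can be hidden behind the compute of another.
  micro_batches = jnp.split(x, num_splits, axis=0)
  return jnp.concatenate([moe_block(mb) for mb in micro_batches], axis=0)

x = jnp.ones((8, 16))
y = jax.jit(batch_split_moe, static_argnames="num_splits")(x)
```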

src/MaxText/configs/types.py

Lines changed: 12 additions & 7 deletions
@@ -553,7 +553,9 @@ class MoEGeneral(BaseModel):
  num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
  capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
  load_balance_loss_weight: NonNegativeFloat = Field(0.01, description="Weight for the load balancing auxiliary loss.")
-  use_custom_sort_vjp: bool = Field(True, description="Whether to use a custom sort VJP for sparse matmul ops.")
+  use_custom_sort_vjp: bool = Field(
+      True, description="Whether to use a custom VJP sort for efficient backward pass processing in sparse matmul."
+  )
  use_ring_of_experts: bool = Field(
      False,
      description="Whether to use Ring of Experts for sparse matmul expert parallelism.",
@@ -570,8 +572,8 @@ class MoEGeneral(BaseModel):
  )
  fsdp_shard_on_exp: bool = Field(
      False,
-      description="Shard the MoE weights on the num_expert dimension. Can be performant when "
-      "num_experts % fsdp_parallelism != 0.",
+      description="Shard the expert dimension of the MLP weights on the FSDP axis; "
+      "recommended when num_experts is a multiple of fsdp_parallelism.",
  )
  use_2d_fsdp_sharding: bool = Field(
      False,
@@ -583,7 +585,7 @@ class MoEGeneral(BaseModel):
  )
  float32_weight_sum: bool = Field(
      True,
-      description="Whether to use full fp32 precision for weight_sum during final unpermute in MoE.",
+      description="Whether to use full fp32 precision to sum expert weights for numerical stability.",
  )

@@ -639,13 +641,16 @@ class DeepSeekMoE(BaseModel):
  routed_scaling_factor: float = Field(1.0, description="Scaling factor for routing scores.")
  routed_score_func: str = Field("", description="Scoring function for routing (e.g., 'softmax', 'sigmoid').")
  routed_bias: bool = Field(False, description="Whether to add a bias term for routing.")
-  mlp_bias: bool = Field(False, description="Whether to add a learnable bias for MLP matmul.")
+  mlp_bias: bool = Field(
+      False,
+      description="Whether to add a learnable bias for MLP matmul; "
+      "originally implemented to support the GPT-OSS model architecture.",
+  )
  n_routing_groups: int = Field(-1, description="Number of groups for routing, disabled by default.")
  topk_routing_group: int = Field(-1, description="Number of top groups to route inputs to.")
  use_batch_split_schedule: bool = Field(
      False,
-      description="Splits the batch to allow for better scheduling when using expert parallelism by overlapping all-to-all "
-      "with compute.",
+      description="Whether to split the batch into micro-batches to hide communications, which can yield performance benefits.",
  )
