docs/reference/core_concepts/moe_configuration.md
6 additions & 6 deletions
@@ -50,7 +50,7 @@ Dropping:
`first_num_dense_layers`: The number of initial dense layers before the first MoE layer is introduced.
- `float32_weight_sum`: If enabled, performs the summation of expert weights using float32 precision for improved numerical stability.
+ `float32_weight_sum`: If enabled, performs the summation of expert weights using float32 precision for improved numerical stability. Recommended when lower-precision types cause convergence or quality issues.
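For illustration, a minimal sketch of what `float32_weight_sum` controls; the function and argument names here are hypothetical, not MaxText's actual implementation:

```python
import jax.numpy as jnp

def combine_expert_outputs(expert_outputs, routing_weights, float32_weight_sum=True):
  """Weighted sum of per-expert outputs for each token.

  expert_outputs:  [tokens, top_k, model_dim], typically bfloat16
  routing_weights: [tokens, top_k] softmax weights from the router
  """
  if float32_weight_sum:
    # Accumulate in float32 for numerical stability, then cast back.
    out = jnp.einsum(
        "tkd,tk->td",
        expert_outputs.astype(jnp.float32),
        routing_weights.astype(jnp.float32),
    )
    return out.astype(expert_outputs.dtype)
  return jnp.einsum("tkd,tk->td", expert_outputs, routing_weights)
```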
### Routing Mechanism
`use_random_routing`: If enabled, ignores the gate logits and routes tokens to random experts. This is designed to simulate load balancing for debugging and performance testing purposes.
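A hedged sketch of the idea behind `use_random_routing` (illustrative only; these names are hypothetical and the real router lives elsewhere in the codebase):

```python
import jax
import jax.numpy as jnp

def random_routing(rng, gate_logits, top_k):
  """Debug-only routing: ignore gate_logits and assign experts uniformly at random."""
  num_tokens, num_experts = gate_logits.shape
  expert_ids = jax.random.randint(rng, (num_tokens, top_k), 0, num_experts)
  # Uniform combine weights so downstream code still receives a [tokens, top_k] weight matrix.
  weights = jnp.full((num_tokens, top_k), 1.0 / top_k, dtype=gate_logits.dtype)
  return expert_ids, weights
```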
@@ -80,11 +80,11 @@ Dropping:
* Value > 0: Enforces a strict capacity limit; tokens exceeding this limit are dropped.
* Value = -1: Dropless with dense matrix multiplication, which is computationally expensive and typically used only as a baseline.
- `use_custom_sort_vjp`: If enabled, use a custom Vector-Jacobian Product (VJP) sort for efficient backward pass processing in sparse matmul.
+ `use_custom_sort_vjp`: If enabled, use a custom Vector-Jacobian Product (VJP) sort for efficient backward-pass processing in sparse matmul. Recommended because it replaces the inefficient scatter-add that `jax.numpy.take` would otherwise generate in the backward pass.
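The intuition can be shown with plain JAX (a sketch of the underlying identity, not the actual kernel): when tokens are permuted into expert order with `jnp.take`, autodiff transposes that gather into a scatter-add, but for a permutation the same gradient is obtained by a second cheap gather with the inverse permutation.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(6, 2)        # [tokens, model_dim]
perm = jnp.array([3, 0, 5, 1, 4, 2])      # tokens sorted by assigned expert
g = jnp.arange(12.0).reshape(6, 2) + 1.0  # an arbitrary incoming cotangent

# Backward pass that autodiff derives for the gather (a scatter-add).
_, vjp_fn = jax.vjp(lambda t: jnp.take(t, perm, axis=0), x)
(autodiff_grad,) = vjp_fn(g)

# Same gradient expressed as another gather, using the inverse permutation.
manual_grad = jnp.take(g, jnp.argsort(perm), axis=0)

assert bool(jnp.allclose(autodiff_grad, manual_grad))
```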
- `mlp_bias`: If enabled, add bias terms within the expert MLP layers.
+ `mlp_bias`: If enabled, add learnable bias terms to the expert MLP matmuls. Originally implemented to support the GPT-OSS model architecture.
- `use_batch_split_schedule` (experimental): If enabled, split batch into micro-batches to hide communications.
+ `use_batch_split_schedule` (experimental): If enabled, split the batch into micro-batches to overlap communication with computation, which can yield performance benefits.
## 2. Sharding
`expert_shard_attention_option`: Determines how the "expert" axis is interpreted when sharding attention layers. Options include:
@@ -93,9 +93,9 @@ Dropping:
`use_ring_of_experts` (experimental): This feature requires expert parallelism. If enabled, it replaces the standard two All-to-All communications with All-Gather in dispatch and Reduce-Scatter in collect. By gathering inputs across all shards, it allows for local routing and Top-K calculations, followed by result aggregation via Reduce-Scatter. This approach is particularly effective for models with a large Top-K, as it gathers activations before they are replicated k times to reduce communication.
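A rough sketch of that communication pattern, under assumptions: an `expert` mesh axis, made-up shapes, and an identity placeholder where routing, Top-K, and the expert MLPs would run. This is not the actual implementation.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("expert",))

def ring_of_experts(local_tokens):
  # Dispatch: All-Gather token activations so every shard can run routing
  # and Top-K locally over the full set of tokens.
  all_tokens = jax.lax.all_gather(local_tokens, "expert", tiled=True)
  # Local routing, Top-K, and this shard's expert MLPs would run here; the
  # identity below stands in for that per-shard partial result.
  partial_out = all_tokens
  # Collect: Reduce-Scatter sums the partial results across shards and
  # returns each token's output to the shard that owns it.
  return jax.lax.psum_scatter(partial_out, "expert", tiled=True)

tokens = jnp.ones((4 * mesh.size, 8))  # [global_tokens, model_dim]
out = shard_map(ring_of_experts, mesh=mesh,
                in_specs=P("expert", None), out_specs=P("expert", None))(tokens)
```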
- `moe_fsdp_use_two_stage_all_gather`: If enabled, splits the All-Gather operation for MoE weights into two separate stages when using FSDP/FSDP-transpose sharding. This is preferred when 3D All-Gather support is unavailable.
+ `moe_fsdp_use_two_stage_all_gather`: If enabled, split the All-Gather operation for MoE weights into two separate stages when using FSDP/FSDP-transpose sharding. This is preferred when 3D All-Gather support is unavailable.
- `fsdp_shard_on_exp`: If enabled, shard MLP weights on expert dimension instead of embedding dimension during FSDP sharding.
+ `fsdp_shard_on_exp`: If enabled, shard the expert dimension of the MLP weights on the FSDP axis instead of the embedding dimension. Recommended when `num_experts` is a multiple of the FSDP parallelism.
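The difference can be expressed roughly as partition specs over a hypothetical MoE weight of shape `[num_experts, embed, mlp]`; the logical axis names are assumptions for illustration, not MaxText's actual sharding rules.

```python
from jax.sharding import PartitionSpec as P

# Default: experts sharded over the expert axis, FSDP shards the embedding dim.
w_spec_default = P("expert", "fsdp", None)            # [num_experts, embed, mlp]

# fsdp_shard_on_exp: FSDP also shards the expert dimension instead of embed,
# which only divides evenly when num_experts is a multiple of the FSDP degree.
w_spec_shard_on_exp = P(("expert", "fsdp"), None, None)
```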
## 3. Performance Tuning
These parameters provide granular control over the tiling dimensions of the sparse matmul Pallas kernel.