support attn_dp_expert parallelism #3141

Open

khatwanimohit wants to merge 1 commit into main from mohit/attn_expert_submit

Conversation

@khatwanimohit
Collaborator

Description

  • Add attn_dp_expert mesh axis for expert parallelism: Introduce a new attn_dp_expert axis in the mesh and logical axis rules to support attention-DP-aware expert parallelism, particularly when using vllm_rpa attention (sketched after this list).
  • Make expert axis name configurable in MoE layer: Replace the hardcoded "expert" axis name in RoutedMoE with a configurable _expert_parallelism_name that switches to "attn_dp_expert" when attention == "vllm_rpa", affecting all collective operations (all_gather, psum_scatter, axis_index, ragged_all_to_all); see the second sketch below.
  • Add rollout_expert_parallelism config for RL training: New config field to specify expert parallelism per replica for rollout, with updated device-count validation (tp * dp * ep == num_sampler_devices; see the third sketch below).
  • Pass expert parallelism settings to vLLM rollout: When rollout_expert_parallelism > 1, pass expert_parallel_size and rollout_vllm_enable_expert_parallelism to the vLLM kwargs.
  • Add vllm_config_path config option: New config field (in base.yml, rl.yml, and types.py) to specify the path to a YAML file for loading the vLLM config, defaulting to src/maxtext/configs/inference/vllm.yml.
  • Update vLLM inference script: Add vllm_swap_space, vllm_async_scheduling, and vllm_config_path flags to vllm_decode.py; remove the hard requirement on hf_config_path.
  • Update vLLM logical axis rules: Add attn_dp_expert to the sharding rules for activation axes (activation_q_length, activation_kv_batch, decode_batch, embed, exp, etc.) and add a moe_mlp rule in the vLLM inference config.
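
The following is a minimal sketch of what a device mesh carrying the new attn_dp_expert axis might look like. The axis ordering, axis sizes, and variable names here are illustrative assumptions; the real layout comes from MaxText's mesh and logical-axis-rule configs.

```python
# Illustrative only: build a mesh that includes an "attn_dp_expert" axis.
# Assumes exactly 8 devices for the reshape; adjust for your topology.
import numpy as np
import jax
from jax.sharding import Mesh

devices = np.asarray(jax.devices()).reshape(2, 2, 2)
mesh = Mesh(devices, axis_names=("data", "attn_dp_expert", "tensor"))
print(dict(mesh.shape))  # e.g. {'data': 2, 'attn_dp_expert': 2, 'tensor': 2}
```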
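
A hedged sketch of the configurable expert axis name described above. Apart from RoutedMoE and _expert_parallelism_name, the constructor shape, attribute names, and config fields are assumptions, not the actual MaxText implementation.

```python
# Sketch: select the expert-parallel mesh axis based on the attention
# implementation, and resolve its size from the mesh instead of a
# hardcoded "expert" axis.
class RoutedMoE:
  def __init__(self, config, mesh):
    self.config = config
    self.mesh = mesh
    # Attention-DP-aware axis only when running with vllm_rpa attention.
    self._expert_parallelism_name = (
        "attn_dp_expert" if config.attention == "vllm_rpa" else "expert"
    )

  def get_expert_parallelism_size(self):
    # Collectives (all_gather, psum_scatter, axis_index, ragged_all_to_all)
    # would likewise be issued over self._expert_parallelism_name.
    return self.mesh.shape.get(self._expert_parallelism_name, 1)
```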
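
The rollout device-count check could look roughly like the sketch below. Only rollout_expert_parallelism and the tp * dp * ep == num_sampler_devices condition come from the PR description; the function and parameter names are assumptions.

```python
# Illustrative validation: per rollout replica, tensor-, data-, and
# expert-parallel sizes must multiply out to the number of sampler devices.
def validate_rollout_parallelism(tp: int, dp: int, ep: int, num_sampler_devices: int) -> None:
  if tp * dp * ep != num_sampler_devices:
    raise ValueError(
        f"tp * dp * ep = {tp * dp * ep} does not match "
        f"num_sampler_devices = {num_sampler_devices}"
    )
```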

Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.

Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov bot commented Feb 13, 2026

Codecov Report

❌ Patch coverage is 22.72727% with 17 lines in your changes missing coverage. Please review.

Files with missing lines   | Patch %  | Lines
src/MaxText/rl/train_rl.py | 0.00%    | 9 Missing ⚠️
src/MaxText/layers/moe.py  | 38.46%   | 6 Missing and 2 partials ⚠️


@khatwanimohit force-pushed the mohit/attn_expert_submit branch from f22b6db to d910ca3 on February 13, 2026 at 23:24

  def get_expert_parallelism_size(self):
-   return self.mesh.shape.get("expert", 1)
+   if isinstance(self._expert_parallelism_name, tuple):
Collaborator

Is the tuple case actually supported? Have you tested it?
