
Add Packing Support for Context Parallelism (Ring Attention)#2906

Merged
copybara-service[bot] merged 1 commit into AI-Hypercomputer:main from kocchop:faysal/add-thd-2-ring-attn
Apr 20, 2026

Conversation

@kocchop
Collaborator

@kocchop kocchop commented Dec 31, 2025

Description

Enables sequence packing for context parallelism with the ring strategy using TransformerEngine's DotProductAttention. Includes comprehensive GPU tests (sm90+) for ring attention with packing.

  • Currently supports packing only for ring attention
  • Replaced the local sequence reordering with TransformerEngine's reorder_causal_load_balancing API
  • The load-balancing strategy is currently picked automatically based on the packing config
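On GPU the reorder is delegated to TE's reorder_causal_load_balancing; as a rough illustration of what the two load-balancing strategies do (a NumPy sketch under our own function names, not the TE implementation), for a context-parallel group of size cp_size:

```python
import numpy as np

def dual_chunk_swap(x, cp_size):
    # Split the sequence into 2 * cp_size chunks; device i gets chunks
    # i and (2 * cp_size - 1 - i), pairing one cheap early chunk with one
    # expensive late chunk under a causal mask.
    chunks = np.split(x, 2 * cp_size)
    out = []
    for i in range(cp_size):
        out += [chunks[i], chunks[2 * cp_size - 1 - i]]
    return np.concatenate(out)

def striped(x, cp_size):
    # Round-robin the tokens: device d gets tokens d, d + cp_size, ...
    # This stays balanced even when several packed sequences share one row.
    seq_len = x.shape[0]
    return x.reshape(seq_len // cp_size, cp_size).T.reshape(-1)

tokens = np.arange(8)
print(dual_chunk_swap(tokens, 2))  # [0 1 6 7 2 3 4 5]
print(striped(tokens, 2))          # [0 2 4 6 1 3 5 7]
```

After either reorder, contiguous shards of the result are assigned to the cp_size devices in order; striped keeps per-device causal work balanced even when segment boundaries fall unevenly, which is why it suits packed data.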

Current support matrix for context parallelism on GPU:

| CP Strategy | Packing | Load Balance Strategy | SWA |
|---|---|---|---|
| all_gather | | dual_chunk_swap | |
| ring | ✅ (with real data) | dual_chunk_swap (non-packed), striped (packed) | ✅ (with packing) |
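For reference, "packing" here means the THD layout in which several sequences share one row and attention is constrained by per-token segment metadata. A small hypothetical example (variable names are illustrative, not the MaxText API):

```python
import numpy as np

# Two sequences, lengths 3 and 5, packed into a single row of length 8.
segment_ids = np.array([1, 1, 1, 2, 2, 2, 2, 2])        # owning sequence per token
segment_positions = np.array([0, 1, 2, 0, 1, 2, 3, 4])  # position within that sequence

# A packed causal mask: query q may attend to key k only if both tokens
# belong to the same segment and k is not in q's future.
mask = (segment_ids[:, None] == segment_ids[None, :]) & (
    segment_positions[None, :] <= segment_positions[:, None]
)
```

Threading segment positions down to the attention op, as this PR does via TE's SequenceDescriptor, is what lets the fused kernel apply this same-segment causal constraint.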

Tests

Added a GPU integration test that works for sm90+.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Dec 31, 2025

Codecov Report

❌ Patch coverage is 41.46341% with 24 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/utils/train_utils.py | 15.38% | 10 Missing and 1 partial ⚠️ |
| src/maxtext/utils/max_utils.py | 41.66% | 5 Missing and 2 partials ⚠️ |
| src/maxtext/layers/attention_op.py | 0.00% | 6 Missing ⚠️ |


Collaborator

@richjames0 richjames0 left a comment


a couple of nits but lgtm

Comment thread src/MaxText/layers/attention_op.py Outdated
Comment thread src/MaxText/maxtext_utils.py Outdated
Comment thread src/MaxText/maxtext_utils.py Outdated
Comment thread src/maxtext/utils/max_utils.py Outdated
Collaborator

@gobbleturk gobbleturk left a comment


Thanks for the tests and great comments illustrating the two reorder strategies!

Collaborator

@RissyRan RissyRan left a comment


LGTM, just minor comments.

Comment thread src/maxtext/configs/base.yml
Comment thread src/maxtext/utils/train_utils.py
@github-actions

This PR has been automatically marked as stale because it has not had recent activity. It will be closed soon if no further activity occurs. Thank you for your contributions.

@github-actions github-actions Bot added the stale label Feb 21, 2026
@github-actions

github-actions Bot commented Mar 5, 2026

This PR was closed because it has been inactive for a while. Please reopen it if you are still working on it.

@github-actions github-actions Bot closed this Mar 5, 2026
@kocchop kocchop reopened this Mar 5, 2026
@kocchop kocchop force-pushed the faysal/add-thd-2-ring-attn branch from 9d04e5a to e9d76dc on March 8, 2026 at 11:43
@kocchop kocchop added the enhancement label and removed the stale label Mar 8, 2026
@kocchop kocchop requested a review from gobbleturk March 8, 2026 11:46
@github-actions

github-actions Bot commented Apr 7, 2026

This PR has been automatically marked as stale because it has not had recent activity. It will be closed soon if no further activity occurs. Thank you for your contributions.

@github-actions github-actions Bot added the stale label Apr 7, 2026
@github-actions

This PR was closed because it has been inactive for a while. Please reopen it if you are still working on it.

@github-actions github-actions Bot closed this Apr 14, 2026
@kocchop kocchop reopened this Apr 16, 2026
@kocchop kocchop requested a review from abhinavclemson as a code owner April 16, 2026 02:56
@kocchop kocchop added the enhancement label and removed the stale label Apr 16, 2026
@kocchop kocchop force-pushed the faysal/add-thd-2-ring-attn branch 2 times, most recently from 2935ff5 to 132bf4c on April 16, 2026 at 23:18
Enable CP + packing for context_parallel_strategy="ring" with load
balancing. On GPU, uses Transformer Engine's striped reorder for
THD-packed sequences. On TPU/CPU, falls back to pure-JAX reorder_sequence
and never imports TE.

Changes:
- common_types: Add ReorderStrategy enum (AUTO, DUAL_CHUNK_SWAP, STRIPED).
- configs: Add context_parallel_reorder_strategy (default "auto"). Reject
    explicit STRIPED on non-GPU at config validation time.
- attention_op: Thread segment_positions through apply_attention,
    cudnn_flash_attention, and __call__. Use segment_positions in TE's
    SequenceDescriptor for packing. Restrict packing+CP to load-balanced
    ring only. Note TE version constraint.
- attentions.py, attention_mla.py, gpt3.py: Pass inputs_positions into
    attention_op calls (None for gpt3).
- max_utils: Hardware-dispatched reorder_causal_load_balanced. GPU uses
    TE's reorder_causal_load_balancing; TPU/CPU uses reorder_sequence.
    TE import is lazy and GPU-only.
- maxtext_utils: Thread reorder_strategy and hardware through
    shard_reorder_causal_load_balanced and get_reorder_callable. Default
    hardware="tpu" never triggers TE import.
- train_utils: Allow ring+packing; forbid all_gather+packing and
    synthetic+packing. Resolve AUTO->STRIPED for packing else
    DUAL_CHUNK_SWAP. Pass config.hardware to reorder callable. Build
    data_loader after reorder wrapper is applied.
- attention_test_util: Pass cfg_cp.hardware so TPU tests use pure-JAX
    reorder. Helper is TPU-oriented and does not model GPU packed behavior.
- tests: Add test_gpu_ring_attention_with_packing (sm90+).

Requires TE with reorder_causal_load_balancing; works with TE <=2.11 or
>=2.14 (incompatible with 2.12 and 2.13 due to a known bug).
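The AUTO resolution and the GPU-only restriction described above can be condensed into a small sketch (a hypothetical helper under our own name, not the actual train_utils code):

```python
from enum import Enum

class ReorderStrategy(Enum):
    AUTO = "auto"
    DUAL_CHUNK_SWAP = "dual_chunk_swap"
    STRIPED = "striped"

def resolve_reorder_strategy(strategy, packing, hardware):
    # AUTO resolves to STRIPED when packing is enabled, else DUAL_CHUNK_SWAP.
    if strategy is ReorderStrategy.AUTO:
        return ReorderStrategy.STRIPED if packing else ReorderStrategy.DUAL_CHUNK_SWAP
    # Explicit STRIPED relies on TE and is rejected off-GPU at config validation.
    if strategy is ReorderStrategy.STRIPED and hardware != "gpu":
        raise ValueError("striped reorder requires TransformerEngine (GPU only)")
    return strategy
```

Keeping the default resolution on the TPU path at DUAL_CHUNK_SWAP means the lazy TE import is never triggered off-GPU, matching the commit's intent.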
@kocchop kocchop force-pushed the faysal/add-thd-2-ring-attn branch from 132bf4c to eb52c0b on April 20, 2026 at 18:55
@copybara-service copybara-service Bot merged commit 9608068 into AI-Hypercomputer:main Apr 20, 2026
22 of 23 checks passed

Labels

enhancement, pull ready


4 participants