Skip to content

[None][perf] CuTeDSL MegaMoE: eliminate per-launch workspace memset o…#15201

Draft
Barry-Delaney wants to merge 1 commit into
NVIDIA:feat/deepseek_v4from
Barry-Delaney:user/jinshik/cutedsl_megamoe_memset_opt
Draft

[None][perf] CuTeDSL MegaMoE: eliminate per-launch workspace memset o…#15201
Barry-Delaney wants to merge 1 commit into
NVIDIA:feat/deepseek_v4from
Barry-Delaney:user/jinshik/cutedsl_megamoe_memset_opt

Conversation

@Barry-Delaney

Copy link
Copy Markdown
Collaborator

…verhead

Two optimizations that reduce CuTeDSL MegaMoE forward latency by up to 2.7x at large token counts (DSv4 shapes, B200 8-GPU EP=8):

  1. Targeted workspace zeroing: the per-launch _zero_local_workspace_preserving_phase was zeroing the entire ~1.2 GB local workspace (minus the 4-byte nvlink_barrier_counter). Only ~0.57 MB of atomic counters and small safety-critical data regions actually need resetting. Atomic counters (l1_arrival_count, expert_send_count, grid_sync_counter, fc1_done_counter, fc2_done_counter, load_balance_counter) must start at zero; l1_topk_weights_buffer and token_src_metadata must be zero for MMA padding rows to prevent stale data from corrupting FC2 → combine scatter. The bulk data regions (l1_token_buffer, l1_sf_buffer, fc1_output, fc1_output_sf) are overwritten by dispatch/FC1 before being read. This eliminates ~3 ms/fwd of memset at 32768T.

  2. Fuse bf16->fp32 cast into sum: replace combine_output.to(fp32).sum(dim=1) with combine_output.sum(dim=1, dtype=fp32), eliminating a full-size bf16->fp32 materialization kernel (~300 us/layer).

  3. Cache get_regions() return value to avoid re-creating symmetric memory tensor views on every forward call, which is incompatible with CUDA-graph capture (cudaErrorStreamCaptureUnsupported).

Measured (min latency, CG, balanced routing, 8×B200):
32768T: 5.78 ms -> 2.15 ms (2.69x)
16384T: 2.98 ms -> 1.16 ms (2.57x)
8192T: 1.56 ms -> 0.65 ms (2.40x)
4096T: 0.85 ms -> 0.40 ms (2.10x)

@coderabbitai summary

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • If PR introduces API changes, an appropriate PR label is added - either api-compatible or api-breaking. For api-breaking, include BREAKING in the PR title.

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

…verhead

Two optimizations that reduce CuTeDSL MegaMoE forward latency by up to
2.7x at large token counts (DSv4 shapes, B200 8-GPU EP=8):

1. Targeted workspace zeroing: the per-launch
   `_zero_local_workspace_preserving_phase` was zeroing the entire
   ~1.2 GB local workspace (minus the 4-byte nvlink_barrier_counter).
   Only ~0.57 MB of atomic counters and small safety-critical data
   regions actually need resetting.  Atomic counters
   (l1_arrival_count, expert_send_count, grid_sync_counter,
   fc1_done_counter, fc2_done_counter, load_balance_counter) must
   start at zero; l1_topk_weights_buffer and token_src_metadata must
   be zero for MMA padding rows to prevent stale data from corrupting
   FC2 → combine scatter.  The bulk data regions (l1_token_buffer,
   l1_sf_buffer, fc1_output, fc1_output_sf) are overwritten by
   dispatch/FC1 before being read.  This eliminates ~3 ms/fwd of
   memset at 32768T.

2. Fuse bf16->fp32 cast into sum: replace
   `combine_output.to(fp32).sum(dim=1)` with
   `combine_output.sum(dim=1, dtype=fp32)`, eliminating a full-size
   bf16->fp32 materialization kernel (~300 us/layer).

3. Cache `get_regions()` return value to avoid re-creating symmetric
   memory tensor views on every forward call, which is incompatible
   with CUDA-graph capture (cudaErrorStreamCaptureUnsupported).

Measured (min latency, CG, balanced routing, 8×B200):
  32768T: 5.78 ms -> 2.15 ms (2.69x)
  16384T: 2.98 ms -> 1.16 ms (2.57x)
   8192T: 1.56 ms -> 0.65 ms (2.40x)
   4096T: 0.85 ms -> 0.40 ms (2.10x)

Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant