SDPA decode perf improvements for qwen-3.5-35B-A3B#18759
Draft
digantdesai wants to merge 5 commits into main
Conversation
Compares ET Triton SDPA (native GQA) against PyTorch Flash/Efficient/Math backends (expanded KV) across Lk=64..16K on A100. Uses triton.testing.do_bench for timing. Standalone script, no changes to the kernel.
🔗 Helpful Links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18759
✅ No failures as of commit ebe61e8 with merge base b24535b. (Comment automatically generated by Dr. CI.)
2cb04c3 to febc419
Add a stats_callback to generate() that prints prefill/decode rates, model load time, TTFT, and sampling time via printf, mirroring the format in extension/llm/runner/stats.h print_report. Uses printf instead of ET_LOG(Info) because the CMake target does not link executorch_no_prim_ops (which provides the PAL logger); adding that dependency pulls in the full runtime and breaks the minimal runner build.
Register `triton::sdpa_decode_splitk` as an independent op so AOTI can trace and compile it without the runtime L_kv conditional that prevents the split-K path from appearing in the standard `sdpa` op. The split-K (flash-decoding) approach partitions the KV sequence across CTAs and reduces partial softmax results in a second kernel. The benchmark script now includes the split-K column for comparison.

Standalone kernel benchmark on H100 (Qwen3.5 MoE decode, B=1, H_q=16, H_kv=2, D=256, bf16):

| Lk | ET Tiled (us) | ET Split-K (us) | Speedup |
|------:|------:|------:|------:|
| 64 | 131.8 | 259.5 | 0.5x |
| 512 | 98.9 | 221.5 | 0.4x |
| 4096 | 199.9 | 214.4 | 0.9x |
| 8192 | 392.2 | 211.3 | 1.9x |
| 16384 | 775.3 | 211.8 | 3.7x |

Split-K breaks even around Lk=4096 and dominates at longer sequences, where the tiled kernel's single-CTA-per-head bottleneck becomes severe.
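A pure-PyTorch sketch of the reduction the split-K kernel performs (reference math only, not the Triton kernel): each KV slice produces an unnormalized partial output plus its softmax statistics (slice max `m` and sum `l`), and a second pass merges the partials with a log-sum-exp rescale.

```python
import torch

def splitk_attention(q, k, v, num_splits=4):
    """q: (H, 1, D) decode query; k, v: (H, L, D). Pure-torch flash-decoding reference."""
    scale = q.shape[-1] ** -0.5
    partials, ms, ls = [], [], []
    for ks, vs in zip(k.chunk(num_splits, dim=1), v.chunk(num_splits, dim=1)):
        s = (q @ ks.transpose(-1, -2)) * scale       # (H, 1, Lc) scores for this slice
        m = s.amax(dim=-1, keepdim=True)             # per-slice softmax max
        p = (s - m).exp()
        ms.append(m)
        ls.append(p.sum(dim=-1, keepdim=True))       # per-slice softmax sum
        partials.append(p @ vs)                      # unnormalized partial output
    # Second pass: rescale every partial to the global max, then normalize once.
    m_all = torch.cat(ms, dim=-1).amax(dim=-1, keepdim=True)
    out = sum(o * (m - m_all).exp() for o, m in zip(partials, ms))
    l_all = sum(l * (m - m_all).exp() for l, m in zip(ls, ms))
    return out / l_all

H, L, D = 16, 4096, 64
q, k, v = torch.randn(H, 1, D), torch.randn(H, L, D), torch.randn(H, L, D)
ref = torch.softmax((q @ k.transpose(-1, -2)) * D**-0.5, dim=-1) @ v
torch.testing.assert_close(splitk_attention(q, k, v), ref, atol=1e-4, rtol=1e-4)
```

The rescale by `exp(m - m_all)` is what lets each slice run independently (one CTA per slice in the kernel) without ever materializing the full softmax row.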
The previous example input used T=2, which caused AOTI to compile the chunk_gated_delta_rule kernel for a single chunk (NT=1). At runtime, prompts longer than 64 tokens (requiring NT>1 chunks) failed with "Error resizing tensor at input 0". Using max_seq_len-1 as the example sequence length ensures AOTI generalizes intermediate buffer sizes across the full sequence-length range.
Runtime dispatch via torch.cond in FullAttention: split-K flash-decoding for decode (L_q==1) and standard tiled SDPA for prefill (L_q>1). Guard sdpa_decode_splitk validation behind isinstance(L_q, int) so AOTI tracing with symbolic shapes doesn't trip the L_q==1 check. Align sdpa_decode_splitk signature with sdpa (dropout_p, is_causal, enable_gqa) for drop-in use with torch.cond; unsupported args fail with clear messages.
WIP