
SDPA decode perf improvements for qwen-3.5-35B-A3B #18759

Draft
digantdesai wants to merge 5 commits into main from
digantdesai/sdpa-bench-and-perf-stats

Conversation

@digantdesai
Contributor

@digantdesai digantdesai commented Apr 8, 2026

WIP

Compares ET Triton SDPA (native GQA) against the PyTorch Flash/Efficient/Math
backends (expanded KV) across Lk=64..16K on A100. Uses triton.testing.do_bench
for timing. Standalone script; no changes to the kernel.
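For reference on the two setups being compared: "native GQA" means the kernel indexes the shared KV head directly for each group of query heads, while "expanded KV" repeats each KV head so every query head gets its own copy before calling the backend. A minimal pure-Python sketch of the head mapping (using the H_q=16, H_kv=2 shapes quoted later in this PR; the index arithmetic is an illustration, not the kernel's actual code):

```python
H_q, H_kv = 16, 2          # query heads, KV heads (Qwen3.5 MoE decode shapes)
group = H_q // H_kv        # query heads sharing one KV head

# Native GQA: map each query head to its shared KV head inside the kernel.
kv_head_of = [h // group for h in range(H_q)]

# Expanded KV: materialize the same mapping by repeating each KV head.
expanded = [h for h in range(H_kv) for _ in range(group)]

assert kv_head_of == expanded   # both schemes read the same KV data
```

The difference is therefore purely about memory traffic and launch shape, which is what the benchmark measures.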
@pytorch-bot

pytorch-bot bot commented Apr 8, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18759

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ebe61e8 with merge base b24535b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 8, 2026
@github-actions

github-actions bot commented Apr 8, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

digantdesai force-pushed the digantdesai/sdpa-bench-and-perf-stats branch from 2cb04c3 to febc419 on April 8, 2026 04:12
digantdesai and others added 4 commits April 8, 2026 21:00
Add a stats_callback to generate() that prints prefill/decode rates,
model load time, TTFT, and sampling time via printf, mirroring the
format in extension/llm/runner/stats.h print_report.

Uses printf instead of ET_LOG(Info) because the CMake target does not
link executorch_no_prim_ops (which provides the PAL logger); adding
that dependency pulls in the full runtime and breaks the minimal
runner build.
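The quantities in that report are simple ratios over the runner's timestamps; a pure-Python sketch of how the prefill/decode rates and TTFT fall out (hypothetical timeline and helper name, not the actual stats.h API):

```python
def tokens_per_second(num_tokens, start_ms, end_ms):
    """Rate over a [start_ms, end_ms) window, as print_report-style stats do."""
    return num_tokens / ((end_ms - start_ms) / 1000.0)

# Hypothetical timeline (ms): prefill of 32 prompt tokens ends at 1000,
# the first token is emitted then, and decode of 64 tokens ends at 3000.
prefill_rate = tokens_per_second(32, 0.0, 1000.0)    # prompt eval rate
decode_rate = tokens_per_second(64, 1000.0, 3000.0)  # generation rate
ttft_s = 1000.0 / 1000.0                             # time to first token

assert prefill_rate == 32.0 and decode_rate == 32.0 and ttft_s == 1.0
```

Since these are plain divisions over millisecond counters, mirroring the stats.h print_report format with printf needs no runtime dependency, which is the point of avoiding the executorch_no_prim_ops link.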
Register `triton::sdpa_decode_splitk` as an independent op so AOTI
can trace and compile it without the runtime L_kv conditional that
prevents the split-K path from appearing in the standard `sdpa` op.

The split-K (flash-decoding) approach partitions the KV sequence
across CTAs and reduces partial softmax results in a second kernel.
The benchmark script now includes the split-K column for comparison.

Standalone kernel benchmark on H100 (Qwen3.5 MoE decode, B=1, H_q=16,
H_kv=2, D=256, bf16):

  Lk       ET Tiled (us)  ET Split-K (us)  Speedup
  64            131.8          259.5         0.5x
  512            98.9          221.5         0.4x
  4096          199.9          214.4         0.9x
  8192          392.2          211.3         1.9x
  16384         775.3          211.8         3.7x

Split-K breaks even around Lk=4096 and dominates at longer sequences
where the tiled kernel's single-CTA-per-head bottleneck becomes severe.
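The two-phase reduction described above can be checked with scalars: each partition keeps a running max m, normalizer l, and weighted accumulator acc, and the second kernel merges them with a log-sum-exp style rescale. A minimal pure-Python sketch (illustrative helper names, scalar values in place of tiles):

```python
import math

def partial_softmax(scores, values):
    """One split-K partition: local max m, normalizer l, weighted sum acc."""
    m = max(scores)
    l = sum(math.exp(s - m) for s in scores)
    acc = sum(math.exp(s - m) * v for s, v in zip(scores, values))
    return m, l, acc

def combine(partials):
    """Second-phase reduction: rescale each partition to the global max."""
    m_g = max(m for m, _, _ in partials)
    l_g = sum(l * math.exp(m - m_g) for m, l, _ in partials)
    acc_g = sum(acc * math.exp(m - m_g) for m, _, acc in partials)
    return acc_g / l_g

scores = [0.1, 2.0, -1.0, 0.5]
values = [1.0, 2.0, 3.0, 4.0]

# Direct softmax-weighted sum over the full KV sequence.
m = max(scores)
w = [math.exp(s - m) for s in scores]
direct = sum(wi * v for wi, v in zip(w, values)) / sum(w)

# Partition the KV sequence across two "CTAs" and reduce.
split = combine([partial_softmax(scores[:2], values[:2]),
                 partial_softmax(scores[2:], values[2:])])

assert abs(direct - split) < 1e-9
```

Because the merge is exact (up to floating-point rounding), the split-K path changes only the parallelization, not the attention output, which is why it can be swapped in for decode without accuracy concerns.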
The previous example used T=2, which caused AOTI to compile the
chunk_gated_delta_rule kernel for a single chunk (NT=1). At runtime,
prompts longer than 64 tokens (requiring NT>1 chunks) failed with
"Error resizing tensor at input 0". Using max_seq_len-1 as the
example ensures AOTI generalizes intermediate buffer sizes for the
full sequence length range.
Runtime dispatch via torch.cond in FullAttention: split-K flash-decoding
for decode (L_q==1) and standard tiled SDPA for prefill (L_q>1). Guard
sdpa_decode_splitk validation behind isinstance(L_q, int) so AOTI tracing
with symbolic shapes doesn't trip the L_q==1 check.

Align sdpa_decode_splitk signature with sdpa (dropout_p, is_causal,
enable_gqa) for drop-in use with torch.cond; unsupported args fail
with clear messages.
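As a shape-only illustration of that dispatch and guard (pure Python, with hypothetical stand-ins for the real ops; the actual PR uses torch.cond so both branches remain traceable under AOTI):

```python
class FakeSymInt:
    """Stand-in for a symbolic size encountered during AOTI tracing."""

def sdpa_tiled(*args):
    """Stand-in for the standard tiled `sdpa` op (prefill path)."""
    return "tiled"

def sdpa_decode_splitk(*args, L_q=None):
    # Mirror the PR's guard: validate L_q == 1 only when L_q is a concrete
    # int, so symbolic shapes don't trip the decode-only check.
    if isinstance(L_q, int) and L_q != 1:
        raise ValueError("sdpa_decode_splitk supports single-token decode only")
    return "splitk"

def full_attention(q, k, v, L_q):
    """Runtime dispatch: split-K flash-decoding for decode (L_q == 1),
    tiled SDPA for prefill (L_q > 1)."""
    if isinstance(L_q, int) and L_q == 1:
        return sdpa_decode_splitk(q, k, v, L_q=L_q)
    return sdpa_tiled(q, k, v)

assert full_attention(None, None, None, 1) == "splitk"   # decode step
assert full_attention(None, None, None, 128) == "tiled"  # prefill
# A symbolic L_q skips the concrete validation inside the split-K op.
assert sdpa_decode_splitk(L_q=FakeSymInt()) == "splitk"
```

Aligning the two signatures matters because torch.cond requires both branches to accept the same operands; the unsupported-argument errors then surface only if a concrete call actually hits them.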
@digantdesai digantdesai changed the title [aoti-cuda] Add SDPA benchmarking script with qwen-3.5-35B-A3B shapes SDPA decode perf improvements for qwen-3.5-35B-A3B Apr 9, 2026