diff --git a/backends/cuda/benchmarks/benchmark_sdpa.py b/backends/cuda/benchmarks/benchmark_sdpa.py new file mode 100644 index 00000000000..47af8f3b34e --- /dev/null +++ b/backends/cuda/benchmarks/benchmark_sdpa.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Benchmark the Triton SDPA kernel against PyTorch SDPA backends. + +Measures latency across decode shapes matching the Qwen3.5 MoE model +(B=1, H_q=16, H_kv=2, D=256). The ET Triton kernel uses native GQA +(2 KV heads), while Flash/Efficient/Math require pre-expanded KV +(16 heads) since they lack native GQA support. + +""" + +import argparse +import warnings +from functools import partial + +import torch +import torch.nn.functional as F + +from executorch.backends.cuda.triton.kernels.sdpa import ( + sdpa as triton_sdpa, + sdpa_decode_splitk as triton_splitk, +) +from torch.nn.attention import sdpa_kernel, SDPBackend +from triton.testing import do_bench + + +# PyTorch's Flash/Efficient backends don't support GQA (H_q != H_kv) directly. +# We expand KV heads via repeat_interleave so they can run, matching what +# the test reference does. This is fair: it measures the kernel itself, not +# the GQA dispatch overhead. + + +def _expand_kv(k, v, num_groups): + if num_groups > 1: + k = k.repeat_interleave(num_groups, dim=1) + v = v.repeat_interleave(num_groups, dim=1) + return k, v + + +def _expand_mask(mask, H_q): + if mask is not None and mask.shape[1] == 1 and H_q > 1: + mask = mask.expand(-1, H_q, -1, -1) + return mask + + +def _run_triton(q, k, v, attn_mask, enable_gqa): + return triton_sdpa(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) + + +def _run_splitk(q, k, v, attn_mask, enable_gqa): + return triton_splitk(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) + + +def _run_pytorch_default(q, k, v, attn_mask, enable_gqa): + return F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + enable_gqa=enable_gqa, + ) + + +def _make_pytorch_runner(backend: SDPBackend): + def run(q, k, v, attn_mask, enable_gqa): + with sdpa_kernel(backend): + return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + + return run + + +# Flash doesn't support attn_mask at all, only is_causal. +# Our benchmark mask is all-ones, so no mask is equivalent. +def _run_flash(q, k, v, attn_mask, enable_gqa): + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + return F.scaled_dot_product_attention(q, k, v) + + +BACKENDS = { + "triton": ("ET Triton (GQA)", _run_triton), + "splitk": ("ET Split-K (GQA)", _run_splitk), + "pytorch": ("PyTorch", _run_pytorch_default), + "flash": ("Flash (expanded KV)", _run_flash), + "efficient": ( + "Efficient (expanded KV)", + _make_pytorch_runner(SDPBackend.EFFICIENT_ATTENTION), + ), + "math": ("Math (expanded KV)", _make_pytorch_runner(SDPBackend.MATH)), +} + +# Backends that need KV heads expanded before calling (no native GQA support) +_NEEDS_KV_EXPAND = {"flash", "efficient", "math"} + +# -- Shapes ------------------------------------------------------------------ + +# Qwen3.5 MoE: B=1, H_q=16, H_kv=2, D=256 +QWEN35_BASE = dict(B=1, H_q=16, H_kv=2, D=256) + +DECODE_SHAPES = [ + dict(**QWEN35_BASE, Lq=1, Lk=64), + dict(**QWEN35_BASE, Lq=1, Lk=128), + dict(**QWEN35_BASE, Lq=1, Lk=256), + dict(**QWEN35_BASE, Lq=1, Lk=512), + dict(**QWEN35_BASE, Lq=1, Lk=1024), + dict(**QWEN35_BASE, Lq=1, Lk=2048), + dict(**QWEN35_BASE, Lq=1, Lk=4096), + dict(**QWEN35_BASE, Lq=1, Lk=8192), + dict(**QWEN35_BASE, Lq=1, Lk=16384), +] + +SCENARIOS = { + "decode": DECODE_SHAPES, +} + +# -- Helpers ----------------------------------------------------------------- + + +def _make_tensors(B, H_q, H_kv, Lq, Lk, D, device="cuda", dtype=torch.bfloat16): + q = torch.randn(B, H_q, Lq, D, device=device, dtype=dtype) + k = torch.randn(B, H_kv, Lk, D, device=device, dtype=dtype) + v = torch.randn(B, H_kv, Lk, D, device=device, dtype=dtype) + mask = torch.ones(B, 1, Lq, Lk, dtype=torch.bool, device=device) + enable_gqa = H_q != H_kv + num_groups = H_q // H_kv + # Pre-expanded versions for backends without native GQA + k_exp, v_exp = _expand_kv(k, v, num_groups) + mask_exp = _expand_mask(mask, H_q) + return q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa + + +def _max_abs_error(out, ref): + return (out.float() - ref.float()).abs().max().item() + + +def _bench_us(fn, num_warmup, num_iters): + """Return median latency in microseconds using triton.testing.do_bench.""" + ms = do_bench(fn, warmup=num_warmup, rep=num_iters, return_mode="median") + return ms * 1000.0 + + +def _try_run(run_fn, q, k, v, mask, enable_gqa): + """Run a backend, returning output or None on failure.""" + try: + return run_fn(q, k, v, mask, enable_gqa) + except RuntimeError: + return None + + +def _try_bench(run_fn, q, k, v, mask, enable_gqa, num_warmup, num_iters): + """Benchmark a backend, returning median us or None on failure.""" + fn = partial(run_fn, q, k, v, mask, enable_gqa) + try: + run_fn(q, k, v, mask, enable_gqa) + return _bench_us(fn, num_warmup, num_iters) + except RuntimeError: + return None + + +# -- Main -------------------------------------------------------------------- + + +def _shape_label(shape): + return ( + f"B={shape['B']} Hq={shape['H_q']} Hkv={shape['H_kv']} " + f"D={shape['D']} Lq={shape['Lq']} Lk={shape['Lk']}" + ) + + +def _short_label(shape, scenario="decode"): + return f"Lq={shape['Lq']},Lk={shape['Lk']}" + + +@torch.inference_mode() +def run_benchmark( + scenario: str = "decode", + num_warmup: int = 25, + num_iters: int = 100, +): + shapes = SCENARIOS[scenario] + backends = [(name, *BACKENDS[name]) for name in BACKENDS] + + device_name = torch.cuda.get_device_name() + print() + print("=" * 100) + print(f"SDPA Benchmark Qwen3.5-35B-A3B — {scenario}") + print(f" Device: {device_name}") + print(f" Warmup: {num_warmup}, Iters: {num_iters}") + print(f" Backends: {', '.join(label for _, label, _ in backends)}") + print("=" * 100) + + # Build column specs: (header_text, unit_text, min_width) + # Each column gets width = max(len(header), len(unit), min_width) + max_label = max(len(_short_label(s, scenario)) for s in shapes) + col_specs = [("Shape", "", max(8, max_label))] + for _, label, _ in backends: + col_specs.append((label, "(us)", 8)) + + col_widths = [max(len(h), len(u), mw) for h, u, mw in col_specs] + + header = " | ".join( + f"{h:<{w}}" if i == 0 else f"{h:>{w}}" + for i, ((h, _, _), w) in enumerate(zip(col_specs, col_widths)) + ) + units = " | ".join( + f"{'':>{w}}" if i == 0 else f"{u:>{w}}" + for i, ((_, u, _), w) in enumerate(zip(col_specs, col_widths)) + ) + print(header) + print(units) + print("-" * len(header)) + + for shape in shapes: + q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa = _make_tensors(**shape) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + # Validate outputs across backends before benchmarking + outputs = {} + for name, _label, run_fn in backends: + if name in _NEEDS_KV_EXPAND: + bk, bv, bmask = k_exp, v_exp, mask_exp + else: + bk, bv, bmask = k, v, mask + outputs[name] = _try_run(run_fn, q, bk, bv, bmask, enable_gqa) + + ref_name, ref_out = None, None + for name, _, _ in backends: + if outputs[name] is not None: + ref_name, ref_out = name, outputs[name] + break + + if ref_out is not None: + for name, label, _ in backends: + if name == ref_name or outputs[name] is None: + continue + err = _max_abs_error(outputs[name], ref_out) + assert err < 1e-2, ( + f"Output mismatch for {_shape_label(shape)}: " + f"{label} vs {BACKENDS[ref_name][0]}, " + f"max abs error {err:.3e} >= 1e-2" + ) + del outputs + + # Benchmark all backends + times = {} + for name, _label, run_fn in backends: + if name in _NEEDS_KV_EXPAND: + bk, bv, bmask = k_exp, v_exp, mask_exp + else: + bk, bv, bmask = k, v, mask + times[name] = _try_bench( + run_fn, q, bk, bv, bmask, enable_gqa, num_warmup, num_iters + ) + + # Format row using col_widths + ci = 0 + row_parts = [f"{_short_label(shape, scenario):<{col_widths[ci]}}"] + ci += 1 + for name, _, _ in backends: + t = times[name] + w = col_widths[ci] + row_parts.append(f"{t:>{w}.1f}" if t is not None else f"{'N/A':>{w}}") + ci += 1 + print(" | ".join(row_parts)) + + del q, k, v, k_exp, v_exp, mask, mask_exp + torch.cuda.empty_cache() + + print("-" * len(header)) + print() + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark Triton SDPA vs PyTorch backends" + ) + parser.add_argument( + "--scenario", + choices=list(SCENARIOS.keys()) + ["all"], + default="all", + help="Which shape set to benchmark (default: all)", + ) + parser.add_argument("--num_warmup", type=int, default=25) + parser.add_argument("--num_iters", type=int, default=100) + args = parser.parse_args() + + scenarios = list(SCENARIOS.keys()) if args.scenario == "all" else [args.scenario] + for s in scenarios: + run_benchmark( + scenario=s, + num_warmup=args.num_warmup, + num_iters=args.num_iters, + ) + + +if __name__ == "__main__": + main() diff --git a/backends/cuda/tests/test_triton_sdpa_splitk.py b/backends/cuda/tests/test_triton_sdpa_splitk.py new file mode 100644 index 00000000000..8b7f0bd867b --- /dev/null +++ b/backends/cuda/tests/test_triton_sdpa_splitk.py @@ -0,0 +1,305 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for the split-K decode SDPA kernel (sdpa_decode_splitk). + +Mirrors test_triton_sdpa.py structure. Reference outputs use torch SDPA with +expanded KV heads in float32. +""" + +import itertools +import unittest + +import torch +import torch.nn.functional as F + + +def _skip_if_no_cuda(): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA not available") + if not torch.cuda.is_bf16_supported(): + raise unittest.SkipTest("BF16 not supported on this GPU") + + +def _import_splitk(): + from executorch.backends.cuda.triton.kernels.sdpa import sdpa_decode_splitk + + return sdpa_decode_splitk + + +def _import_sdpa(): + from executorch.backends.cuda.triton.kernels.sdpa import sdpa + + return sdpa + + +def _reference_sdpa(q, k, v, attn_mask=None, scale=None): + """Compute reference SDPA in float32 with expanded KV heads for GQA.""" + H_q = q.shape[1] + H_kv = k.shape[1] + num_groups = H_q // H_kv + + if num_groups > 1: + k = k.repeat_interleave(num_groups, dim=1) + v = v.repeat_interleave(num_groups, dim=1) + + if attn_mask is not None and attn_mask.shape[1] == 1 and H_q > 1: + attn_mask = attn_mask.expand(-1, H_q, -1, -1) + + return F.scaled_dot_product_attention( + q.float(), + k.float(), + v.float(), + attn_mask=attn_mask, + scale=scale, + ) + + +def _max_abs_error(out, ref): + return (out.float() - ref.float()).abs().max().item() + + +HEAD_DIMS_POW2 = [64, 128, 256] + +GQA_CONFIGS = [ + (6, 3, "gqa_2x"), + (8, 2, "gqa_4x"), + (16, 2, "gqa_8x"), + (6, 1, "mqa"), +] + +LK_LENGTHS = [64, 128, 512, 1024, 4096] + + +class TestTritonSdpaSplitK(unittest.TestCase): + """Test split-K decode SDPA kernel correctness against PyTorch reference.""" + + @classmethod + def setUpClass(cls): + _skip_if_no_cuda() + cls.splitk = _import_splitk() + cls.sdpa = _import_sdpa() + + # ------------------------------------------------------------------ + # Correctness + # ------------------------------------------------------------------ + + def test_decode_basic(self): + """GQA decode across head configs, head dims, and KV lengths.""" + for (H_q, H_kv, label), D, Lk in itertools.product( + GQA_CONFIGS, + HEAD_DIMS_POW2, + LK_LENGTHS, + ): + with self.subTest(label=label, D=D, Lk=Lk): + B, Lq = 1, 1 + torch.manual_seed(42) + q = torch.randn(B, H_q, Lq, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + + out = self.splitk(q, k, v) + ref = _reference_sdpa(q, k, v) + + self.assertEqual(out.shape, (B, H_q, Lq, D)) + self.assertFalse(torch.isnan(out).any(), "NaN in output") + self.assertLess( + _max_abs_error(out, ref), + 0.05, + f"{label} D={D} Lk={Lk}", + ) + + def test_decode_with_mask(self): + """Decode with bool mask (KV cache style: first N positions valid).""" + for H_q, H_kv, label in GQA_CONFIGS: + with self.subTest(label=label): + B, Lq, Lk, D = 1, 1, 512, 128 + torch.manual_seed(42) + q = torch.randn(B, H_q, Lq, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + + mask = torch.zeros(B, 1, Lq, Lk, dtype=torch.bool, device="cuda") + mask[:, :, :, :200] = True + + out = self.splitk(q, k, v, attn_mask=mask) + ref = _reference_sdpa(q, k, v, attn_mask=mask) + + self.assertFalse(torch.isnan(out).any()) + self.assertLess(_max_abs_error(out, ref), 0.05) + + def test_decode_mha(self): + """MHA (H_q==H_kv, num_groups=1) should work with split-K.""" + for D, Lk in itertools.product([64, 128], [128, 512]): + with self.subTest(D=D, Lk=Lk): + B, H, Lq = 1, 4, 1 + torch.manual_seed(42) + q = torch.randn(B, H, Lq, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(B, H, Lk, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(B, H, Lk, D, dtype=torch.bfloat16, device="cuda") + + out = self.splitk(q, k, v) + ref = _reference_sdpa(q, k, v) + + self.assertFalse(torch.isnan(out).any()) + self.assertLess(_max_abs_error(out, ref), 0.05) + + def test_qwen35_config(self): + """Exact Qwen3.5 MoE config: H_q=16, H_kv=2, D=256.""" + B, H_q, H_kv, D = 1, 16, 2, 256 + for Lk in [128, 512, 1024, 4096]: + with self.subTest(Lk=Lk): + Lq = 1 + torch.manual_seed(42) + q = torch.randn(B, H_q, Lq, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + + mask = torch.ones(B, 1, Lq, Lk, dtype=torch.bool, device="cuda") + + out = self.splitk(q, k, v, attn_mask=mask) + ref = _reference_sdpa(q, k, v, attn_mask=mask) + + self.assertEqual(out.shape, (B, H_q, Lq, D)) + self.assertFalse(torch.isnan(out).any()) + self.assertLess(_max_abs_error(out, ref), 0.05) + + def test_custom_scale(self): + """Non-default attention scale.""" + B, H_q, H_kv, Lq, Lk, D = 1, 8, 2, 1, 256, 128 + torch.manual_seed(42) + q = torch.randn(B, H_q, Lq, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + + scale = 0.05 + out = self.splitk(q, k, v, scale=scale) + ref = _reference_sdpa(q, k, v, scale=scale) + + self.assertFalse(torch.isnan(out).any()) + self.assertLess(_max_abs_error(out, ref), 0.05) + + def test_cross_validate_with_sdpa(self): + """Split-K output matches tiled sdpa output for decode shapes.""" + B, H_q, H_kv, D = 1, 8, 2, 128 + for Lk in [128, 512, 1024]: + with self.subTest(Lk=Lk): + Lq = 1 + torch.manual_seed(42) + q = torch.randn(B, H_q, Lq, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + mask = torch.ones(B, 1, Lq, Lk, dtype=torch.bool, device="cuda") + + out_splitk = self.splitk(q, k, v, attn_mask=mask) + out_tiled = self.sdpa(q, k, v, attn_mask=mask, enable_gqa=True) + + self.assertLess( + _max_abs_error(out_splitk, out_tiled), + 0.05, + f"Split-K vs tiled mismatch at Lk={Lk}", + ) + + # ------------------------------------------------------------------ + # Edge cases + # ------------------------------------------------------------------ + + def test_all_masked(self): + """All-False mask should produce zeros, not NaN.""" + B, H_q, H_kv, Lq, Lk, D = 1, 8, 2, 1, 128, 64 + torch.manual_seed(42) + q = torch.randn(B, H_q, Lq, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + + mask = torch.zeros(B, 1, Lq, Lk, dtype=torch.bool, device="cuda") + out = self.splitk(q, k, v, attn_mask=mask) + + self.assertFalse(torch.isnan(out).any(), "All-masked should not NaN") + self.assertFalse(torch.isinf(out).any(), "All-masked should not Inf") + + def test_lk_1(self): + """Degenerate single KV position (num_splits=1).""" + B, H_q, H_kv, Lq, Lk, D = 1, 4, 2, 1, 1, 64 + torch.manual_seed(42) + q = torch.randn(B, H_q, Lq, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + + out = self.splitk(q, k, v) + ref = _reference_sdpa(q, k, v) + + self.assertFalse(torch.isnan(out).any()) + self.assertLess(_max_abs_error(out, ref), 0.05) + + def test_batch_size(self): + """Batch size > 1.""" + for B in [2, 4]: + with self.subTest(B=B): + H_q, H_kv, Lq, Lk, D = 8, 2, 1, 256, 128 + torch.manual_seed(42) + q = torch.randn(B, H_q, Lq, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + + out = self.splitk(q, k, v) + ref = _reference_sdpa(q, k, v) + + self.assertFalse(torch.isnan(out).any()) + self.assertLess(_max_abs_error(out, ref), 0.05) + + # ------------------------------------------------------------------ + # Validation errors + # ------------------------------------------------------------------ + + def test_lq_not_1_rejected(self): + """L_q != 1 should raise RuntimeError.""" + B, H_q, H_kv, D = 1, 8, 2, 64 + q = torch.randn(B, H_q, 4, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(B, H_kv, 64, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(B, H_kv, 64, D, dtype=torch.bfloat16, device="cuda") + with self.assertRaises(RuntimeError): + self.splitk(q, k, v) + + def test_dropout_rejected(self): + """dropout_p != 0 should raise RuntimeError.""" + B, H_q, H_kv, D = 1, 8, 2, 64 + q = torch.randn(B, H_q, 1, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(B, H_kv, 64, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(B, H_kv, 64, D, dtype=torch.bfloat16, device="cuda") + with self.assertRaises(RuntimeError): + self.splitk(q, k, v, dropout_p=0.1) + + def test_is_causal_rejected(self): + """is_causal=True should raise RuntimeError.""" + B, H_q, H_kv, D = 1, 8, 2, 64 + q = torch.randn(B, H_q, 1, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(B, H_kv, 64, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(B, H_kv, 64, D, dtype=torch.bfloat16, device="cuda") + with self.assertRaises(RuntimeError): + self.splitk(q, k, v, is_causal=True) + + def test_hq_not_divisible_rejected(self): + """H_q % H_kv != 0 should raise RuntimeError.""" + B, D = 1, 64 + q = torch.randn(B, 5, 1, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(B, 3, 64, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(B, 3, 64, D, dtype=torch.bfloat16, device="cuda") + with self.assertRaises(RuntimeError): + self.splitk(q, k, v) + + def test_non_pow2_d_rejected(self): + """Non-power-of-2 D should raise RuntimeError.""" + B, H_q, H_kv, D = 1, 8, 2, 96 + q = torch.randn(B, H_q, 1, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(B, H_kv, 64, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(B, H_kv, 64, D, dtype=torch.bfloat16, device="cuda") + with self.assertRaises(RuntimeError): + self.splitk(q, k, v) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/cuda/triton/kernels/__init__.py b/backends/cuda/triton/kernels/__init__.py index e7af2bdaf84..c7589a48fdf 100644 --- a/backends/cuda/triton/kernels/__init__.py +++ b/backends/cuda/triton/kernels/__init__.py @@ -5,12 +5,13 @@ # LICENSE file in the root directory of this source tree. from executorch.backends.cuda.triton.kernels.fused_moe import fused_moe -from executorch.backends.cuda.triton.kernels.sdpa import sdpa +from executorch.backends.cuda.triton.kernels.sdpa import sdpa, sdpa_decode_splitk from executorch.backends.cuda.triton.kernels.topk import topk __all__ = [ "fused_moe", "sdpa", + "sdpa_decode_splitk", "topk", ] diff --git a/backends/cuda/triton/kernels/sdpa.py b/backends/cuda/triton/kernels/sdpa.py index d83f8e0557a..0adb998fbd8 100644 --- a/backends/cuda/triton/kernels/sdpa.py +++ b/backends/cuda/triton/kernels/sdpa.py @@ -46,7 +46,11 @@ def _is_power_of_2(n: int) -> bool: def _next_power_of_2(x: int) -> int: - """Get the next power of 2 >= x, clamped to [16, 256].""" + """Get the next power of 2 >= x, clamped to [16, 256]. + + Used for HEAD_DIM tiling where tile sizes below 16 waste warps + and head dims above 256 are unsupported. + """ if x <= 16: return 16 if x <= 32: @@ -58,6 +62,17 @@ def _next_power_of_2(x: int) -> int: return 256 +def _next_power_of_2_unclamped(x: int) -> int: + """Get the next power of 2 >= x (no clamping). + + Used for GQA group-count tiling where num_groups can be small (1, 2, ...) + and should not be inflated to 16. + """ + if x <= 0: + return 1 + return 1 << (x - 1).bit_length() + + def _should_pack_gqa(L_q: int, num_groups: int, block_m: int) -> bool: """Decide whether to use pack GQA based on tile utilization. @@ -1032,3 +1047,428 @@ def _sdpa_abstract( B, H_q, _H_kv, L_q, _, D_q, _ = _validate_qkv_shapes(query, key, value, enable_gqa) return torch.empty(B, H_q, L_q, D_q, dtype=query.dtype, device=query.device) + + +# ============================================================================== +# Split-K decode kernel (flash-decoding) +# ============================================================================== +# When L_q == 1 with GQA, the standard kernel launches only +# ceil(num_groups / BLOCK_M) * B * H_kv CTAs (e.g. 2 for Qwen3.5 MoE). +# Split-K partitions the KV sequence across many CTAs for better occupancy, +# then reduces partial results in a second kernel. + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_N": 32}, num_warps=2, num_stages=1), + triton.Config({"BLOCK_N": 32}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_N": 64}, num_warps=2, num_stages=1), + triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_N": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_N": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_N": 128}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_N": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_N": 256}, num_warps=8, num_stages=2), + ], + key=["Lk", "HEAD_DIM", "NUM_GROUPS", "HAS_MASK"], +) +@triton.jit +def _sdpa_decode_splitk_kernel( + Q_ptr, + K_ptr, + V_ptr, + O_partial_ptr, + M_partial_ptr, + L_partial_ptr, + Mask_ptr, + B, + H_kv, + Lk, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_op_s, + stride_op_b, + stride_op_h, + stride_op_d, + stride_mp_s, + stride_mp_b, + stride_mp_h, + stride_mb, + stride_mq, + stride_mk, + sm_scale: tl.float32, + chunk_size, + HAS_MASK: tl.constexpr, + BLOCK_N: tl.constexpr, + HEAD_DIM: tl.constexpr, + NUM_GROUPS: tl.constexpr, + BLOCK_G: tl.constexpr, +): + split_id = tl.program_id(axis=0) + pid_bh = tl.program_id(axis=1) + b = pid_bh // H_kv + h_kv = pid_bh % H_kv + + start_n = split_id * chunk_size + end_n = tl.minimum(start_n + chunk_size, Lk) + + offs_d = tl.arange(0, HEAD_DIM) + offs_g = tl.arange(0, BLOCK_G) + g_valid = offs_g < NUM_GROUPS + h_q_heads = h_kv * NUM_GROUPS + offs_g # [BLOCK_G] + + # Load Q for all heads in this group: [BLOCK_G, HEAD_DIM] + q_ptrs = Q_ptr + ( + b * stride_qb + + h_q_heads[:, None] * stride_qh + + 0 * stride_qm + + offs_d[None, :] * stride_qd + ) + q = tl.load(q_ptrs, mask=g_valid[:, None], other=0.0).to(tl.bfloat16) + + m_i = tl.full([BLOCK_G], -float("inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_G], dtype=tl.float32) + acc = tl.zeros([BLOCK_G, HEAD_DIM], dtype=tl.float32) + + offs_n_init = tl.arange(0, BLOCK_N) + + for tile_start in tl.range(start_n, end_n, BLOCK_N): + offs_n = tile_start + offs_n_init + n_valid = offs_n < end_n + + k_ptrs = K_ptr + ( + b * stride_kb + + h_kv * stride_kh + + offs_n[:, None] * stride_kn + + offs_d[None, :] * stride_kd + ) + k = tl.load(k_ptrs, mask=n_valid[:, None], other=0.0).to(tl.bfloat16) + + # QK: [BLOCK_G, BLOCK_N] + qk = (tl.dot(q, tl.trans(k)).to(tl.float32) * sm_scale).to(tl.float32) + + # Mask out-of-bounds KV positions + qk = tl.where( + n_valid[None, :], + qk, + tl.full(qk.shape, -float("inf"), dtype=tl.float32), + ) + + if HAS_MASK: + mask_ptrs = Mask_ptr + ( + b * stride_mb + 0 * stride_mq + offs_n[None, :] * stride_mk + ) + mask_block = tl.load(mask_ptrs, mask=n_valid[None, :], other=False) + qk = tl.where( + mask_block, qk, tl.full(qk.shape, -float("inf"), dtype=tl.float32) + ) + + # Online softmax update + m_ij = tl.maximum(m_i, tl.max(qk, axis=1).to(tl.float32)) + safe_diff = tl.where( + m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf") + ) + p_f32 = tl.exp(safe_diff).to(tl.float32) + l_ij = tl.sum(p_f32, axis=1).to(tl.float32) + safe_alpha_diff = tl.where(m_ij > -float("inf"), m_i - m_ij, 0.0) + alpha = tl.exp(safe_alpha_diff).to(tl.float32) + + v_ptrs = V_ptr + ( + b * stride_vb + + h_kv * stride_vh + + offs_n[:, None] * stride_vn + + offs_d[None, :] * stride_vd + ) + v = tl.load(v_ptrs, mask=n_valid[:, None], other=0.0).to(tl.bfloat16) + + p_bf16 = p_f32.to(tl.bfloat16) + acc = (acc * alpha[:, None] + tl.dot(p_bf16, v)).to(tl.float32) + l_i = (l_i * alpha + l_ij).to(tl.float32) + m_i = m_ij + + # Store partial results for valid groups only + h_q_all = h_kv * NUM_GROUPS + offs_g # [BLOCK_G] + o_ptrs = O_partial_ptr + ( + split_id * stride_op_s + + b * stride_op_b + + h_q_all[:, None] * stride_op_h + + offs_d[None, :] * stride_op_d + ) + tl.store(o_ptrs, acc, mask=g_valid[:, None]) + + ml_ptrs = M_partial_ptr + ( + split_id * stride_mp_s + b * stride_mp_b + h_q_all * stride_mp_h + ) + tl.store(ml_ptrs, m_i, mask=g_valid) + + ll_ptrs = L_partial_ptr + ( + split_id * stride_mp_s + b * stride_mp_b + h_q_all * stride_mp_h + ) + tl.store(ll_ptrs, l_i, mask=g_valid) + + +@triton.jit +def _sdpa_decode_reduce_kernel( + O_partial_ptr, + M_partial_ptr, + L_partial_ptr, + O_ptr, + num_splits, + stride_op_s, + stride_op_b, + stride_op_h, + stride_op_d, + stride_mp_s, + stride_mp_b, + stride_mp_h, + stride_ob, + stride_oh, + stride_om, + stride_od, + HEAD_DIM: tl.constexpr, +): + pid = tl.program_id(axis=0) + offs_d = tl.arange(0, HEAD_DIM) + + # pid indexes into flattened (B, H_q). Partial buffers are allocated + # contiguous in _launch_decode_splitk, so pid * stride_*_h is valid. + # Find global max across all splits + m_global = tl.full([1], -float("inf"), dtype=tl.float32) + for s in tl.range(0, num_splits): + m_ptr = M_partial_ptr + s * stride_mp_s + pid * stride_mp_h + m_s = tl.load(m_ptr) + m_global = tl.maximum(m_global, m_s) + + # Accumulate rescaled outputs + acc = tl.zeros([HEAD_DIM], dtype=tl.float32) + l_global = tl.zeros([1], dtype=tl.float32) + for s in tl.range(0, num_splits): + m_ptr = M_partial_ptr + s * stride_mp_s + pid * stride_mp_h + l_ptr = L_partial_ptr + s * stride_mp_s + pid * stride_mp_h + o_ptrs = O_partial_ptr + ( + s * stride_op_s + pid * stride_op_h + offs_d * stride_op_d + ) + + m_s = tl.load(m_ptr) + l_s = tl.load(l_ptr) + o_s = tl.load(o_ptrs) + + safe_diff = tl.where(m_global > -float("inf"), m_s - m_global, 0.0) + scale = tl.exp(safe_diff).to(tl.float32) + acc += o_s * scale + l_global += l_s * scale + + inv_l = tl.where(l_global > 0, 1.0 / l_global, 0.0) + acc = acc * inv_l + + # pid = b*H_q + h_q. For contiguous output [B, H_q, 1, D] with L_q=1, + # stride_ob == H_q * stride_oh, so pid * stride_oh is correct. + # This relies on `out` being freshly allocated and contiguous. + o_out_ptrs = O_ptr + pid * stride_oh + offs_d * stride_od + tl.store(o_out_ptrs, acc.to(tl.bfloat16)) + + +def _launch_decode_splitk( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + B: int, + H_q: int, + H_kv: int, + L_kv: int, + D: int, + sm_scale: float, + HAS_MASK: bool, + Mask_ptr, + stride_mb: int, + stride_mq: int, + stride_mk: int, + num_groups: int, +) -> None: + num_splits = min(max(triton.cdiv(L_kv, 256), 1), 128) + chunk_size = triton.cdiv(L_kv, num_splits) + + O_partial = torch.empty( + (num_splits, B, H_q, D), device=query.device, dtype=torch.float32 + ) + M_partial = torch.full( + (num_splits, B, H_q), -float("inf"), device=query.device, dtype=torch.float32 + ) + L_partial = torch.zeros( + (num_splits, B, H_q), device=query.device, dtype=torch.float32 + ) + + stride_qb, stride_qh, stride_qm, stride_qd = query.stride() + stride_kb, stride_kh, stride_kn, stride_kd = key.stride() + stride_vb, stride_vh, stride_vn, stride_vd = value.stride() + stride_ob, stride_oh, stride_om, stride_od = out.stride() + stride_op_s, stride_op_b, stride_op_h, stride_op_d = O_partial.stride() + stride_mp_s, stride_mp_b, stride_mp_h = M_partial.stride() + + grid_split = (num_splits, B * H_kv) + wrap_triton(_sdpa_decode_splitk_kernel)[grid_split]( + query, + key, + value, + O_partial, + M_partial, + L_partial, + Mask_ptr if HAS_MASK else 0, + B, + H_kv, + L_kv, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_op_s, + stride_op_b, + stride_op_h, + stride_op_d, + stride_mp_s, + stride_mp_b, + stride_mp_h, + stride_mb, + stride_mq, + stride_mk, + sm_scale, + chunk_size, + HAS_MASK=HAS_MASK, + HEAD_DIM=D, + NUM_GROUPS=num_groups, + BLOCK_G=_next_power_of_2_unclamped(num_groups), + ) + + grid_reduce = (B * H_q,) + wrap_triton(_sdpa_decode_reduce_kernel)[grid_reduce]( + O_partial, + M_partial, + L_partial, + out, + num_splits, + stride_op_s, + stride_op_b, + stride_op_h, + stride_op_d, + stride_mp_s, + stride_mp_b, + stride_mp_h, + stride_ob, + stride_oh, + stride_om, + stride_od, + HEAD_DIM=D, + num_warps=4, + num_stages=1, + ) + + +@triton_op("triton::sdpa_decode_splitk", mutates_args={}) +def sdpa_decode_splitk( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = 0.0, + enable_gqa: bool = False, +) -> torch.Tensor: + """Split-K flash-decoding SDPA for L_q=1 (decode step). + + Signature mirrors sdpa() for drop-in use with torch.cond dispatch. + enable_gqa is accepted but ignored — GQA is handled natively via + H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1. + """ + _validate_sdpa_inputs(query, key, value, dropout_p, enable_gqa) + + B, H_q, L_q, D = query.shape + _, H_kv, L_kv, _ = key.shape + + out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype) + + if is_causal: + raise RuntimeError( + "sdpa_decode_splitk does not support is_causal=True " + "(causal masking is a no-op at L_q=1; pass attn_mask instead)" + ) + + # Validation — only check at runtime (concrete shapes), not during AOTI + # tracing where shapes are symbolic. torch.cond traces both branches with + # the same symbolic L_q, so L_q is not necessarily 1 during tracing. + if isinstance(L_q, int): + if L_q != 1: + raise RuntimeError( + f"sdpa_decode_splitk requires L_q == 1 (decode); got L_q={L_q}" + ) + if H_q % H_kv != 0: + raise RuntimeError( + f"H_q must be divisible by H_kv; got H_q={H_q}, H_kv={H_kv}" + ) + if not _is_power_of_2(D): + raise RuntimeError( + f"sdpa_decode_splitk requires power-of-2 head dim; got D={D}" + ) + + num_groups = H_q // H_kv + sm_scale = 1.0 / math.sqrt(D) if scale == 0.0 else scale + HAS_MASK, Mask_ptr, stride_mb, stride_mq, stride_mk = _prepare_mask_params( + attn_mask, B, L_q, L_kv + ) + + _launch_decode_splitk( + query, + key, + value, + out, + B, + H_q, + H_kv, + L_kv, + D, + sm_scale, + HAS_MASK, + Mask_ptr, + stride_mb, + stride_mq, + stride_mk, + num_groups, + ) + return out + + +@sdpa_decode_splitk.register_fake +def _sdpa_decode_splitk_abstract( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: float = 0.0, + enable_gqa: bool = False, +) -> torch.Tensor: + assert query.dtype == key.dtype == value.dtype, "Q, K, V must have the same dtype" + B, H_q, L_q, D = query.shape + return torch.empty(B, H_q, L_q, D, dtype=query.dtype, device=query.device) diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py index 7437bc5f461..d885b74024c 100644 --- a/examples/models/qwen3_5_moe/export.py +++ b/examples/models/qwen3_5_moe/export.py @@ -398,9 +398,13 @@ def export_and_lower(model, config, args): # -O0 compiles ~8x faster than -O1 with no measurable runtime impact. inductor_config.aot_inductor.compile_wrapper_opt_level = "O0" - # Dynamic shapes - example_tokens = torch.tensor([[0, 1]], dtype=torch.long) - example_input_pos = torch.tensor([0, 1], dtype=torch.long) + # Dynamic shapes — example T must equal max_seq_len-1 so AOTI compiles + # kernels (especially chunk_gated_delta_rule with CHUNK_SIZE=64) for the + # full range of sequence lengths. Smaller examples cause AOTI to bake in + # intermediate buffer sizes that reject longer prompts at runtime. + example_seq_len = config.max_seq_len - 1 + example_tokens = torch.zeros((1, example_seq_len), dtype=torch.long) + example_input_pos = torch.arange(example_seq_len, dtype=torch.long) seq_dim = Dim("seq_len", min=1, max=config.max_seq_len - 1) dynamic_shapes = ({1: seq_dim}, {0: seq_dim}) diff --git a/examples/models/qwen3_5_moe/main.cpp b/examples/models/qwen3_5_moe/main.cpp index 266d0e65419..4e334b60135 100644 --- a/examples/models/qwen3_5_moe/main.cpp +++ b/examples/models/qwen3_5_moe/main.cpp @@ -13,6 +13,8 @@ #include #include +#include +#include #include #include @@ -67,7 +69,49 @@ int main(int argc, char** argv) { config.temperature = FLAGS_temperature; config.max_new_tokens = FLAGS_max_new_tokens; - auto error = runner->generate(FLAGS_prompt.c_str(), config); + auto error = runner->generate( + FLAGS_prompt.c_str(), + config, + /*token_callback=*/{}, + [](const llm::Stats& stats) { + double scale = stats.SCALING_FACTOR_UNITS_PER_SECOND; + double model_load_s = + (stats.model_load_end_ms - stats.model_load_start_ms) / scale; + double inference_s = + (stats.inference_end_ms - stats.inference_start_ms) / scale; + double prefill_s = + (stats.prompt_eval_end_ms - stats.inference_start_ms) / scale; + double decode_s = + (stats.inference_end_ms - stats.prompt_eval_end_ms) / scale; + double ttft_s = + (stats.first_token_ms - stats.inference_start_ms) / scale; + double sampling_s = stats.aggregate_sampling_time_ms / scale; + + printf( + "\n\tPrompt Tokens: %" PRIu64 " Generated Tokens: %" PRIu64, + stats.num_prompt_tokens, + stats.num_generated_tokens); + printf("\n\tModel Load Time:\t\t%f (seconds)", model_load_s); + printf( + "\n\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)", + inference_s, + inference_s > 0 ? stats.num_generated_tokens / inference_s : 0.0); + printf( + "\n\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)", + prefill_s, + prefill_s > 0 ? stats.num_prompt_tokens / prefill_s : 0.0); + printf( + "\n\t\tGenerated %" PRIu64 + " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)", + stats.num_generated_tokens, + decode_s, + decode_s > 0 ? stats.num_generated_tokens / decode_s : 0.0); + printf("\n\tTime to first generated token:\t%f (seconds)", ttft_s); + printf( + "\n\tSampling time over %" PRIu64 " tokens:\t%f (seconds)\n", + stats.num_prompt_tokens + stats.num_generated_tokens, + sampling_s); + }); if (error != executorch::runtime::Error::Ok) { ET_LOG(Error, "Generation failed"); return 1; diff --git a/examples/models/qwen3_5_moe/model.py b/examples/models/qwen3_5_moe/model.py index d9f127d9ed1..5f0aa047286 100644 --- a/examples/models/qwen3_5_moe/model.py +++ b/examples/models/qwen3_5_moe/model.py @@ -20,6 +20,8 @@ import torch import torch.nn as nn + +from executorch.backends.cuda.triton.kernels.sdpa import sdpa, sdpa_decode_splitk from torch.nn import functional as F @@ -285,8 +287,14 @@ def forward(self, x, input_pos): ) else: k, v = self.kv_cache.update(input_pos, k, v) - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, enable_gqa=True + # Runtime dispatch via torch.cond: + # decode (L_q==1): split-K flash-decoding for high KV occupancy + # prefill (L_q>1): standard tiled SDPA + y = torch.cond( + q.shape[2] == 1, + lambda q, k, v, mask: sdpa_decode_splitk(q, k, v, attn_mask=mask), + lambda q, k, v, mask: sdpa(q, k, v, attn_mask=mask, enable_gqa=True), + [q, k, v, attn_mask], ) y = y.transpose(1, 2).contiguous().view(B, T, -1)