Skip to content

Commit 5d3b620

Browse files
committed
Add torch.cond split-K decode dispatch to Qwen3.5 MoE attention

Runtime dispatch via torch.cond in FullAttention: split-K flash-decoding for decode (L_q == 1) and standard tiled SDPA for prefill (L_q > 1). Guard sdpa_decode_splitk validation behind isinstance(L_q, int) so AOTI tracing with symbolic shapes doesn't trip the L_q == 1 check. Align the sdpa_decode_splitk signature with sdpa (dropout_p, is_causal, enable_gqa) for drop-in use with torch.cond; unsupported arguments fail with clear error messages.

End-to-end on H100 (Qwen3.5-35B-A3B, HQQ-INT4, max_seq_len=4096, 1024 decode tokens, prompt="Hi", temperature=0, median of 5 runs):

| Metric        | Baseline (tiled) | Split-K       | Speedup |
|---------------|------------------|---------------|---------|
| Decode tok/s  | 61.7             | 89.9          | 1.46x   |
| Prefill tok/s | 378.2            | 378.2         | 1.00x   |
| nsys GPU time | 13853 ms         | 8674 ms       | 1.60x   |
| SDPA kernel   | 5370 ms (38.8%)  | 209 ms (2.4%) | 25.7x   |
1 parent e06db27 commit 5d3b620

2 files changed

Lines changed: 46 additions & 13 deletions

File tree

backends/cuda/triton/kernels/sdpa.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,26 +1390,48 @@ def sdpa_decode_splitk(
13901390
key: torch.Tensor,
13911391
value: torch.Tensor,
13921392
attn_mask: Optional[torch.Tensor] = None,
1393+
dropout_p: float = 0.0,
1394+
is_causal: bool = False,
13931395
scale: float = 0.0,
1396+
enable_gqa: bool = False,
13941397
) -> torch.Tensor:
1398+
"""Split-K flash-decoding SDPA for L_q=1 (decode step).
1399+
1400+
Signature mirrors sdpa() for drop-in use with torch.cond dispatch.
1401+
enable_gqa is accepted but ignored — GQA is handled natively via
1402+
H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1.
1403+
"""
1404+
_validate_sdpa_inputs(query, key, value, dropout_p, enable_gqa)
1405+
13951406
B, H_q, L_q, D = query.shape
13961407
_, H_kv, L_kv, _ = key.shape
13971408

1398-
if L_q != 1:
1399-
raise RuntimeError(
1400-
f"sdpa_decode_splitk requires L_q == 1 (decode); got L_q={L_q}"
1401-
)
1402-
if H_q % H_kv != 0:
1403-
raise RuntimeError(
1404-
f"H_q must be divisible by H_kv; got H_q={H_q}, H_kv={H_kv}"
1405-
)
1406-
if not _is_power_of_2(D):
1409+
out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype)
1410+
1411+
if is_causal:
14071412
raise RuntimeError(
1408-
f"sdpa_decode_splitk requires power-of-2 head dim; got D={D}"
1413+
"sdpa_decode_splitk does not support is_causal=True "
1414+
"(causal masking is a no-op at L_q=1; pass attn_mask instead)"
14091415
)
14101416

1417+
# Validation — only check at runtime (concrete shapes), not during AOTI
1418+
# tracing where shapes are symbolic. torch.cond traces both branches with
1419+
# the same symbolic L_q, so L_q is not necessarily 1 during tracing.
1420+
if isinstance(L_q, int):
1421+
if L_q != 1:
1422+
raise RuntimeError(
1423+
f"sdpa_decode_splitk requires L_q == 1 (decode); got L_q={L_q}"
1424+
)
1425+
if H_q % H_kv != 0:
1426+
raise RuntimeError(
1427+
f"H_q must be divisible by H_kv; got H_q={H_q}, H_kv={H_kv}"
1428+
)
1429+
if not _is_power_of_2(D):
1430+
raise RuntimeError(
1431+
f"sdpa_decode_splitk requires power-of-2 head dim; got D={D}"
1432+
)
1433+
14111434
num_groups = H_q // H_kv
1412-
out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype)
14131435
sm_scale = 1.0 / math.sqrt(D) if scale == 0.0 else scale
14141436
HAS_MASK, Mask_ptr, stride_mb, stride_mq, stride_mk = _prepare_mask_params(
14151437
attn_mask, B, L_q, L_kv
@@ -1430,7 +1452,10 @@ def _sdpa_decode_splitk_abstract(
14301452
key: torch.Tensor,
14311453
value: torch.Tensor,
14321454
attn_mask: Optional[torch.Tensor] = None,
1455+
dropout_p: float = 0.0,
1456+
is_causal: bool = False,
14331457
scale: float = 0.0,
1458+
enable_gqa: bool = False,
14341459
) -> torch.Tensor:
14351460
assert query.dtype == key.dtype == value.dtype, "Q, K, V must have the same dtype"
14361461
B, H_q, L_q, D = query.shape

examples/models/qwen3_5_moe/model.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import torch.nn as nn
2323
from torch.nn import functional as F
2424

25+
from executorch.backends.cuda.triton.kernels.sdpa import sdpa, sdpa_decode_splitk
26+
2527

2628
# ---------------------------------------------------------------------------
2729
# Config
@@ -285,8 +287,14 @@ def forward(self, x, input_pos):
285287
)
286288
else:
287289
k, v = self.kv_cache.update(input_pos, k, v)
288-
y = F.scaled_dot_product_attention(
289-
q, k, v, attn_mask=attn_mask, enable_gqa=True
290+
# Runtime dispatch via torch.cond:
291+
# decode (L_q==1): split-K flash-decoding for high KV occupancy
292+
# prefill (L_q>1): standard tiled SDPA
293+
y = torch.cond(
294+
q.shape[2] == 1,
295+
lambda q, k, v, mask: sdpa_decode_splitk(q, k, v, attn_mask=mask),
296+
lambda q, k, v, mask: sdpa(q, k, v, attn_mask=mask, enable_gqa=True),
297+
[q, k, v, attn_mask],
290298
)
291299

292300
y = y.transpose(1, 2).contiguous().view(B, T, -1)

0 commit comments

Comments
 (0)