Commit 5d3b620
Add torch.cond split-K decode dispatch to Qwen3.5 MoE attention
Runtime dispatch via torch.cond in FullAttention: split-K flash-decoding
for decode (L_q==1) and standard tiled SDPA for prefill (L_q>1). Guard
sdpa_decode_splitk validation behind isinstance(L_q, int) so AOTI tracing
with symbolic shapes doesn't trip the L_q==1 check.
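
For readers unfamiliar with split-K flash-decoding, the idea can be sketched in plain Python for a single query (L_q==1): split the KV sequence into chunks, reduce each chunk independently, then merge the partial results using their running maxima. This is only an illustrative sketch and not the kernel from this commit; the function names `attend_full` and `attend_split_k` and the `num_splits` parameter are hypothetical.

```python
import math


def _scores(q, keys):
    # Scaled dot-product scores for one query vector against a list of keys.
    scale = 1.0 / math.sqrt(len(q))
    return [scale * sum(a * b for a, b in zip(q, k)) for k in keys]


def attend_full(q, keys, values):
    """Reference path: one softmax over the whole KV sequence."""
    s = _scores(q, keys)
    m = max(s)
    w = [math.exp(x - m) for x in s]
    denom = sum(w)
    dim = len(values[0])
    return [sum(wi * v[d] for wi, v in zip(w, values)) / denom for d in range(dim)]


def attend_split_k(q, keys, values, num_splits):
    """Split-K path: per-chunk partial reductions followed by a merge step."""
    n = len(keys)
    dim = len(values[0])
    chunk = math.ceil(n / num_splits)
    partials = []  # one (local max, local exp-sum, unnormalized output) per chunk
    for start in range(0, n, chunk):
        ks = keys[start:start + chunk]
        vs = values[start:start + chunk]
        s = _scores(q, ks)
        m = max(s)
        w = [math.exp(x - m) for x in s]
        acc = [sum(wi * v[d] for wi, v in zip(w, vs)) for d in range(dim)]
        partials.append((m, sum(w), acc))
    # Merge: rescale every partial to the global max before summing.
    g = max(m for m, _, _ in partials)
    denom = sum(l * math.exp(m - g) for m, l, _ in partials)
    return [
        sum(acc[d] * math.exp(m - g) for m, _, acc in partials) / denom
        for d in range(dim)
    ]
```

Because each chunk is reduced independently, a GPU kernel built this way can spread one decode query over many thread blocks, which is the usual motivation for using this path when L_q==1.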
Align sdpa_decode_splitk signature with sdpa (dropout_p, is_causal,
enable_gqa) for drop-in use with torch.cond; unsupported args fail
with clear messages.
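
The guard and the aligned signature described above might look roughly like the following plain-Python sketch. It is not the code from this commit: `FakeSymInt` and `check_decode_args` are hypothetical names, and which arguments are actually rejected is an assumption (here, nonzero `dropout_p` and `is_causal=True`); `enable_gqa` is accepted only to mirror the `sdpa` signature.

```python
class FakeSymInt:
    """Hypothetical stand-in for a symbolic dimension seen during tracing.

    Like a traced symbolic size, it is deliberately not an `int` subclass.
    """

    def __init__(self, name):
        self.name = name


def check_decode_args(L_q, dropout_p=0.0, is_causal=False, enable_gqa=False):
    """Validate arguments for a decode-only attention path (illustrative)."""
    # Assumption: dropout and causal masking are the unsupported arguments.
    if dropout_p != 0.0:
        raise ValueError(f"split-K decode does not support dropout_p={dropout_p}")
    if is_causal:
        raise ValueError("split-K decode does not support is_causal=True")
    # Only check concrete lengths; a symbolic L_q cannot be compared at trace
    # time, so the check is skipped and the runtime dispatch decides instead.
    if isinstance(L_q, int) and L_q != 1:
        raise ValueError(f"split-K decode expects L_q == 1, got L_q={L_q}")
    return True
```

Keeping the keyword arguments identical to `sdpa` matters because `torch.cond` passes the same operands to both branch callables, so the two branches have to accept the same inputs.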
End-to-end on H100 (Qwen3.5-35B-A3B, HQQ-INT4, max_seq_len=4096,
1024 decode tokens, prompt="Hi", temperature=0, 5 runs median):
Metric          Baseline (tiled)   Split-K          Speedup
Decode tok/s    61.7               89.9             1.46x
Prefill tok/s   378.2              378.2            1.00x
nsys GPU time   13853 ms           8674 ms          1.60x
SDPA kernel     5370 ms (38.8%)    209 ms (2.4%)    25.7x

1 parent e06db27 commit 5d3b620
2 files changed
Lines changed: 46 additions & 13 deletions