|
33 | 33 | BlockSizes = splash_attention_kernel.BlockSizes |
34 | 34 |
|
35 | 35 | AxisNames = tuple[str, ...] |
36 | | - |
| 36 | +# Physical axis names for device meshes. |
| 37 | +DATA = "data" |
| 38 | +FSDP = "fsdp" |
| 39 | +TENSOR = "tensor" |
| 40 | +# Logical axis names for model parameters and activations. |
37 | 41 | BATCH = "activation_batch" |
38 | 42 | LENGTH = "activation_length" |
39 | 43 | KV_LENGTH = "activation_kv_length" |
|
48 | 52 | WAN2_2 = "wan2.2" |
49 | 53 |
|
50 | 54 | WAN_MODEL = WAN2_1 |
| 55 | + |
| 56 | +# For setting self/cross attention independently in splash kernel |
| 57 | +SELF_ATTN_HEAD = "activation_self_attn_heads" |
| 58 | +SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length" |
| 59 | +SELF_ATTN_KV_LENGTH = "activation_self_attn_kv_length" |
| 60 | +CROSS_ATTN_HEAD = "activation_cross_attn_heads" |
| 61 | +CROSS_ATTN_Q_LENGTH = "activation_cross_attn_q_length" |
| 62 | +CROSS_ATTN_KV_LENGTH = "activation_cross_attn_kv_length" |
| 63 | + |
| 64 | + |
 | 65 | +WAN_MODEL = WAN2_1 |
| 66 | + |
| 67 | +### Common axis rules for ring attention ### |
| 68 | +RING_ATTENTION_AXIS_RULES = [ |
| 69 | + [SELF_ATTN_HEAD, None], |
| 70 | + [SELF_ATTN_Q_LENGTH, FSDP], |
| 71 | + [SELF_ATTN_KV_LENGTH, FSDP], |
| 72 | + [CROSS_ATTN_HEAD, None], |
| 73 | + [CROSS_ATTN_Q_LENGTH, FSDP], |
| 74 | + [CROSS_ATTN_KV_LENGTH, FSDP], |
| 75 | +] |
| 76 | + |
| 77 | +SEQUENCE_PARALLEL_AXIS_RULES = [ |
| 78 | + [SELF_ATTN_HEAD, None], |
| 79 | + [SELF_ATTN_Q_LENGTH, FSDP], |
| 80 | + [SELF_ATTN_KV_LENGTH, None], |
| 81 | + [CROSS_ATTN_HEAD, None], |
| 82 | + [CROSS_ATTN_Q_LENGTH, FSDP], |
| 83 | + [CROSS_ATTN_KV_LENGTH, None], |
| 84 | +] |