Add support for SWA (left, right) with FusedAttention #2477
Conversation
…IA#1369 Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci pytorch L0
Greptile Summary: This PR adds support for left-right sliding window attention (SWA) to FusedAttention by introducing a `bottom_right_diagonal` parameter.
Confidence Score: 3/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant TransformerLayer
    participant MHA as MultiheadAttention
    participant DPA as DotProductAttention
    participant Backend as FusedAttention Backend
    participant CPP as C++ Extensions
    participant CUDA as CUDA Kernels
    User->>TransformerLayer: forward(bottom_right_diagonal)
    TransformerLayer->>TransformerLayer: Set defaults based on mask_type
    TransformerLayer->>MHA: forward(bottom_right_diagonal)
    MHA->>MHA: Apply mask_type logic
    MHA->>DPA: forward(bottom_right_diagonal)
    DPA->>DPA: Calculate default if None
    DPA->>Backend: get_attention_backend(bottom_right_diagonal)
    Backend->>Backend: Filter backends by SWA support
    Backend-->>DPA: Selected backend
    DPA->>CPP: fused_attn_fwd(bottom_right_diagonal)
    CPP->>CUDA: nvte_fused_attn_fwd(bottom_right_diagonal)
    alt F16 Arbitrary Seqlen
        CUDA->>CUDA: fused_attn_arbitrary_seqlen_fwd
        CUDA->>CUDA: Set diagonal_alignment (TOP_LEFT/BOTTOM_RIGHT)
        CUDA->>CUDA: Set diagonal_band_right_bound if needed
        CUDA->>CUDA: Build cuDNN graph with SWA config
    else FP8 Path
        CUDA->>CUDA: fused_attn_fp8_fwd_impl_v1
        Note over CUDA: Hardcoded bottom_right_diagonal=true
        CUDA->>CUDA: FADescriptor_v1 with hardcoded value
    end
    CUDA-->>CPP: Attention output
    CPP-->>DPA: Output tensor
    DPA-->>MHA: Output tensor
    MHA-->>TransformerLayer: Output tensor
    TransformerLayer-->>User: Final output
```
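To make this flow concrete, here is a toy Python sketch of the plumbing in the diagram. It is a simplification under assumed names (real TE signatures carry many more arguments), not the PR's actual code:

```python
from typing import Optional

def dot_product_attention_forward(bottom_right_diagonal: bool) -> bool:
    # Stand-in for the DPA -> fused_attn_fwd -> nvte_fused_attn_fwd hops:
    # by this point the flag has been resolved to a concrete boolean.
    return bottom_right_diagonal

def transformer_layer_forward(
    mask_type: str, bottom_right_diagonal: Optional[bool] = None
) -> bool:
    # TransformerLayer/MHA level: derive a default from the mask type when the
    # caller left the flag unset (the diagram's "Set defaults" step).
    if bottom_right_diagonal is None:
        bottom_right_diagonal = mask_type not in {"causal", "padding_causal"}
    return dot_product_attention_forward(bottom_right_diagonal)

print(transformer_layer_forward("causal"))               # False
print(transformer_layer_forward("causal_bottom_right"))  # True
```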
Additional Comments (2)
- `transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py`, line 1281 (link). logic: Trailing comma creates a single-element tuple instead of a boolean. Should this be just `bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"]`? (A reduced example of this pitfall follows this list.)
- `transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py`, line 1482 (link). style: Uses a hardcoded mask type check instead of the new `bottom_right_diagonal` parameter for ALiBi alignment. Should this use the `bottom_right_diagonal` parameter for consistency? Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
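For illustration, a reduced Python example of the trailing-comma pitfall flagged in the first comment (variable names borrowed from the comment, not the PR's exact line):

```python
attn_mask_type = "causal"

# With a trailing comma, the right-hand side is a one-element tuple, and any
# non-empty tuple is truthy, regardless of the boolean inside it.
bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"],
print(bottom_right_alignment, bool(bottom_right_alignment))  # (False,) True

# Without the comma, it is the intended plain boolean.
bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"]
print(bottom_right_alignment)  # False
```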
15 files reviewed, 8 comments
```python
if self_attn_mask_type in {"causal", "padding_causal"}:
    bottom_right_diagonal = False
if bottom_right_diagonal is None or self_attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    bottom_right_diagonal = True
```
logic: This overrides the instance variable even when it was explicitly set in the forward call; the user's explicit choice should be preserved. Should the mask type check override an explicitly passed `bottom_right_diagonal` parameter, or only apply when it is `None`?
```python
if enc_dec_attn_mask_type in {"causal", "padding_causal"}:
    enc_dec_bottom_right_diagonal = False
if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    enc_dec_bottom_right_diagonal = True
```
logic: Same logic issue as above - mask type check overrides explicit parameter values
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
15 files reviewed, no comments
Additional Comments (4)
- `transformer_engine/pytorch/transformer.py`, lines 777-783 (link). logic: Conditional logic overwrites an explicitly passed parameter. When `self_attn_mask_type in {"causal", "padding_causal"}` at line 777, this unconditionally sets `bottom_right_diagonal = False`, even if the user explicitly passed a different value in the forward call. The check at line 779 (`if bottom_right_diagonal is None`) becomes unreachable in those cases. Consider restructuring to only apply defaults when `bottom_right_diagonal is None` (see the sketch after this list).
- `transformer_engine/pytorch/transformer.py`, lines 787-793 (link). logic: Same conditional logic issue for encoder-decoder attention. The mask type check at line 787 unconditionally overwrites `enc_dec_bottom_right_diagonal`, ignoring explicitly passed values.
- `transformer_engine/pytorch/attention/dot_product_attention/utils.py`, line 920 (link). logic: Variable name mismatch: setting `use_flash_attention` instead of `use_flash_attention_2`.
- `transformer_engine/pytorch/attention/dot_product_attention/utils.py`, line 947 (link). logic: Variable name mismatch: setting `use_flash_attention` instead of `use_flash_attention_2`.
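A minimal sketch of the restructuring these comments suggest (hypothetical, not the PR's final code): derive the mask-type default only when the caller passed `None`, so an explicit value is never overwritten.

```python
# Only derive a default when the user did not set the flag explicitly.
if bottom_right_diagonal is None:
    bottom_right_diagonal = self_attn_mask_type not in {"causal", "padding_causal"}
```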
15 files reviewed, 4 comments
/te-ci pytorch L0
1 similar comment
/te-ci pytorch L0
Additional Comments (1)
- `transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py`, line 1283 (link). syntax: Tuple has a trailing comma, causing it to be a single-element tuple instead of a boolean.
15 files reviewed, 1 comment
/te-ci pytorch L0
15 files reviewed, no comments
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
/te-ci pytorch L0
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Force-pushed from a56c4d2 to a545ebf (Compare)
/te-ci pytorch L1
17 files reviewed, 4 comments
You may need to add the extra argument
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci L1
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…ansformerEngine into swa_padding_brcm_try2
for more information, see https://pre-commit.ci
20 files reviewed, 4 comments
/te-ci L1
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…ansformerEngine into swa_padding_brcm_try2
… swa_padding_brcm_try2
/te-ci jax L1
20 files reviewed, 2 comments
| "Disabling FlashAttention as it only supports sliding window with bottom right" | ||
| " diagonal alignment for cross-attention" | ||
| ) | ||
| use_flash_attention = False |
syntax: wrong variable; should be `use_flash_attention_2 = False`.
This disables all FlashAttention backends, when only FlashAttention 2 should be disabled for sliding window with cross-attention.
Suggested change:
```diff
-use_flash_attention = False
+use_flash_attention_2 = False
```
| "Disabling FlashAttention as it only supports ALiBi with bottom right diagonal" | ||
| " alignment for cross-attention" | ||
| ) | ||
| use_flash_attention = False |
syntax: wrong variable; should be `use_flash_attention_2 = False`.
This disables all FlashAttention backends, when only FlashAttention 2 should be disabled for ALiBi with cross-attention.
Suggested change:
```diff
-use_flash_attention = False
+use_flash_attention_2 = False
```
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
/te-ci jax L1
for more information, see https://pre-commit.ci
Additional Comments (2)
- `transformer_engine/common/fused_attn/fused_attn_fp8.cu`, lines 2480-2486 (link). logic: The `bottom_right_diagonal` parameter is not being passed to `fused_attn_fp8_fwd_impl_v1`, even though it was added to the higher-level API. This causes the function to use hardcoded values (line 1709: `true`) instead of respecting the user-specified configuration. Add the `bottom_right_diagonal` parameter to the function call and update the function signature to accept it.
- `transformer_engine/common/fused_attn/fused_attn_fp8.cu`, lines 2575-2580 (link). logic: The `bottom_right_diagonal` parameter is not being passed to `fused_attn_fp8_bwd_impl_v1`. Add it to the function call to match the forward pass fix.
20 files reviewed, 6 comments
| "Disabling FlashAttention as it only supports sliding window with bottom right" | ||
| " diagonal alignment for cross-attention" | ||
| ) | ||
| use_flash_attention = False |
syntax: wrong variable; should be `use_flash_attention_2 = False`
Suggested change:
```diff
-use_flash_attention = False
+use_flash_attention_2 = False
```
| "Disabling FlashAttention as it only supports ALiBi with bottom right diagonal" | ||
| " alignment for cross-attention" | ||
| ) | ||
| use_flash_attention = False |
syntax: wrong variable; should be `use_flash_attention_2 = False`
Suggested change:
```diff
-use_flash_attention = False
+use_flash_attention_2 = False
```
```cpp
true,
true,
```
logic: hardcoded `bottom_right_diagonal` value (position 9 in the `FADescriptor_v1` initialization) bypasses the parameter mechanism. Accept `bottom_right_diagonal` as a function parameter and use it here instead of hardcoding `true`.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
```cpp
true,
false,
```
logic: hardcoded `bottom_right_diagonal` value in the backward pass. Accept it as a parameter and pass it through instead of hardcoding `true`.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
/te-ci L1
/te-ci L1
cyanguwa left a comment
Looks good to me, thanks!
* SWA (left, right) with FusedAttention changes cherry-picked from #1369
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix test_kv_cache failures
* remove unnecessary comments
* fix some more filter issues, address feedback
* fix for local test case failures - `bottom_right_diagonal` should be calculated in `fused_attn_fwd` call as well
* make conditions more accurate
* add cp tests to test swa (left, right)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* remove dead code and make conditions better
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix lint
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* feedback form Charlene
* small er
* plumb `bottom_right_diagonal` through jax
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* plumb `bottom_right_diagonal` through jax
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* add missing fields
* use proper mask type in CP
* [pre-commit.ci] auto fixes from pre-commit.com hooks

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Description
FusedAttention has supported "right"-side sliding window attention for some time now. This PR adds support for SWA (left, right) with the FusedAttention backend in TE.
(changes cherry-picked from original PR: #1369)
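As a usage sketch (hedged: `window_size=(left, right)` is TE's existing sliding-window knob, and `bottom_right_diagonal` is the forward-level argument this PR plumbs through; exact defaults may differ):

```python
import torch
from transformer_engine.pytorch import DotProductAttention

# Each query attends to up to 128 tokens to its left and 16 to its right;
# a non-zero right window with the FusedAttention backend is what this PR adds.
attn = DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attn_mask_type="no_mask",
    window_size=(128, 16),
)

sq, b, h, d = 512, 2, 16, 64  # qkv_format defaults to "sbhd"
q, k, v = (torch.randn(sq, b, h, d, dtype=torch.bfloat16, device="cuda")
           for _ in range(3))

# bottom_right_diagonal selects the window's diagonal alignment; passing None
# lets TE derive a default from attn_mask_type.
out = attn(q, k, v, bottom_right_diagonal=None)
```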
Type of change
Changes
Please list the changes introduced in this PR:
- `transformer_engine`
  - `common`
    - `fused_attn`
      - `fused_attn.cpp`: add `bottom_right_diagonal` parameter to the API
      - `fused_attn_f16_arbitrary_seqlen.cu`: add `bottom_right_diagonal` parameter to the API
      - `fused_attn_fp8.cu`: add `bottom_right_diagonal` parameter to the `FADescriptor_v1` API
      - `utils.h`: add `bottom_right_diagonal` parameter to the `FADescriptor_v1` API
  - `pytorch`
    - `transformer.py`: plumb `bottom_right_diagonal` through the call stack: `TransformerLayer` --> `SelfAttention`/`CrossAttention`
    - `attention`
      - `dot_product_attention`
        - `backends.py`:
          - `UnfusedDotProductAttention`: add `bottom_right_diagonal` parameter to the `forward` API (in `forward`, `bottom_right_alignment` is being used in the ALiBi call; perhaps this should be corrected)
          - `FusedAttn` custom module: add `bottom_right_diagonal` parameter to the `forward` API
          - `FusedAttention` module: plumb `bottom_right_diagonal` through the call stack
        - `dot_product_attention.py`: `DotProductAttention`: plumb `bottom_right_diagonal` through the call stack; calculate default `bottom_right_diagonal` if it's `None`
        - `utils.py`: add `bottom_right_diagonal` to `AttentionParams` and `get_attention_backend`
      - `multi_head_attention.py`: add `bottom_right_diagonal` to the forward API and call; calculate default `bottom_right_diagonal` if it's `None`
    - `cpp_extensions`
      - `fused_attn.py`: plumb `bottom_right_diagonal` in `fused_attn_fwd`/`fused_attn_bwd`
    - `csrc`
      - `extensions`
        - `attention.cpp`: plumb `bottom_right_diagonal` through the call stack: `fused_attn_fwd` --> `nvte_fused_attn_fwd`
      - `extensions.h`: add `bottom_right_diagonal` to the `fused_attn_fwd` and `fused_attn_bwd` API definitions

Checklist: