
Conversation

@sudhakarsingh27
Collaborator

Description

FusedAttention has supported "right"-side sliding window attention for some time. This PR adds support for SWA (left, right) with the FusedAttention backend in TE.
(changes cherry-picked from original PR: #1369)
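
An illustrative call with the new argument (a sketch only: window_size and attn_mask_type are pre-existing TE options, while the exact placement of bottom_right_diagonal in the forward signature comes from this PR and may differ in detail):

import torch
from transformer_engine.pytorch import DotProductAttention

# Sliding window of 128 tokens to the left and 0 to the right, with the mask
# diagonal aligned bottom-right -- the alignment knob this PR plumbs through.
attn = DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attn_mask_type="causal_bottom_right",
    window_size=(128, 0),
)

# sbhd layout: (seq, batch, heads, head_dim)
q = torch.randn(512, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn(512, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
v = torch.randn(512, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
out = attn(q, k, v, bottom_right_diagonal=True)  # new keyword added by this PR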

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

transformer_engine

  • common

    • fused_attn
      • fused_attn.cpp
        • add bottom_right_diagonal parameter to the API
        • update the backend selection filters so a sliding-window config can pick the arbitrary-seqlen fused attention backend
      • fused_attn_f16_arbitrary_seqlen.cu: add bottom_right_diagonal parameter to the API
      • fused_attn_fp8.cu: add bottom_right_diagonal parameter to the FADescriptor_v1 API
      • utils.h: add bottom_right_diagonal parameter to FADescriptor_v1 API
  • pytorch

    • transformer.py
      • plumb bottom_right_diagonal through the call stack: TransformerLayer --> SelfAttention/CrossAttention
    • attention
      • dot_product_attention
        • backends.py:
          • UnfusedDotProductAttention
            • add bottom_right_diagonal parameter to the forward API
              • why is it not used in the forward?
                • bottom_right_alignment is used in the ALiBi call instead; this should perhaps be corrected
          • FusedAttn custom module
            • add bottom_right_diagonal parameter to the forward API
          • FusedAttention module
            • plumb bottom_right_diagonal through the call stack
        • dot_product_attention.py
          • DotProductAttention
            • Plumb bottom_right_diagonal through the call stack
            • Add calculation of bottom_right_diagonal if it's None
        • utils.py
          • AttentionParams
            • [x]
          • get_attention_backend
            • update sliding window filter section
            • update attention bias filter section
      • multi_head_attention.py
        • Add bottom_right_diagonal to forward API and call
        • Add calculation of bottom_right_diagonal if it's None
    • cpp_extensions
      • fused_attn.py
        • plumb bottom_right_diagonal in fused_attn_fwd/fused_attn_bwd
    • csrc
      • extension
        • attention.cpp
          • plumb bottom_right_diagonal through the call stack: fused_attn_fwd --> nvte_fused_attn_fwd
          • same as above for bwd
      • extensions.h
        • add bottom_right_diagonal to fused_attn_fwd and fused_attn_bwd API definitions

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…IA#1369

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

@greptile-apps
Contributor

greptile-apps bot commented Dec 4, 2025

Greptile Summary

This PR adds support for left-right sliding window attention (SWA) to FusedAttention by introducing the bottom_right_diagonal parameter throughout the stack. The implementation properly threads this parameter from Python APIs through C++ extensions to CUDA kernels, enabling cuDNN to configure diagonal alignment for SWA.

Key changes:

  • Added bottom_right_diagonal field to FADescriptor_v1 struct
  • Updated backend selection filters to support SWA with arbitrary seqlen
  • Implemented diagonal alignment configuration in cuDNN graphs using set_diagonal_alignment and set_diagonal_band_right_bound
  • Expanded test coverage for SWA with multiple mask types and layouts
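
For reference, a minimal sketch (plain PyTorch, independent of TE) of what the diagonal alignment means for a sliding-window mask when the query and key/value lengths differ:

import torch

def swa_mask(s_q, s_kv, window, bottom_right_diagonal):
    """Boolean SWA mask: True = may attend. window is (left, right); -1 = unbounded."""
    left, right = window
    q_pos = torch.arange(s_q).unsqueeze(1)    # (s_q, 1)
    kv_pos = torch.arange(s_kv).unsqueeze(0)  # (1, s_kv)
    # Bottom-right alignment lines the last query row up with the last key/value
    # column; top-left alignment keeps query row i aligned with column i.
    diag = q_pos + (s_kv - s_q if bottom_right_diagonal else 0)
    allowed = torch.ones(s_q, s_kv, dtype=torch.bool)
    if left >= 0:
        allowed &= kv_pos >= diag - left
    if right >= 0:
        allowed &= kv_pos <= diag + right
    return allowed

# e.g. swa_mask(2, 4, (1, 0), True) lets query row 0 see keys {1, 2} and row 1 see keys {2, 3}.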

Critical issues identified:

  • FP8 path: Functions fused_attn_fp8_fwd_impl_v1 and fused_attn_fp8_bwd_impl_v1 hardcode bottom_right_diagonal=true instead of accepting it as a parameter, breaking configurability for FP8 attention
  • Backend selection bug: Lines 911 and 938 in utils.py incorrectly set use_flash_attention = False instead of use_flash_attention_2 = False, which disables all FlashAttention backends when only FlashAttention 2 should be disabled

Confidence Score: 3/5

  • This PR has a correct F16 implementation but contains bugs in the FP8 path and the backend selection logic
  • The F16 arbitrary seqlen implementation is correct and comprehensive with proper test coverage. However, two critical bugs significantly impact functionality: (1) FP8 attention path hardcodes bottom_right_diagonal=true, preventing users from configuring this for FP8 operations, and (2) variable name typos in backend selection incorrectly disable all FlashAttention variants instead of just v2. These bugs don't affect the main F16 path but create inconsistent behavior across backends.
  • Pay close attention to transformer_engine/common/fused_attn/fused_attn_fp8.cu (FP8 hardcoded values) and transformer_engine/pytorch/attention/dot_product_attention/utils.py (variable name typos)

Important Files Changed

  • transformer_engine/common/fused_attn/fused_attn.cpp: properly threads the bottom_right_diagonal parameter through all fused attention APIs and updates the backend filters for SWA support
  • transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu: correctly implements the bottom_right_diagonal parameter in cuDNN graphs using set_diagonal_alignment and set_diagonal_band_right_bound
  • transformer_engine/common/fused_attn/fused_attn_fp8.cu: FP8 functions hardcode bottom_right_diagonal to true instead of accepting it as a parameter, breaking configurability
  • transformer_engine/pytorch/attention/dot_product_attention/utils.py: variable name typos on lines 911 and 938 incorrectly disable all FlashAttention backends instead of just FlashAttention 2
  • tests/pytorch/attention/test_attention.py: adds comprehensive test coverage for SWA with multiple mask types and layouts (thd, sbhd)

Sequence Diagram

sequenceDiagram
    participant User
    participant TransformerLayer
    participant MHA as MultiheadAttention
    participant DPA as DotProductAttention
    participant Backend as FusedAttention Backend
    participant CPP as C++ Extensions
    participant CUDA as CUDA Kernels
    
    User->>TransformerLayer: forward(bottom_right_diagonal)
    TransformerLayer->>TransformerLayer: Set defaults based on mask_type
    TransformerLayer->>MHA: forward(bottom_right_diagonal)
    MHA->>MHA: Apply mask_type logic
    MHA->>DPA: forward(bottom_right_diagonal)
    DPA->>DPA: Calculate default if None
    DPA->>Backend: get_attention_backend(bottom_right_diagonal)
    Backend->>Backend: Filter backends by SWA support
    Backend-->>DPA: Selected backend
    DPA->>CPP: fused_attn_fwd(bottom_right_diagonal)
    CPP->>CUDA: nvte_fused_attn_fwd(bottom_right_diagonal)
    
    alt F16 Arbitrary Seqlen
        CUDA->>CUDA: fused_attn_arbitrary_seqlen_fwd
        CUDA->>CUDA: Set diagonal_alignment (TOP_LEFT/BOTTOM_RIGHT)
        CUDA->>CUDA: Set diagonal_band_right_bound if needed
        CUDA->>CUDA: Build cuDNN graph with SWA config
    else FP8 Path
        CUDA->>CUDA: fused_attn_fp8_fwd_impl_v1
        Note over CUDA: Hardcoded bottom_right_diagonal=true
        CUDA->>CUDA: FADescriptor_v1 with hardcoded value
    end
    
    CUDA-->>CPP: Attention output
    CPP-->>DPA: Output tensor
    DPA-->>MHA: Output tensor
    MHA-->>TransformerLayer: Output tensor
    TransformerLayer-->>User: Final output

@greptile-apps greptile-apps bot (Contributor) left a comment

Additional Comments (2)

  1. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1281 (link)

    logic: Trailing comma creates a single-element tuple instead of a boolean - should this be just bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"]? (A short illustration of the pitfall follows this list.)

  2. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1482 (link)

    style: Uses hardcoded mask type check instead of the new bottom_right_diagonal parameter for ALiBi alignment. Should this use bottom_right_diagonal parameter for consistency?

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
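
A short illustration of the pitfall the first comment points at (standalone Python, not the actual TE line):

attn_mask_type = "causal"
bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"],  # trailing comma
print(bottom_right_alignment)        # (False,) -- a one-element tuple
print(bool(bottom_right_alignment))  # True -- any non-empty tuple is truthy

bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"]
print(bottom_right_alignment)        # False -- the intended boolean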

15 files reviewed, 8 comments


Comment on lines +777 to +783
if self_attn_mask_type in {"causal", "padding_causal"}:
    bottom_right_diagonal = False
if bottom_right_diagonal is None or self_attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    bottom_right_diagonal = True

logic: This overrides the instance variable even when it is explicitly set in the forward call; the user's explicit choice should be preserved. Should the mask type check override an explicitly passed bottom_right_diagonal parameter, or only apply when it is None?

Comment on lines +787 to +793
if enc_dec_attn_mask_type in {"causal", "padding_causal"}:
    enc_dec_bottom_right_diagonal = False
if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    enc_dec_bottom_right_diagonal = True

logic: Same logic issue as above - mask type check overrides explicit parameter values
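
A minimal sketch of the restructuring both comments ask for (the helper name is hypothetical; the real change would live in TransformerLayer.forward): derive the mask-type default only when the caller did not pass a value.

def resolve_bottom_right_diagonal(bottom_right_diagonal, attn_mask_type):
    """Apply the mask-type default only when no explicit value was passed."""
    if bottom_right_diagonal is None:
        # "causal" / "padding_causal" imply a top-left diagonal; the
        # *_bottom_right mask types (and everything else) default to True.
        bottom_right_diagonal = attn_mask_type not in ("causal", "padding_causal")
    return bottom_right_diagonal

# resolve_bottom_right_diagonal(None, "causal")  -> False
# resolve_bottom_right_diagonal(True, "causal")  -> True  (explicit value preserved)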

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@greptile-apps greptile-apps bot (Contributor) left a comment

15 files reviewed, no comments


@greptile-apps greptile-apps bot (Contributor) left a comment

Additional Comments (4)

  1. transformer_engine/pytorch/transformer.py, line 777-783 (link)

    logic: Conditional logic overwrites explicitly passed parameter. When self_attn_mask_type in {"causal", "padding_causal"} at line 777, this unconditionally sets bottom_right_diagonal = False, even if the user explicitly passed a different value in the forward call. The check at line 779 (if bottom_right_diagonal is None) becomes unreachable in those cases.

    Consider restructuring to only apply defaults when bottom_right_diagonal is None.

  2. transformer_engine/pytorch/transformer.py, line 787-793 (link)

    logic: Same conditional logic issue for encoder-decoder attention. The mask type check at line 787 unconditionally overwrites enc_dec_bottom_right_diagonal, ignoring explicitly passed values.

  3. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 920 (link)

    logic: Variable name mismatch: setting use_flash_attention instead of use_flash_attention_2

  4. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 947 (link)

    logic: Variable name mismatch: setting use_flash_attention instead of use_flash_attention_2

15 files reviewed, 4 comments


@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

1 similar comment
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

@greptile-apps greptile-apps bot (Contributor) left a comment

Additional Comments (1)

  1. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1283 (link)

    syntax: Tuple has trailing comma causing it to be a single-element tuple instead of boolean

15 files reviewed, 1 comment


@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

@greptile-apps greptile-apps bot (Contributor) left a comment

15 files reviewed, no comments


Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L1

@greptile-apps greptile-apps bot (Contributor) left a comment

17 files reviewed, 4 comments


@cyanguwa
Collaborator

You may need to add the extra argument bottom_right_diagonal to the nvte_xxx C API calls in Jax too, otherwise TE-Jax won't build. Have a look at PR 2584 maybe, where the deterministic arg was added, so files such as jax/csrc/extensions/attention.cpp need to be changed as well. Thanks!

@sudhakarsingh27
Collaborator Author

/te-ci L1

@greptile-apps greptile-apps bot (Contributor) left a comment

20 files reviewed, 4 comments


@sudhakarsingh27
Collaborator Author

/te-ci L1

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27
Collaborator Author

/te-ci jax L1

@greptile-apps greptile-apps bot (Contributor) left a comment

20 files reviewed, 2 comments


"Disabling FlashAttention as it only supports sliding window with bottom right"
" diagonal alignment for cross-attention"
)
use_flash_attention = False

syntax: wrong variable - should be use_flash_attention_2 = False

this disables all FlashAttention backends when only FlashAttention 2 should be disabled for sliding window with cross-attention

Suggested change
use_flash_attention = False
use_flash_attention_2 = False

"Disabling FlashAttention as it only supports ALiBi with bottom right diagonal"
" alignment for cross-attention"
)
use_flash_attention = False

syntax: wrong variable - should be use_flash_attention_2 = False

this disables all FlashAttention backends when only FlashAttention 2 should be disabled for ALiBi with cross-attention

Suggested change
use_flash_attention = False
use_flash_attention_2 = False

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27
Collaborator Author

/te-ci jax L1

@greptile-apps greptile-apps bot (Contributor) left a comment

Additional Comments (2)

  1. transformer_engine/common/fused_attn/fused_attn_fp8.cu, line 2480-2486 (link)

    logic: The bottom_right_diagonal parameter is not being passed to fused_attn_fp8_fwd_impl_v1, even though it was added to the higher-level API. This causes the function to use hardcoded values (line 1709: true) instead of respecting the user-specified configuration.

    Add bottom_right_diagonal parameter to the function call and update the function signature to accept it.

  2. transformer_engine/common/fused_attn/fused_attn_fp8.cu, line 2575-2580 (link)

    logic: The bottom_right_diagonal parameter is not being passed to fused_attn_fp8_bwd_impl_v1. Add it to the function call to match the forward pass fix.

20 files reviewed, 6 comments


"Disabling FlashAttention as it only supports sliding window with bottom right"
" diagonal alignment for cross-attention"
)
use_flash_attention = False

syntax: wrong variable - should be use_flash_attention_2 = False

Suggested change
use_flash_attention = False
use_flash_attention_2 = False

"Disabling FlashAttention as it only supports ALiBi with bottom right diagonal"
" alignment for cross-attention"
)
use_flash_attention = False

syntax: wrong variable - should be use_flash_attention_2 = False

Suggested change
use_flash_attention = False
use_flash_attention_2 = False

Comment on lines 1709 to +1710
true,
true,

logic: hardcoded bottom_right_diagonal value (position 9 in FADescriptor_v1 initialization) bypasses the parameter mechanism. Accept bottom_right_diagonal as a function parameter and use it here instead of hardcoding true.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +2039 to 2040
true,
false,

logic: hardcoded bottom_right_diagonal value in backward pass. Accept as parameter and pass through instead of hardcoding true.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@sudhakarsingh27
Collaborator Author

/te-ci L1

@sudhakarsingh27
Collaborator Author

/te-ci L1

@cyanguwa cyanguwa (Collaborator) left a comment

Looks good to me, thanks!

@cyanguwa cyanguwa merged commit c6a92a4 into NVIDIA:main Jan 22, 2026
45 of 54 checks passed
KshitijLakhani pushed a commit that referenced this pull request Jan 27, 2026
* SWA (left, right) with FusedAttention changes cherry-picked from #1369

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test_kv_cache failures

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* remove unnecessary comments

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix some more filter issues, address feedback

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix for local test case failures - `bottom_right_diagonal` should be calculated in `fused_attn_fwd` call as well

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* make conditions more accurate

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add cp tests to test swa (left, right)

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove dead code and make conditions better

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feedback from Charlene

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* small er

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* plumb `bottom_right_diagonal` through jax

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* plumb `bottom_right_diagonal` through jax

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add missing fields

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* use proper mask type in CP

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
