
Conversation

@cyanguwa (Collaborator) commented Jan 12, 2026

Description

This PR enables determinism for FP16/BF16 attention on Blackwell. It requires cuDNN >= 9.18.1.

To enable deterministic execution, set export NVTE_ALLOW_NONDETERMINISTIC_ALGO=0.

Support matrix for FP16/BF16 on Blackwell:

  • cuDNN 9.7.0-9.18.0: non-deterministic mode only; dbias is supported only without dropout
  • cuDNN 9.18.1+: non-deterministic mode with dbias (without dropout); deterministic mode with neither dbias nor dropout
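
For example, the deterministic path could be exercised from PyTorch roughly as follows (a minimal sketch, assuming transformer_engine.pytorch.DotProductAttention with its default sbhd layout; the sizes are arbitrary, and dropout is disabled because determinism requires it):

import os
# Request deterministic kernels before Transformer Engine reads the variable
# (0 disallows non-deterministic algorithms).
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"

import torch
from transformer_engine.pytorch import DotProductAttention

# 2 heads of dim 64; tensors are [seq, batch, heads, dim] in the default sbhd layout.
attn = DotProductAttention(num_attention_heads=2, kv_channels=64,
                           attention_dropout=0.0)  # determinism requires no dropout
q, k, v = (torch.randn(128, 2, 2, 64, dtype=torch.bfloat16,
                       device="cuda", requires_grad=True) for _ in range(3))
out = attn(q, k, v)    # forward is deterministic in either mode
out.sum().backward()   # backward selects a deterministic kernel on Blackwell

Note that the environment variable only affects which backward kernels are eligible; the forward pass is deterministic either way.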

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please see Description.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Charlene Yang <[email protected]>
@cyanguwa changed the title from "[Common] Enable determinism for SDPA on Blackwell" to "[Common] Enable determinism for cuDNN >= 9.18 on Blackwell" on Jan 12, 2026
@greptile-apps (bot) commented Jan 12, 2026

Greptile Summary

This PR enables deterministic FP16/BF16 attention on Blackwell GPUs (sm100+) with cuDNN >= 9.18.1 by threading a deterministic parameter through the backend selection logic.

Key changes:

  • Added deterministic parameter to nvte_get_fused_attn_backend() and C++ extension wrappers
  • Updated backend selection logic in fused_attn.cpp to handle Blackwell determinism requirements (lines 444-452)
  • Forward passes use deterministic=false while backward passes use the actual user setting from NVTE_ALLOW_NONDETERMINISTIC_ALGO
  • JAX and PyTorch integration layers now pass deterministic flag to backend selection
  • Updated test assertions to reflect support matrix: non-deterministic mode (cuDNN 9.7+) requires no bias OR no dropout; deterministic mode (cuDNN 9.18.1+) requires no bias AND no dropout
  • Added separate CI test runs with NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 to validate deterministic path

The implementation correctly handles the asymmetry where forward passes are always deterministic while backward passes support both modes.
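
The support matrix above boils down to a simple predicate; a Python restatement for illustration (the actual check lives in C++ in fused_attn.cpp, and the function and parameter names here are hypothetical):

def fused_attn_bwd_supported_on_blackwell(cudnn_version, deterministic,
                                          has_bias, dropout):
    """Blackwell (sm100+) backward-pass eligibility, per the summary above."""
    if deterministic:
        # Deterministic mode: cuDNN 9.18.1+, no bias AND no dropout.
        return cudnn_version >= (9, 18, 1) and not has_bias and dropout == 0.0
    # Non-deterministic mode: cuDNN 9.7.0+, no bias OR no dropout.
    return cudnn_version >= (9, 7, 0) and (not has_bias or dropout == 0.0)

# e.g. deterministic mode on cuDNN 9.18.0 is rejected:
# fused_attn_bwd_supported_on_blackwell((9, 18, 0), True, False, 0.0) -> False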

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The implementation is clean and well-tested. The changes thread a deterministic parameter through the call stack correctly, the logic in fused_attn.cpp properly handles Blackwell-specific requirements, comprehensive tests cover both deterministic and non-deterministic modes, and CI validates both paths separately. All previous review comments have been addressed by the developer.
  • No files require special attention

Important Files Changed

  • transformer_engine/common/fused_attn/fused_attn.cpp: adds deterministic parameter to backend selection for Blackwell (sm100+) support with cuDNN 9.18.1+
  • transformer_engine/jax/cpp_extensions/attention.py: updates Blackwell assertions for non-deterministic (cuDNN 9.7+) and deterministic (cuDNN 9.18.1+) modes
  • tests/jax/test_fused_attn.py: adds determinism-specific test class and updates Blackwell skip conditions for bias/dropout support
  • tests/pytorch/attention/test_attention.py: adds determinism flag to all backend availability checks and sets the NVTE_UNFUSED_ATTN environment variable

Sequence Diagram

sequenceDiagram
    participant User
    participant Python as Python Layer<br/>(JAX/PyTorch)
    participant CPP as C++ Extension<br/>(attention.cpp)
    participant Backend as Backend Selection<br/>(fused_attn.cpp)
    participant cuDNN

    Note over User,cuDNN: Backward Pass (Training)
    User->>Python: Set NVTE_ALLOW_NONDETERMINISTIC_ALGO
    Python->>Python: Read env var & determine<br/>deterministic flag
    Python->>CPP: Call fused_attn_bwd with<br/>deterministic parameter
    CPP->>Backend: nvte_get_fused_attn_backend(...,<br/>deterministic)
    
    alt Blackwell (sm100+) && training
        alt deterministic=false (non-deterministic)
            Backend->>Backend: Check cuDNN >= 9.7.0
            Backend->>Backend: Require (dropout=0 OR bias=NO_BIAS)
            Backend->>Backend: Enable arbitrary_seqlen backend
        else deterministic=true
            Backend->>Backend: Check cuDNN >= 9.18.1
            Backend->>Backend: Require (dropout=0 AND bias=NO_BIAS)
            Backend->>Backend: Enable arbitrary_seqlen backend
        end
    else Non-Blackwell or inference
        Backend->>Backend: Standard backend selection
    end
    
    Backend-->>CPP: Return selected backend
    CPP->>cuDNN: Execute attention backward
    cuDNN-->>CPP: Gradients
    CPP-->>Python: Return gradients
    
    Note over User,cuDNN: Forward Pass (Always Deterministic)
    Python->>CPP: Call fused_attn_fwd
    CPP->>Backend: nvte_get_fused_attn_backend(...,<br/>deterministic=false)
    Backend->>Backend: Forward is always deterministic<br/>(hardcoded false)
    Backend-->>CPP: Return selected backend
    CPP->>cuDNN: Execute attention forward
    cuDNN-->>CPP: Output + aux tensors
    CPP-->>Python: Return output

@greptile-apps (bot) left a comment: 1 file reviewed, 1 comment

@greptile-apps (bot) commented Jan 12, 2026

Greptile Summary

Overview

This PR enables determinism for FusedAttention on Blackwell GPUs (SM 100) with cuDNN version 9.18.0 or higher. The implementation moves determinism checking logic from Python to the C++ backend selection layer.

Key Changes

  1. Backend Selection Logic: Added a new condition in nvte_get_fused_attn_backend() that disables the arbitrary sequence length backend for Blackwell when:

    • Training mode is enabled
    • Determinism is required
    • Any of: cuDNN < 9.18.0, bias is used, or dropout > 0
  2. API Updates: Added deterministic parameter to the backend selection function across Python, C++, and JAX interfaces. Forward passes hardcode deterministic=true while backward passes accept it as a parameter.

  3. Code Migration: Moved Blackwell determinism checks from Python (utils.py) to C++ backend selection, consolidating version, bias, and dropout checks in one place.

  4. Test Infrastructure: Added environment variable NVTE_ALLOW_NONDETERMINISTIC_ALGO to control determinism in tests, and added explicit NVTE_UNFUSED_ATTN=0 settings to ensure proper backend isolation.

  5. Dependency Update: Updated cudnn-frontend submodule to version 1.17 to support the new determinism features.

Architecture

The change follows a layered approach:

  • User API Level: Python tests set deterministic flag via environment variable or torch settings
  • Python Layer: Extracts deterministic flag and passes to C++ extension
  • C++ Backend Selection: Evaluates hardware, cuDNN version, bias, and dropout to determine if deterministic FusedAttention is supported
  • Execution: If requirements aren't met, falls back to other backends (FlashAttention or UnfusedDotProductAttention)

The implementation correctly restricts deterministic FusedAttention to cases where cuDNN guarantees deterministic behavior, avoiding silent non-determinism.
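
As an illustration of the Python-layer step, extracting the deterministic flag from the environment can be as simple as the following (a sketch; the helper name is hypothetical, and the convention that 0 requests determinism follows the PR description):

import os

def requested_deterministic() -> bool:
    # NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 requests deterministic kernels;
    # any other value, or leaving it unset, permits non-deterministic algorithms.
    return os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1") == "0"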

Confidence Score: 4/5

  • This PR is safe to merge with minor issues that should be addressed
  • The implementation is sound and correctly adds determinism support for Blackwell GPUs. The core logic properly checks cuDNN version, bias, and dropout constraints. However, two issues lower the confidence: (1) inconsistent tab/space indentation in the critical condition on line 444 of fused_attn.cpp, and (2) duplicate XML output file in test.sh causing test results to be overwritten. Both are non-critical but should be fixed before merge.
  • Pay attention to transformer_engine/common/fused_attn/fused_attn.cpp (line 444 indentation) and qa/L0_pytorch_unittest/test.sh (line 48 XML filename collision)

Important Files Changed

  • transformer_engine/common/fused_attn/fused_attn.cpp (4/5): added determinism check for Blackwell (sm100) to disable FusedAttention when cuDNN < 9.18.0 or bias/dropout are used; contains a tab-indentation inconsistency on line 444
  • transformer_engine/pytorch/attention/dot_product_attention/utils.py (5/5): removed the Python-side Blackwell determinism check, now handled in C++; added deterministic parameter to the backend selection call
  • tests/pytorch/attention/test_attention.py (5/5): added deterministic flag from the environment variable and torch settings; updated tests to explicitly set NVTE_UNFUSED_ATTN=0 to ensure correct backend isolation
  • qa/L0_pytorch_unittest/test.sh (3/5): added a deterministic test run with NVTE_ALLOW_NONDETERMINISTIC_ALGO=0; both test runs write to the same XML file, causing results to be overwritten

Sequence Diagram

sequenceDiagram
    participant User as User/Test
    participant PyAPI as Python API
    participant Utils as utils.py
    participant CppExt as C++ Extensions
    participant Backend as Backend Selection
    participant cuDNN as cuDNN Library

    User->>PyAPI: Call attention with deterministic=True
    PyAPI->>Utils: get_attention_backend(params)
    Utils->>Utils: Extract deterministic from params
    Utils->>CppExt: get_fused_attn_backend(..., deterministic)
    CppExt->>Backend: nvte_get_fused_attn_backend(..., deterministic)
    
    alt Blackwell (sm_arch >= 100) & Training & Deterministic
        Backend->>Backend: Check cuDNN version >= 9.18.0
        Backend->>Backend: Check bias_type == NO_BIAS
        Backend->>Backend: Check dropout == 0.0
        alt All checks pass
            Backend-->>CppExt: F16_arbitrary_seqlen backend
        else Any check fails
            Backend-->>CppExt: No_Backend (disabled)
        end
    else Other architectures or inference
        Backend->>Backend: Apply standard backend selection
        Backend-->>CppExt: Selected backend
    end
    
    CppExt-->>Utils: Backend choice
    Utils-->>PyAPI: Backend configuration
    
    alt Forward Pass
        PyAPI->>CppExt: nvte_fused_attn_fwd(..., deterministic=true)
        Note over PyAPI,CppExt: Forward always uses deterministic=true
    else Backward Pass
        PyAPI->>CppExt: nvte_fused_attn_bwd(..., deterministic)
        Note over PyAPI,CppExt: Backward respects user's deterministic flag
    end
    
    CppExt->>cuDNN: Execute attention operation
    cuDNN-->>CppExt: Results
    CppExt-->>PyAPI: Output tensors
    PyAPI-->>User: Attention output

@greptile-apps (bot) left a comment: 2 files reviewed, 2 comments

make .xml file specific to deterministic tests in qa/

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <[email protected]>
@greptile-apps (bot) left a comment: 1 file reviewed, 1 comment

@greptile-apps (bot) left a comment: no files reviewed, no comments

@greptile-apps (bot) left a comment: 1 file reviewed, 1 comment

@greptile-apps (bot) left a comment: 1 file reviewed, 1 comment

cyanguwa and others added 3 commits January 13, 2026 06:00
fix typo

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <[email protected]>
fix indentation

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <[email protected]>
@greptile-apps (bot) left a comment: 1 file reviewed, 1 comment

@greptile-apps (bot) left a comment: 1 file reviewed, 1 comment

@greptile-apps (bot) left a comment: 3 files reviewed, 3 comments

@greptile-apps (bot) left a comment: 2 files reviewed, 2 comments

@greptile-apps (bot) left a comment: 3 files reviewed, 3 comments

@cyanguwa (author) commented:

/te-ci jax L0

@cyanguwa (author) commented:

/te-ci L0

@cyanguwa (author) commented:

/te-ci L1

@cyanguwa (author) commented:

/te-ci L1

@greptile-apps (bot) left a comment: 13 files reviewed, 3 comments

cyanguwa and others added 3 commits January 15, 2026 06:57
Signed-off-by: Charlene Yang <[email protected]>
fix and/or logic

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <[email protected]>
@cyanguwa (author) commented:

/te-ci L1

@liayan commented Jan 16, 2026
Cool, we are currently suffering from this issue.
Do we have a rough timeline for when it could be merged?
Let me know if there is anything I can do, such as running a test; I'd like to help.

@KshitijLakhani (Collaborator) previously approved these changes Jan 16, 2026 and left a comment:
Left a few comments: some suggested changes and some questions.
Otherwise, looks good to me. Approving so as not to block the merge, if this is urgent.

It would be helpful to have a table in the PR description of what is supported for cuDNN < 9.18, cuDNN >= 9.18, < sm100, sm100+, dropout, dbias, etc.

I would also suggest looking into the number of tests being run and their timing (you can compare your PR's L0 JAX and L0 PyTorch timings against the timings in TE 2.11 or in TE main CI); we would not want to go over our timing budget, for sure. If you can report the timing in the PR, that would be helpful as well.
Worst case, if urgent, we can merge this PR and address the QA part (which runs in the CI) in a separate PR.

Lastly, this might take some effort but would ensure correctness: since the test-skipping logic in the TE JAX tests has been modified, it would be good to compare the test count before and after this PR to verify that tests that should not be skipped are not incorrectly being skipped.

mkdir -p "$XML_LOG_DIR"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_deterministic.xml $TE_PATH/tests/jax/test_fused_attn.py || test_fail "tests/jax/test_fused_attn.py"
@KshitijLakhani (Collaborator) commented:

It seems like this will first run the non-deterministic fused attn tests as part of line 31, which runs all non-distributed tests, followed by the deterministic fused attn tests as part of line 32.
Is that the intention: to run fused attn twice, with and without determinism?

That would greatly increase our test time and might be unnecessary. The last pipeline launched was for L1, so I am unsure I can track the effect this change will have on timing, as this is an L0 change. Could you report that in the PR, please?
Thanks!

@KshitijLakhani (Collaborator) commented:

Maybe we could come up with an approach that runs half the fused attn tests deterministically and the other half non-deterministically? Or run all of them deterministically only?

@cyanguwa (author) replied:

Yes, this extra line tests test_fused_attn.py with determinism, while the line before it tests everything with non-determinism. The extra test_fused_attn.py run takes ~20 minutes on Blackwell:

================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test_backward                                                | 5040x | 1336.28s | avg:   0.27s
================================================================================
TOTAL RUNTIME                                                |      | 1336.28s |
================================================================================

@cyanguwa (author) replied:

Now with cd5bcf3, the extra determinism tests should really take no time at all (there are only 20 tests added).

    float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-   int64_t window_size_right, bool return_max_logit, bool cuda_graph);
+   int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic);
@KshitijLakhani (Collaborator) commented:
nit: To be consistent, should we call this flag is_deterministic, similar to the first arg, is_training?

@cyanguwa (author) replied:

I felt there was a bit of a distinction when I was implementing it: is_training is a description of the state we are in, while deterministic is more of a request from the user (that they want to run in deterministic mode). Not a lot of difference, to be honest; it's just the feel of the words. I did this when I introduced deterministic as a parameter in AttentionParams, so I followed along with it in this PR. Any strong objections?

@KshitijLakhani (Collaborator) commented:

/te-ci L0 L1

@cyanguwa changed the title from "[Common] Enable determinism for cuDNN >= 9.18 on Blackwell" to "[Common] Enable determinism for cuDNN >= 9.18.1 on Blackwell" on Jan 19, 2026
@cyanguwa (author) commented:

/te-ci L1

@cyanguwa (author) commented:

Pipeline 42017245 for CI with updated cuDNN.

@cyanguwa (author) commented:

/te-ci L1

@cyanguwa (author) commented:

Pipeline 42067766 for 9.18.1 tests.
