
Add ModelOpt Triton attention kernels for WAN2.2 diffusion (sparse, skip-softmax, NVFP4)#1190

Open
yeyu-nvidia wants to merge 52 commits into main from yeyu/sage-attention-diffusion

Conversation

Contributor

@yeyu-nvidia yeyu-nvidia commented Apr 7, 2026

Summary

This PR adds four Triton-backed attention kernel modes for WAN2.2 diffusion models, integrated into the modelopt.torch.sparsity.attention_sparsity framework. It also fixes a registration bug and adds NVFP4 P-matrix quantization to the Triton kernel.

New Features

Four kernel modes are available via mtsa.sparsify() with backend="diffusers_triton":

| Kernel | Method | Description |
|---|---|---|
| triton-sparse | triton_sparse_softmax | 2:4 N:M sparse softmax — keeps the top-2 attention scores per 4 KV positions |
| triton-skip | triton_skip_softmax | Skip-softmax tile pruning — skips KV tiles where the max score is below the threshold |
| triton-sparse-nvfp4 | triton_sparse_softmax | Same as triton-sparse, plus NVFP4 E2M1 per-tile quantization of the post-softmax P matrix |
| triton-skip-nvfp4 | triton_skip_softmax | Same as triton-skip, plus NVFP4 E2M1 per-tile quantization of the post-softmax P matrix |

Usage:

import modelopt.torch.sparsity.attention_sparsity as mtsa

# 2:4 N:M sparse softmax
mtsa.sparsify(transformer, {
    "sparse_cfg": {
        "*": {
            "method": "triton_sparse_softmax",
            "sparsity_n": 2,
            "sparsity_m": 4,
            "num_sink_tokens": 0,
            "dense_window_size": 64,
            "backend": "diffusers_triton",
        }
    }
})
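For intuition, the top-2-of-4 selection that triton_sparse_softmax applies to attention scores can be modeled outside the kernel. This is an illustrative NumPy sketch (the function name and 1-D row shape are hypothetical), not the Triton implementation:

```python
import numpy as np

def sparse_24_mask(scores: np.ndarray) -> np.ndarray:
    """Keep the top-2 scores in every group of 4 KV positions (2:4 N:M sparsity).

    `scores` is a 1-D row of attention logits; its length must be a multiple of 4.
    """
    groups = scores.reshape(-1, 4)
    # Indices of the two largest entries in each group of four.
    top2 = np.argsort(groups, axis=1)[:, -2:]
    mask = np.zeros_like(groups, dtype=bool)
    np.put_along_axis(mask, top2, True, axis=1)
    return mask.reshape(scores.shape)

row = np.array([0.9, 0.1, 0.5, 0.2, 3.0, -1.0, 2.0, 0.0])
print(sparse_24_mask(row))  # exactly 2 True per group of 4
```

In the real kernel this masking happens per tile on the pre-softmax scores, with sink tokens and the dense window exempted.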

# Skip-softmax tile pruning
# skip_softmax_threshold (lambda): skip a KV tile when exp(tile_max - running_max) < lambda.
# Threshold is scaled by sm_scale so sparsity is head-dim-relative.
mtsa.sparsify(transformer, {
    "sparse_cfg": {
        "*": {
            "method": "triton_skip_softmax",
            "skip_softmax_threshold": 0.1,
            "backend": "diffusers_triton",
        }
    }
})
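The skip criterion from the comment above can be stated as a one-line predicate. The sketch below (helper name hypothetical) ignores the sm_scale scaling and just illustrates the exp(tile_max - running_max) < lambda rule:

```python
import math

def tile_is_skipped(tile_max: float, running_max: float, lam: float) -> bool:
    """Skip a KV tile when its best score cannot contribute more than `lam`
    relative to the running row maximum: exp(tile_max - running_max) < lam."""
    return math.exp(tile_max - running_max) < lam

# With lambda = 0.1, a tile whose max logit trails the running max by more
# than ln(10) ≈ 2.303 is pruned.
print(tile_is_skipped(tile_max=5.0, running_max=8.0, lam=0.1))  # True
print(tile_is_skipped(tile_max=7.0, running_max=8.0, lam=0.1))  # False
```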

# Add quantize_p=True to either method above to also apply NVFP4 E2M1
# per-tile quantization to the post-softmax attention weights (P matrix).
# This is a Triton kernel option — it only applies within these two methods.
mtsa.sparsify(transformer, {
    "sparse_cfg": {
        "*": {
            "method": "triton_skip_softmax",
            "skip_softmax_threshold": 0.1,
            "backend": "diffusers_triton",
            "quantize_p": True,
        }
    }
})

Design note: why quantize_p is not exposed via mtq.quantize()

ModelOpt already has a softmax_quantizer in _QuantAttention (see modelopt/torch/quantization/plugins/diffusion/diffusers.py) which conceptually covers the same operation. We considered routing NVFP4 P-matrix quantization through mtq.quantize(), but it is not feasible for three reasons:

  1. Per-tile granularity. TensorQuantizer calibrates a global amax over the full tensor. Our NVFP4 quantization is per-Triton-tile ([BLOCK_M × BLOCK_N]), with the scale computed on-the-fly inside the kernel. There is no PyTorch-level tensor to calibrate against.

  2. The P matrix never materializes. In the Triton kernel, P only exists tile-by-tile as a register-level intermediate. There is no [seq_q × seq_kv] PyTorch tensor for mtq to intercept.

  3. Inseparable from the Triton path. quantize_p=True is only meaningful when the Triton kernel is already active (via backend="diffusers_triton").
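To make the per-tile math concrete, here is a NumPy fake-quantization sketch of what the kernel does to each P tile. The grid values are the standard FP4 E2M1 magnitudes; the amax-to-6.0 scaling is an assumption based on the description above, and the function name is hypothetical — the real quantization runs on register-level tiles inside Triton:

```python
import numpy as np

# Representable magnitudes of the FP4 E2M1 format (post-softmax P is
# non-negative, so sign handling is omitted in this sketch).
E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fake_quantize_p_tile(p: np.ndarray) -> np.ndarray:
    """Fake-quantize one post-softmax tile to NVFP4 E2M1.

    The per-tile scale maps the tile's amax onto 6.0 (the largest E2M1
    magnitude); each entry then snaps to the nearest grid point and is
    scaled back.
    """
    amax = p.max()
    if amax == 0.0:
        return p
    scale = amax / 6.0
    scaled = p / scale
    # Round each scaled entry to the nearest E2M1 grid point.
    idx = np.abs(scaled[..., None] - E2M1_GRID).argmin(axis=-1)
    return E2M1_GRID[idx] * scale

tile = np.array([[0.60, 0.26], [0.12, 0.0]])
print(fake_quantize_p_tile(tile))  # ≈ [[0.6, 0.3], [0.1, 0.0]]
```

Because the scale is recomputed from each [BLOCK_M × BLOCK_N] tile's amax, there is no global amax to calibrate — which is exactly why this cannot go through TensorQuantizer.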

Bug Fix

WanSparseAttentionModule never executed (PSNR = 100 dB / byte-identical to baseline)

  • Root cause: plugins/__init__.py imported huggingface.py before diffusers.py. The HF generic plugin registered WanAttention with _GenericSparseAttention first; the diffusers-specific plugin then saw WanAttention already registered and skipped, so WanSparseAttentionModule / ModelOptWanAttnProcessor were never installed.
  • Fix: Swap the import order in plugins/__init__.py so that diffusers is imported before huggingface.
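The failure mode is the classic first-registration-wins pattern. A minimal model of such a registry (names here are illustrative, not the ModelOpt API) shows why import order alone decides which implementation is installed:

```python
# First-registration-wins registry: later registrations for the same
# attention class are silently skipped.
registry: dict[str, str] = {}

def register(attn_class: str, impl: str) -> None:
    if attn_class in registry:  # already registered -> skipped
        return
    registry[attn_class] = impl

# Buggy order: the generic HF plugin claims WanAttention first.
register("WanAttention", "_GenericSparseAttention")
register("WanAttention", "WanSparseAttentionModule")  # skipped
print(registry["WanAttention"])  # _GenericSparseAttention

# Fixed order: the diffusers-specific module is installed first.
registry.clear()
register("WanAttention", "WanSparseAttentionModule")
register("WanAttention", "_GenericSparseAttention")  # skipped
print(registry["WanAttention"])  # WanSparseAttentionModule
```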

Files Changed

  • modelopt/torch/kernels/triton_fa.py — NVFP4 per-tile P-matrix quantization
  • modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py — WanSparseAttentionModule, ModelOptWanAttnProcessor, register_wan_sparse_attention
  • modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py — Fix import order (diffusers before huggingface)
  • modelopt/torch/quantization/plugins/diffusion/diffusers.py — _QuantWanAttnProcessor for the NVFP4 quantization path
  • tests/unit/ — Unit tests for the new diffusers WAN sparse attention plugin
  • examples/diffusers/quantization/wan2_sage_attention.py — Example script with all four kernel modes, --kernel, --compare, --skip-threshold CLI flags

Test plan

  • python -m pytest tests/unit -k "wan" — unit tests for WanSparseAttentionModule and ModelOptWanAttnProcessor
  • python wan2_sage_attention.py --kernel triton-sparse --compare --seed 42 — verify PSNR > 30 dB vs baseline
  • python wan2_sage_attention.py --kernel triton-skip --compare --seed 42 — verify PSNR > 30 dB vs baseline
  • python wan2_sage_attention.py --kernel triton-skip-nvfp4 --compare --seed 42 — verify visually acceptable output
  • python wan2_sage_attention.py --kernel triton-sparse-nvfp4 --compare --seed 42 — verify visually acceptable output

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features

    • Wan text-to-video examples with selectable attention kernels, FP8/Sage/Triton modes, benchmarking, metrics and MP4 export.
    • vLLM server/workers with optional sparse-attention & quantization support.
    • Triton-backed sparse attention (paged KV-cache, optional post-softmax quantization) and diffusers/vLLM integration with runtime registration.
  • Documentation

    • Added guides for sparse attention, skip-softmax calibration and Wan examples.
  • Tests

    • GPU and unit tests covering paged KV, diffusers plugin and kernel backends.
  • Chores

    • Lint config updated to recognize vllm as third-party.

jingyu-ml and others added 11 commits April 2, 2026 06:02
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Contributor

coderabbitai bot commented Apr 7, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Adds paged KV-cache and optional post-softmax quantization to the Triton flash-attention kernel; integrates Triton sparse-attention with Diffusers (WanAttention) and vLLM (paged KV); provides example scripts/workers for Wan2.2 and vLLM sparse-serving; and adds tests, calibration tooling, and plugin/registration plumbing.

Changes

Cohort / File(s) Summary
Triton FA kernel & tests
modelopt/torch/kernels/triton_fa.py, modelopt/torch/kernels/__init__.py, tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py
Added paged-KV load paths, optional post-softmax NVFP4 p-tile quantization (quantize_p), extended attention API and autograd wrapper, added calibration kernel attention_calibrate, and GPU tests validating paged vs contiguous behavior and sparsity scenarios.
Diffusers sparse attention plugin & docs/tests
modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py, modelopt/torch/sparsity/attention_sparsity/config.py, modelopt/torch/sparsity/attention_sparsity/conversion.py, examples/diffusers/README.md, examples/diffusers/sparsity/README.md, tests/unit/torch/sparsity/attention_sparsity/test_diffusers_plugin.py, modelopt/torch/kernels/__init__.py
Introduced WanAttention Triton plugin (processor + SparseAttentionModule), added quantize_p config field and diffusers_triton backend handling, auto-registration logic, docs and unit tests, and exported calibration symbol.
vLLM kernel plugin & serving examples
modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py, examples/vllm_serve/sparse_attn_worker.py, examples/vllm_serve/vllm_serve_sparse_attn.py, modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
Added ModelOpt vLLM backend ModelOptSparseAttentionImpl using paged KV and Triton kernel; worker classes to replace attention impls (and run quant prolog), server entrypoint that selects custom workers via env vars; guarded plugin imports.
Diffusers eager/Triton kernel backends & LTX hooks
modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py, modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py, modelopt/torch/sparsity/attention_sparsity/kernels/ltx_eager_attention.py, modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py, modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py
Added diffusers eager and Triton backend adapters (including calibration thread-local state), LTX eager/Triton wrappers, registration helpers, and thread-local skip-softmax/calibration context management.
Skip-softmax methods, calibration & stats
modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py, modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py, modelopt/torch/sparsity/attention_sparsity/methods/registry.py, modelopt/torch/sparsity/attention_sparsity/calibration/*.py, modelopt/torch/sparsity/attention_sparsity/stats_manager.py
Expanded Triton skip-softmax to support calibration/inference modes and multi-threshold collection, added calibration flow changes (lazy tokenizer, optional forward_loop), updated calibrator threshold writes, added set_calibration_mode API, and adjusted stats aggregation.
ModelOpt plugin plumbing & registration
modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py, modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py, modelopt/torch/sparsity/attention_sparsity/conversion.py
Made plugin imports dependency-guarded, deferred transformers/diffusers checks in plugin detection, and added auto-register hook for diffusers/LTX backends during model conversion.
Wan2.2 examples & Sage/FP8 exploration
examples/diffusers/quantization/wan2_sage_attention.py, examples/diffusers/sparsity/wan22_skip_softmax.py, examples/diffusers/README.md, examples/diffusers/sparsity/README.md
New Wan2.2 attention example supporting FP8 quantization, SageAttention modes, ModelOpt Triton sparse/skip kernels, benchmarking, metrics (PSNR/MAE/CLIP), calibration/generation example, and a skip-softmax example script with calibration support.
ModelOpt Diffusers quantization registry tweak
modelopt/torch/quantization/plugins/diffusion/diffusers.py
Reordered QuantModuleRegistry registrations so WanAttention is registered after LTXAttention (no behavioral changes to mixin).
Tests & tooling
tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py, pyproject.toml
Added unit tests for kernel backends, skip-softmax context, and registration behavior; added vllm to isort known-third-party.
New example: vLLM sparse workers
examples/vllm_serve/sparse_attn_worker.py
Workers that read sparse config from env, replace vLLM Attention impls with ModelOpt-backed sparse implementations, and optionally run quantization prolog during warm-up.

Sequence Diagram(s)

sequenceDiagram
    participant Model as Diffusers Model
    participant WanModule as WanAttention
    participant SparseModule as WanSparseAttentionModule
    participant Processor as ModelOptWanAttnProcessor
    participant TritonKernel as Triton FA Kernel

    Model->>WanModule: forward()
    WanModule->>SparseModule: delegate (installed)
    SparseModule->>Processor: prepare Q/K/V, rotary, proj
    Processor->>TritonKernel: call attention(paged KV, sparse_kw, quantize_p)
    TritonKernel-->>Processor: attention output
    Processor-->>SparseModule: return output
    SparseModule-->>Model: continue forward
sequenceDiagram
    participant Server as vLLM Server
    participant Worker as SparseQuantWorker
    participant Model as Loaded Model
    participant Attn as vLLM Attention Module
    participant Impl as ModelOptSparseAttentionImpl
    participant TritonKernel as Triton FA Kernel

    Server->>Worker: start / load_model()
    Worker->>Model: load base model
    Worker->>Attn: iterate named_modules()
    Worker->>Impl: replace impl with ModelOptSparseAttentionImpl (paged KV)
    Worker->>Worker: compile_or_warm_up_model (quant prolog if configured)
    Model->>Attn: inference forward(metadata, caches)
    Attn->>Impl: forward(...)
    Impl->>TritonKernel: call attention(paged KV, block_table, sparse_kw)
    TritonKernel-->>Impl: output
    Impl-->>Attn: write output
    Attn-->>Model: continue

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • cjluo-nv
  • kevalmorabia97
  • Edwardf0t1
🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 66.67%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped — CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The PR title accurately and concisely summarizes the primary changes: adding ModelOpt Triton attention kernels for WAN2.2 diffusion with sparse, skip-softmax, and NVFP4 quantization modes. It reflects the main objectives and the most significant additions to the codebase. |
| Security Anti-Patterns | ✅ Passed | The PR introduces attention kernel implementations and sparse attention plugins without violating critical security anti-patterns; all existing unsafe patterns have proper justification comments. |


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 8

🧹 Nitpick comments (1)
modelopt/torch/sparsity/attention_sparsity/config.py (1)

142-151: Validate quantize_p compatibility to avoid silent no-op configs.

quantize_p is documented as diffusers-triton-specific, but invalid combinations are currently accepted. Adding a compatibility validator will prevent confusing runtime behavior.

♻️ Proposed validation
 class SparseAttentionAttributeConfig(ModeloptBaseConfig):
@@
     @model_validator(mode="after")
     def validate_sparsity_n_vs_m(self):
@@
         return self
+
+    @model_validator(mode="after")
+    def validate_quantize_p_compatibility(self):
+        if self.quantize_p and self.backend != "diffusers_triton":
+            raise ValueError(
+                "quantize_p=True requires backend='diffusers_triton'."
+            )
+        if self.quantize_p and self.method not in (
+            "triton_sparse_softmax",
+            "triton_skip_softmax",
+        ):
+            raise ValueError(
+                "quantize_p=True is only supported with "
+                "'triton_sparse_softmax' or 'triton_skip_softmax'."
+            )
+        return self
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 142 - 151,
Add a compatibility validator for the quantize_p ModeloptField so invalid combos
are rejected: if quantize_p is True ensure the Triton-specific backend (e.g.,
backend == "diffusers_triton") is selected and reject/raise an error when used
with non-diffusers_triton backends; also validate interactions with related
flags (triton_sparse_softmax, triton_skip_softmax) to either allow or explicitly
forbid combinations that make quantize_p a no-op, by adding the check in the
same config class where quantize_p is defined and raising a clear validation
error when the constraints are violated.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/diffusers/quantization/wan2_sage_attention.py`:
- Around line 612-666: compute_clip_score defaults to device="cuda" and loads
the CLIP model directly onto that device which can OOM smaller GPUs; wrap the
CLIPProcessor/CLIPModel.from_pretrained calls (and any subsequent .to(device))
in a try/except that catches both OSError and RuntimeError, and on failure retry
loading the model on CPU (set device="cpu" for the retry), logging a warning;
ensure downstream tensors (text_inputs/img_inputs) are moved to the active
device variable and that clip_model is .to(device) only after successful load so
the function can gracefully fall back from GPU to CPU without crashing
(references: compute_clip_score, clip_model_id, device, processor, clip_model,
text_inputs, img_inputs).
- Around line 790-800: The help text for the "--skip-threshold" /
skip_softmax_threshold flag is incorrect: the implementation performs an
additive/log-space comparison using exp(tile_max - running_max) < lambda (i.e.,
compares the softmax ratio via exponent of the logit difference), not a
multiplicative comparison of raw logits; update the help string to describe that
a tile is skipped when exp(tile_max - running_max) is less than LAMBDA (softmax
ratio threshold), and mention this applies to the triton-skip /
triton-skip-nvfp4 kernels so users understand it's a log-space/additive
threshold.

In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 72-85: The preset validation is too permissive:
_build_sparse_config currently accepts any module-level preset but
_replace_attention_impl only supports sparsity_* keys and
skip_softmax_threshold, so presets containing other fields (e.g.,
backend="pytorch" or missing a fixed skip_softmax_threshold like
SKIP_SOFTMAX_CALIB) silently degrade to dense attention. Update
_build_sparse_config (and the same logic at the other occurrence) to validate
the chosen preset: ensure the dict contains only supported keys (keys matching
"sparsity_" prefixes and/or "skip_softmax_threshold") and that if a softmax-skip
behavior is required it includes a fixed numeric skip_softmax_threshold; if the
preset includes unsupported keys or lacks the required numeric threshold, raise
ValueError telling the user to pick "default" or a compatible preset so
_replace_attention_impl and VLLMAttention won't be silently ignored.

In `@examples/vllm_serve/vllm_serve_sparse_attn.py`:
- Line 75: The PYTHONPATH assignment currently concatenates with ":" and can
produce a leading empty entry; update the logic that sets
os.environ["PYTHONPATH"] (the existing statement referencing repo_root) to build
the path using os.pathsep and only include non-empty components: read
os.environ.get("PYTHONPATH",""), split/filter out empty strings, append the
repo_root, then join with os.pathsep and assign back to os.environ["PYTHONPATH"]
so no empty/CWD entry is injected and it is platform portable.
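The fix this comment asks for can be sketched as a small helper (name hypothetical) that filters empty components and joins with os.pathsep:

```python
import os

def augmented_pythonpath(repo_root: str, existing: str) -> str:
    """Append repo_root to an existing PYTHONPATH value, dropping empty
    components so no implicit-CWD entry is injected, and joining with the
    platform's path separator."""
    parts = [p for p in existing.split(os.pathsep) if p]
    parts.append(repo_root)
    return os.pathsep.join(parts)

# A naive '":".join((existing, repo_root))' would yield ":/repo" when
# PYTHONPATH is unset or empty.
print(augmented_pythonpath("/repo", ""))          # /repo
print(augmented_pythonpath("/repo", "/usr/lib"))  # /usr/lib:/repo on POSIX
```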

In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 536-538: The forward-only quantization for numerator p (guarded by
QUANTIZE_P and using _quantize_p_nvfp4) is not mirrored in backward: the
backward path still reconstructs unquantized softmax and ignores ctx.quantize_p,
producing incorrect gradients; fix by either (A) implementing a matching
backward branch that reconstructs/uses the same quantized representation and
dequantization logic when ctx.quantize_p is true (ensure grad path uses the
quantized p->dequant sequence and adjust gradient accumulation consistent with
tl.dot(p.to(v.dtype), v, acc)), or (B) explicitly reject/raise an error when
QUANTIZE_P/ctx.quantize_p is enabled and any of q.requires_grad,
k.requires_grad, or v.requires_grad is true (check these tensors in the forward
function and set ctx.quantize_p accordingly), so the mode is only allowed for
inference. Ensure you reference and update the forward use of
QUANTIZE_P/_quantize_p_nvfp4 and the backward branch that reads ctx.quantize_p.
- Around line 988-991: Only synthesize b_start_loc_k when actually in paged-KV
mode: compute a paged_mode boolean from the real indicators (e.g., presence of
block_table or missing v_cache) and if paged_mode is True set b_start_loc_k =
torch.zeros_like(b_start_loc); otherwise, do not substitute silently—raise an
explicit error (or assert) when b_start_loc_k is None so the contiguous K/V path
fails fast before later stride(...) / shape[...] access. Reference
b_start_loc_k, b_start_loc, k_cache, v_cache, and block_table when making this
conditional.

In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 63-66: The early return when "diffusers_triton" is present can
skip `_attn_implementation` setup for other sparse backends; update the check
around the `if "diffusers_triton" in backends: return` (in conversion.py where
`backends` is evaluated and ModelOptWanAttnProcessor is mentioned) to fail fast
if `diffusers_triton` is combined with any other backend (e.g., if
"diffusers_triton" in backends and len(backends) > 1: raise a ValueError with a
clear message) so that mixed configurations are rejected instead of silently
skipping attention registration.

In `@modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py`:
- Around line 81-132: The Triton path (_wan_forward_triton) currently ignores
the attention_mask causing different semantics than the SDPA path; add an early
guard in _wan_forward_triton that checks if attention_mask is not None and if so
immediately calls and returns self._wan_forward_sdpa(...) with the same
arguments (including attention_mask) to preserve fallback behavior; ensure the
call uses the same parameters and return value so signatures/outputs remain
consistent.

---

Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Around line 142-151: Add a compatibility validator for the quantize_p
ModeloptField so invalid combos are rejected: if quantize_p is True ensure the
Triton-specific backend (e.g., backend == "diffusers_triton") is selected and
reject/raise an error when used with non-diffusers_triton backends; also
validate interactions with related flags (triton_sparse_softmax,
triton_skip_softmax) to either allow or explicitly forbid combinations that make
quantize_p a no-op, by adding the check in the same config class where
quantize_p is defined and raising a clear validation error when the constraints
are violated.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: a409c49f-7309-45c3-89ed-693c7124c99d

📥 Commits

Reviewing files that changed from the base of the PR and between 80d2f02 and 8996ef1.

📒 Files selected for processing (13)
  • examples/diffusers/quantization/wan2_sage_attention.py
  • examples/vllm_serve/sparse_attn_worker.py
  • examples/vllm_serve/vllm_serve_sparse_attn.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/quantization/plugins/diffusion/diffusers.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/conversion.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
  • pyproject.toml
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py
  • tests/unit/torch/sparsity/attention_sparsity/test_diffusers_plugin.py


[tool.ruff.lint.isort]
known-first-party = ["modelopt"]
known-third-party = ["vllm"]
Collaborator


why is this needed?

Contributor Author


vLLM does not follow standard Python package naming conventions — ruff's isort treats it as a first-party import without this entry, which causes import ordering lint failures in the vLLM worker files. Adding it to known-third-party keeps isort from re-classifying it incorrectly.

# ---------------------------------------------------------------------------


@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton")
Collaborator

@kevalmorabia97 kevalmorabia97 Apr 7, 2026


Does our current CICD env skip this test or run it?

Contributor Author


The test lives in tests/gpu/ which runs in the GPU CI environment (python -m pytest tests/gpu). The test itself also has @pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") as an additional guard, so it will be skipped gracefully on any runner without a CUDA GPU or Triton installed.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

♻️ Duplicate comments (1)
examples/vllm_serve/sparse_attn_worker.py (1)

146-166: ⚠️ Potential issue | 🟠 Major

Make method drive sparse_kw.

At Line 148, method is only validated and then discarded. ModelOptSparseAttentionImpl.forward() only consumes self.sparse_kw in modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py:64-66, so a triton_skip_softmax config with no numeric threshold still patches the layer with empty kwargs, and mixed configs can pass both sparsity_* and skip_softmax_threshold, selecting a different path than the one you just validated. Build kwargs from method and fail fast when that method’s required parameters are missing.

🛠️ Proposed fix
         # Build per-layer sparse kwargs
         sparse_kw = {}
-        sparsity_n = layer_cfg.get("sparsity_n", 0)
-        if sparsity_n > 0:
-            sparse_kw["sparsity_n"] = sparsity_n
-            sparse_kw["sparsity_m"] = layer_cfg.get("sparsity_m", 4)
-            sparse_kw["num_sink_tokens"] = layer_cfg.get("num_sink_tokens", 0)
-            sparse_kw["dense_window_size"] = layer_cfg.get("dense_window_size", 1)
-        threshold = layer_cfg.get("skip_softmax_threshold")
-        if threshold:
-            sparse_kw["skip_softmax_threshold"] = threshold
+        if method == "triton_sparse_softmax":
+            sparsity_n = layer_cfg.get("sparsity_n")
+            if not isinstance(sparsity_n, int) or sparsity_n <= 0:
+                raise ValueError(f"{name}: triton_sparse_softmax requires a positive sparsity_n")
+
+            sparse_kw["sparsity_n"] = sparsity_n
+            sparse_kw["sparsity_m"] = layer_cfg.get("sparsity_m", 4)
+            sparse_kw["num_sink_tokens"] = layer_cfg.get("num_sink_tokens", 0)
+            sparse_kw["dense_window_size"] = layer_cfg.get("dense_window_size", 1)
+        else:
+            threshold = layer_cfg.get("skip_softmax_threshold")
+            if not isinstance(threshold, (int, float)):
+                raise ValueError(
+                    f"{name}: triton_skip_softmax requires a numeric skip_softmax_threshold"
+                )
+            sparse_kw["skip_softmax_threshold"] = float(threshold)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/vllm_serve/sparse_attn_worker.py` around lines 146 - 166, The code
currently validates method but doesn't let method drive sparse_kw; update the
logic in sparse_attn_worker.py so that method determines which keys are
populated and that required params are present: if method ==
"triton_sparse_softmax" require sparsity_n > 0 and populate only sparsity_n,
sparsity_m, num_sink_tokens, dense_window_size into sparse_kw (fail fast with
ValueError if missing); if method == "triton_skip_softmax" require
skip_softmax_threshold to be set and populate only skip_softmax_threshold into
sparse_kw (fail fast if missing); keep references to method, layer_cfg,
sparse_kw and ensure compatibility with ModelOptSparseAttentionImpl.forward()
which reads self.sparse_kw in
modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py.
🧹 Nitpick comments (2)
examples/vllm_serve/sparse_attn_worker.py (1)

72-123: Add concrete types at the worker/config boundary.

These helpers still use bare dict and an untyped worker, so mypy can't validate the config shape or the model_runner.model / unwrap() contract on this new extension path. A small TypedDict + Protocol here would turn several runtime-only failures into lint errors.

As per coding guidelines, "Ensure type hints are properly annotated for static type checking with mypy".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/vllm_serve/sparse_attn_worker.py` around lines 72 - 123, Introduce
concrete typing at the worker/config boundary: define a TypedDict (e.g.,
SparseLayerCfg and SparseConfigTypedDict) that describes the expected
keys/values returned by _build_sparse_config/_load_sparse_config, and a Protocol
(e.g., ModelRunnerProtocol or WorkerWithModel) that specifies the
model_runner.model and unwrap() shape used by _replace_attention_impl; then
update signatures for _build_sparse_config, _load_sparse_config,
_match_sparse_config, and _replace_attention_impl to use these types (replace
bare dict and untyped worker), and cast/validate where needed so mypy can
statically check config shape and the model.unwrap() contract referenced in
_replace_attention_impl.
examples/diffusers/quantization/wan2_sage_attention.py (1)

425-431: Complete the public type annotations.

attention_kernel_ctx, load_pipeline, and run_inference still leave important parts of the public surface as Any, which weakens mypy coverage for a new example module. As per coding guidelines, "Ensure type hints are properly annotated for static type checking with mypy".

Also applies to: 682-694

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/quantization/wan2_sage_attention.py` around lines 425 -
431, Annotate the public API precisely: change attention_kernel_ctx to return a
ContextManager[None] (import ContextManager from typing) instead of Any; update
load_pipeline to have typed params (e.g., ckpt: str, dtype: torch.dtype, device:
torch.device | str) and return a DiffusionPipeline (import DiffusionPipeline
from diffusers) rather than Any; update run_inference to type its parameters
(pipeline: DiffusionPipeline, prompt: str | Sequence[str], num_inference_steps:
int, generator: Optional[torch.Generator], etc.) and give it a concrete return
type such as Dict[str, Any] or a more specific result type, importing Optional,
Sequence, Dict from typing and torch types as needed so mypy can check these
public symbols (attention_kernel_ctx, load_pipeline, run_inference).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/diffusers/quantization/wan2_sage_attention.py`:
- Around line 504-531: In apply_triton_sparse_kernel, validate any provided
skip_threshold when the kernel is KERNEL_TRITON_SKIP or
KERNEL_TRITON_SKIP_NVFP4: if skip_threshold is not None and skip_threshold <= 0,
raise a ValueError with a clear message (e.g. "skip_threshold must be > 0")
before modifying config["sparse_cfg"]["*"]["skip_softmax_threshold"]; otherwise
proceed to set the value as currently implemented. Ensure you reference the
function apply_triton_sparse_kernel and the constants KERNEL_TRITON_SKIP /
KERNEL_TRITON_SKIP_NVFP4 so the check is only applied for skip-softmax kernels.
- Around line 756-760: Move the kernel availability check to immediately after
argument parsing to fail fast: after parse_args(), verify args.kernel is present
in AVAILABLE_KERNELS (and/or that Triton/SageAttention and required PyTorch
version are available) and raise/exit with a clear message if not; remove or
keep only redundant checks later. Specifically, perform this validation before
constructing/loading the WAN pipeline (the code that currently loads the large
model) so that calls like apply_triton_sparse_kernel() and
enable_attention_kernel() are only reached when args.kernel is actually
supported; reference the args.kernel variable, AVAILABLE_KERNELS,
apply_triton_sparse_kernel(), and enable_attention_kernel() when making the
change.
- Around line 591-595: The current PSNR calculation in psnr_per_frame uses
np.where with mse_per_frame which causes the divide-by-zero branch to be
evaluated eagerly; fix by clamping mse_per_frame to a small positive floor
before the division (e.g., denom = np.maximum(mse_per_frame, 1e-10)) and compute
the PSNR using that denom, then optionally use np.where to override values for
truly-zero MSE if you need a specific constant; update the code around
psnr_per_frame and any direct uses of mse_per_frame to use the clamped
denominator to avoid warnings.
- Around line 352-385: The _patched_sdpa function is dropping caller kwargs on
both fallback paths; update both calls to _orig_sdpa (the early fallback when
attn_mask/dropout/dtype unsupported and the exception fallback in the except
block) to forward the original **kwargs so caller flags like enable_gqa (and
future SDPA kwargs) are preserved; keep the existing explicit args (query, key,
value, attn_mask, dropout_p, is_causal, scale) and add **kwargs to those
_orig_sdpa invocations to pass through any additional options.

---

Duplicate comments:
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 146-166: The code currently validates method but doesn't let
method drive sparse_kw; update the logic in sparse_attn_worker.py so that method
determines which keys are populated and that required params are present: if
method == "triton_sparse_softmax" require sparsity_n > 0 and populate only
sparsity_n, sparsity_m, num_sink_tokens, dense_window_size into sparse_kw (fail
fast with ValueError if missing); if method == "triton_skip_softmax" require
skip_softmax_threshold to be set and populate only skip_softmax_threshold into
sparse_kw (fail fast if missing); keep references to method, layer_cfg,
sparse_kw and ensure compatibility with ModelOptSparseAttentionImpl.forward()
which reads self.sparse_kw in
modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py.

---

Nitpick comments:
In `@examples/diffusers/quantization/wan2_sage_attention.py`:
- Around line 425-431: Annotate the public API precisely: change
attention_kernel_ctx to return a ContextManager[None] (import ContextManager
from typing) instead of Any; update load_pipeline to have typed params (e.g.,
ckpt: str, dtype: torch.dtype, device: torch.device | str) and return a
DiffusionPipeline (import DiffusionPipeline from diffusers) rather than Any;
update run_inference to type its parameters (pipeline: DiffusionPipeline,
prompt: str | Sequence[str], num_inference_steps: int, generator:
Optional[torch.Generator], etc.) and give it a concrete return type such as
Dict[str, Any] or a more specific result type, importing Optional, Sequence,
Dict from typing and torch types as needed so mypy can check these public
symbols (attention_kernel_ctx, load_pipeline, run_inference).

In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 72-123: Introduce concrete typing at the worker/config boundary:
define a TypedDict (e.g., SparseLayerCfg and SparseConfigTypedDict) that
describes the expected keys/values returned by
_build_sparse_config/_load_sparse_config, and a Protocol (e.g.,
ModelRunnerProtocol or WorkerWithModel) that specifies the model_runner.model
and unwrap() shape used by _replace_attention_impl; then update signatures for
_build_sparse_config, _load_sparse_config, _match_sparse_config, and
_replace_attention_impl to use these types (replace bare dict and untyped
worker), and cast/validate where needed so mypy can statically check config
shape and the model.unwrap() contract referenced in _replace_attention_impl.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: bde9eacb-58d5-4815-92e3-46606b7fa475

📥 Commits

Reviewing files that changed from the base of the PR and between 8996ef1 and 4248327.

📒 Files selected for processing (5)
  • examples/diffusers/quantization/wan2_sage_attention.py
  • examples/vllm_serve/sparse_attn_worker.py
  • examples/vllm_serve/vllm_serve_sparse_attn.py
  • modelopt/torch/sparsity/attention_sparsity/conversion.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py
✅ Files skipped from review due to trivial changes (1)
  • examples/vllm_serve/vllm_serve_sparse_attn.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py

Comment on lines +352 to +385
def _patched_sdpa(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    **kwargs,
) -> torch.Tensor:
    global _sage_calls, _fallback_calls
    # Fall back to standard SDPA for unsupported cases
    if (
        attn_mask is not None
        or dropout_p > 0.0
        or query.dtype not in (torch.float16, torch.bfloat16)
    ):
        _fallback_calls += 1
        return _orig_sdpa(
            query,
            key,
            value,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
        )

    _sage_calls += 1
    try:
        return _run_kernel(query, key, value, is_causal=is_causal, scale=scale)
    except (AssertionError, RuntimeError) as e:
        print(f"[Attention] WARNING: kernel={_active_kernel!r} failed ({e}). Falling back to SDPA.")
        return _orig_sdpa(query, key, value, is_causal=is_causal, scale=scale)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cat -n examples/diffusers/quantization/wan2_sage_attention.py | head -400 | tail -100

Repository: NVIDIA/Model-Optimizer

Length of output: 3695


🏁 Script executed:

# Check the _orig_sdpa assignment and context
rg "_orig_sdpa|_patched_sdpa" examples/diffusers/quantization/wan2_sage_attention.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 1229


🏁 Script executed:

# Check torch.nn.functional.scaled_dot_product_attention signature
python3 << 'EOF'
import torch
import inspect
sig = inspect.signature(torch.nn.functional.scaled_dot_product_attention)
print("Signature:", sig)
print("\nParameters:")
for name, param in sig.parameters.items():
    print(f"  {name}: {param.kind.name} = {param.default}")
EOF

Repository: NVIDIA/Model-Optimizer

Length of output: 183


🏁 Script executed:

# Check PyTorch documentation and what parameters F.scaled_dot_product_attention supports
# by searching for any documentation or comments in the codebase
rg "scaled_dot_product_attention" examples/diffusers/quantization/wan2_sage_attention.py -B 5 -A 5

Repository: NVIDIA/Model-Optimizer

Length of output: 1390


🏁 Script executed:

# Check if there are any other kwargs passed at call sites
rg "scaled_dot_product_attention\|_patched_sdpa" --type py -r "." | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Search for pytorch version requirements to determine available parameters
cat pyproject.toml | grep -A 5 -B 5 torch

Repository: NVIDIA/Model-Optimizer

Length of output: 2204


🏁 Script executed:

# Check if there are any call sites in the example that use _patched_sdpa with extra kwargs
rg "_patched_sdpa\|F\.scaled_dot_product_attention" examples/diffusers/quantization/wan2_sage_attention.py -B 3 -A 3

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check the full context around patching to see how it's used
sed -n '388,430p' examples/diffusers/quantization/wan2_sage_attention.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1420


🏁 Script executed:

# Check if there are any test files or usage examples that call SDPA
find . -name "*test*" -type f | xargs grep -l "scaled_dot_product_attention\|_patched_sdpa" 2>/dev/null | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 251


🏁 Script executed:

# Check the test files to understand how scaled_dot_product_attention is called
cat ./tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py | head -100

Repository: NVIDIA/Model-Optimizer

Length of output: 3861


🏁 Script executed:

# Check the attention quantization test
cat ./tests/unit/torch/quantization/plugins/test_attention_quant.py | head -150

Repository: NVIDIA/Model-Optimizer

Length of output: 5720


🏁 Script executed:

# Look for any documentation or comments about SDPA kwargs in PyTorch
web_search
PyTorch scaled_dot_product_attention signature parameters kwargs torch 2.6

Repository: NVIDIA/Model-Optimizer

Length of output: 162


🌐 Web query:

PyTorch torch.nn.functional.scaled_dot_product_attention signature parameters 2.6

💡 Result:

In PyTorch 2.6, the function signature is:

torch.nn.functional.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False
) -> Tensor

Parameters (with defaults): query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None (keyword-only), enable_gqa=False. [1]

Source
[1] PyTorch 2.6 docs page for torch.nn.functional.scaled_dot_product_attention https://pytorch.com.tw/docs/2.6/generated/torch.nn.functional.scaled_dot_product_attention.html


🏁 Script executed:

# Check if any code in the examples or tests calls SDPA/patched_sdpa with kwargs like enable_gqa
rg "scaled_dot_product_attention|_patched_sdpa|F\.sdpa" examples/ tests/ -A 2 | grep -E "(enable_gqa|enable_flash_|dropout_p|**kwargs)" | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 145


🏁 Script executed:

# Search more broadly for diffusers usage and how SDPA is called in the pipeline
grep -r "scaled_dot_product_attention" examples/diffusers/ -B 2 -A 2 2>/dev/null || echo "No direct calls found"

Repository: NVIDIA/Model-Optimizer

Length of output: 2572


🏁 Script executed:

# Check the full wan2_sage_attention.py file to see if there are any actual kwargs being passed to SDPA anywhere
grep -E "scaled_dot_product_attention|attention\(" examples/diffusers/quantization/wan2_sage_attention.py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 383


Forward caller kwargs when SDPA falls back.

_patched_sdpa() accepts **kwargs but both fallback paths (lines 370–378 and 385) drop them before calling _orig_sdpa(). This causes silent semantic changes: if a caller passes enable_gqa=True (or any other supported kwarg like future PyTorch SDPA extensions) and the kernel fails or early fallback triggers, the fallback ignores those parameters instead of honoring them.

Suggested fix
 def _patched_sdpa(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@
     if (
         attn_mask is not None
         or dropout_p > 0.0
         or query.dtype not in (torch.float16, torch.bfloat16)
     ):
         _fallback_calls += 1
         return _orig_sdpa(
             query,
             key,
             value,
             attn_mask=attn_mask,
             dropout_p=dropout_p,
             is_causal=is_causal,
             scale=scale,
+            **kwargs,
         )
@@
     except (AssertionError, RuntimeError) as e:
         print(f"[Attention] WARNING: kernel={_active_kernel!r} failed ({e}). Falling back to SDPA.")
-        return _orig_sdpa(query, key, value, is_causal=is_causal, scale=scale)
+        return _orig_sdpa(query, key, value, is_causal=is_causal, scale=scale, **kwargs)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/quantization/wan2_sage_attention.py` around lines 352 -
385, The _patched_sdpa function is dropping caller kwargs on both fallback
paths; update both calls to _orig_sdpa (the early fallback when
attn_mask/dropout/dtype unsupported and the exception fallback in the except
block) to forward the original **kwargs so caller flags like enable_gqa (and
future SDPA kwargs) are preserved; keep the existing explicit args (query, key,
value, attn_mask, dropout_p, is_causal, scale) and add **kwargs to those
_orig_sdpa invocations to pass through any additional options.

Comment on lines +504 to +531
def apply_triton_sparse_kernel(
    transformer: torch.nn.Module,
    kernel: str,
    skip_threshold: float | None = None,
) -> None:
    """Apply a ModelOpt Triton sparse attention kernel to the WAN transformer.

    Calls ``mtsa.sparsify()`` with ``backend="diffusers_triton"``, which installs
    a ``ModelOptWanAttnProcessor`` on every ``WanAttention`` module. The NVFP4
    variants additionally pass ``quantize_p=True`` to the Triton kernel, enabling
    per-tile NVFP4 E2M1 P-matrix quantization in a single fused pass.

    This modifies the model in-place.

    Args:
        transformer: The ``pipe.transformer`` WAN model.
        kernel: One of the ``KERNEL_TRITON_*`` constants.
        skip_threshold: Override ``skip_softmax_threshold`` for skip-softmax kernels.
            ``None`` uses the kernel's built-in default.
            Lower = better quality, less speedup. Typical range: 0.001–0.1.
    """
    import copy

    import modelopt.torch.sparsity.attention_sparsity as mtsa

    config = copy.deepcopy(_TRITON_KERNEL_CONFIGS[kernel])
    if skip_threshold is not None and kernel in (KERNEL_TRITON_SKIP, KERNEL_TRITON_SKIP_NVFP4):
        config["sparse_cfg"]["*"]["skip_softmax_threshold"] = skip_threshold

⚠️ Potential issue | 🟡 Minor

Reject non-positive skip_threshold values inside the helper.

The skip-softmax path now depends on log2(skip_softmax_threshold). Passing 0 or a negative override should fail here with a clear ValueError instead of surfacing later in the Triton path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/quantization/wan2_sage_attention.py` around lines 504 -
531, In apply_triton_sparse_kernel, validate any provided skip_threshold when
the kernel is KERNEL_TRITON_SKIP or KERNEL_TRITON_SKIP_NVFP4: if skip_threshold
is not None and skip_threshold <= 0, raise a ValueError with a clear message
(e.g. "skip_threshold must be > 0") before modifying
config["sparse_cfg"]["*"]["skip_softmax_threshold"]; otherwise proceed to set
the value as currently implemented. Ensure you reference the function
apply_triton_sparse_kernel and the constants KERNEL_TRITON_SKIP /
KERNEL_TRITON_SKIP_NVFP4 so the check is only applied for skip-softmax kernels.

Comment on lines +591 to +595
psnr_per_frame = np.where(
    mse_per_frame < 1e-10,
    100.0,
    10.0 * np.log10(255.0**2 / mse_per_frame),
)

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, find and verify the file exists
fd wan2_sage_attention.py

Repository: NVIDIA/Model-Optimizer

Length of output: 120


🏁 Script executed:

# Once we know the file exists, let's examine the exact code at lines 591-595
cat -n examples/diffusers/quantization/wan2_sage_attention.py | sed -n '585,600p'

Repository: NVIDIA/Model-Optimizer

Length of output: 697


🌐 Web query:

numpy np.where() eager evaluation both branches behavior

💡 Result:

np.where(cond, x, y) is not lazy / short‑circuiting. In typical usage like:

np.where(cond, f(a), g(a))

both f(a) and g(a) are computed before np.where runs, because they are evaluated as normal Python function arguments; np.where then just selects elements from the already-computed x and y arrays. This is why you can still get warnings/errors (e.g., divide-by-zero, sqrt of negatives) even where the condition would “exclude” those elements. [1][2][3]

The NumPy docs’ note that, for 1‑D arrays, where is “equivalent to” a Python conditional list-comprehension is about the resulting selection behavior, not about lazily evaluating the expressions you pass in for x and y. [4]

What to do instead (to avoid evaluating the “bad” branch):

  • Prefer a ufunc with a where= mask and out=, so computation is only performed where allowed, e.g. np.divide(a, b, out=out, where=mask) (avoids computing division where mask is False). [5]
  • Similarly, many elementwise operations can be done as ufuncs with where= (and often out=) instead of np.where(...). [3][5]

Sources: [1] np.where() eager evaluation discussion (Stack Overflow) [2] np.where computing unused values (Stack Overflow) [3] np.where can’t avoid sqrt(negative) evaluation (Stack Overflow) [4] NumPy numpy.where documentation (NumPy manual) [5] Using np.divide(..., where=...) to avoid divide-by-zero warnings (Stack Overflow)


The PSNR zero-MSE guard still evaluates the divide-by-zero branch.

np.where() eagerly evaluates both the true and false branches before selecting values, so identical frames (where mse_per_frame < 1e-10) still compute 255.0**2 / mse_per_frame, triggering divide-by-zero warnings. Clamp the denominator first to avoid this.

Suggested fix
-    psnr_per_frame = np.where(
-        mse_per_frame < 1e-10,
-        100.0,
-        10.0 * np.log10(255.0**2 / mse_per_frame),
-    )
+    safe_mse = np.maximum(mse_per_frame, 1e-10)
+    psnr_per_frame = 10.0 * np.log10(255.0**2 / safe_mse)
+    psnr_per_frame = np.where(mse_per_frame < 1e-10, 100.0, psnr_per_frame)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/quantization/wan2_sage_attention.py` around lines 591 -
595, The current PSNR calculation in psnr_per_frame uses np.where with
mse_per_frame which causes the divide-by-zero branch to be evaluated eagerly;
fix by clamping mse_per_frame to a small positive floor before the division
(e.g., denom = np.maximum(mse_per_frame, 1e-10)) and compute the PSNR using that
denom, then optionally use np.where to override values for truly-zero MSE if you
need a specific constant; update the code around psnr_per_frame and any direct
uses of mse_per_frame to use the clamped denominator to avoid warnings.

Comment on lines +756 to +760
"--kernel",
type=str,
default=KERNEL_FP8,
choices=KERNEL_CHOICES,
help=(

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

wc -l examples/diffusers/quantization/wan2_sage_attention.py

Repository: NVIDIA/Model-Optimizer

Length of output: 124


🏁 Script executed:

sed -n '750,770p' examples/diffusers/quantization/wan2_sage_attention.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1157


🏁 Script executed:

sed -n '810,825p' examples/diffusers/quantization/wan2_sage_attention.py

Repository: NVIDIA/Model-Optimizer

Length of output: 501


🏁 Script executed:

# Search for KERNEL_CHOICES and AVAILABLE_KERNELS definitions
rg -n "KERNEL_CHOICES|AVAILABLE_KERNELS" examples/diffusers/quantization/wan2_sage_attention.py -A 3 -B 1

Repository: NVIDIA/Model-Optimizer

Length of output: 1818


🏁 Script executed:

# Search for where kernel validation happens
rg -n "kernel" examples/diffusers/quantization/wan2_sage_attention.py | grep -i "valid\|check\|if"

Repository: NVIDIA/Model-Optimizer

Length of output: 1115


🏁 Script executed:

sed -n '815,835p' examples/diffusers/quantization/wan2_sage_attention.py

Repository: NVIDIA/Model-Optimizer

Length of output: 943


🏁 Script executed:

# Find and show apply_triton_sparse_kernel function
rg -n "def apply_triton_sparse_kernel" examples/diffusers/quantization/wan2_sage_attention.py -A 20

Repository: NVIDIA/Model-Optimizer

Length of output: 1071


🏁 Script executed:

# Check what _detect_available_kernels does
rg -n "def _detect_available_kernels" examples/diffusers/quantization/wan2_sage_attention.py -A 30

Repository: NVIDIA/Model-Optimizer

Length of output: 1302


🏁 Script executed:

sed -n '183,235p' examples/diffusers/quantization/wan2_sage_attention.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1794


🏁 Script executed:

# Check _TRITON_MODELOPT_KERNELS definition
rg -n "_TRITON_MODELOPT_KERNELS" examples/diffusers/quantization/wan2_sage_attention.py -B 2 -A 5

Repository: NVIDIA/Model-Optimizer

Length of output: 2608


Add kernel availability validation before loading the WAN pipeline.

The --kernel argument accepts any name in KERNEL_CHOICES via argparse, but this only validates the name string—not whether the required dependencies (Triton, SageAttention, PyTorch version) are actually available. The multi-GB model loads at line 820 before kernel validation occurs. The Triton branch (lines 824–825, 898–899) calls apply_triton_sparse_kernel() without checking AVAILABLE_KERNELS, while the non-Triton branch only validates via enable_attention_kernel() after the model is already in memory. On a missing Triton or SageAttention setup, this loads the pipeline unnecessarily before failing.

Move the args.kernel in AVAILABLE_KERNELS check to immediately after parse_args() to fail fast without wasting resources.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/quantization/wan2_sage_attention.py` around lines 756 -
760, Move the kernel availability check to immediately after argument parsing to
fail fast: after parse_args(), verify args.kernel is present in
AVAILABLE_KERNELS (and/or that Triton/SageAttention and required PyTorch version
are available) and raise/exit with a clear message if not; remove or keep only
redundant checks later. Specifically, perform this validation before
constructing/loading the WAN pipeline (the code that currently loads the large
model) so that calls like apply_triton_sparse_kernel() and
enable_attention_kernel() are only reached when args.kernel is actually
supported; reference the args.kernel variable, AVAILABLE_KERNELS,
apply_triton_sparse_kernel(), and enable_attention_kernel() when making the
change.
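The fail-fast pattern can be sketched as follows; `AVAILABLE_KERNELS` and `KERNEL_CHOICES` carry stand-in values here, and `parse_and_validate` is an invented wrapper around the script's argparse setup:

```python
import argparse

# Stand-in values; the real script detects AVAILABLE_KERNELS at import time.
AVAILABLE_KERNELS = {"fp8", "triton-sparse"}
KERNEL_CHOICES = ["fp8", "triton-sparse", "triton-skip"]


def parse_and_validate(argv: list[str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--kernel", default="fp8", choices=KERNEL_CHOICES)
    args = parser.parse_args(argv)
    # Validate availability immediately, before any multi-GB pipeline load.
    if args.kernel not in AVAILABLE_KERNELS:
        parser.error(f"kernel {args.kernel!r} is not available on this setup")
    return args
```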

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (1)
modelopt/torch/kernels/triton_fa.py (1)

309-315: Update the quantize_p docs to match the new runtime guard.

Both docstrings still describe an STE/backward path, but Lines 1021-1024 now make quantize_p inference-only.

Suggested doc update
-        quantize_p: If ``True``, quantize the post-softmax p tile to NVFP4
-            E2M1 before the p @ V matmul (per-tile max scaling, STE in
-            backward). Default ``False``.
+        quantize_p: If ``True``, quantize the post-softmax p tile to NVFP4
+            E2M1 before the p @ V matmul. This mode is inference-only and
+            raises if autograd is required. Default ``False``.

Also applies to: 1294-1296

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 309 - 315, The docstring
for quantize_p incorrectly describes a straight-through estimator and backward
behavior even though the function is now guarded to be inference-only at
runtime; update the quantize_p docstring (and the other similar docstring
instance) to remove any mention of STE/backward passthrough and instead state
that this implementation performs post-softmax tile quantization for inference
only (describe scaling and level mapping briefly and note no gradient/_backward
behavior is applied at runtime).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 973-980: The code incorrectly infers paged layout from the
placeholder k tensor and an independent page_size; instead, when is_paged
(k_cache is not None) derive and validate paged metadata from the cache tensors
themselves: read num_kv_heads from k_cache.shape[1] (and page_size / page_length
from the cache dim that represents pages), recompute kv_group_num from
num_q_heads // num_kv_heads, and ignore k.shape[1]/page_size inputs; also verify
k_cache and v_cache shapes match and that computed kv_group_num divides
num_q_heads and matches any b_seq_len-derived batch logic (validate and
raise/error if mismatched). Apply the same change to the other blocks referenced
around the file (the regions near the existing checks at lines ~988-997 and
~1077-1084) so all paged launches use cache-derived layout and perform
consistency checks before kernel launch.
- Around line 979-980: Add a guard that rejects paged-KV when autograd is
enabled: where is_paged = k_cache is not None is computed, check
torch.is_grad_enabled() (or equivalent autograd check) and raise a clear
RuntimeError explaining that paged k/v/block_table is incompatible with autograd
because gradients are rebuilt from contiguous k/v; make this change for both
occurrences around the is_paged checks (the block using
k_cache/v_cache/block_table near the is_paged assignment and the other spot at
the 1021-1024 section) so callers get an explicit error instead of
incorrect/missing gradients.

---

Nitpick comments:
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 309-315: The docstring for quantize_p incorrectly describes a
straight-through estimator and backward behavior even though the function is now
guarded to be inference-only at runtime; update the quantize_p docstring (and
the other similar docstring instance) to remove any mention of STE/backward
passthrough and instead state that this implementation performs post-softmax
tile quantization for inference only (describe scaling and level mapping briefly
and note no gradient/_backward behavior is applied at runtime).
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 567259d0-7c54-45c7-8d57-f53403768c13

📥 Commits

Reviewing files that changed from the base of the PR and between 4248327 and c548f6f.

📒 Files selected for processing (1)
  • modelopt/torch/kernels/triton_fa.py

Comment on lines +979 to +980
is_paged = k_cache is not None


⚠️ Potential issue | 🔴 Critical

Reject paged-KV when autograd is enabled.

Lines 1109-1236 still rebuild gradients from contiguous k/v and return no gradients for k_cache/v_cache/block_table. In paged mode that makes autograd incorrect, because the forward path no longer reads those contiguous tensors.

Proposed guard
         if quantize_p and (q.requires_grad or k.requires_grad or v.requires_grad):
             raise NotImplementedError(
                 "quantize_p supports inference only; backward does not model the quantized P path"
             )
+        if is_paged and (
+            q.requires_grad
+            or k.requires_grad
+            or v.requires_grad
+            or k_cache.requires_grad
+            or v_cache.requires_grad
+        ):
+            raise NotImplementedError(
+                "paged KV cache supports inference only; backward does not model k_cache/v_cache"
+            )

Also applies to: 1021-1024

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 979 - 980, Add a guard that
rejects paged-KV when autograd is enabled: where is_paged = k_cache is not None
is computed, check torch.is_grad_enabled() (or equivalent autograd check) and
raise a clear RuntimeError explaining that paged k/v/block_table is incompatible
with autograd because gradients are rebuilt from contiguous k/v; make this
change for both occurrences around the is_paged checks (the block using
k_cache/v_cache/block_table near the is_paged assignment and the other spot at
the 1021-1024 section) so callers get an explicit error instead of
incorrect/missing gradients.

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
yeyu-nvidia and others added 24 commits April 8, 2026 11:37
…ling

- Add --clip-model arg for local CLIP model path (avoids HF rate limits)
- Read HF_TOKEN / HUGGING_FACE_HUB_TOKEN env vars for authenticated downloads
- Catch OSError from rate limits / network errors with a helpful message
- Add `import os` for env var access

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Implements Python-level NVFP4 E2M1 quantization (SA3-inspired):
- Channel-wise mean smoothing before quantization
- Per-token scale to NVFP4 range (max 6.0)
- Round to nearest of 8 representable levels via torch.bucketize
- No CUDA kernel required, works on any GPU

Use --kernel nvfp4 to compare accuracy vs FP8 and baseline.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
The CLIP score range depends on the model/prompt distribution;
remove the hardcoded "0.25-0.35" and instead tell users to focus
on the baseline vs quantized delta.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Quantizing V to FP4 causes severe color/texture loss because V carries
the actual content that gets weighted and summed into the output.
SA3 applies NVFP4 to the post-softmax attention probability matrix P,
not to V directly. V stays in BF16 for the P*V multiply.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
V quantized to NVFP4 (8 levels) caused visible color/texture degradation.
SA3 paper applies NVFP4 only to post-softmax P matrix; V stays in higher
precision. Use FP8 E4M3 (256 encodings, max finite value 448) for V to
match SA2-style V handling while keeping NVFP4 for Q/K.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
FP8 V still causes color/texture degradation. V carries actual content
that gets weighted and summed — any quantization noise is directly
visible. Keep V in BF16, matching SA3 design where only Q/K (routing
weights) are quantized to NVFP4.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Previous approach quantized Q and K to NVFP4, which caused visible quality
degradation. SA3 (arXiv 2505.11594) applies NVFP4 to the post-softmax P
matrix, not to Q/K/V.

Implement manual attention (Q@K^T -> scale -> mask -> softmax -> NVFP4(P)
-> P@V) with per-row max-based scaling of P to NVFP4 range [0, 6].
Q, K, V stay in BF16. This faithfully simulates SA3 quantization behavior
for accuracy measurement.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
NVFP4 (8 levels) caused generation trajectory divergence on WAN2.2 video
diffusion due to dense attention patterns. Replace with unsigned INT4
(16 levels, 0..15) per-row scaled quantization of the post-softmax P
matrix — same SA3-inspired approach (quantize P not Q/K/V) but with
2x more precision levels to better cover dense probability distributions.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Previous implementations used per-row scaling, causing trajectory divergence
in video diffusion: a single row-max outlier dominated the scale, wasting
most of the 8 FP4 levels on near-zero values.

SA3 quantizes P at flash-attention tile granularity (per-tile, not per-row).
Implement per-tile NVFP4 E2M1 of the post-softmax P matrix: reshape P into
64x64 tiles, compute per-tile max scale, quantize to nearest NVFP4 level,
dequantize back. Each local tile's value range is fully covered by the 8 levels.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
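The per-tile scale that replaces the per-row scale above amounts to a reshape plus a reduction. A minimal sketch (hypothetical helper name `per_tile_max`, assuming shapes divisible by the tile size; real code would pad the boundary tiles):

```python
import torch


def per_tile_max(p: torch.Tensor, tile_r: int = 64, tile_c: int = 64) -> torch.Tensor:
    """One max per (tile_r x tile_c) tile of a 2D P slice.

    A row-max outlier now only affects its own tile's scale, so the
    8 NVFP4 levels cover each local tile's value range.
    """
    rows, cols = p.shape  # assumed divisible by tile_r / tile_c
    tiles = p.reshape(rows // tile_r, tile_r, cols // tile_c, tile_c)
    return tiles.amax(dim=(1, 3))  # shape: (rows // tile_r, cols // tile_c)
```

Broadcasting the result back over `(1, tile_r, 1, tile_c)` yields the per-tile scale used for quantize/dequantize.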
Materialising the full N×N float32 attention matrix for WAN2.2 (~8k tokens,
24 heads) requires ~6 GB, causing OOM on 48 GB GPUs already loaded with the
model. Process Q in 512-row chunks: each chunk produces a (512 x N) P slice,
which is quantized per-tile to NVFP4 and accumulated into the output.

Peak memory for the attention matrix drops from O(N^2) to O(chunk x N).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
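The chunking pattern described above can be sketched as follows. Names are illustrative (the real code works in BF16 on BSND layouts and inserts per-tile NVFP4 quantization where the comment marks the hook):

```python
import torch


def chunked_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                      chunk: int = 512) -> torch.Tensor:
    """Manual attention over (N, d) tensors, processing Q in row chunks.

    Peak score-matrix memory is O(chunk * N) instead of O(N^2).
    """
    scale = q.shape[-1] ** -0.5
    out = torch.empty_like(q)
    for start in range(0, q.shape[0], chunk):
        # (chunk, N) slice of attention scores.
        s = (q[start:start + chunk] @ k.T) * scale
        p = torch.softmax(s.float(), dim=-1)
        # <- per-tile NVFP4 quantization of the P slice would go here.
        out[start:start + chunk] = (p @ v.float()).to(q.dtype)
    return out
```

With N ≈ 8k tokens and 24 heads this keeps the float32 P slice at a few hundred MB rather than the ~6 GB full matrix.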
Smaller tiles give finer-grained scales and better accuracy, at the cost
of more overhead (acceptable for Python-level simulation). Add --nvfp4-tile
CLI arg in TRxTC format (e.g. '1x64', '32x32', '64x64') to control the
tile shape for per-tile NVFP4 quantization of the post-softmax P matrix.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Implements a ModelOpt sparse attention plugin for diffusers WAN models,
building on the triton_fa kernel infrastructure from PR #1127.

New files:
- modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py
  - ModelOptWanAttnProcessor: replaces WanAttnProcessor, calls
    triton_fa.attention() directly with BSND->varlen conversion.
    Supports I2V cross-attention path and N:M/skip-softmax sparsity.
  - WanSparseAttentionModule: subclasses SparseAttentionModule, installs
    the Triton processor and syncs enabled state on each forward.
  - register_wan_sparse_attention(): plugin callback auto-registered in
    CUSTOM_MODEL_PLUGINS; fires during mtsa.sparsify().

Updated files:
- plugins/__init__.py: lazy-import diffusers plugin via import_plugin()
- config.py: add "diffusers_triton" to validate_backend whitelist
- conversion.py: skip HF attn registration for "diffusers_triton" backend
- wan2_sage_attention.py: add triton-sparse and triton-skip kernel options
  backed by mtsa.sparsify() with diffusers_triton backend

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Covers:
- "diffusers_triton" backend config validation
- WanAttention registration with SparseAttentionRegistry
- Module and processor type after mtsa.sparsify()
- Correct sparse_kw populated for triton_sparse_softmax and triton_skip_softmax
- Forward output shape (via SDPA fallback on CPU)
- enable/disable flag propagation into ModelOptWanAttnProcessor

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
WanSparseAttentionModule.forward() syncs proc._enabled from self.is_enabled
on every call, so directly setting proc._enabled was immediately overwritten.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
- Add _QuantWanAttnProcessor: manual Q@K^T->softmax->softmax_quantizer->@v
  with BSND layout, I2V cross-attention support, and q/k/v_bmm_quantizer hooks
- Add _QuantWanAttention(_QuantAttentionModuleMixin): installs the processor in
  _setup(); forward() bypasses _QuantFunctionalMixin patching context since the
  processor calls TensorQuantizers directly
- Register _QuantWanAttention for WanAttention (replaces generic mixin that
  went through F.sdpa and never invoked softmax_quantizer)
- Add NVFP4_WAN_SOFTMAX_CFG: enables only *softmax_quantizer with dynamic
  NVFP4 block_size=16; no calibration needed (dynamic scaling)
- Add nvfp4-modelopt kernel to wan2_sage_attention.py example via
  apply_nvfp4_modelopt_kernel() using mtq.quantize(transformer, NVFP4_WAN_SOFTMAX_CFG)
- Add unit tests for registry, processor type, quantizer attributes,
  config structure, and forward shape

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…float32 matrix

WAN2.2 video sequences exceed 8k tokens; materialising [B, S, H, T] in float32
before softmax exhausts GPU memory. Process query rows in chunks of 512 (matching
the existing nvfp4 Python kernel) so peak memory stays bounded.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Extends the ModelOpt Triton flash-attention kernel (triton_fa.py) with
per-tile NVFP4 E2M1 P-matrix quantization, enabling fused sparse+quantized
attention in a single kernel pass.

Key changes:
- triton_fa.py: add _quantize_p_nvfp4() Triton helper (per-tile max scaling,
  boundary-compare to 8 NVFP4 levels); add QUANTIZE_P constexpr to _attn_fwd
  (applied in both standard and skip-softmax paths, STE for backward); thread
  quantize_p: bool = False through _Attention and attention()
- sparsity/attention_sparsity/config.py: add quantize_p field to
  SparseAttentionAttributeConfig
- sparsity/attention_sparsity/plugins/diffusers.py: extract quantize_p in
  _build_sparse_kw() so it is forwarded to triton_fa.attention()
- examples/diffusers/quantization/wan2_sage_attention.py: remove nvfp4 and
  nvfp4-modelopt Python-level kernels; add triton-sparse-nvfp4 and
  triton-skip-nvfp4 kernel options that use the fused Triton path
- quantization/plugins/diffusion/diffusers.py: revert WAN-specific
  _QuantWanAttnProcessor/_QuantWanAttention; restore WanAttention ->
  _QuantAttentionModuleMixin registration
- quantization/config.py: remove NVFP4_WAN_SOFTMAX_CFG
- tests: remove test_wan_quant.py (superseded by Triton-based approach)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
WanAttention is a new-style diffusers AttentionModuleMixin that uses
dispatch_attention_fn internally and does NOT call self.processor().
The previous forward() delegated to super().forward() (WanAttention),
which bypassed our ModelOptWanAttnProcessor entirely, producing output
byte-identical to the SDPA baseline.

Fix: intercept in WanSparseAttentionModule.forward() and call the
processor directly. The processor's __call__ handles both paths:
- enabled=True  → _wan_forward_triton (ModelOpt Triton kernel)
- enabled=False → _wan_forward_sdpa   (dispatch_attention_fn fallback)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…F generic

The HuggingFace generic plugin was imported before the diffusers plugin in
plugins/__init__.py. This caused register_sparse_attention_on_the_fly to
register WanAttention with _GenericSparseAttention first; the diffusers plugin
then found WanAttention already registered and skipped, so
WanSparseAttentionModule / ModelOptWanAttnProcessor were never installed.

Fix: swap import order so the diffusers plugin registers WanAttention with
WanSparseAttentionModule before the HF generic plugin runs.

Signed-off-by: Ye Yu <yeyu@nvidia.com>
lambda=0.1 (ln=−2.3) is too aggressive for WAN's 8190-token video
sequences with relatively diffuse attention: it skips tiles where the
max score is >2.3 below the running max, which includes many tiles
with genuine contribution to the output, causing PSNR≈11 dB.

Lower default to 0.01 (ln=−4.6) for both triton-skip and
triton-skip-nvfp4 kernels.  Add --skip-threshold CLI arg so the
threshold can be swept without editing code:

  --skip-threshold 0.1    aggressive (high sparsity, lower quality)
  --skip-threshold 0.01   moderate (default)
  --skip-threshold 0.001  conservative (near-baseline quality)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
The BLASST (https://arxiv.org/pdf/2512.12087) criterion checks
ln(lambda) on the sm_scale-SCALED attention logits a_ij = q·k/sqrt(d).

The Triton kernel stores scores as x = a * log2(e), so the correct
threshold in kernel (log2) space is log2(lambda), not log2(lambda)*sm_scale.

Previous code multiplied by sm_scale (~0.088 for head_dim=128), making
every threshold 11× too aggressive. With lambda=0.1 the kernel-space
threshold was -0.29 instead of the correct -3.32, skipping most attention
tiles and producing garbage output (PSNR~11 dB). Even lambda=0.0001 was
still too aggressive (-1.18 vs correct -13.29).

Fix: use `log2(lambda)` directly as SKIP_THRESHOLD_LOG2, and restore the
default threshold to 0.1 (the standard BLASST value).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
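The unit conversion at issue can be checked numerically. `skip_threshold_log2` is an illustrative name, not the kernel's actual symbol; the point is that a gap of ln(lambda) on the logits a = q·k/sqrt(d) corresponds to a gap of log2(lambda) on the kernel's stored scores x = a·log2(e):

```python
import math


def skip_threshold_log2(lam: float) -> float:
    """BLASST skip threshold expressed in the kernel's log2 score space."""
    return math.log2(lam)


# Sanity check: an a-space gap of ln(lam) maps to an x-space gap of log2(lam),
# since x = a * log2(e) and ln(lam) * log2(e) == log2(lam).
lam = 0.1
a_gap = math.log(lam)              # ln(0.1) ~= -2.303
x_gap = a_gap * math.log2(math.e)  # ~= -3.322
assert abs(x_gap - skip_threshold_log2(lam)) < 1e-12
```

Multiplying by sm_scale on top of this (as the reverted code did) would shrink the log2-space threshold by a further factor of ~0.088 for head_dim=128, which is the "11x too aggressive" figure in the commit message.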
- CLIP scoring: use CPU to avoid OOM with WAN pipeline on GPU; catch
  RuntimeError in addition to OSError
- --skip-threshold help: fix description to match actual exp(tile_max -
  running_max) < lambda criterion
- vLLM worker: reject unsupported sparse presets (non-triton backend or
  unknown method) with a clear ValueError instead of silently degrading
  to dense attention
- PYTHONPATH construction: use os.pathsep and skip empty entries to
  avoid CWD injection when PYTHONPATH is unset
- diffusers_triton backend: raise ValueError when mixed with other
  backends instead of silently skipping _attn_implementation setup
- _wan_forward_triton: fall back to SDPA when attention_mask is not
  None to preserve masking semantics

Signed-off-by: Ye Yu <yeyu@nvidia.com>
- quantize_p: raise NotImplementedError when any of q/k/v requires_grad
  since backward does not model the quantized P path (inference-only)
- b_start_loc_k: only synthesize dummy tensor in paged mode; raise
  ValueError in contiguous separate-KV path when b_start_loc_k is None;
  also validate that v_cache and block_table are provided alongside k_cache

Signed-off-by: Ye Yu <yeyu@nvidia.com>
The * sm_scale factor is intentional: it scales the tile-skip threshold
relative to head dimension, so larger head_dim (smaller sm_scale) produces
more aggressive sparsity for the same lambda value. The previous 'fix' was
incorrect.

Signed-off-by: Ye Yu <yeyu@nvidia.com>
@yeyu-nvidia yeyu-nvidia force-pushed the yeyu/sage-attention-diffusion branch from 708f113 to 3f0bfd3 Compare April 8, 2026 18:38
@github-actions
Contributor

github-actions bot commented Apr 8, 2026

PR Preview Action v1.8.1

QR code for preview link

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1190/

Built to branch gh-pages at 2026-04-08 19:47 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov

codecov bot commented Apr 8, 2026

Codecov Report

❌ Patch coverage is 43.82022% with 350 lines in your changes missing coverage. Please review.
✅ Project coverage is 74.27%. Comparing base (df80a0f) to head (356c517).
⚠️ Report is 20 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/kernels/triton_fa.py 13.63% 114 Missing ⚠️
...attention_sparsity/kernels/ltx_triton_attention.py 5.71% 66 Missing ⚠️
...ion_sparsity/kernels/diffusers_triton_attention.py 49.47% 48 Missing ⚠️
.../attention_sparsity/methods/triton_skip_softmax.py 53.75% 37 Missing ⚠️
.../attention_sparsity/kernels/ltx_eager_attention.py 11.11% 32 Missing ⚠️
.../torch/sparsity/attention_sparsity/plugins/vllm.py 0.00% 30 Missing ⚠️
...pt/torch/sparsity/attention_sparsity/conversion.py 64.00% 9 Missing ⚠️
...lopt/torch/quantization/sage_attention/__init__.py 70.58% 5 Missing ⚠️
...tion_sparsity/kernels/diffusers_eager_attention.py 93.61% 3 Missing ⚠️
...arsity/attention_sparsity/calibration/calibrate.py 90.90% 1 Missing ⚠️
... and 5 more
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1190      +/-   ##
==========================================
- Coverage   74.77%   74.27%   -0.51%     
==========================================
  Files         351      359       +8     
  Lines       40072    43949    +3877     
==========================================
+ Hits        29964    32641    +2677     
- Misses      10108    11308    +1200     
Flag Coverage Δ
examples 43.64% <22.31%> (+3.41%) ⬆️
gpu 56.53% <30.49%> (-0.58%) ⬇️
unit 54.74% <30.81%> (-0.02%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 9

♻️ Duplicate comments (6)
modelopt/torch/kernels/triton_fa.py (2)

1005-1012: ⚠️ Potential issue | 🔴 Critical

Reject paged KV when autograd is enabled.

Forward reads K/V from k_cache / v_cache, but backward still reconstructs gradients only from contiguous k / v and returns nothing for the cache tensors. That is silently wrong once the paged path is used with dummy k/v or differentiable caches, so this mode needs the same inference-only guard as quantize_p.

Suggested fix
         if quantize_p and (q.requires_grad or k.requires_grad or v.requires_grad):
             raise NotImplementedError(
                 "quantize_p supports inference only; backward does not model the quantized P path"
             )
+        if is_paged and (
+            q.requires_grad
+            or k.requires_grad
+            or v.requires_grad
+            or (k_cache is not None and k_cache.requires_grad)
+            or (v_cache is not None and v_cache.requires_grad)
+        ):
+            raise NotImplementedError(
+                "paged KV cache supports inference only; backward does not model k_cache/v_cache"
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 1005 - 1012, The forward
path currently allows using paged KV (k_cache/v_cache) while autograd is
enabled, but the backward only reconstructs grads from contiguous k/v and omits
cache tensors; add the same inference-only guard used for quantize_p: detect
when paged KV is in use (check k_cache or v_cache presence/usage) and if any of
q.requires_grad, k.requires_grad, or v.requires_grad are true, raise
NotImplementedError with a clear message like "paged KV supports inference only;
backward does not model the paged K/V path" (mirror the existing quantize_p
check) so the code in the function handling forward (the block containing
quantize_p and references to k_cache/v_cache) rejects autograd-enabled paged KV.

973-997: ⚠️ Potential issue | 🟠 Major

Derive paged layout from k_cache / v_cache, not from the placeholder k.

In paged mode the caller can pass dummy contiguous k/v, but num_kv_heads, kv_group_num, and the page geometry are still taken from k and page_size. A mismatch silently remaps Q heads to the wrong KV heads or indexes the cache with the wrong page size. Validate the cache shapes and derive paged metadata from the cache tensors themselves.

Suggested fix
         HEAD_DIM = q.shape[2]
         num_q_heads = q.shape[1]
-        num_kv_heads = k.shape[1]
-        kv_group_num = num_q_heads // num_kv_heads
         batch = b_seq_len.shape[0]
 
-        is_paged = k_cache is not None
+        is_paged = any(x is not None for x in (k_cache, v_cache, block_table))
+        if is_paged:
+            if k_cache is None or v_cache is None or block_table is None:
+                raise ValueError("k_cache, v_cache, and block_table must all be provided together")
+            if k_cache.shape[1] != v_cache.shape[1] or page_size != k_cache.shape[1]:
+                raise ValueError("page_size must match k_cache/v_cache.shape[1]")
+            if k_cache.shape[2:] != v_cache.shape[2:]:
+                raise ValueError("k_cache and v_cache must share num_kv_heads/head_dim")
+            num_kv_heads = k_cache.shape[2]
+        else:
+            num_kv_heads = k.shape[1]
+
+        if num_q_heads % num_kv_heads != 0:
+            raise ValueError("num_q_heads must be divisible by num_kv_heads")
+        kv_group_num = num_q_heads // num_kv_heads
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 973 - 997, When paged mode
is detected (k_cache/v_cache provided) don't derive KV head counts or page
geometry from the placeholder tensor k; instead read and validate num_kv_heads,
kv_group_num and page_size (and any page layout metadata) from k_cache and
v_cache and use those values to compute kv_group_num and related variables used
later in triton_fa.py (e.g., num_kv_heads, kv_group_num, page_size,
b_seq_len_k/b_start_loc_k defaults). Add explicit shape checks that k_cache and
v_cache agree with each other and with block_table (and that num_q_heads is
divisible by num_kv_heads), and raise a clear ValueError if they mismatch; when
b_start_loc_k is needed, construct/derive it from the cache metadata rather than
the placeholder k. Ensure all references to num_kv_heads/kv_group_num/page_size
later in the function use the cache-derived values.
examples/diffusers/quantization/wan2_sage_attention.py (4)

370-385: ⚠️ Potential issue | 🟡 Minor

Caller kwargs still dropped on SDPA fallback paths.

_patched_sdpa() accepts **kwargs (line 360) but both fallback paths drop them:

  • Lines 370-378: Early fallback doesn't forward **kwargs
  • Line 385: Exception fallback doesn't forward **kwargs

This causes silent semantic changes if callers pass enable_gqa=True or other SDPA kwargs.

Suggested fix
         return _orig_sdpa(
             query,
             key,
             value,
             attn_mask=attn_mask,
             dropout_p=dropout_p,
             is_causal=is_causal,
             scale=scale,
+            **kwargs,
         )
@@
-        return _orig_sdpa(query, key, value, is_causal=is_causal, scale=scale)
+        return _orig_sdpa(query, key, value, is_causal=is_causal, scale=scale, **kwargs)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/quantization/wan2_sage_attention.py` around lines 370 -
385, _patched_sdpa currently drops caller kwargs on both fallback paths causing
silently changed behavior; update the calls so that any received **kwargs are
forwarded: call _run_kernel(query, key, value, is_causal=is_causal, scale=scale,
**kwargs) and call _orig_sdpa(query, key, value, attn_mask=attn_mask,
dropout_p=dropout_p, is_causal=is_causal, scale=scale, **kwargs) (and likewise
in the exception fallback) so _patched_sdpa, _run_kernel and _orig_sdpa receive
the original caller options (references: _patched_sdpa, _run_kernel, _orig_sdpa,
_sage_calls, _active_kernel).

504-531: ⚠️ Potential issue | 🟡 Minor

Missing validation for non-positive skip_threshold values.

The skip-softmax path depends on log2(skip_softmax_threshold) in the Triton kernel. Passing 0 or a negative value should fail early with a clear error instead of causing cryptic failures in the kernel.

Suggested fix
 def apply_triton_sparse_kernel(
     transformer: torch.nn.Module,
     kernel: str,
     skip_threshold: float | None = None,
 ) -> None:
     ...
     config = copy.deepcopy(_TRITON_KERNEL_CONFIGS[kernel])
-    if skip_threshold is not None and kernel in (KERNEL_TRITON_SKIP, KERNEL_TRITON_SKIP_NVFP4):
+    if skip_threshold is not None:
+        if kernel in (KERNEL_TRITON_SKIP, KERNEL_TRITON_SKIP_NVFP4):
+            if skip_threshold <= 0:
+                raise ValueError(f"skip_threshold must be > 0, got {skip_threshold}")
         config["sparse_cfg"]["*"]["skip_softmax_threshold"] = skip_threshold
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/quantization/wan2_sage_attention.py` around lines 504 -
531, The apply_triton_sparse_kernel function does not validate skip_threshold
for the skip-softmax kernels (KERNEL_TRITON_SKIP, KERNEL_TRITON_SKIP_NVFP4),
which can lead to cryptic failures when zero or negative values are used; add an
early validation inside apply_triton_sparse_kernel that when kernel is one of
the skip-softmax constants and skip_threshold is not None, raise a ValueError if
skip_threshold <= 0 with a clear message (e.g., require skip_threshold > 0)
before mutating config["sparse_cfg"]["*"]["skip_softmax_threshold"] so invalid
values never reach the Triton kernel.

815-817: ⚠️ Potential issue | 🟡 Minor

Pipeline loads before kernel availability validation.

The model loads at line 817 (load_pipeline) before kernel availability is checked (lines 824, 827, 898, 903). If the requested kernel isn't available, the multi-GB model download/load is wasted.

Move the args.kernel in AVAILABLE_KERNELS check immediately after parse_args().

Suggested fix
 def main() -> None:
     args = parse_args()
+
+    # Fail fast if kernel is unavailable
+    if args.kernel not in AVAILABLE_KERNELS and not args.baseline and not args.benchmark:
+        raise RuntimeError(
+            f"Kernel {args.kernel!r} is not available. "
+            f"Available kernels: {AVAILABLE_KERNELS}"
+        )
+
     pipe = load_pipeline(args.model)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/quantization/wan2_sage_attention.py` around lines 815 -
817, The code currently calls load_pipeline(args.model) before validating kernel
availability; move the kernel check immediately after parse_args() in main() so
you verify args.kernel in AVAILABLE_KERNELS before any heavy model download.
Concretely, inside main(), call args = parse_args(), then perform the
AVAILABLE_KERNELS membership check (and exit/raise with a clear message if
invalid), and only then call load_pipeline(args.model); update/remove the later
redundant kernel checks around load_pipeline/usage (referencing main(),
parse_args(), load_pipeline, and AVAILABLE_KERNELS).

589-596: ⚠️ Potential issue | 🟡 Minor

PSNR calculation still triggers divide-by-zero warning due to np.where eager evaluation.

np.where() evaluates both branches before selecting, so 255.0**2 / mse_per_frame computes even where mse_per_frame < 1e-10, triggering divide-by-zero warnings for identical frames.

Suggested fix
     # PSNR
     mse_per_frame = ((ref - quant) ** 2).mean(axis=(1, 2, 3))  # (N,)
     # Avoid log(0) for identical frames
-    psnr_per_frame = np.where(
-        mse_per_frame < 1e-10,
-        100.0,
-        10.0 * np.log10(255.0**2 / mse_per_frame),
-    )
+    safe_mse = np.maximum(mse_per_frame, 1e-10)
+    psnr_per_frame = 10.0 * np.log10(255.0**2 / safe_mse)
+    psnr_per_frame = np.where(mse_per_frame < 1e-10, 100.0, psnr_per_frame)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/quantization/wan2_sage_attention.py` around lines 589 -
596, The PSNR code uses np.where which still triggers divide-by-zero because
both branches are evaluated; fix by first protecting the denominator (e.g.,
compute safe_mse = np.maximum(mse_per_frame, 1e-10)) then compute the PSNR
values from safe_mse (psnr_vals = 10.0 * np.log10(255.0**2 / safe_mse)) and
finally set exact-zero cases to 100.0 with psnr_per_frame =
np.where(mse_per_frame < 1e-10, 100.0, psnr_vals) and keep psnr =
float(psnr_per_frame.mean()); update the code that references mse_per_frame,
psnr_per_frame, and psnr accordingly.
🧹 Nitpick comments (9)
modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py (1)

387-401: Make the sparse-context setup exception-safe and reentrant.

If one of the later enter_context() calls raises, the thread-local flag stays stuck on True because the ExitStack is never unwound. The callback also always restores False, so a nested sparse context would clear an outer scope too early. Build the stack under with ExitStack() as stack: and restore the previous flag before return stack.pop_all().

♻️ Proposed fix
-        from ..kernels import set_skip_softmax_context
+        from ..kernels import get_skip_softmax_context, set_skip_softmax_context

-        stack = ExitStack()
-        set_skip_softmax_context(True)
-        stack.callback(set_skip_softmax_context, False)
-
-        try:
-            from ..kernels.diffusers_eager_attention import get_skip_softmax_attention_backend
-
-            stack.enter_context(get_skip_softmax_attention_backend())
-        except (ImportError, RuntimeError):
-            pass
-
-        stack.enter_context(replace_function(torch.nn.functional, "softmax", sparse_softmax))
-        return stack
+        previous_skip_softmax = get_skip_softmax_context()
+        with ExitStack() as stack:
+            set_skip_softmax_context(True)
+            stack.callback(set_skip_softmax_context, previous_skip_softmax)
+
+            try:
+                from ..kernels.diffusers_eager_attention import (
+                    get_skip_softmax_attention_backend,
+                )
+
+                stack.enter_context(get_skip_softmax_attention_backend())
+            except (ImportError, RuntimeError):
+                pass
+
+            stack.enter_context(
+                replace_function(torch.nn.functional, "softmax", sparse_softmax)
+            )
+            return stack.pop_all()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`
around lines 387 - 401, The current setup sets a thread-local flag via
set_skip_softmax_context(True) and builds an ExitStack but if any
enter_context() raises the flag remains True and nested contexts are not
reentrant; change the pattern to build the stack inside a with ExitStack() as
stack: block so any exception will unwind and restore, capture the previous flag
value before setting it, call set_skip_softmax_context(True), and before
returning call and return stack.pop_all() while restoring the previous flag via
set_skip_softmax_context(previous_value); make the adjustments around the
functions get_skip_softmax_attention_backend,
replace_function(torch.nn.functional, "softmax", sparse_softmax), and the
ExitStack usage so the callback restoring the flag no longer clobbers outer
contexts.
modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py (1)

50-61: Defer these optional backend imports until registration time.

Importing diffusers/LTX backends from __init__ means every consumer of modelopt.torch.sparsity.attention_sparsity.kernels pays those import attempts, and real registration failures degrade into a silent None export via suppress(...). Resolve these symbols lazily when registration is actually needed instead. As per coding guidelines: "Use lazy imports and gate optional dependencies via import_plugin() for integrations (HuggingFace, Megatron, etc.); avoid hard imports at module level for optional features."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py` around lines
50 - 61, The module currently does eager imports of optional backends
(register_diffusers_eager_attention, register_diffusers_triton_attention,
register_ltx_eager_attention, register_ltx_triton_attention) inside __init__.py
using contextlib.suppress, which forces all consumers to attempt those imports;
instead, remove these top-level imports and make the four registration callables
resolved lazily at registration time (e.g., implement an import_plugin() or
inline import inside the functions that perform registration), so each of
register_diffusers_eager_attention, register_diffusers_triton_attention,
register_ltx_eager_attention and register_ltx_triton_attention is imported only
when actually invoked and ImportError/RuntimeError are handled there.
modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py (2)

72-131: Significant code duplication with diffusers_triton_attention.py.

The _ltx_triton_attention function (lines 72-131) shares nearly identical logic with _diffusers_triton_attention in diffusers_triton_attention.py:

  • Same varlen metadata construction
  • Same calibration mode branching
  • Same inference mode threshold handling

Consider extracting the shared core into a common helper to reduce maintenance burden.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py`
around lines 72 - 131, The two functions _ltx_triton_attention and
_diffusers_triton_attention duplicate varlen metadata construction,
calibration/inference branching and Triton call logic; refactor by extracting a
shared helper (e.g., _triton_attention_core) that accepts parameters needed to
build q_flat/k_flat/v_flat, seq lengths, heads/dim_head, and optional flags
(threshold, calibration trials), move the common kw construction and
calibration/inference handling (including use of _thread_local,
attention_calibrate, attention) into that helper, then have
_ltx_triton_attention and _diffusers_triton_attention prepare their
layout-specific tensors and call the new helper to return the reshaped output.

147-153: register_ltx_triton_attention doesn't track registration state.

Unlike register_diffusers_triton_attention which uses _BACKEND_REGISTERED to prevent double-registration, this function relies only on the isinstance(fn, _TritonLTXAttentionWrapper) check. While this works, it's inconsistent with the diffusers pattern and scans all modules on every call.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py`
around lines 147 - 153, register_ltx_triton_attention currently scans all
modules every call and only avoids double-wrapping by checking isinstance(fn,
_TritonLTXAttentionWrapper); make it consistent with
register_diffusers_triton_attention by adding a module-level flag
_BACKEND_REGISTERED (or reusing the same symbol) to short-circuit if already
registered. Modify register_ltx_triton_attention to check _BACKEND_REGISTERED at
the top and set it to True after the first successful registration pass, while
still keeping the existing isinstance(fn, _TritonLTXAttentionWrapper) guard and
targeting Attention and _TritonLTXAttentionWrapper as before.
modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py (1)

169-179: Fragile internal API manipulation — document diffusers version dependency.

The registration directly manipulates AttentionBackendName._member_map_ and _AttentionBackendRegistry._backends internals. This works but is fragile if diffusers changes its internal structure.

Consider adding a comment noting the diffusers version this was tested against, or wrapping in a try/except to provide a clear error if the internal API changes.

Suggested documentation
 def register_diffusers_triton_attention() -> None:
     """Register ``modelopt_triton`` backend in diffusers.
 
     Safe to call multiple times; registration happens only once.
+
+    Note: This manipulates diffusers internal APIs (_member_map_, _backends, etc.)
+    and was tested with diffusers >= 0.31. May require updates for future versions.
     """
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`
around lines 169 - 179, The code directly mutates diffusers internals
(AttentionBackendName._member_map_, _AttentionBackendRegistry._backends, etc.),
which is fragile; add a brief comment near the block stating the diffusers
version(s) this was tested against and expected internal shape, and wrap the
registration logic for AttentionBackendName and _AttentionBackendRegistry (the
new_member creation, map assignments, and setting of _supported_arg_names via
inspect.signature(_diffusers_triton_attention)) in a try/except that catches
AttributeError/KeyError and raises a clear, actionable RuntimeError explaining
that the diffusers internal API changed and what fields are expected (e.g.,
_member_map_, _value2member_map_, _backends, _constraints, _supported_arg_names)
so maintainers can update accordingly.
modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py (2)

172-192: Calibration counter retrieval tries diffusers first, then LTX — order may matter.

The code tries to get counters from diffusers backend first (lines 176-181), then falls back to LTX (lines 183-189). If both backends are active in the same forward pass, only one set of counters will be collected. Per context snippets 1-2, each backend maintains independent thread-local storage.

This is likely fine for single-backend usage, but consider documenting that mixed-backend calibration is not supported, or aggregating counters from both backends if both are present.
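If aggregation is preferred over documenting the limitation, the merge could be as simple as the sketch below. The counter layout (a dict of per-trial lists keyed by metric name) is an assumption for illustration; the real thread-local counters may be structured differently:

```python
# Hypothetical merge of calibration counters collected independently by the
# diffusers and LTX backends. Assumes each backend returns either None or a
# dict mapping metric name -> per-trial list of counts.
def merge_calibration_counters(*counter_dicts):
    """Sum corresponding per-trial counters across all active backends."""
    merged = {}
    for counters in counter_dicts:
        if not counters:
            continue  # backend absent or collected nothing this pass
        for key, values in counters.items():
            if key in merged:
                merged[key] = [a + b for a, b in zip(merged[key], values)]
            else:
                merged[key] = list(values)
    return merged
```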

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`
around lines 172 - 192, _collect_calibration_stats currently only uses the first
available get_calibration_counters (diffusers_triton_attention) and ignores the
other backend, which prevents collecting counters when both backends are active;
change the logic in _collect_calibration_stats to attempt to import
get_calibration_counters from both ..kernels.diffusers_triton_attention and
..kernels.ltx_triton_attention, call each present function and merge/aggregate
their returned counters (e.g., sum corresponding metrics or combine entries)
into a single counters object before proceeding with the existing threshold
logic (referencing get_calibration_counters and self._threshold_trials to decide
whether aggregation is needed), or alternatively add a clear docstring note on
mixed-backend non-support if aggregation is not implemented.

110-128: Hardcoded sequence length 4224 limits calibration accuracy for varying input sizes.

The threshold calculation uses a fixed seqlen = 4224 (line 127), which may not match actual runtime sequence lengths. The TODO acknowledges this, but for diffusion models with variable resolution/frame counts, this could lead to suboptimal threshold selection.

Consider exposing this as a configurable parameter or computing it from the model/input metadata when available.

Would you like me to suggest an approach for passing actual sequence length at runtime?
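One possible shape for the suggested parameter, keeping the old constant as a fallback. The scaling rule and names (`DEFAULT_SEQLEN`, `get_effective_threshold`) are illustrative assumptions, not the actual `_get_effective_threshold` math:

```python
# Hypothetical seqlen-aware threshold: accept the runtime sequence length and
# fall back to the previous hardcoded default when it is unavailable.
DEFAULT_SEQLEN = 4224  # previous hardcoded value, kept for backward compatibility


def get_effective_threshold(base_threshold, scale_factor, seqlen=None):
    """Scale the skip-softmax threshold by the actual sequence length."""
    if seqlen is None:
        seqlen = DEFAULT_SEQLEN  # matches the old behavior exactly
    return base_threshold * scale_factor / seqlen
```

Callers that know the runtime sequence length (e.g., from the flattened query shape) pass it explicitly; existing call sites keep working unchanged.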

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`
around lines 110 - 128, The method _get_effective_threshold currently divides
scale_factor by a hardcoded seqlen (4224); change it to use a real sequence
length by accepting an explicit seqlen (e.g., add a seqlen parameter to
_get_effective_threshold) or by reading a sequence-length property from the
passed module (e.g., module.seq_len or module.input_shape) and fall back to the
previous default when unavailable; update callers that compute thresholds to
pass the actual runtime seqlen (or set a configurable class attribute like
self.default_seqlen) so calibration_params/target_sparse_ratio are scaled
correctly instead of using 4224, and keep returning self.skip_softmax_threshold
when no calibration info exists.
modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py (2)

339-353: _build_sparse_kw silently ignores unknown methods.

If _method is neither "triton_sparse_softmax" nor "triton_skip_softmax", the method returns an empty dict (plus optional quantize_p). This could lead to silent misconfiguration if a user specifies an invalid method name.

Consider logging a warning for unknown methods.

Suggested improvement
     def _build_sparse_kw(self) -> dict:
         """Extract triton_fa kwargs from the current method config."""
         cfg = getattr(self, "_method_config", {})
         method = getattr(self, "_method", "")
         kw: dict = {}
         if method == "triton_sparse_softmax":
             kw["sparsity_n"] = cfg.get("sparsity_n", 2)
             kw["sparsity_m"] = cfg.get("sparsity_m", 4)
             kw["num_sink_tokens"] = cfg.get("num_sink_tokens", 0)
             kw["dense_window_size"] = cfg.get("dense_window_size", 64)
         elif method == "triton_skip_softmax":
             kw["skip_softmax_threshold"] = cfg.get("skip_softmax_threshold", 0.1)
+        elif method:
+            import warnings
+            warnings.warn(f"Unknown sparse method '{method}'; sparse_kw will be empty")
         if cfg.get("quantize_p", False):
             kw["quantize_p"] = True
         return kw
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py` around lines
339 - 353, The _build_sparse_kw function currently ignores unknown _method
values silently; update it to detect when method is non-empty and not one of the
supported values ("triton_sparse_softmax", "triton_skip_softmax") and emit a
warning including the invalid method name and available options to aid
debugging. Locate the _build_sparse_kw method and add a membership check (e.g.
allowed_methods = {"triton_sparse_softmax","triton_skip_softmax"}); if method
and method not in allowed_methods, call a logger on the instance (use
getattr(self, "_logger", logging.getLogger(__name__))) to log a warning that
includes method and allowed_methods, then continue building kw (still respect
quantize_p as currently implemented).

224-229: Hardcoded context length 512 is fragile if WAN architecture changes.

Line 227 uses a hardcoded 512 for the text-encoder context length. The comment acknowledges this is a "WAN hardcoded constant", but if the WAN model changes or different variants use different context lengths, this will break silently.

Consider extracting this as a class constant or reading it from the attention module's configuration if available.

Suggested improvement
+_WAN_TEXT_CONTEXT_LENGTH = 512  # WAN hardcoded constant for text-encoder context

 class ModelOptWanAttnProcessor:
     ...
     def _wan_forward_triton(self, ...):
         ...
         if attn.add_k_proj is not None:
-            # 512 is the text-encoder context length (WAN hardcoded constant)
             assert encoder_hidden_states is not None
-            image_context_length = encoder_hidden_states.shape[1] - 512
+            image_context_length = encoder_hidden_states.shape[1] - _WAN_TEXT_CONTEXT_LENGTH
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py` around lines
224 - 229, The code slices encoder_hidden_states using a hardcoded 512 (seen in
the block checking attn.add_k_proj and computing image_context_length) which is
brittle; replace that magic number with a configurable value read from the
attention module or plugin class: attempt to read a context length from attn
(e.g., attn.config.text_encoder_context_length or a similar attribute) and fall
back to a plugin/class constant like TEXT_ENCODER_CONTEXT_LEN if not present,
then use that variable when computing image_context_length and slicing
encoder_hidden_states; keep the assert for encoder_hidden_states and add a
warning/log if neither config nor constant exists and you must fall back to 512.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/diffusers/sparsity/wan22_skip_softmax.py`:
- Around line 22-24: Update the example header text to reflect that calibration
now uses triton_skip_softmax with the Triton backend (not flash_skip_softmax
with the eager backend); locate the header block at the top of the file and
change any mentions of flash_skip_softmax/eager to triton_skip_softmax/Triton so
it matches the implementation in build_sparse_config(), ensuring documentation
and code paths align for debugging and comparison.

In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 164-166: The current check uses a falsy test on threshold and will
skip valid zero values; update the condition around threshold retrieval to test
for None explicitly (e.g., replace "if threshold:" with "if threshold is not
None:") so that when layer_cfg.get("skip_softmax_threshold") returns 0 the key
"skip_softmax_threshold" is still set on sparse_kw; locate the threshold
variable, layer_cfg.get call, and sparse_kw assignment in sparse_attn_worker.py
to apply this change.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py`:
- Around line 259-266: The current guard `if forward_loop is None and
(calibrate_prefill or calibrate_decode):` prevents building the tokenizer and
RULER decode dataset whenever a custom `forward_loop` is provided, so built-in
decode calibration never runs when `target_sparse_ratio.decode > 0`; change the
condition to always build the tokenizer/calibration_data when `calibrate_decode`
is true (even if `forward_loop` is provided) while keeping the existing prefill
behavior for `calibrate_prefill`. Specifically, call
`_extract_tokenizer_from_model(model)` and construct the RULER dataset whenever
`calibrate_decode` is enabled (or `target_sparse_ratio.decode > 0`), and only
gate prefill dataset construction on `forward_loop is None` and
`calibrate_prefill`, ensuring `tokenizer`/`calibration_data` are available to
the decode calibration code paths (see `forward_loop`, `calibrate_prefill`,
`calibrate_decode`, `tokenizer`, `calibration_data`, and
`_extract_tokenizer_from_model`).

In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 127-156: The diffusers/ltx registration blocks currently swallow
all exceptions; change them to only suppress expected optional-dependency
errors: in the ModelMixin/diffusers block catch only ImportError (remove broad
except (ImportError, Exception)) so that after importing ModelMixin you still
check and call register_diffusers_eager_attention and
register_diffusers_triton_attention but let any runtime errors from those calls
raise; in the ltx block stop using contextlib.suppress(Exception) around
register_ltx_eager_attention/register_ltx_triton_attention and instead wrap the
import in except (ImportError, RuntimeError) as you already do, and for the
actual calls use a narrow try/except that only catches ImportError or
RuntimeError so real bugs in
register_ltx_eager_attention/register_ltx_triton_attention surface.

In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py`:
- Around line 79-81: The code currently always adds attn_mask to scores which is
incorrect when attn_mask is a boolean mask; detect boolean masks before the
additive path and convert them to an additive mask (e.g., torch.where(attn_mask,
neg_inf, 0) or attn_mask.to(scores.dtype) * neg_inf) using scores.device and
scores.dtype to construct neg_inf (like -torch.finfo(scores.dtype).max or -1e9
for float32), then add that converted mask to scores; update the code around the
scores and attn_mask usage to only add when attn_mask is already additive or
after converting when attn_mask.dtype is bool.

In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`:
- Around line 130-143: When calibration_mode (thread-local "calibration_mode")
is enabled but either "threshold_trials" (local variable trials) is
missing/empty or the calibration kernel "attention_calibrate" is None, the
function should not silently fall through to inference; instead detect this case
and raise a clear RuntimeError (or at least log an error and raise) that
includes which condition failed (missing trials vs missing attention_calibrate)
and mentions the function/mode involved (the calibration branch around
attention_calibrate(q, k, v, **kw, threshold_trials=trials)). Keep the existing
logic that accumulates _thread_local.calibration_counters when calibration runs,
but before falling back, explicitly check calib_mode && (not trials or
attention_calibrate is None) and raise with a descriptive message so callers
know calibration was requested but cannot proceed.

In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_eager_attention.py`:
- Around line 103-114: The registration is non-idempotent because
register_ltx_eager_attention only skips when module.attention_function is a
_SkipSoftmaxLTXAttentionWrapper, so if _TritonLTXAttentionWrapper is outermost
it will wrap again; change the logic to always operate on the underlying base
callable: implement a small unwrap helper (e.g., get_base_attention_fn) that,
given a callable, repeatedly drills into known wrapper attributes (the wrappers
used here: _SkipSoftmaxLTXAttentionWrapper and _TritonLTXAttentionWrapper expose
the original as an inner attribute such as .fn or .wrapped_fn) until a
non-wrapper callable is reached, then if the base callable is not already
wrapped by the eager wrapper create a fresh
_SkipSoftmaxLTXAttentionWrapper(base_fn) and assign that to
module.attention_function (or rewrap preserving the other wrapper by
reconstructing outer wrappers around the new eager wrapper), and update
register_ltx_triton_attention similarly so both registrations first unwrap to
the base before reapplying their wrapper.

In `@modelopt/torch/sparsity/attention_sparsity/stats_manager.py`:
- Around line 67-75: The code assumes all incoming sparse_blocks vectors have
the same length and increments total_blocks even when sparse_blocks is missing;
fix by (1) tracking a separate counter like
self.aggregated_stats["sparse_blocks_reports"] (increment only when
stats.get("sparse_blocks") is not None) and use that for denominators, (2) when
receiving incoming = stats.get("sparse_blocks"), ensure the accumulator
self.aggregated_stats["sparse_blocks"] is resized to accommodate varying widths
(extend with zeros if incoming is longer, or pad incoming with zeros if shorter)
before elementwise addition, and (3) do not rely on total_blocks for
sparse_blocks normalization — only count calls that actually reported
sparse_blocks. Ensure all references use self.aggregated_stats, incoming,
stats.get("sparse_blocks"), and "total_blocks" so the changes integrate with
existing logic.

In `@tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py`:
- Around line 82-96: Tests are reusing cached backend modules across fixture
runs causing stale mocks; ensure you remove cached modules from sys.modules
before reimporting. Inside the patch.dict(sys.modules, _mock_diffusers())
context (and likewise for the triton fixture), pop any existing
"modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_eager_attention"
and
"modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_triton_attention"
entries from sys.modules, then import the module and access
_diffusers_eager_attention, register_diffusers_eager_attention,
get_skip_softmax_attention_backend and reset mod._BACKEND_REGISTERED to
guarantee a fresh module and mocks for each fixture run.

---

Duplicate comments:
In `@examples/diffusers/quantization/wan2_sage_attention.py`:
- Around line 370-385: _patched_sdpa currently drops caller kwargs on both
fallback paths causing silently changed behavior; update the calls so that any
received **kwargs are forwarded: call _run_kernel(query, key, value,
is_causal=is_causal, scale=scale, **kwargs) and call _orig_sdpa(query, key,
value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal,
scale=scale, **kwargs) (and likewise in the exception fallback) so
_patched_sdpa, _run_kernel and _orig_sdpa receive the original caller options
(references: _patched_sdpa, _run_kernel, _orig_sdpa, _sage_calls,
_active_kernel).
- Around line 504-531: The apply_triton_sparse_kernel function does not validate
skip_threshold for the skip-softmax kernels (KERNEL_TRITON_SKIP,
KERNEL_TRITON_SKIP_NVFP4), which can lead to cryptic failures when zero or
negative values are used; add an early validation inside
apply_triton_sparse_kernel that when kernel is one of the skip-softmax constants
and skip_threshold is not None, raise a ValueError if skip_threshold <= 0 with a
clear message (e.g., require skip_threshold > 0) before mutating
config["sparse_cfg"]["*"]["skip_softmax_threshold"] so invalid values never
reach the Triton kernel.
- Around line 815-817: The code currently calls load_pipeline(args.model) before
validating kernel availability; move the kernel check immediately after
parse_args() in main() so you verify args.kernel in AVAILABLE_KERNELS before any
heavy model download. Concretely, inside main(), call args = parse_args(), then
perform the AVAILABLE_KERNELS membership check (and exit/raise with a clear
message if invalid), and only then call load_pipeline(args.model); update/remove
the later redundant kernel checks around load_pipeline/usage (referencing
main(), parse_args(), load_pipeline, and AVAILABLE_KERNELS).
- Around line 589-596: The PSNR code uses np.where which still triggers
divide-by-zero because both branches are evaluated; fix by first protecting the
denominator (e.g., compute safe_mse = np.maximum(mse_per_frame, 1e-10)) then
compute the PSNR values from safe_mse (psnr_vals = 10.0 * np.log10(255.0**2 /
safe_mse)) and finally set exact-zero cases to 100.0 with psnr_per_frame =
np.where(mse_per_frame < 1e-10, 100.0, psnr_vals) and keep psnr =
float(psnr_per_frame.mean()); update the code that references mse_per_frame,
psnr_per_frame, and psnr accordingly.

In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 1005-1012: The forward path currently allows using paged KV
(k_cache/v_cache) while autograd is enabled, but the backward only reconstructs
grads from contiguous k/v and omits cache tensors; add the same inference-only
guard used for quantize_p: detect when paged KV is in use (check k_cache or
v_cache presence/usage) and if any of q.requires_grad, k.requires_grad, or
v.requires_grad are true, raise NotImplementedError with a clear message like
"paged KV supports inference only; backward does not model the paged K/V path"
(mirror the existing quantize_p check) so the code in the function handling
forward (the block containing quantize_p and references to k_cache/v_cache)
rejects autograd-enabled paged KV.
- Around line 973-997: When paged mode is detected (k_cache/v_cache provided)
don't derive KV head counts or page geometry from the placeholder tensor k;
instead read and validate num_kv_heads, kv_group_num and page_size (and any page
layout metadata) from k_cache and v_cache and use those values to compute
kv_group_num and related variables used later in triton_fa.py (e.g.,
num_kv_heads, kv_group_num, page_size, b_seq_len_k/b_start_loc_k defaults). Add
explicit shape checks that k_cache and v_cache agree with each other and with
block_table (and that num_q_heads is divisible by num_kv_heads), and raise a
clear ValueError if they mismatch; when b_start_loc_k is needed,
construct/derive it from the cache metadata rather than the placeholder k.
Ensure all references to num_kv_heads/kv_group_num/page_size later in the
function use the cache-derived values.

---

Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py`:
- Around line 50-61: The module currently does eager imports of optional
backends (register_diffusers_eager_attention,
register_diffusers_triton_attention, register_ltx_eager_attention,
register_ltx_triton_attention) inside __init__.py using contextlib.suppress,
which forces all consumers to attempt those imports; instead, remove these
top-level imports and make the four registration callables resolved lazily at
registration time (e.g., implement an import_plugin() or inline import inside
the functions that perform registration), so each of
register_diffusers_eager_attention, register_diffusers_triton_attention,
register_ltx_eager_attention and register_ltx_triton_attention is imported only
when actually invoked and ImportError/RuntimeError are handled there.

In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`:
- Around line 169-179: The code directly mutates diffusers internals
(AttentionBackendName._member_map_, _AttentionBackendRegistry._backends, etc.),
which is fragile; add a brief comment near the block stating the diffusers
version(s) this was tested against and expected internal shape, and wrap the
registration logic for AttentionBackendName and _AttentionBackendRegistry (the
new_member creation, map assignments, and setting of _supported_arg_names via
inspect.signature(_diffusers_triton_attention)) in a try/except that catches
AttributeError/KeyError and raises a clear, actionable RuntimeError explaining
that the diffusers internal API changed and what fields are expected (e.g.,
_member_map_, _value2member_map_, _backends, _constraints, _supported_arg_names)
so maintainers can update accordingly.

In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py`:
- Around line 72-131: The two functions _ltx_triton_attention and
_diffusers_triton_attention duplicate varlen metadata construction,
calibration/inference branching and Triton call logic; refactor by extracting a
shared helper (e.g., _triton_attention_core) that accepts parameters needed to
build q_flat/k_flat/v_flat, seq lengths, heads/dim_head, and optional flags
(threshold, calibration trials), move the common kw construction and
calibration/inference handling (including use of _thread_local,
attention_calibrate, attention) into that helper, then have
_ltx_triton_attention and _diffusers_triton_attention prepare their
layout-specific tensors and call the new helper to return the reshaped output.
- Around line 147-153: register_ltx_triton_attention currently scans all modules
every call and only avoids double-wrapping by checking isinstance(fn,
_TritonLTXAttentionWrapper); make it consistent with
register_diffusers_triton_attention by adding a module-level flag
_BACKEND_REGISTERED (or reusing the same symbol) to short-circuit if already
registered. Modify register_ltx_triton_attention to check _BACKEND_REGISTERED at
the top and set it to True after the first successful registration pass, while
still keeping the existing isinstance(fn, _TritonLTXAttentionWrapper) guard and
targeting Attention and _TritonLTXAttentionWrapper as before.

In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`:
- Around line 387-401: The current setup sets a thread-local flag via
set_skip_softmax_context(True) and builds an ExitStack but if any
enter_context() raises the flag remains True and nested contexts are not
reentrant; change the pattern to build the stack inside a with ExitStack() as
stack: block so any exception will unwind and restore, capture the previous flag
value before setting it, call set_skip_softmax_context(True), and before
returning call and return stack.pop_all() while restoring the previous flag via
set_skip_softmax_context(previous_value); make the adjustments around the
functions get_skip_softmax_attention_backend,
replace_function(torch.nn.functional, "softmax", sparse_softmax), and the
ExitStack usage so the callback restoring the flag no longer clobbers outer
contexts.

In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`:
- Around line 172-192: _collect_calibration_stats currently only uses the first
available get_calibration_counters (diffusers_triton_attention) and ignores the
other backend, which prevents collecting counters when both backends are active;
change the logic in _collect_calibration_stats to attempt to import
get_calibration_counters from both ..kernels.diffusers_triton_attention and
..kernels.ltx_triton_attention, call each present function and merge/aggregate
their returned counters (e.g., sum corresponding metrics or combine entries)
into a single counters object before proceeding with the existing threshold
logic (referencing get_calibration_counters and self._threshold_trials to decide
whether aggregation is needed), or alternatively add a clear docstring note on
mixed-backend non-support if aggregation is not implemented.
- Around line 110-128: The method _get_effective_threshold currently divides
scale_factor by a hardcoded seqlen (4224); change it to use a real sequence
length by accepting an explicit seqlen (e.g., add a seqlen parameter to
_get_effective_threshold) or by reading a sequence-length property from the
passed module (e.g., module.seq_len or module.input_shape) and fall back to the
previous default when unavailable; update callers that compute thresholds to
pass the actual runtime seqlen (or set a configurable class attribute like
self.default_seqlen) so calibration_params/target_sparse_ratio are scaled
correctly instead of using 4224, and keep returning self.skip_softmax_threshold
when no calibration info exists.

In `@modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py`:
- Around line 339-353: The _build_sparse_kw function currently ignores unknown
_method values silently; update it to detect when method is non-empty and not
one of the supported values ("triton_sparse_softmax", "triton_skip_softmax") and
emit a warning including the invalid method name and available options to aid
debugging. Locate the _build_sparse_kw method and add a membership check (e.g.
allowed_methods = {"triton_sparse_softmax","triton_skip_softmax"}); if method
and method not in allowed_methods, call a logger on the instance (use
getattr(self, "_logger", logging.getLogger(__name__))) to log a warning that
includes method and allowed_methods, then continue building kw (still respect
quantize_p as currently implemented).
- Around line 224-229: The code slices encoder_hidden_states using a hardcoded
512 (seen in the block checking attn.add_k_proj and computing
image_context_length) which is brittle; replace that magic number with a
configurable value read from the attention module or plugin class: attempt to
read a context length from attn (e.g., attn.config.text_encoder_context_length
or a similar attribute) and fall back to a plugin/class constant like
TEXT_ENCODER_CONTEXT_LEN if not present, then use that variable when computing
image_context_length and slicing encoder_hidden_states; keep the assert for
encoder_hidden_states and add a warning/log if neither config nor constant
exists and you must fall back to 512.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: f4a94f9e-f0a7-436f-bbed-e4e2ad821a33

📥 Commits

Reviewing files that changed from the base of the PR and between 708f113 and 3f0bfd3.

📒 Files selected for processing (30)
  • examples/diffusers/README.md
  • examples/diffusers/quantization/wan2_sage_attention.py
  • examples/diffusers/sparsity/README.md
  • examples/diffusers/sparsity/wan22_skip_softmax.py
  • examples/vllm_serve/sparse_attn_worker.py
  • examples/vllm_serve/vllm_serve_sparse_attn.py
  • modelopt/torch/kernels/__init__.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/quantization/plugins/diffusion/diffusers.py
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/conversion.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/ltx_eager_attention.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/methods/registry.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
  • modelopt/torch/sparsity/attention_sparsity/stats_manager.py
  • pyproject.toml
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py
  • tests/unit/torch/sparsity/attention_sparsity/test_diffusers_plugin.py
  • tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py
✅ Files skipped from review due to trivial changes (4)
  • modelopt/torch/quantization/plugins/diffusion/diffusers.py
  • modelopt/torch/sparsity/attention_sparsity/methods/registry.py
  • examples/diffusers/README.md
  • pyproject.toml
🚧 Files skipped from review as they are similar to previous changes (3)
  • modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • tests/unit/torch/sparsity/attention_sparsity/test_diffusers_plugin.py

Comment on lines +22 to +24
During calibration, ``flash_skip_softmax`` with the eager attention backend
collects sparsity statistics across multiple threshold trials. The fitted
exponential model then allows runtime control of the target sparsity ratio

⚠️ Potential issue | 🟡 Minor

Sync the example header with the actual calibration path.

The file header still says calibration uses flash_skip_softmax with the eager backend, but build_sparse_config() now wires triton_skip_softmax with the Triton path. That will send users to the wrong codepath when they debug or compare results.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/sparsity/wan22_skip_softmax.py` around lines 22 - 24,
Update the example header text to reflect that calibration now uses
triton_skip_softmax with the Triton backend (not flash_skip_softmax with the
eager backend); locate the header block at the top of the file and change any
mentions of flash_skip_softmax/eager to triton_skip_softmax/Triton so it matches
the implementation in build_sparse_config(), ensuring documentation and code
paths align for debugging and comparison.

Comment on lines +164 to +166
        threshold = layer_cfg.get("skip_softmax_threshold")
        if threshold:
            sparse_kw["skip_softmax_threshold"] = threshold
⚠️ Potential issue | 🟡 Minor

The check if threshold: treats 0 as falsy, skipping a valid zero threshold.

Line 165 uses if threshold: which will be False when threshold=0. If a user intentionally sets skip_softmax_threshold: 0 to disable skip-softmax, the key won't be added to sparse_kw, potentially causing different behavior than expected.

Consider using if threshold is not None: for explicit None check.

Suggested fix
         threshold = layer_cfg.get("skip_softmax_threshold")
-        if threshold:
+        if threshold is not None:
             sparse_kw["skip_softmax_threshold"] = threshold
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/vllm_serve/sparse_attn_worker.py` around lines 164 - 166, The
current check uses a falsy test on threshold and will skip valid zero values;
update the condition around threshold retrieval to test for None explicitly
(e.g., replace "if threshold:" with "if threshold is not None:") so that when
layer_cfg.get("skip_softmax_threshold") returns 0 the key
"skip_softmax_threshold" is still set on sparse_kw; locate the threshold
variable, layer_cfg.get call, and sparse_kw assignment in sparse_attn_worker.py
to apply this change.
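The falsy-zero pitfall is easy to reproduce in isolation. The sketch below uses a hypothetical build_kwargs helper (not the actual worker code) to show why the explicit None check is the safer guard:

```python
def build_kwargs(layer_cfg: dict) -> dict:
    """Mimics the buggy pattern: a falsy check drops threshold=0."""
    sparse_kw = {}
    threshold = layer_cfg.get("skip_softmax_threshold")
    if threshold:  # BUG: 0 and 0.0 are falsy, so a valid zero threshold is dropped
        sparse_kw["skip_softmax_threshold"] = threshold
    return sparse_kw


def build_kwargs_fixed(layer_cfg: dict) -> dict:
    """Explicit None check keeps an intentional zero threshold."""
    sparse_kw = {}
    threshold = layer_cfg.get("skip_softmax_threshold")
    if threshold is not None:
        sparse_kw["skip_softmax_threshold"] = threshold
    return sparse_kw
```

With a zero threshold, the buggy version returns an empty dict while the fixed version propagates the key, which is exactly the behavioral difference the review flags.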

Comment on lines +259 to +266
    # Extract tokenizer and build calibration data only if no forward_loop is provided.
    # When the user supplies their own forward_loop (e.g. for diffusion models),
    # RULER dataset generation is skipped entirely.
    tokenizer = None
    calibration_data = None

-    if calibrate_prefill or calibrate_decode:
+    if forward_loop is None and (calibrate_prefill or calibrate_decode):
         tokenizer = _extract_tokenizer_from_model(model)
⚠️ Potential issue | 🟠 Major

Still build decode calibration data when forward_loop is provided.

forward_loop is documented as prefill-only, but this guard skips tokenizer/dataset construction for decode as soon as any custom loop is passed. With target_sparse_ratio.decode > 0, the function now always hits Lines 313-314 instead of running the built-in decode calibration. Build the RULER dataset whenever decode calibration is enabled, even if prefill uses a custom loop.

🔧 Proposed fix
-    if forward_loop is None and (calibrate_prefill or calibrate_decode):
+    needs_builtin_dataset = calibrate_decode or forward_loop is None
+    if needs_builtin_dataset and (calibrate_prefill or calibrate_decode):
         tokenizer = _extract_tokenizer_from_model(model)
         builder = RulerDatasetBuilder(
             samples=calib_config.samples,
             max_seqlen=calib_config.max_seqlen,
             tokenizer_name_or_path=tokenizer,

Also applies to: 313-314

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` around
lines 259 - 266, The current guard `if forward_loop is None and
(calibrate_prefill or calibrate_decode):` prevents building the tokenizer and
RULER decode dataset whenever a custom `forward_loop` is provided, so built-in
decode calibration never runs when `target_sparse_ratio.decode > 0`; change the
condition to always build the tokenizer/calibration_data when `calibrate_decode`
is true (even if `forward_loop` is provided) while keeping the existing prefill
behavior for `calibrate_prefill`. Specifically, call
`_extract_tokenizer_from_model(model)` and construct the RULER dataset whenever
`calibrate_decode` is enabled (or `target_sparse_ratio.decode > 0`), and only
gate prefill dataset construction on `forward_loop is None` and
`calibrate_prefill`, ensuring `tokenizer`/`calibration_data` are available to
the decode calibration code paths (see `forward_loop`, `calibrate_prefill`,
`calibrate_decode`, `tokenizer`, `calibration_data`, and
`_extract_tokenizer_from_model`).
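The corrected gating reduces to a small predicate. The sketch below uses a hypothetical needs_builtin_dataset helper (the name comes from the proposed fix, not the shipped code) to make the decode-always-builds rule explicit:

```python
def needs_builtin_dataset(forward_loop, calibrate_prefill: bool, calibrate_decode: bool) -> bool:
    """Decide whether to build the tokenizer + RULER calibration dataset.

    Decode calibration always needs the built-in dataset, even when the user
    supplies a custom forward_loop (which is documented as prefill-only).
    Prefill only needs it when no custom loop was provided.
    """
    if calibrate_decode:
        return True
    return calibrate_prefill and forward_loop is None
```

Under this rule a custom forward_loop suppresses dataset construction only for prefill-only calibration, which matches the behavior the review asks for.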

Comment on lines +79 to +81
# Apply attention mask if provided
if attn_mask is not None:
scores = scores + attn_mask
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cd /root/repo && find . -name "diffusers_eager_attention.py" -type f

Repository: NVIDIA/Model-Optimizer

Length of output: 126


🏁 Script executed:

cd /root/repo && wc -l modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py

Repository: NVIDIA/Model-Optimizer

Length of output: 126


🏁 Script executed:

cd /root/repo && sed -n '60,100p' modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py

Repository: NVIDIA/Model-Optimizer

Length of output: 126


🏁 Script executed:

cd /root/repo && head -100 modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py | tail -50

Repository: NVIDIA/Model-Optimizer

Length of output: 126


🏁 Script executed:

fd diffusers_eager_attention.py

Repository: NVIDIA/Model-Optimizer

Length of output: 145


🏁 Script executed:

cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py | head -110

Repository: NVIDIA/Model-Optimizer

Length of output: 4474


🏁 Script executed:

rg "torch.nn.functional.scaled_dot_product_attention|torch.ops.aten.scaled_dot_product_attention" -A 5 -B 5

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

rg "attn_mask|attention_mask" modelopt/torch/sparsity/attention_sparsity/kernels/ -A 3 -B 3

Repository: NVIDIA/Model-Optimizer

Length of output: 2471


🏁 Script executed:

rg "scaled_dot_product_attention" --type py -A 10 -B 5

Repository: NVIDIA/Model-Optimizer

Length of output: 40075


🏁 Script executed:

fd "*.py" modelopt/torch/sparsity/attention_sparsity/ | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 349


🏁 Script executed:

cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py | sed -n '1,150p'

Repository: NVIDIA/Model-Optimizer

Length of output: 6853


🏁 Script executed:

rg "attn_mask" modelopt/torch/sparsity/attention_sparsity/ -B 5 -A 5 --type py | head -100

Repository: NVIDIA/Model-Optimizer

Length of output: 8090


🏁 Script executed:

fd --glob "*.py" -x grep -l "diffusers.*attn_mask\|attn_mask.*diffusers" {} \;

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

cat -n modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py | sed -n '1,200p'

Repository: NVIDIA/Model-Optimizer

Length of output: 9214


🏁 Script executed:

rg "attention_mask" modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py -B 10 -A 10

Repository: NVIDIA/Model-Optimizer

Length of output: 4438


🏁 Script executed:

web_search
PyTorch scaled_dot_product_attention attn_mask boolean float mask semantics

Repository: NVIDIA/Model-Optimizer

Length of output: 162


🌐 Web query:

PyTorch torch.nn.functional.scaled_dot_product_attention attn_mask parameter documentation

💡 Result:

In torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None, ...), the attn_mask parameter is:

  • Type: optional Tensor
  • Shape: must be broadcastable to the attention-weight shape (N, ..., L, S), where L = query.size(-2) and S = key.size(-2). (docs.pytorch.org)
  • Supported mask kinds:
    1. Boolean mask (dtype=torch.bool): True means “allowed to attend / takes part in attention”; False positions are effectively masked out (internally turned into -inf bias before softmax). (docs.pytorch.org)
    2. Float/additive mask (same dtype as query/key/value): values are added to the attention scores (commonly 0 for keep and -inf for mask). (docs.pytorch.org)
  • Interaction with is_causal: it’s an error to set both attn_mask and is_causal=True. (docs.pytorch.org)

Note: PyTorch’s Transformer module explicitly warns that its boolean mask semantics are the opposite of SDPA’s attn_mask boolean semantics (Transformer uses True = “not allowed”). (docs.pytorch.org)


🏁 Script executed:

fd test -type f -name "*.py" modelopt/torch/sparsity/attention_sparsity/ | xargs grep -l "diffusers_eager_attention\|eager.*attn_mask" 2>/dev/null | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 236


🏁 Script executed:

rg "test.*diffusers.*attention|diffusers.*test" modelopt/torch/sparsity/attention_sparsity/ -l --type py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

rg "attn_mask.*bool|bool.*attn_mask" --type py

Repository: NVIDIA/Model-Optimizer

Length of output: 346


🏁 Script executed:

rg "dispatch_attention_fn" modelopt/torch/sparsity/attention_sparsity/ -B 5 -A 10 --type py | head -80

Repository: NVIDIA/Model-Optimizer

Length of output: 8321


🏁 Script executed:

cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py | sed -n '110,150p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1760


🏁 Script executed:

rg "register.*backend\|_AttentionBackendRegistry" modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py -A 15

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

rg "dtype.*bool|bool.*dtype" modelopt/torch/sparsity/attention_sparsity/ -B 3 -A 3

Repository: NVIDIA/Model-Optimizer

Length of output: 1582


🏁 Script executed:

python3 << 'EOF'
# Verify the masking semantics claimed in the review
import torch

# Simulate attention scores
scores = torch.randn(2, 4, 8, 8)

# Case 1: Boolean mask (SDPA-style)
bool_mask = torch.ones(2, 4, 8, 8, dtype=torch.bool)
bool_mask[0, 0, :, 7] = False  # Mask out last position

# What the current code does (WRONG for boolean)
wrong_result = scores + bool_mask.float()  # Treats False as 0, True as 1
print("Boolean mask treated as float (WRONG):")
print(f"  masked position bias: {wrong_result[0, 0, 0, 7].item():.4f}")
print(f"  unmasked position bias: {wrong_result[0, 0, 0, 6].item():.4f}")

# What SDPA expects for boolean mask (right semantics)
correct_result = scores.clone()
correct_result = correct_result.masked_fill(bool_mask.logical_not(), float("-inf"))
print("\nBoolean mask handled correctly (what SDPA expects):")
print(f"  masked position bias: {correct_result[0, 0, 0, 7].item():.4f}")
print(f"  unmasked position bias: {correct_result[0, 0, 0, 6].item():.4f}")

# Case 2: Float/additive mask (what current code assumes)
float_mask = torch.zeros(2, 4, 8, 8)
float_mask[0, 0, :, 7] = float("-inf")  # Mask out last position

proper_add = scores + float_mask
print("\nFloat additive mask (current code works):")
print(f"  masked position bias: {proper_add[0, 0, 0, 7].item():.4f}")
print(f"  unmasked position bias: {proper_add[0, 0, 0, 6].item():.4f}")
EOF

Repository: NVIDIA/Model-Optimizer

Length of output: 183


🏁 Script executed:

rg "dispatch_attention_fn" --type py -A 20 | grep -A 20 "def dispatch_attention_fn\|attn_mask"

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

rg "dispatch_attention_fn\|_diffusers_eager_attention" --type py -l | head -10

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

find . -name "*test*.py" -type f -exec grep -l "diffusers_eager_attention\|diffusers.*attention" {} \; | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 208


🏁 Script executed:

cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py | sed -n '1,110p' | tail -30

Repository: NVIDIA/Model-Optimizer

Length of output: 1162


🏁 Script executed:

cat -n tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py | head -150

Repository: NVIDIA/Model-Optimizer

Length of output: 6349


🏁 Script executed:

rg "attn_mask\|attention_mask" tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py -B 3 -A 3

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

cat -n tests/unit/torch/sparsity/attention_sparsity/test_diffusers_plugin.py | sed -n '1,100p'

Repository: NVIDIA/Model-Optimizer

Length of output: 4134


🏁 Script executed:

rg "attn_mask.*bool\|bool.*attn_mask" tests/unit/torch/sparsity/attention_sparsity/ -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

cat -n tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py | sed -n '98,124p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1098


🏁 Script executed:

rg "def test.*attn_mask\|def test.*mask" tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py -A 10

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

rg "bool.*mask|mask.*bool" modelopt/torch/sparsity/attention_sparsity/ --type py -B 2 -A 2 | grep -E "(eager|triton|diffusers)" -A 2 -B 2

Repository: NVIDIA/Model-Optimizer

Length of output: 740


🏁 Script executed:

cat -n tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py | sed -n '98,180p'

Repository: NVIDIA/Model-Optimizer

Length of output: 3614


🏁 Script executed:

rg "attn_mask\|attention_mask" tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

grep -n "attn_mask" tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

wc -l tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py

Repository: NVIDIA/Model-Optimizer

Length of output: 138


🏁 Script executed:

cat -n tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py | tail -100

Repository: NVIDIA/Model-Optimizer

Length of output: 4496


🏁 Script executed:

cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py | sed -n '87,160p'

Repository: NVIDIA/Model-Optimizer

Length of output: 3505


🏁 Script executed:

rg "dispatch_attention_fn" --type py -A 30 | grep -A 30 "def dispatch\|attn_mask"

Repository: NVIDIA/Model-Optimizer

Length of output: 3283


🏁 Script executed:

rg "def.*eager.*attention\|def.*diffusers.*attention" --type py -l

Repository: NVIDIA/Model-Optimizer

Length of output: 48


Handle boolean masks before the additive path.

attn_mask on SDPA-style backends can be boolean or additive. This code always adds it, so a boolean mask becomes 0/1 instead of -inf and returns wrong attention whenever diffusers passes a bool mask.

Suggested fix
     # Apply attention mask if provided
     if attn_mask is not None:
-        scores = scores + attn_mask
+        if attn_mask.dtype == torch.bool:
+            scores = scores.masked_fill(attn_mask.logical_not(), float("-inf"))
+        else:
+            scores = scores + attn_mask
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py`
around lines 79 - 81, The code currently always adds attn_mask to scores which
is incorrect when attn_mask is a boolean mask; detect boolean masks before the
additive path and convert them to an additive mask (e.g., torch.where(attn_mask,
neg_inf, 0) or attn_mask.to(scores.dtype) * neg_inf) using scores.device and
scores.dtype to construct neg_inf (like -torch.finfo(scores.dtype).max or -1e9
for float32), then add that converted mask to scores; update the code around the
scores and attn_mask usage to only add when attn_mask is already additive or
after converting when attn_mask.dtype is bool.
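The two SDPA mask kinds reduce to a single additive form. The sketch below illustrates the semantics in plain Python (no torch dependency; to_additive_mask and masked_scores are hypothetical helpers mirroring the proposed masked_fill conversion):

```python
import math


def to_additive_mask(mask):
    """Convert an SDPA-style mask to an additive bias.

    Boolean masks use True = "may attend", so False positions become -inf;
    float masks are already additive and pass through unchanged.
    """
    if all(isinstance(v, bool) for v in mask):
        return [0.0 if allowed else -math.inf for allowed in mask]
    return list(mask)


def masked_scores(scores, mask):
    """Apply a mask to raw attention scores before softmax."""
    return [s + b for s, b in zip(scores, to_additive_mask(mask))]
```

Adding a boolean mask directly (the current code path) would bias scores by 0/1; converting False to -inf first reproduces what scaled_dot_product_attention does internally.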

Comment on lines +130 to +143
    calib_mode = getattr(_thread_local, "calibration_mode", False)
    if calib_mode:
        trials = getattr(_thread_local, "threshold_trials", None)
        if trials and attention_calibrate is not None:
            o, counters = attention_calibrate(q, k, v, **kw, threshold_trials=trials)

            # Accumulate counters across all attention calls in this forward pass
            prev = getattr(_thread_local, "calibration_counters", None)
            if prev is None:
                _thread_local.calibration_counters = counters
            else:
                _thread_local.calibration_counters = prev + counters

            return o.view(batch, seq_q, num_heads_q, head_dim)
⚠️ Potential issue | 🟡 Minor

Calibration mode silently returns early if trials is empty or attention_calibrate is None.

When calibration_mode=True but trials is empty/None or attention_calibrate is unavailable, the code falls through to inference mode instead of raising an error or warning. This could silently produce unexpected results if a user enables calibration but the calibration kernel isn't available.

Suggested improvement
     calib_mode = getattr(_thread_local, "calibration_mode", False)
     if calib_mode:
         trials = getattr(_thread_local, "threshold_trials", None)
-        if trials and attention_calibrate is not None:
+        if not trials:
+            raise ValueError("calibration_mode=True but threshold_trials is empty")
+        if attention_calibrate is None:
+            raise RuntimeError("calibration_mode=True but attention_calibrate is unavailable (requires CUDA + triton)")
+        if trials and attention_calibrate is not None:
             o, counters = attention_calibrate(q, k, v, **kw, threshold_trials=trials)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`
around lines 130 - 143, When calibration_mode (thread-local "calibration_mode")
is enabled but either "threshold_trials" (local variable trials) is
missing/empty or the calibration kernel "attention_calibrate" is None, the
function should not silently fall through to inference; instead detect this case
and raise a clear RuntimeError (or at least log an error and raise) that
includes which condition failed (missing trials vs missing attention_calibrate)
and mentions the function/mode involved (the calibration branch around
attention_calibrate(q, k, v, **kw, threshold_trials=trials)). Keep the existing
logic that accumulates _thread_local.calibration_counters when calibration runs,
but before falling back, explicitly check calib_mode && (not trials or
attention_calibrate is None) and raise with a descriptive message so callers
know calibration was requested but cannot proceed.
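A fail-loud guard for this branch can be factored out as a small precondition check. The sketch below (hypothetical check_calibration_ready helper, not the kernel code itself) shows the shape of the suggested fix:

```python
def check_calibration_ready(calib_mode: bool, trials, attention_calibrate) -> None:
    """Raise instead of silently falling back to inference when calibration
    was requested but cannot run."""
    if not calib_mode:
        return  # inference mode: nothing to validate
    if not trials:
        raise ValueError("calibration_mode=True but threshold_trials is empty")
    if attention_calibrate is None:
        raise RuntimeError(
            "calibration_mode=True but attention_calibrate is unavailable "
            "(requires CUDA + triton)"
        )
```

Calling this at the top of the calibration branch preserves the normal inference path while surfacing a descriptive error for each failure mode the review identifies.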

Comment on lines +103 to +114
def register_ltx_eager_attention(model: torch.nn.Module) -> None:
"""Walk *model* and patch all ``ltx_core.model.transformer.attention.Attention`` modules.

Patches modules so their ``attention_function`` is routed through the eager wrapper.
Safe to call multiple times on the same model — already-wrapped modules are
skipped.
"""
for module in model.modules():
if isinstance(module, Attention):
fn = module.attention_function
if not isinstance(fn, _SkipSoftmaxLTXAttentionWrapper):
module.attention_function = _SkipSoftmaxLTXAttentionWrapper(fn)
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if the file exists and examine the full ltx_eager_attention.py
find . -type f -name "ltx_eager_attention.py" | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 141


🏁 Script executed:

#!/bin/bash
# Read the complete ltx_eager_attention.py file
cat modelopt/torch/sparsity/attention_sparsity/kernels/ltx_eager_attention.py

Repository: NVIDIA/Model-Optimizer

Length of output: 3926


🏁 Script executed:

#!/bin/bash
# Read the ltx_triton_attention.py file to check wrapper type and guard logic
cat modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py

Repository: NVIDIA/Model-Optimizer

Length of output: 5621


🏁 Script executed:

#!/bin/bash
# Check conversion.py to see how registrations are called
grep -n "register_ltx" modelopt/torch/sparsity/attention_sparsity/conversion.py -A 3 -B 3

Repository: NVIDIA/Model-Optimizer

Length of output: 633


This registration becomes non-idempotent after both wrappers are applied.

The guard only recognizes _SkipSoftmaxLTXAttentionWrapper. After register_ltx_triton_attention wraps the same attention_function with _TritonLTXAttentionWrapper, a second call to register_ltx_eager_attention will wrap again because the outer callable is no longer the eager wrapper type. Since conversion.py:151–156 calls both registrations sequentially, repeated conversions will stack nested wrappers instead of being safely reapplied. Both functions should either detect any existing wrapper (not just their own type), or restore to the base callable before rewrapping.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_eager_attention.py`
around lines 103 - 114, The registration is non-idempotent because
register_ltx_eager_attention only skips when module.attention_function is a
_SkipSoftmaxLTXAttentionWrapper, so if _TritonLTXAttentionWrapper is outermost
it will wrap again; change the logic to always operate on the underlying base
callable: implement a small unwrap helper (e.g., get_base_attention_fn) that,
given a callable, repeatedly drills into known wrapper attributes (the wrappers
used here: _SkipSoftmaxLTXAttentionWrapper and _TritonLTXAttentionWrapper expose
the original as an inner attribute such as .fn or .wrapped_fn) until a
non-wrapper callable is reached, then if the base callable is not already
wrapped by the eager wrapper create a fresh
_SkipSoftmaxLTXAttentionWrapper(base_fn) and assign that to
module.attention_function (or rewrap preserving the other wrapper by
reconstructing outer wrappers around the new eager wrapper), and update
register_ltx_triton_attention similarly so both registrations first unwrap to
the base before reapplying their wrapper.

Comment on lines 67 to +75
        self.aggregated_stats["total_blocks"] += stats.get("total_blocks", 0)

-        incoming = stats["sparse_blocks"]
-        if "sparse_blocks" not in self.aggregated_stats:
-            self.aggregated_stats["sparse_blocks"] = list(incoming)
-        else:
-            for i, val in enumerate(incoming):
-                self.aggregated_stats["sparse_blocks"][i] += val
+        incoming = stats.get("sparse_blocks")
+        if incoming is not None:
+            if "sparse_blocks" not in self.aggregated_stats:
+                self.aggregated_stats["sparse_blocks"] = list(incoming)
+            else:
+                for i, val in enumerate(incoming):
+                    self.aggregated_stats["sparse_blocks"][i] += val
⚠️ Potential issue | 🟠 Major

Handle changing sparse_blocks widths before accumulating.

Line 75 assumes every incoming vector has the same length as the first one. That is no longer guaranteed once calibration paths can emit multi-threshold stats while inference emits single-threshold stats, so this can either raise IndexError or mix incompatible totals. Also, total_blocks keeps growing even when sparse_blocks is missing, which will underreport summary sparsity. Reset/resize the accumulator and use a denominator that only counts calls that actually reported block data.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/stats_manager.py` around lines 67
- 75, The code assumes all incoming sparse_blocks vectors have the same length
and increments total_blocks even when sparse_blocks is missing; fix by (1)
tracking a separate counter like self.aggregated_stats["sparse_blocks_reports"]
(increment only when stats.get("sparse_blocks") is not None) and use that for
denominators, (2) when receiving incoming = stats.get("sparse_blocks"), ensure
the accumulator self.aggregated_stats["sparse_blocks"] is resized to accommodate
varying widths (extend with zeros if incoming is longer, or pad incoming with
zeros if shorter) before elementwise addition, and (3) do not rely on
total_blocks for sparse_blocks normalization — only count calls that actually
reported sparse_blocks. Ensure all references use self.aggregated_stats,
incoming, stats.get("sparse_blocks"), and "total_blocks" so the changes
integrate with existing logic.
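The width-resizing and report-counting fix can be sketched as a small accumulator class (SparseBlockStats below is a hypothetical stand-in for the aggregated_stats bookkeeping, not the stats_manager implementation):

```python
class SparseBlockStats:
    """Accumulate per-threshold sparse-block counts of varying width.

    Shorter/longer incoming vectors are handled by zero-extending the
    accumulator, and a separate report counter is kept so normalization
    only divides by calls that actually reported block data.
    """

    def __init__(self):
        self.sparse_blocks = []
        self.reports = 0

    def update(self, incoming):
        if incoming is None:
            return  # missing stats must not count toward the denominator
        self.reports += 1
        # Grow the accumulator if this call reports more thresholds than before
        if len(incoming) > len(self.sparse_blocks):
            self.sparse_blocks.extend([0] * (len(incoming) - len(self.sparse_blocks)))
        for i, val in enumerate(incoming):
            self.sparse_blocks[i] += val
```

Mixing a multi-threshold calibration report with a single-threshold inference report then accumulates without an IndexError, and reports gives the correct denominator.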

Comment on lines +82 to +96
with patch.dict(sys.modules, _mock_diffusers()):
from modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_eager_attention import (
_diffusers_eager_attention,
get_skip_softmax_attention_backend,
register_diffusers_eager_attention,
)

self._fn = _diffusers_eager_attention
self._register = register_diffusers_eager_attention
self._get_backend = get_skip_softmax_attention_backend

import modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_eager_attention as mod

mod._BACKEND_REGISTERED = False
yield
⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

fd test_kernel_backends.py

Repository: NVIDIA/Model-Optimizer

Length of output: 134


🏁 Script executed:

cat -n tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py

Repository: NVIDIA/Model-Optimizer

Length of output: 8940


🏁 Script executed:

cat -n tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py | sed -n '98,123p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1090


🏁 Script executed:

cat -n tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py | sed -n '160,171p'

Repository: NVIDIA/Model-Optimizer

Length of output: 487


🏁 Script executed:

fd diffusers_eager_attention.py

Repository: NVIDIA/Model-Optimizer

Length of output: 145


🏁 Script executed:

fd diffusers_triton_attention.py

Repository: NVIDIA/Model-Optimizer

Length of output: 146


🏁 Script executed:

cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py | head -80

Repository: NVIDIA/Model-Optimizer

Length of output: 3377


🏁 Script executed:

cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py | head -80

Repository: NVIDIA/Model-Optimizer

Length of output: 3541


🏁 Script executed:

cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py | sed -n '26,40p'

Repository: NVIDIA/Model-Optimizer

Length of output: 446


🏁 Script executed:

cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py | sed -n '26,45p'

Repository: NVIDIA/Model-Optimizer

Length of output: 632


🏁 Script executed:

python3 << 'EOF'
import sys
import types
from unittest.mock import patch

# Simulate the issue
def _mock_diffusers_v1():
    m = types.ModuleType("diffusers.models.attention_dispatch")
    m.NAME = "MOCK_V1"
    return {"diffusers.models.attention_dispatch": m}

def _mock_diffusers_v2():
    m = types.ModuleType("diffusers.models.attention_dispatch")
    m.NAME = "MOCK_V2"
    return {"diffusers.models.attention_dispatch": m}

# First fixture run
print("=== First fixture run ===")
with patch.dict(sys.modules, _mock_diffusers_v1()):
    from diffusers.models.attention_dispatch import NAME as NAME1
    print(f"First import: NAME = {NAME1}")
    print(f"Module in sys.modules: {'diffusers.models.attention_dispatch' in sys.modules}")

print(f"After first fixture, module still in sys.modules: {'diffusers.models.attention_dispatch' in sys.modules}")

# Second fixture run  
print("\n=== Second fixture run ===")
with patch.dict(sys.modules, _mock_diffusers_v2()):
    from diffusers.models.attention_dispatch import NAME as NAME2
    print(f"Second import: NAME = {NAME2}")
    print(f"sys.modules entry NAME: {sys.modules['diffusers.models.attention_dispatch'].NAME}")
    print(f"NAME1 vs NAME2: {NAME1} vs {NAME2} - Same? {NAME1 == NAME2}")

EOF

Repository: NVIDIA/Model-Optimizer

Length of output: 343


🏁 Script executed:

python3 << 'EOF'
import sys
import types
from unittest.mock import patch

def _mock_diffusers_v1():
    m = types.ModuleType("diffusers.models.attention_dispatch")
    m.NAME = "MOCK_V1"
    m.VALUE = 100
    return {
        "diffusers": types.ModuleType("diffusers"),
        "diffusers.models": types.ModuleType("diffusers.models"),
        "diffusers.models.attention_dispatch": m,
    }

def _mock_diffusers_v2():
    m = types.ModuleType("diffusers.models.attention_dispatch")
    m.NAME = "MOCK_V2"
    m.VALUE = 200
    return {
        "diffusers": types.ModuleType("diffusers"),
        "diffusers.models": types.ModuleType("diffusers.models"),
        "diffusers.models.attention_dispatch": m,
    }

# Create a backend module that imports from mocked diffusers
backend_code = '''
from diffusers.models.attention_dispatch import NAME, VALUE
imported_name = NAME
imported_value = VALUE
'''

print("=== First fixture run ===")
with patch.dict(sys.modules, _mock_diffusers_v1()):
    # Simulate backend module being imported
    backend_v1 = types.ModuleType("backend")
    exec(backend_code, backend_v1.__dict__)
    sys.modules["backend"] = backend_v1
    print(f"First import: backend.imported_name = {backend_v1.imported_name}")
    print(f"First import: backend.imported_value = {backend_v1.imported_value}")

print(f"After first fixture, backend in sys.modules: {'backend' in sys.modules}")
print(f"Cached backend values: name={sys.modules['backend'].imported_name}, value={sys.modules['backend'].imported_value}")

print("\n=== Second fixture run ===")
with patch.dict(sys.modules, _mock_diffusers_v2()):
    # Try to import backend again (would reuse cached)
    from backend import imported_name, imported_value
    print(f"Second import: imported_name = {imported_name}")
    print(f"Second import: imported_value = {imported_value}")
    print(f"Expected V2 values (200, 'MOCK_V2'), got ({imported_value}, '{imported_name}')")
    print(f"Mismatch? {imported_value != 200}")

EOF

Repository: NVIDIA/Model-Optimizer

Length of output: 342


🏁 Script executed:

python3 << 'EOF'
import sys
import types
from unittest.mock import patch
import importlib

def _mock_diffusers_v1():
    m = types.ModuleType("diffusers.models.attention_dispatch")
    m.NAME = "MOCK_V1"
    m.VALUE = 100
    return {
        "diffusers": types.ModuleType("diffusers"),
        "diffusers.models": types.ModuleType("diffusers.models"),
        "diffusers.models.attention_dispatch": m,
    }

def _mock_diffusers_v2():
    m = types.ModuleType("diffusers.models.attention_dispatch")
    m.NAME = "MOCK_V2"
    m.VALUE = 200
    return {
        "diffusers": types.ModuleType("diffusers"),
        "diffusers.models": types.ModuleType("diffusers.models"),
        "diffusers.models.attention_dispatch": m,
    }

# Create a fake backend module that imports from diffusers
backend_code = '''
from diffusers.models.attention_dispatch import NAME, VALUE
backend_name = NAME
backend_value = VALUE
'''

print("=== First fixture run ===")
with patch.dict(sys.modules, _mock_diffusers_v1()):
    backend_v1 = types.ModuleType("fake_backend")
    exec(backend_code, backend_v1.__dict__)
    sys.modules["fake_backend"] = backend_v1
    print(f"First: backend.backend_name = {backend_v1.backend_name}")
    print(f"First: backend.backend_value = {backend_v1.backend_value}")

print(f"After patch context, fake_backend still in sys.modules: {'fake_backend' in sys.modules}")
print(f"Cached values: {sys.modules['fake_backend'].backend_name}, {sys.modules['fake_backend'].backend_value}")

print("\n=== Second fixture run (with new mocks) ===")
with patch.dict(sys.modules, _mock_diffusers_v2()):
    # Import the cached backend module again
    from fake_backend import backend_name, backend_value
    print(f"Second: backend_name = {backend_name}")
    print(f"Second: backend_value = {backend_value}")
    print(f"ISSUE: Got old mock values (MOCK_V1, 100) instead of new (MOCK_V2, 200)")

EOF

Repository: NVIDIA/Model-Optimizer

Length of output: 342


🏁 Script executed:

python3 << 'EOF'
import sys
import types
from unittest.mock import patch

def _mock_diffusers_v1():
    m = types.ModuleType("diffusers.models.attention_dispatch")
    m.NAME = "MOCK_V1"
    return {
        "diffusers": types.ModuleType("diffusers"),
        "diffusers.models": types.ModuleType("diffusers.models"),
        "diffusers.models.attention_dispatch": m,
    }

def _mock_diffusers_v2():
    m = types.ModuleType("diffusers.models.attention_dispatch")
    m.NAME = "MOCK_V2"
    return {
        "diffusers": types.ModuleType("diffusers"),
        "diffusers.models": types.ModuleType("diffusers.models"),
        "diffusers.models.attention_dispatch": m,
    }

# Test 1: Check if backend modules persist in sys.modules
print("=== Simulating actual test fixture pattern ===")
print("\nFirst fixture run:")
with patch.dict(sys.modules, _mock_diffusers_v1()):
    import diffusers.models.attention_dispatch as d1
    print(f"Inside patch: diffusers NAME = {d1.NAME}")
    print(f"Inside patch, 'diffusers.models.attention_dispatch' in sys.modules: {('diffusers.models.attention_dispatch' in sys.modules)}")

print(f"Outside patch, 'diffusers.models.attention_dispatch' in sys.modules: {('diffusers.models.attention_dispatch' in sys.modules)}")

print("\nSecond fixture run (with different mock):")
with patch.dict(sys.modules, _mock_diffusers_v2()):
    import diffusers.models.attention_dispatch as d2
    print(f"Inside patch: diffusers NAME = {d2.NAME}")
    print(f"Are they the same object? {d1 is d2}")

EOF

Repository: NVIDIA/Model-Optimizer

Length of output: 433


🏁 Script executed:

python3 << 'EOF'
import sys
import types
from unittest.mock import patch

def _mock_diffusers_v1():
    m = types.ModuleType("diffusers.models.attention_dispatch")
    m.FakeRegistry = {"backends": {}}
    return {
        "diffusers": types.ModuleType("diffusers"),
        "diffusers.models": types.ModuleType("diffusers.models"),
        "diffusers.models.attention_dispatch": m,
    }

def _mock_diffusers_v2():
    m = types.ModuleType("diffusers.models.attention_dispatch")
    m.FakeRegistry = {"backends": {}}
    return {
        "diffusers": types.ModuleType("diffusers"),
        "diffusers.models": types.ModuleType("diffusers.models"),
        "diffusers.models.attention_dispatch": m,
    }

# Simulate backend module
backend_code = '''
from diffusers.models.attention_dispatch import FakeRegistry
_BACKEND_REGISTERED = False

def register_backend():
    global _BACKEND_REGISTERED
    _BACKEND_REGISTERED = True
    FakeRegistry["backends"]["test"] = "value"
'''

print("=== Simulating stale module state across fixtures ===")

print("\nFirst fixture run:")
with patch.dict(sys.modules, _mock_diffusers_v1()):
    backend_v1 = types.ModuleType("my_backend")
    exec(backend_code, backend_v1.__dict__)
    sys.modules["my_backend"] = backend_v1
    
    print(f"First: FakeRegistry id = {id(backend_v1.FakeRegistry)}")
    print(f"First: Before register - _BACKEND_REGISTERED = {backend_v1._BACKEND_REGISTERED}")
    backend_v1.register_backend()
    print(f"First: After register - _BACKEND_REGISTERED = {backend_v1._BACKEND_REGISTERED}")
    print(f"First: Registry state = {backend_v1.FakeRegistry}")

print(f"\nAfter first fixture, my_backend in sys.modules: {'my_backend' in sys.modules}")

print("\n--- Second fixture run (simulating next test) ---")
# Don't use patch.dict for second run to simulate what happens when module is already imported
if "my_backend" in sys.modules:
    backend_v2 = sys.modules["my_backend"]
    print(f"Second: Using cached backend module")
    print(f"Second: FakeRegistry id = {id(backend_v2.FakeRegistry)}")
    print(f"Second: _BACKEND_REGISTERED before reset = {backend_v2._BACKEND_REGISTERED}")
    backend_v2._BACKEND_REGISTERED = False
    print(f"Second: _BACKEND_REGISTERED after reset = {backend_v2._BACKEND_REGISTERED}")
    print(f"Second: Registry state (STALE!) = {backend_v2.FakeRegistry}")
    print(f"ISSUE: Registry still has old backends from first test!")

EOF

Repository: NVIDIA/Model-Optimizer

Length of output: 446


🏁 Script executed:

python3 << 'EOF'
import sys
import types
from unittest.mock import patch

def _mock_diffusers_v1():
    m = types.ModuleType("diffusers.models.attention_dispatch")
    m.FakeRegistry = {"backends": {}}
    print(f"  Created MOCK_V1: {id(m.FakeRegistry)}")
    return {
        "diffusers": types.ModuleType("diffusers"),
        "diffusers.models": types.ModuleType("diffusers.models"),
        "diffusers.models.attention_dispatch": m,
    }

def _mock_diffusers_v2():
    m = types.ModuleType("diffusers.models.attention_dispatch")
    m.FakeRegistry = {"backends": {}}
    print(f"  Created MOCK_V2: {id(m.FakeRegistry)}")
    return {
        "diffusers": types.ModuleType("diffusers"),
        "diffusers.models": types.ModuleType("diffusers.models"),
        "diffusers.models.attention_dispatch": m,
    }

# Create a REAL backend module file in memory
import tempfile
import os
import importlib.util

backend_code = '''
from diffusers.models.attention_dispatch import FakeRegistry
_BACKEND_REGISTERED = False

def register():
    global _BACKEND_REGISTERED
    _BACKEND_REGISTERED = True
    FakeRegistry["backends"]["test"] = "value"
'''

# Write to temp file
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, dir='/tmp') as f:
    f.write(backend_code)
    backend_file = f.name

spec = importlib.util.spec_from_file_location("my_real_backend", backend_file)
my_real_backend = importlib.util.module_from_spec(spec)
sys.modules["my_real_backend"] = my_real_backend

print("=== Real backend module persists across patch contexts ===")

print("\nFirst fixture (patch.dict context):")
with patch.dict(sys.modules, _mock_diffusers_v1()):
    spec.loader.exec_module(my_real_backend)
    print(f"  Imported FakeRegistry: {id(my_real_backend.FakeRegistry)}")
    my_real_backend.register()
    print(f"  After register: {my_real_backend.FakeRegistry}")

print(f"\nAfter first patch context exits:")
print(f"  my_real_backend still in sys.modules: {'my_real_backend' in sys.modules}")
print(f"  Cached FakeRegistry: {id(my_real_backend.FakeRegistry)}")
print(f"  Cached FakeRegistry content: {my_real_backend.FakeRegistry}")

print("\nSecond fixture (NEW patch.dict context with different mock):")
with patch.dict(sys.modules, _mock_diffusers_v2()):
    from diffusers.models.attention_dispatch import FakeRegistry as FakeRegistry2
    print(f"  New mock FakeRegistry: {id(FakeRegistry2)}")
    print(f"  Backend's OLD FakeRegistry: {id(my_real_backend.FakeRegistry)}")
    print(f"  MISMATCH: Backend still references old mock!")
    print(f"  Backend's old registry content: {my_real_backend.FakeRegistry}")
    print(f"  New mock registry content: {FakeRegistry2}")

os.unlink(backend_file)

EOF

Repository: NVIDIA/Model-Optimizer

Length of output: 811


🏁 Script executed:

cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py | sed -n '100,150p'

Repository: NVIDIA/Model-Optimizer

Length of output: 2100


🏁 Script executed:

cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py | grep -A 20 "def register_diffusers_eager_attention"

Repository: NVIDIA/Model-Optimizer

Length of output: 1161


🏁 Script executed:

cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py | grep -A 20 "def get_skip_softmax_attention_backend"

Repository: NVIDIA/Model-Optimizer

Length of output: 581


🏁 Script executed:

cat -n tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py | sed -n '1,30p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1323


Reload backend modules on each fixture to ensure test isolation with fresh mocks.

The patch.dict(sys.modules, ...) context patches the module namespace, but diffusers_eager_attention and diffusers_triton_attention persist in sys.modules after the first import. On subsequent fixture runs, the cached modules retain references to the previous patch's mock objects, so resetting _BACKEND_REGISTERED alone does not provide a clean state. Pop and reimport the backend module before using it in each fixture.

Also applies to: 140-158
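A minimal, self-contained sketch of the suggested fix, using synthetic stand-in modules (`fake_dep`, `fake_backend` are illustrative names, not the real modelopt module paths): popping the cached module from `sys.modules` and reimporting it inside each patch context makes the module re-bind its imports to the currently active mock.

```python
import sys
import types
from unittest.mock import patch

# Stand-in for the backend module's top-level "from diffusers... import" binding.
BACKEND_CODE = "from fake_dep import NAME\nbackend_name = NAME\n"

def mock_dep(name):
    dep = types.ModuleType("fake_dep")
    dep.NAME = name
    return {"fake_dep": dep}

def load_backend():
    """Build and cache the backend module, mimicking a normal import."""
    mod = types.ModuleType("fake_backend")
    exec(BACKEND_CODE, mod.__dict__)
    sys.modules["fake_backend"] = mod
    return mod

with patch.dict(sys.modules, mock_dep("MOCK_V1")):
    sys.modules.pop("fake_backend", None)  # drop any stale cached module
    b1 = load_backend()

with patch.dict(sys.modules, mock_dep("MOCK_V2")):
    sys.modules.pop("fake_backend", None)  # without this pop, a cached module would keep old bindings
    b2 = load_backend()
    assert b2.backend_name == "MOCK_V2"    # fresh import sees the new mock
    assert b1.backend_name == "MOCK_V1"    # the old module object keeps its stale binding
print("fresh reimport picks up new mock:", b2.backend_name)
```

The key point: resetting a flag on the cached module object cannot fix its module-level imports; only a fresh import re-executes the `from ... import` statements against the current mocks.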

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py` around
lines 82-96, tests are reusing cached backend modules across fixture runs,
causing stale mocks; remove the cached modules from sys.modules before
reimporting. Inside the patch.dict(sys.modules, _mock_diffusers()) context (and
likewise for the triton fixture), pop any existing
"modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_eager_attention"
and
"modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_triton_attention"
entries from sys.modules, then import the module, access
_diffusers_eager_attention, register_diffusers_eager_attention, and
get_skip_softmax_attention_backend, and reset mod._BACKEND_REGISTERED to
guarantee fresh modules and mocks for each fixture run.

yeyu-nvidia and others added 2 commits April 8, 2026 12:39
Move NVFP4 P-matrix quantization (quantize_p) out of the sparsity module
and into a new modelopt/torch/quantization/sage_attention/ module.

Key changes:
- Add modelopt/torch/quantization/sage_attention/__init__.py with
  apply_sage_attention(transformer) API exposed via mtq namespace.
  Wraps the transformer forward to activate the modelopt_triton diffusers
  backend and set quantize_p=True in thread-local for every call.

- Remove quantize_p from SparseAttentionAttributeConfig (config.py),
  TritonSkipSoftmaxMethod, and TritonSparseSoftmaxMethod — sparsity
  methods no longer control quantization.

- Split thread-local management in diffusers_triton_attention.py:
  * set_triton_skip_softmax_config() no longer accepts quantize_p
  * clear_triton_skip_softmax_config() does NOT reset quantize_p
  * New set_sage_attention_config() / clear_sage_attention_config()
    manage quantize_p independently
  This enables transparent composition: apply_sage_attention() sets
  quantize_p=True at the outer forward level; per-layer sparsity
  contexts clear only their own params without clobbering quantize_p.

- Delete plugins/diffusers.py (WanSparseAttentionModule) — superseded
  by PR #1166's diffusers_triton_attention.py backend approach.

- Update wan2_sage_attention.py example: apply_triton_sparse_kernel()
  no longer accepts quantize_p; --quantize-p now calls
  apply_sage_attention() from modelopt.torch.quantization.

- Update tests to reflect the new API boundaries.

Usage:
    from modelopt.torch.quantization import apply_sage_attention
    apply_sage_attention(pipe.transformer)  # standalone

    # combined with N:M sparse softmax:
    mtsa.sparsify(transformer, mtsa.SPARSE_SOFTMAX_DEFAULT)
    apply_sage_attention(transformer)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Adds --kernel nvfp4 as an explicit entry point for standalone NVFP4
P-matrix quantization via apply_sage_attention(), removing the previous
awkward implicit path through --quantize-p without a triton kernel.

Usage:
    python wan2_sage_attention.py --prompt "..." --kernel nvfp4
    python wan2_sage_attention.py --prompt "..." --kernel nvfp4 --compare
    python wan2_sage_attention.py --prompt "..." --kernel triton-sparse --quantize-p

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>