Add ModelOpt Triton attention kernels for WAN2.2 diffusion (sparse, skip-softmax, NVFP4) #1190

yeyu-nvidia wants to merge 52 commits into main from
Conversation
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. This behavior can be configured in the review settings.
📝 Walkthrough

Adds paged KV-cache and optional post-softmax quantization to the Triton flash-attention kernel; integrates Triton sparse attention with Diffusers (WanAttention) and vLLM (paged KV); provides example scripts/workers for Wan2.2 and vLLM sparse serving; and adds tests, calibration tooling, and plugin/registration plumbing.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Model as Diffusers Model
    participant WanModule as WanAttention
    participant SparseModule as WanSparseAttentionModule
    participant Processor as ModelOptWanAttnProcessor
    participant TritonKernel as Triton FA Kernel
    Model->>WanModule: forward()
    WanModule->>SparseModule: delegate (installed)
    SparseModule->>Processor: prepare Q/K/V, rotary, proj
    Processor->>TritonKernel: call attention(paged KV, sparse_kw, quantize_p)
    TritonKernel-->>Processor: attention output
    Processor-->>SparseModule: return output
    SparseModule-->>Model: continue forward
```
```mermaid
sequenceDiagram
    participant Server as vLLM Server
    participant Worker as SparseQuantWorker
    participant Model as Loaded Model
    participant Attn as vLLM Attention Module
    participant Impl as ModelOptSparseAttentionImpl
    participant TritonKernel as Triton FA Kernel
    Server->>Worker: start / load_model()
    Worker->>Model: load base model
    Worker->>Attn: iterate named_modules()
    Worker->>Impl: replace impl with ModelOptSparseAttentionImpl (paged KV)
    Worker->>Worker: compile_or_warm_up_model (quant prolog if configured)
    Model->>Attn: inference forward(metadata, caches)
    Attn->>Impl: forward(...)
    Impl->>TritonKernel: call attention(paged KV, block_table, sparse_kw)
    TritonKernel-->>Impl: output
    Impl-->>Attn: write output
    Attn-->>Model: continue
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks: ✅ 3 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
Actionable comments posted: 8
🧹 Nitpick comments (1)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
142-151: Validate `quantize_p` compatibility to avoid silent no-op configs.

`quantize_p` is documented as diffusers-triton-specific, but invalid combinations are currently accepted. Adding a compatibility validator will prevent confusing runtime behavior.

♻️ Proposed validation
```diff
 class SparseAttentionAttributeConfig(ModeloptBaseConfig):
 @@
     @model_validator(mode="after")
     def validate_sparsity_n_vs_m(self):
 @@
         return self
+
+    @model_validator(mode="after")
+    def validate_quantize_p_compatibility(self):
+        if self.quantize_p and self.backend != "diffusers_triton":
+            raise ValueError(
+                "quantize_p=True requires backend='diffusers_triton'."
+            )
+        if self.quantize_p and self.method not in (
+            "triton_sparse_softmax",
+            "triton_skip_softmax",
+        ):
+            raise ValueError(
+                "quantize_p=True is only supported with "
+                "'triton_sparse_softmax' or 'triton_skip_softmax'."
+            )
+        return self
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/config.py` around lines 142 - 151, Add a compatibility validator for the quantize_p ModeloptField so invalid combos are rejected: if quantize_p is True ensure the Triton-specific backend (e.g., backend == "diffusers_triton") is selected and reject/raise an error when used with non-diffusers_triton backends; also validate interactions with related flags (triton_sparse_softmax, triton_skip_softmax) to either allow or explicitly forbid combinations that make quantize_p a no-op, by adding the check in the same config class where quantize_p is defined and raising a clear validation error when the constraints are violated.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/diffusers/quantization/wan2_sage_attention.py`:
- Around line 612-666: compute_clip_score defaults to device="cuda" and loads
the CLIP model directly onto that device which can OOM smaller GPUs; wrap the
CLIPProcessor/CLIPModel.from_pretrained calls (and any subsequent .to(device))
in a try/except that catches both OSError and RuntimeError, and on failure retry
loading the model on CPU (set device="cpu" for the retry), logging a warning;
ensure downstream tensors (text_inputs/img_inputs) are moved to the active
device variable and that clip_model is .to(device) only after successful load so
the function can gracefully fall back from GPU to CPU without crashing
(references: compute_clip_score, clip_model_id, device, processor, clip_model,
text_inputs, img_inputs).
- Around line 790-800: The help text for the "--skip-threshold" /
skip_softmax_threshold flag is incorrect: the implementation performs an
additive/log-space comparison using exp(tile_max - running_max) < lambda (i.e.,
compares the softmax ratio via exponent of the logit difference), not a
multiplicative comparison of raw logits; update the help string to describe that
a tile is skipped when exp(tile_max - running_max) is less than LAMBDA (softmax
ratio threshold), and mention this applies to the triton-skip /
triton-skip-nvfp4 kernels so users understand it's a log-space/additive
threshold.
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 72-85: The preset validation is too permissive:
_build_sparse_config currently accepts any module-level preset but
_replace_attention_impl only supports sparsity_* keys and
skip_softmax_threshold, so presets containing other fields (e.g.,
backend="pytorch" or missing a fixed skip_softmax_threshold like
SKIP_SOFTMAX_CALIB) silently degrade to dense attention. Update
_build_sparse_config (and the same logic at the other occurrence) to validate
the chosen preset: ensure the dict contains only supported keys (keys matching
"sparsity_" prefixes and/or "skip_softmax_threshold") and that if a softmax-skip
behavior is required it includes a fixed numeric skip_softmax_threshold; if the
preset includes unsupported keys or lacks the required numeric threshold, raise
ValueError telling the user to pick "default" or a compatible preset so
_replace_attention_impl and VLLMAttention won't be silently ignored.
In `@examples/vllm_serve/vllm_serve_sparse_attn.py`:
- Line 75: The PYTHONPATH assignment currently concatenates with ":" and can
produce a leading empty entry; update the logic that sets
os.environ["PYTHONPATH"] (the existing statement referencing repo_root) to build
the path using os.pathsep and only include non-empty components: read
os.environ.get("PYTHONPATH",""), split/filter out empty strings, append the
repo_root, then join with os.pathsep and assign back to os.environ["PYTHONPATH"]
so no empty/CWD entry is injected and it is platform portable.
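The `os.pathsep` construction described above can be sketched as a small helper (the `build_pythonpath` name and `repo_root` parameter are illustrative, not the example's actual code):

```python
import os


def build_pythonpath(existing: str, repo_root: str) -> str:
    """Join non-empty PYTHONPATH entries with os.pathsep, then append repo_root.

    Filtering empty strings prevents a leading/duplicate separator from
    injecting an implicit CWD entry, and os.pathsep keeps it portable.
    """
    parts = [p for p in existing.split(os.pathsep) if p]
    parts.append(repo_root)
    return os.pathsep.join(parts)
```

A caller would then do `os.environ["PYTHONPATH"] = build_pythonpath(os.environ.get("PYTHONPATH", ""), repo_root)`.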
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 536-538: The forward-only quantization for numerator p (guarded by
QUANTIZE_P and using _quantize_p_nvfp4) is not mirrored in backward: the
backward path still reconstructs unquantized softmax and ignores ctx.quantize_p,
producing incorrect gradients; fix by either (A) implementing a matching
backward branch that reconstructs/uses the same quantized representation and
dequantization logic when ctx.quantize_p is true (ensure grad path uses the
quantized p->dequant sequence and adjust gradient accumulation consistent with
tl.dot(p.to(v.dtype), v, acc)), or (B) explicitly reject/raise an error when
QUANTIZE_P/ctx.quantize_p is enabled and any of q.requires_grad,
k.requires_grad, or v.requires_grad is true (check these tensors in the forward
function and set ctx.quantize_p accordingly), so the mode is only allowed for
inference. Ensure you reference and update the forward use of
QUANTIZE_P/_quantize_p_nvfp4 and the backward branch that reads ctx.quantize_p.
- Around line 988-991: Only synthesize b_start_loc_k when actually in paged-KV
mode: compute a paged_mode boolean from the real indicators (e.g., presence of
block_table or missing v_cache) and if paged_mode is True set b_start_loc_k =
torch.zeros_like(b_start_loc); otherwise, do not substitute silently—raise an
explicit error (or assert) when b_start_loc_k is None so the contiguous K/V path
fails fast before later stride(...) / shape[...] access. Reference
b_start_loc_k, b_start_loc, k_cache, v_cache, and block_table when making this
conditional.
In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 63-66: The early return when "diffusers_triton" is present can
skip `_attn_implementation` setup for other sparse backends; update the check
around the `if "diffusers_triton" in backends: return` (in conversion.py where
`backends` is evaluated and ModelOptWanAttnProcessor is mentioned) to fail fast
if `diffusers_triton` is combined with any other backend (e.g., if
"diffusers_triton" in backends and len(backends) > 1: raise a ValueError with a
clear message) so that mixed configurations are rejected instead of silently
skipping attention registration.
In `@modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py`:
- Around line 81-132: The Triton path (_wan_forward_triton) currently ignores
the attention_mask causing different semantics than the SDPA path; add an early
guard in _wan_forward_triton that checks if attention_mask is not None and if so
immediately calls and returns self._wan_forward_sdpa(...) with the same
arguments (including attention_mask) to preserve fallback behavior; ensure the
call uses the same parameters and return value so signatures/outputs remain
consistent.
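The log-space skip criterion described in the `--skip-threshold` comment above can be illustrated with a minimal sketch (the tile/running maxima here are made-up scalars, not the kernel's actual Triton variables):

```python
import math


def should_skip_tile(tile_max: float, running_max: float, lam: float) -> bool:
    """Skip a tile when its softmax contribution ratio falls below lambda.

    The comparison exp(tile_max - running_max) < lam is additive in log
    space: a tile whose max logit sits far below the running max contributes
    a vanishing fraction of the softmax mass and can be dropped.
    """
    return math.exp(tile_max - running_max) < lam
```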
---
Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Around line 142-151: Add a compatibility validator for the quantize_p
ModeloptField so invalid combos are rejected: if quantize_p is True ensure the
Triton-specific backend (e.g., backend == "diffusers_triton") is selected and
reject/raise an error when used with non-diffusers_triton backends; also
validate interactions with related flags (triton_sparse_softmax,
triton_skip_softmax) to either allow or explicitly forbid combinations that make
quantize_p a no-op, by adding the check in the same config class where
quantize_p is defined and raising a clear validation error when the constraints
are violated.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: a409c49f-7309-45c3-89ed-693c7124c99d
📒 Files selected for processing (13)
- examples/diffusers/quantization/wan2_sage_attention.py
- examples/vllm_serve/sparse_attn_worker.py
- examples/vllm_serve/vllm_serve_sparse_attn.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/quantization/plugins/diffusion/diffusers.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/conversion.py
- modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
- modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py
- modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
- pyproject.toml
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py
- tests/unit/torch/sparsity/attention_sparsity/test_diffusers_plugin.py
modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py (outdated comment, resolved)
```toml
[tool.ruff.lint.isort]
known-first-party = ["modelopt"]
known-third-party = ["vllm"]
```
why is this needed?
vLLM does not follow standard Python package naming conventions — ruff's isort treats it as a first-party import without this entry, which causes import ordering lint failures in the vLLM worker files. Adding it to known-third-party keeps isort from re-classifying it incorrectly.
```python
# ---------------------------------------------------------------------------

@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton")
```
Does our current CICD env skip this test or run it?
The test lives in tests/gpu/ which runs in the GPU CI environment (python -m pytest tests/gpu). The test itself also has @pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton") as an additional guard, so it will be skipped gracefully on any runner without a CUDA GPU or Triton installed.
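The guard pattern described in the reply above looks roughly like this self-contained sketch (the availability probe shown here is illustrative; the real test's `TRITON_KERNEL_AVAILABLE` check also verifies CUDA):

```python
import importlib.util

import pytest

# Probe for the optional dependencies at collection time without importing
# the heavyweight packages themselves.
TRITON_KERNEL_AVAILABLE = (
    importlib.util.find_spec("triton") is not None
    and importlib.util.find_spec("torch") is not None
)


@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton")
def test_paged_kv_attention():
    ...  # exercised only on runners with Triton installed
```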
Actionable comments posted: 4
♻️ Duplicate comments (1)
examples/vllm_serve/sparse_attn_worker.py (1)
146-166: ⚠️ Potential issue | 🟠 Major

Make `method` drive `sparse_kw`.

At line 148, `method` is only validated and then discarded. `ModelOptSparseAttentionImpl.forward()` only consumes `self.sparse_kw` in modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py:64-66, so a `triton_skip_softmax` config with no numeric threshold still patches the layer with empty kwargs, and mixed configs can pass both `sparsity_*` and `skip_softmax_threshold`, selecting a different path than the one you just validated. Build kwargs from `method` and fail fast when that method's required parameters are missing.

🛠️ Proposed fix
```diff
     # Build per-layer sparse kwargs
     sparse_kw = {}
-    sparsity_n = layer_cfg.get("sparsity_n", 0)
-    if sparsity_n > 0:
-        sparse_kw["sparsity_n"] = sparsity_n
-        sparse_kw["sparsity_m"] = layer_cfg.get("sparsity_m", 4)
-        sparse_kw["num_sink_tokens"] = layer_cfg.get("num_sink_tokens", 0)
-        sparse_kw["dense_window_size"] = layer_cfg.get("dense_window_size", 1)
-    threshold = layer_cfg.get("skip_softmax_threshold")
-    if threshold:
-        sparse_kw["skip_softmax_threshold"] = threshold
+    if method == "triton_sparse_softmax":
+        sparsity_n = layer_cfg.get("sparsity_n")
+        if not isinstance(sparsity_n, int) or sparsity_n <= 0:
+            raise ValueError(f"{name}: triton_sparse_softmax requires a positive sparsity_n")
+
+        sparse_kw["sparsity_n"] = sparsity_n
+        sparse_kw["sparsity_m"] = layer_cfg.get("sparsity_m", 4)
+        sparse_kw["num_sink_tokens"] = layer_cfg.get("num_sink_tokens", 0)
+        sparse_kw["dense_window_size"] = layer_cfg.get("dense_window_size", 1)
+    else:
+        threshold = layer_cfg.get("skip_softmax_threshold")
+        if not isinstance(threshold, (int, float)):
+            raise ValueError(
+                f"{name}: triton_skip_softmax requires a numeric skip_softmax_threshold"
+            )
+        sparse_kw["skip_softmax_threshold"] = float(threshold)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/sparse_attn_worker.py` around lines 146 - 166, The code currently validates method but doesn't let method drive sparse_kw; update the logic in sparse_attn_worker.py so that method determines which keys are populated and that required params are present: if method == "triton_sparse_softmax" require sparsity_n > 0 and populate only sparsity_n, sparsity_m, num_sink_tokens, dense_window_size into sparse_kw (fail fast with ValueError if missing); if method == "triton_skip_softmax" require skip_softmax_threshold to be set and populate only skip_softmax_threshold into sparse_kw (fail fast if missing); keep references to method, layer_cfg, sparse_kw and ensure compatibility with ModelOptSparseAttentionImpl.forward() which reads self.sparse_kw in modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py.
🧹 Nitpick comments (2)
examples/vllm_serve/sparse_attn_worker.py (1)
72-123: Add concrete types at the worker/config boundary.

These helpers still use bare `dict` and an untyped `worker`, so mypy can't validate the config shape or the `model_runner.model` / `unwrap()` contract on this new extension path. A small `TypedDict` + `Protocol` here would turn several runtime-only failures into lint errors.

As per coding guidelines, "Ensure type hints are properly annotated for static type checking with mypy".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/vllm_serve/sparse_attn_worker.py` around lines 72 - 123, Introduce concrete typing at the worker/config boundary: define a TypedDict (e.g., SparseLayerCfg and SparseConfigTypedDict) that describes the expected keys/values returned by _build_sparse_config/_load_sparse_config, and a Protocol (e.g., ModelRunnerProtocol or WorkerWithModel) that specifies the model_runner.model and unwrap() shape used by _replace_attention_impl; then update signatures for _build_sparse_config, _load_sparse_config, _match_sparse_config, and _replace_attention_impl to use these types (replace bare dict and untyped worker), and cast/validate where needed so mypy can statically check config shape and the model.unwrap() contract referenced in _replace_attention_impl.

examples/diffusers/quantization/wan2_sage_attention.py (1)
425-431: Complete the public type annotations.

`attention_kernel_ctx`, `load_pipeline`, and `run_inference` still leave important parts of the public surface as `Any`, which weakens mypy coverage for a new example module. As per coding guidelines, "Ensure type hints are properly annotated for static type checking with mypy".

Also applies to: 682-694
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/diffusers/quantization/wan2_sage_attention.py` around lines 425 - 431, Annotate the public API precisely: change attention_kernel_ctx to return a ContextManager[None] (import ContextManager from typing) instead of Any; update load_pipeline to have typed params (e.g., ckpt: str, dtype: torch.dtype, device: torch.device | str) and return a DiffusionPipeline (import DiffusionPipeline from diffusers) rather than Any; update run_inference to type its parameters (pipeline: DiffusionPipeline, prompt: str | Sequence[str], num_inference_steps: int, generator: Optional[torch.Generator], etc.) and give it a concrete return type such as Dict[str, Any] or a more specific result type, importing Optional, Sequence, Dict from typing and torch types as needed so mypy can check these public symbols (attention_kernel_ctx, load_pipeline, run_inference).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/diffusers/quantization/wan2_sage_attention.py`:
- Around line 504-531: In apply_triton_sparse_kernel, validate any provided
skip_threshold when the kernel is KERNEL_TRITON_SKIP or
KERNEL_TRITON_SKIP_NVFP4: if skip_threshold is not None and skip_threshold <= 0,
raise a ValueError with a clear message (e.g. "skip_threshold must be > 0")
before modifying config["sparse_cfg"]["*"]["skip_softmax_threshold"]; otherwise
proceed to set the value as currently implemented. Ensure you reference the
function apply_triton_sparse_kernel and the constants KERNEL_TRITON_SKIP /
KERNEL_TRITON_SKIP_NVFP4 so the check is only applied for skip-softmax kernels.
- Around line 756-760: Move the kernel availability check to immediately after
argument parsing to fail fast: after parse_args(), verify args.kernel is present
in AVAILABLE_KERNELS (and/or that Triton/SageAttention and required PyTorch
version are available) and raise/exit with a clear message if not; remove or
keep only redundant checks later. Specifically, perform this validation before
constructing/loading the WAN pipeline (the code that currently loads the large
model) so that calls like apply_triton_sparse_kernel() and
enable_attention_kernel() are only reached when args.kernel is actually
supported; reference the args.kernel variable, AVAILABLE_KERNELS,
apply_triton_sparse_kernel(), and enable_attention_kernel() when making the
change.
- Around line 591-595: The current PSNR calculation in psnr_per_frame uses
np.where with mse_per_frame which causes the divide-by-zero branch to be
evaluated eagerly; fix by clamping mse_per_frame to a small positive floor
before the division (e.g., denom = np.maximum(mse_per_frame, 1e-10)) and compute
the PSNR using that denom, then optionally use np.where to override values for
truly-zero MSE if you need a specific constant; update the code around
psnr_per_frame and any direct uses of mse_per_frame to use the clamped
denominator to avoid warnings.
- Around line 352-385: The _patched_sdpa function is dropping caller kwargs on
both fallback paths; update both calls to _orig_sdpa (the early fallback when
attn_mask/dropout/dtype unsupported and the exception fallback in the except
block) to forward the original **kwargs so caller flags like enable_gqa (and
future SDPA kwargs) are preserved; keep the existing explicit args (query, key,
value, attn_mask, dropout_p, is_causal, scale) and add **kwargs to those
_orig_sdpa invocations to pass through any additional options.
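The clamped-denominator PSNR fix described in the comments above can be sketched in pure Python (the 1e-10 floor and the MAX value of 1.0 are assumptions from the review text, not the example's actual constants):

```python
import math


def psnr_per_frame(mse_per_frame: list[float], max_val: float = 1.0) -> list[float]:
    """Compute per-frame PSNR with the MSE clamped to a small positive floor.

    Clamping before the division means zero-MSE (identical) frames simply
    saturate at a large finite PSNR instead of triggering divide-by-zero.
    """
    floor = 1e-10
    return [10.0 * math.log10(max_val**2 / max(mse, floor)) for mse in mse_per_frame]
```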
---
Duplicate comments:
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 146-166: The code currently validates method but doesn't let
method drive sparse_kw; update the logic in sparse_attn_worker.py so that method
determines which keys are populated and that required params are present: if
method == "triton_sparse_softmax" require sparsity_n > 0 and populate only
sparsity_n, sparsity_m, num_sink_tokens, dense_window_size into sparse_kw (fail
fast with ValueError if missing); if method == "triton_skip_softmax" require
skip_softmax_threshold to be set and populate only skip_softmax_threshold into
sparse_kw (fail fast if missing); keep references to method, layer_cfg,
sparse_kw and ensure compatibility with ModelOptSparseAttentionImpl.forward()
which reads self.sparse_kw in
modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py.
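The method-driven kwarg construction described above can be reduced to a standalone sketch (the `build_sparse_kw` helper name is illustrative; the key names mirror the review text):

```python
def build_sparse_kw(name: str, method: str, layer_cfg: dict) -> dict:
    """Populate only the kwargs the chosen method consumes; fail fast otherwise."""
    if method == "triton_sparse_softmax":
        sparsity_n = layer_cfg.get("sparsity_n")
        if not isinstance(sparsity_n, int) or sparsity_n <= 0:
            raise ValueError(f"{name}: triton_sparse_softmax requires a positive sparsity_n")
        return {
            "sparsity_n": sparsity_n,
            "sparsity_m": layer_cfg.get("sparsity_m", 4),
            "num_sink_tokens": layer_cfg.get("num_sink_tokens", 0),
            "dense_window_size": layer_cfg.get("dense_window_size", 1),
        }
    if method == "triton_skip_softmax":
        threshold = layer_cfg.get("skip_softmax_threshold")
        if not isinstance(threshold, (int, float)):
            raise ValueError(
                f"{name}: triton_skip_softmax requires a numeric skip_softmax_threshold"
            )
        return {"skip_softmax_threshold": float(threshold)}
    raise ValueError(f"{name}: unknown method {method!r}")
```

Because each branch returns only its own keys, a mixed config can no longer select a different kernel path than the one that was validated.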
---
Nitpick comments:
In `@examples/diffusers/quantization/wan2_sage_attention.py`:
- Around line 425-431: Annotate the public API precisely: change
attention_kernel_ctx to return a ContextManager[None] (import ContextManager
from typing) instead of Any; update load_pipeline to have typed params (e.g.,
ckpt: str, dtype: torch.dtype, device: torch.device | str) and return a
DiffusionPipeline (import DiffusionPipeline from diffusers) rather than Any;
update run_inference to type its parameters (pipeline: DiffusionPipeline,
prompt: str | Sequence[str], num_inference_steps: int, generator:
Optional[torch.Generator], etc.) and give it a concrete return type such as
Dict[str, Any] or a more specific result type, importing Optional, Sequence,
Dict from typing and torch types as needed so mypy can check these public
symbols (attention_kernel_ctx, load_pipeline, run_inference).
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 72-123: Introduce concrete typing at the worker/config boundary:
define a TypedDict (e.g., SparseLayerCfg and SparseConfigTypedDict) that
describes the expected keys/values returned by
_build_sparse_config/_load_sparse_config, and a Protocol (e.g.,
ModelRunnerProtocol or WorkerWithModel) that specifies the model_runner.model
and unwrap() shape used by _replace_attention_impl; then update signatures for
_build_sparse_config, _load_sparse_config, _match_sparse_config, and
_replace_attention_impl to use these types (replace bare dict and untyped
worker), and cast/validate where needed so mypy can statically check config
shape and the model.unwrap() contract referenced in _replace_attention_impl.
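The `TypedDict` + `Protocol` boundary suggested above might look like this sketch (names such as `SparseLayerCfg` and `ModelRunnerProtocol` are the illustrative ones proposed in the comment, not existing code):

```python
from typing import Protocol, TypedDict


class SparseLayerCfg(TypedDict, total=False):
    """Shape of one layer's entry in the sparse config (all keys optional)."""

    sparsity_n: int
    sparsity_m: int
    num_sink_tokens: int
    dense_window_size: int
    skip_softmax_threshold: float


class ModelRunnerProtocol(Protocol):
    """Minimal contract _replace_attention_impl relies on."""

    model: object

    def unwrap(self) -> object: ...


def match_sparse_config(name: str, cfg: dict[str, SparseLayerCfg]) -> SparseLayerCfg:
    """Return the config for `name`, falling back to a wildcard entry."""
    return cfg.get(name, cfg.get("*", SparseLayerCfg()))
```

With these annotations in place, mypy can flag a misspelled key or an object missing `unwrap()` at lint time rather than at runtime.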
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: bde9eacb-58d5-4815-92e3-46606b7fa475
📒 Files selected for processing (5)
examples/diffusers/quantization/wan2_sage_attention.pyexamples/vllm_serve/sparse_attn_worker.pyexamples/vllm_serve/vllm_serve_sparse_attn.pymodelopt/torch/sparsity/attention_sparsity/conversion.pymodelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py
✅ Files skipped from review due to trivial changes (1)
- examples/vllm_serve/vllm_serve_sparse_attn.py
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py
```python
def _patched_sdpa(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    **kwargs,
) -> torch.Tensor:
    global _sage_calls, _fallback_calls
    # Fall back to standard SDPA for unsupported cases
    if (
        attn_mask is not None
        or dropout_p > 0.0
        or query.dtype not in (torch.float16, torch.bfloat16)
    ):
        _fallback_calls += 1
        return _orig_sdpa(
            query,
            key,
            value,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
        )

    _sage_calls += 1
    try:
        return _run_kernel(query, key, value, is_causal=is_causal, scale=scale)
    except (AssertionError, RuntimeError) as e:
        print(f"[Attention] WARNING: kernel={_active_kernel!r} failed ({e}). Falling back to SDPA.")
        return _orig_sdpa(query, key, value, is_causal=is_causal, scale=scale)
```
🧩 Analysis chain
🏁 Script executed:
```shell
cat -n examples/diffusers/quantization/wan2_sage_attention.py | head -400 | tail -100
```

Repository: NVIDIA/Model-Optimizer
Length of output: 3695
🏁 Script executed:
```shell
# Check the _orig_sdpa assignment and context
rg "_orig_sdpa|_patched_sdpa" examples/diffusers/quantization/wan2_sage_attention.py -B 2 -A 2
```

Repository: NVIDIA/Model-Optimizer
Length of output: 1229
🏁 Script executed:
```shell
# Check torch.nn.functional.scaled_dot_product_attention signature
python3 << 'EOF'
import torch
import inspect
sig = inspect.signature(torch.nn.functional.scaled_dot_product_attention)
print("Signature:", sig)
print("\nParameters:")
for name, param in sig.parameters.items():
    print(f"  {name}: {param.kind.name} = {param.default}")
EOF
```

Repository: NVIDIA/Model-Optimizer
Length of output: 183
🏁 Script executed:
```shell
# Check PyTorch documentation and what parameters F.scaled_dot_product_attention supports
# by searching for any documentation or comments in the codebase
rg "scaled_dot_product_attention" examples/diffusers/quantization/wan2_sage_attention.py -B 5 -A 5
```

Repository: NVIDIA/Model-Optimizer
Length of output: 1390
🏁 Script executed:
```shell
# Check if there are any other kwargs passed at call sites
rg "scaled_dot_product_attention\|_patched_sdpa" --type py -r "." | head -20
```

Repository: NVIDIA/Model-Optimizer
Length of output: 48
🏁 Script executed:
```shell
# Search for pytorch version requirements to determine available parameters
cat pyproject.toml | grep -A 5 -B 5 torch
```

Repository: NVIDIA/Model-Optimizer
Length of output: 2204
🏁 Script executed:
```shell
# Check if there are any call sites in the example that use _patched_sdpa with extra kwargs
rg "_patched_sdpa\|F\.scaled_dot_product_attention" examples/diffusers/quantization/wan2_sage_attention.py -B 3 -A 3
```

Repository: NVIDIA/Model-Optimizer
Length of output: 48
🏁 Script executed:
```shell
# Check the full context around patching to see how it's used
sed -n '388,430p' examples/diffusers/quantization/wan2_sage_attention.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 1420
🏁 Script executed:
```shell
# Check if there are any test files or usage examples that call SDPA
find . -name "*test*" -type f | xargs grep -l "scaled_dot_product_attention\|_patched_sdpa" 2>/dev/null | head -5
```

Repository: NVIDIA/Model-Optimizer
Length of output: 251
🏁 Script executed:
```shell
# Check the test files to understand how scaled_dot_product_attention is called
cat ./tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py | head -100
```

Repository: NVIDIA/Model-Optimizer
Length of output: 3861
🏁 Script executed:
```shell
# Check the attention quantization test
cat ./tests/unit/torch/quantization/plugins/test_attention_quant.py | head -150
```

Repository: NVIDIA/Model-Optimizer
Length of output: 5720
🏁 Script executed:
```shell
# Look for any documentation or comments about SDPA kwargs in PyTorch
web_search
PyTorch scaled_dot_product_attention signature parameters kwargs torch 2.6
```

Repository: NVIDIA/Model-Optimizer
Length of output: 162
🌐 Web query:
PyTorch torch.nn.functional.scaled_dot_product_attention signature parameters 2.6
💡 Result:
In PyTorch 2.6, the function signature is:

```python
torch.nn.functional.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
) -> Tensor
```

Parameters (with defaults): query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None (keyword-only), enable_gqa=False. [1]

Source:
[1] PyTorch 2.6 docs page for torch.nn.functional.scaled_dot_product_attention https://pytorch.com.tw/docs/2.6/generated/torch.nn.functional.scaled_dot_product_attention.html
🏁 Script executed:
```shell
# Check if any code in the examples or tests calls SDPA/patched_sdpa with kwargs like enable_gqa
rg "scaled_dot_product_attention|_patched_sdpa|F\.sdpa" examples/ tests/ -A 2 | grep -E "(enable_gqa|enable_flash_|dropout_p|**kwargs)" | head -20
```

Repository: NVIDIA/Model-Optimizer
Length of output: 145
🏁 Script executed:
```shell
# Search more broadly for diffusers usage and how SDPA is called in the pipeline
grep -r "scaled_dot_product_attention" examples/diffusers/ -B 2 -A 2 2>/dev/null || echo "No direct calls found"
```

Repository: NVIDIA/Model-Optimizer
Length of output: 2572
🏁 Script executed:
```shell
# Check the full wan2_sage_attention.py file to see if there are any actual kwargs being passed to SDPA anywhere
grep -E "scaled_dot_product_attention|attention\(" examples/diffusers/quantization/wan2_sage_attention.py | head -20
```

Repository: NVIDIA/Model-Optimizer
Length of output: 383
Forward caller kwargs when SDPA falls back.
_patched_sdpa() accepts **kwargs but both fallback paths (lines 370–378 and 385) drop them before calling _orig_sdpa(). This causes silent semantic changes: if a caller passes enable_gqa=True (or any other supported kwarg like future PyTorch SDPA extensions) and the kernel fails or early fallback triggers, the fallback ignores those parameters instead of honoring them.
Suggested fix:

```diff
 def _patched_sdpa(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@
     if (
         attn_mask is not None
         or dropout_p > 0.0
         or query.dtype not in (torch.float16, torch.bfloat16)
     ):
         _fallback_calls += 1
         return _orig_sdpa(
             query,
             key,
             value,
             attn_mask=attn_mask,
             dropout_p=dropout_p,
             is_causal=is_causal,
             scale=scale,
+            **kwargs,
         )
@@
     except (AssertionError, RuntimeError) as e:
         print(f"[Attention] WARNING: kernel={_active_kernel!r} failed ({e}). Falling back to SDPA.")
-        return _orig_sdpa(query, key, value, is_causal=is_causal, scale=scale)
+        return _orig_sdpa(query, key, value, is_causal=is_causal, scale=scale, **kwargs)
```

🤖 Prompt for AI Agents
+ return _orig_sdpa(query, key, value, is_causal=is_causal, scale=scale, **kwargs)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/diffusers/quantization/wan2_sage_attention.py` around lines 352 -
385, The _patched_sdpa function is dropping caller kwargs on both fallback
paths; update both calls to _orig_sdpa (the early fallback when
attn_mask/dropout/dtype unsupported and the exception fallback in the except
block) to forward the original **kwargs so caller flags like enable_gqa (and
future SDPA kwargs) are preserved; keep the existing explicit args (query, key,
value, attn_mask, dropout_p, is_causal, scale) and add **kwargs to those
_orig_sdpa invocations to pass through any additional options.
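The forwarding pattern the prompt describes can be sketched in isolation. This is a minimal, torch-free sketch: `_orig_sdpa` below is a stand-in recorder, not PyTorch's real `F.scaled_dot_product_attention`, and the tensor arguments are placeholders.

```python
# Stand-in for the saved original F.scaled_dot_product_attention: it just
# records what it received so kwargs forwarding can be verified.
def _orig_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
               is_causal=False, scale=None, **kwargs):
    return {"is_causal": is_causal, "scale": scale, **kwargs}


def _patched_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                  is_causal=False, scale=None, **kwargs):
    # Early fallback path: forward *all* caller kwargs, not only the
    # explicitly named ones, so flags like enable_gqa survive.
    if attn_mask is not None or dropout_p > 0.0:
        return _orig_sdpa(query, key, value, attn_mask=attn_mask,
                          dropout_p=dropout_p, is_causal=is_causal,
                          scale=scale, **kwargs)
    try:
        raise RuntimeError("kernel unavailable")  # simulate a kernel failure
    except RuntimeError:
        # Exception fallback path: the same rule applies here.
        return _orig_sdpa(query, key, value, is_causal=is_causal,
                          scale=scale, **kwargs)


out = _patched_sdpa(None, None, None, is_causal=True, enable_gqa=True)
```

With the forwarding in place, `out` carries `enable_gqa` through the exception fallback instead of silently dropping it.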
    def apply_triton_sparse_kernel(
        transformer: torch.nn.Module,
        kernel: str,
        skip_threshold: float | None = None,
    ) -> None:
        """Apply a ModelOpt Triton sparse attention kernel to the WAN transformer.

        Calls ``mtsa.sparsify()`` with ``backend="diffusers_triton"``, which installs
        a ``ModelOptWanAttnProcessor`` on every ``WanAttention`` module. The NVFP4
        variants additionally pass ``quantize_p=True`` to the Triton kernel, enabling
        per-tile NVFP4 E2M1 P-matrix quantization in a single fused pass.

        This modifies the model in-place.

        Args:
            transformer: The ``pipe.transformer`` WAN model.
            kernel: One of the ``KERNEL_TRITON_*`` constants.
            skip_threshold: Override ``skip_softmax_threshold`` for skip-softmax kernels.
                ``None`` uses the kernel's built-in default.
                Lower = better quality, less speedup. Typical range: 0.001–0.1.
        """
        import copy

        import modelopt.torch.sparsity.attention_sparsity as mtsa

        config = copy.deepcopy(_TRITON_KERNEL_CONFIGS[kernel])
        if skip_threshold is not None and kernel in (KERNEL_TRITON_SKIP, KERNEL_TRITON_SKIP_NVFP4):
            config["sparse_cfg"]["*"]["skip_softmax_threshold"] = skip_threshold
Reject non-positive skip_threshold values inside the helper.
The skip-softmax path now depends on log2(skip_softmax_threshold). Passing 0 or a negative override should fail here with a clear ValueError instead of surfacing later in the Triton path.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/diffusers/quantization/wan2_sage_attention.py` around lines 504 -
531, In apply_triton_sparse_kernel, validate any provided skip_threshold when
the kernel is KERNEL_TRITON_SKIP or KERNEL_TRITON_SKIP_NVFP4: if skip_threshold
is not None and skip_threshold <= 0, raise a ValueError with a clear message
(e.g. "skip_threshold must be > 0") before modifying
config["sparse_cfg"]["*"]["skip_softmax_threshold"]; otherwise proceed to set
the value as currently implemented. Ensure you reference the function
apply_triton_sparse_kernel and the constants KERNEL_TRITON_SKIP /
KERNEL_TRITON_SKIP_NVFP4 so the check is only applied for skip-softmax kernels.
    psnr_per_frame = np.where(
        mse_per_frame < 1e-10,
        100.0,
        10.0 * np.log10(255.0**2 / mse_per_frame),
    )
🧩 Analysis chain
🏁 Script executed:
# First, find and verify the file exists
fd wan2_sage_attention.py
Length of output: 120
🏁 Script executed:
# Once we know the file exists, let's examine the exact code at lines 591-595
cat -n examples/diffusers/quantization/wan2_sage_attention.py | sed -n '585,600p'
Length of output: 697
🌐 Web query:
numpy np.where() eager evaluation both branches behavior
💡 Result:
np.where(cond, x, y) is not lazy / short‑circuiting. In typical usage like:
`np.where(cond, f(a), g(a))`

both f(a) and g(a) are computed before np.where runs, because they are evaluated as normal Python function arguments; np.where then just selects elements from the already-computed x and y arrays. This is why you can still get warnings/errors (e.g., divide-by-zero, sqrt of negatives) even where the condition would “exclude” those elements. [1][2][3]
The NumPy docs’ note that, for 1‑D arrays, where is “equivalent to” a Python conditional list-comprehension is about the resulting selection behavior, not about lazily evaluating the expressions you pass in for x and y. [4]
What to do instead (to avoid evaluating the “bad” branch):
- Prefer a ufunc with a `where=mask` and `out=`, so computation is only performed where allowed, e.g. `np.divide(a, b, out=out, where=mask)` (avoids computing division where `mask` is False). [5]
- Similarly, many elementwise operations can be done as ufuncs with `where=` (and often `out=`) instead of `np.where(...)`. [3][5]
Sources: [1] np.where() eager evaluation discussion (Stack Overflow) [2] np.where computing unused values (Stack Overflow) [3] np.where can’t avoid sqrt(negative) evaluation (Stack Overflow) [4] NumPy numpy.where documentation (NumPy manual) [5] Using np.divide(..., where=...) to avoid divide-by-zero warnings (Stack Overflow)
The PSNR zero-MSE guard still evaluates the divide-by-zero branch.
np.where() eagerly evaluates both the true and false branches before selecting values, so identical frames (where mse_per_frame < 1e-10) still compute 255.0**2 / mse_per_frame, triggering divide-by-zero warnings. Clamp the denominator first to avoid this.
Suggested fix
- psnr_per_frame = np.where(
- mse_per_frame < 1e-10,
- 100.0,
- 10.0 * np.log10(255.0**2 / mse_per_frame),
- )
+ safe_mse = np.maximum(mse_per_frame, 1e-10)
+ psnr_per_frame = 10.0 * np.log10(255.0**2 / safe_mse)
+ psnr_per_frame = np.where(mse_per_frame < 1e-10, 100.0, psnr_per_frame)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/diffusers/quantization/wan2_sage_attention.py` around lines 591 -
595, The current PSNR calculation in psnr_per_frame uses np.where with
mse_per_frame which causes the divide-by-zero branch to be evaluated eagerly;
fix by clamping mse_per_frame to a small positive floor before the division
(e.g., denom = np.maximum(mse_per_frame, 1e-10)) and compute the PSNR using that
denom, then optionally use np.where to override values for truly-zero MSE if you
need a specific constant; update the code around psnr_per_frame and any direct
uses of mse_per_frame to use the clamped denominator to avoid warnings.
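The clamp-then-override pattern from the suggested fix can be exercised on toy MSE values; assuming only numpy:

```python
import numpy as np

# Clamp the denominator first so log10 never sees a zero, then pin
# identical frames (mse below the floor) to the 100 dB sentinel.
mse_per_frame = np.array([0.0, 4.0, 16.0])  # toy values, not real frames

safe_mse = np.maximum(mse_per_frame, 1e-10)
psnr_per_frame = 10.0 * np.log10(255.0**2 / safe_mse)
psnr_per_frame = np.where(mse_per_frame < 1e-10, 100.0, psnr_per_frame)
```

No divide-by-zero warning is emitted because the "bad" branch is never evaluated on a zero denominator.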
    "--kernel",
    type=str,
    default=KERNEL_FP8,
    choices=KERNEL_CHOICES,
    help=(
🧩 Analysis chain
🏁 Script executed:
wc -l examples/diffusers/quantization/wan2_sage_attention.py
Length of output: 124
🏁 Script executed:
sed -n '750,770p' examples/diffusers/quantization/wan2_sage_attention.py
Length of output: 1157
🏁 Script executed:
sed -n '810,825p' examples/diffusers/quantization/wan2_sage_attention.py
Length of output: 501
🏁 Script executed:
# Search for KERNEL_CHOICES and AVAILABLE_KERNELS definitions
rg -n "KERNEL_CHOICES|AVAILABLE_KERNELS" examples/diffusers/quantization/wan2_sage_attention.py -A 3 -B 1
Length of output: 1818
🏁 Script executed:
# Search for where kernel validation happens
rg -n "kernel" examples/diffusers/quantization/wan2_sage_attention.py | grep -i "valid\|check\|if"
Length of output: 1115
🏁 Script executed:
sed -n '815,835p' examples/diffusers/quantization/wan2_sage_attention.py
Length of output: 943
🏁 Script executed:
# Find and show apply_triton_sparse_kernel function
rg -n "def apply_triton_sparse_kernel" examples/diffusers/quantization/wan2_sage_attention.py -A 20
Length of output: 1071
🏁 Script executed:
# Check what _detect_available_kernels does
rg -n "def _detect_available_kernels" examples/diffusers/quantization/wan2_sage_attention.py -A 30
Length of output: 1302
🏁 Script executed:
sed -n '183,235p' examples/diffusers/quantization/wan2_sage_attention.py
Length of output: 1794
🏁 Script executed:
# Check _TRITON_MODELOPT_KERNELS definition
rg -n "_TRITON_MODELOPT_KERNELS" examples/diffusers/quantization/wan2_sage_attention.py -B 2 -A 5
Length of output: 2608
Add kernel availability validation before loading the WAN pipeline.
The --kernel argument accepts any name in KERNEL_CHOICES via argparse, but this only validates the name string—not whether the required dependencies (Triton, SageAttention, PyTorch version) are actually available. The multi-GB model loads at line 820 before kernel validation occurs. The Triton branch (lines 824–825, 898–899) calls apply_triton_sparse_kernel() without checking AVAILABLE_KERNELS, while the non-Triton branch only validates via enable_attention_kernel() after the model is already in memory. On a missing Triton or SageAttention setup, this loads the pipeline unnecessarily before failing.
Move the args.kernel in AVAILABLE_KERNELS check to immediately after parse_args() to fail fast without wasting resources.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/diffusers/quantization/wan2_sage_attention.py` around lines 756 -
760, Move the kernel availability check to immediately after argument parsing to
fail fast: after parse_args(), verify args.kernel is present in
AVAILABLE_KERNELS (and/or that Triton/SageAttention and required PyTorch version
are available) and raise/exit with a clear message if not; remove or keep only
redundant checks later. Specifically, perform this validation before
constructing/loading the WAN pipeline (the code that currently loads the large
model) so that calls like apply_triton_sparse_kernel() and
enable_attention_kernel() are only reached when args.kernel is actually
supported; reference the args.kernel variable, AVAILABLE_KERNELS,
apply_triton_sparse_kernel(), and enable_attention_kernel() when making the
change.
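A hedged sketch of the fail-fast ordering the review asks for, using a toy `AVAILABLE_KERNELS` set and kernel names rather than the script's real detection logic:

```python
import argparse

# Placeholder for the set the script would detect at import time
# (e.g. via its _detect_available_kernels helper); here we pretend
# Triton/SageAttention are not installed.
AVAILABLE_KERNELS = {"fp8"}

parser = argparse.ArgumentParser()
parser.add_argument("--kernel", default="fp8",
                    choices=["fp8", "triton-sparse", "triton-skip"])
args = parser.parse_args(["--kernel", "triton-skip"])

# argparse only validated the *name* against choices; now check that the
# runtime dependency is actually present, before any multi-GB model load.
if args.kernel not in AVAILABLE_KERNELS:
    failed_fast = True   # the real script would raise SystemExit with a message
else:
    failed_fast = False
# ...only proceed to the expensive pipeline load when failed_fast is False.
```

The point is ordering: the availability check runs immediately after `parse_args()`, so a missing Triton/SageAttention setup never reaches the pipeline-loading code.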
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Actionable comments posted: 2
🧹 Nitpick comments (1)
modelopt/torch/kernels/triton_fa.py (1)
309-315: Update the `quantize_p` docs to match the new runtime guard. Both docstrings still describe an STE/backward path, but lines 1021-1024 now make `quantize_p` inference-only.
quantize_pinference-only.Suggested doc update
-        quantize_p: If ``True``, quantize the post-softmax p tile to NVFP4
-            E2M1 before the p @ V matmul (per-tile max scaling, STE in
-            backward). Default ``False``.
+        quantize_p: If ``True``, quantize the post-softmax p tile to NVFP4
+            E2M1 before the p @ V matmul. This mode is inference-only and
+            raises if autograd is required. Default ``False``.

Also applies to: 1294-1296
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 309 - 315, The docstring for quantize_p incorrectly describes a straight-through estimator and backward behavior even though the function is now guarded to be inference-only at runtime; update the quantize_p docstring (and the other similar docstring instance) to remove any mention of STE/backward passthrough and instead state that this implementation performs post-softmax tile quantization for inference only (describe scaling and level mapping briefly and note no gradient/_backward behavior is applied at runtime).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 973-980: The code incorrectly infers paged layout from the
placeholder k tensor and an independent page_size; instead, when is_paged
(k_cache is not None) derive and validate paged metadata from the cache tensors
themselves: read num_kv_heads from k_cache.shape[1] (and page_size / page_length
from the cache dim that represents pages), recompute kv_group_num from
num_q_heads // num_kv_heads, and ignore k.shape[1]/page_size inputs; also verify
k_cache and v_cache shapes match and that computed kv_group_num divides
num_q_heads and matches any b_seq_len-derived batch logic (validate and
raise/error if mismatched). Apply the same change to the other blocks referenced
around the file (the regions near the existing checks at lines ~988-997 and
~1077-1084) so all paged launches use cache-derived layout and perform
consistency checks before kernel launch.
- Around line 979-980: Add a guard that rejects paged-KV when autograd is
enabled: where is_paged = k_cache is not None is computed, check
torch.is_grad_enabled() (or equivalent autograd check) and raise a clear
RuntimeError explaining that paged k/v/block_table is incompatible with autograd
because gradients are rebuilt from contiguous k/v; make this change for both
occurrences around the is_paged checks (the block using
k_cache/v_cache/block_table near the is_paged assignment and the other spot at
the 1021-1024 section) so callers get an explicit error instead of
incorrect/missing gradients.
---
Nitpick comments:
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 309-315: The docstring for quantize_p incorrectly describes a
straight-through estimator and backward behavior even though the function is now
guarded to be inference-only at runtime; update the quantize_p docstring (and
the other similar docstring instance) to remove any mention of STE/backward
passthrough and instead state that this implementation performs post-softmax
tile quantization for inference only (describe scaling and level mapping briefly
and note no gradient/_backward behavior is applied at runtime).
📒 Files selected for processing (1)
modelopt/torch/kernels/triton_fa.py
    is_paged = k_cache is not None
Reject paged-KV when autograd is enabled.
Lines 1109-1236 still rebuild gradients from contiguous k/v and return no gradients for k_cache/v_cache/block_table. In paged mode that makes autograd incorrect, because the forward path no longer reads those contiguous tensors.
Proposed guard
if quantize_p and (q.requires_grad or k.requires_grad or v.requires_grad):
raise NotImplementedError(
"quantize_p supports inference only; backward does not model the quantized P path"
)
+ if is_paged and (
+ q.requires_grad
+ or k.requires_grad
+ or v.requires_grad
+ or k_cache.requires_grad
+ or v_cache.requires_grad
+ ):
+ raise NotImplementedError(
+ "paged KV cache supports inference only; backward does not model k_cache/v_cache"
+     )

Also applies to: 1021-1024
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/kernels/triton_fa.py` around lines 979 - 980, Add a guard that
rejects paged-KV when autograd is enabled: where is_paged = k_cache is not None
is computed, check torch.is_grad_enabled() (or equivalent autograd check) and
raise a clear RuntimeError explaining that paged k/v/block_table is incompatible
with autograd because gradients are rebuilt from contiguous k/v; make this
change for both occurrences around the is_paged checks (the block using
k_cache/v_cache/block_table near the is_paged assignment and the other spot at
the 1021-1024 section) so callers get an explicit error instead of
incorrect/missing gradients.
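The guard's control flow can be sketched without torch; `FakeTensor` below is a hypothetical stand-in exposing only `requires_grad`, so this shows the decision logic rather than the real kernel entry point:

```python
# Torch-free sketch of the inference-only guard for the paged-KV path.
class FakeTensor:
    def __init__(self, requires_grad=False):
        self.requires_grad = requires_grad


def check_paged_inference_only(q, k, v, k_cache, v_cache):
    is_paged = k_cache is not None
    if is_paged and any(
        t is not None and t.requires_grad for t in (q, k, v, k_cache, v_cache)
    ):
        raise NotImplementedError(
            "paged KV cache supports inference only; backward does not model k_cache/v_cache"
        )
    return is_paged


q = FakeTensor(requires_grad=True)
k = v = FakeTensor()
cache = FakeTensor()

dense_ok = check_paged_inference_only(q, k, v, None, None)  # no guard when not paged
try:
    check_paged_inference_only(q, k, v, cache, cache)
    guarded = False
except NotImplementedError:
    guarded = True
```

Dense (contiguous) calls are unaffected; only the paged path with any differentiable input raises, which is exactly the explicit-error behavior the comment requests.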
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
…ling
- Add --clip-model arg for local CLIP model path (avoids HF rate limits)
- Read HF_TOKEN / HUGGING_FACE_HUB_TOKEN env vars for authenticated downloads
- Catch OSError from rate limits / network errors with a helpful message
- Add `import os` for env var access
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Implements Python-level NVFP4 E2M1 quantization (SA3-inspired): - Channel-wise mean smoothing before quantization - Per-token scale to NVFP4 range (max 6.0) - Round to nearest of 8 representable levels via torch.bucketize - No CUDA kernel required, works on any GPU Use --kernel nvfp4 to compare accuracy vs FP8 and baseline. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
The CLIP score range depends on the model/prompt distribution; remove the hardcoded "0.25-0.35" and instead tell users to focus on the baseline vs quantized delta. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
Quantizing V to FP4 causes severe color/texture loss because V carries the actual content that gets weighted and summed into the output. SA3 applies NVFP4 to the post-softmax attention probability matrix P, not to V directly. V stays in BF16 for the P*V multiply. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
V quantized to NVFP4 (8 levels) caused visible color/texture degradation. SA3 paper applies NVFP4 only to post-softmax P matrix; V stays in higher precision. Use FP8 E4M3 (448 levels) for V to match SA2-style V handling while keeping NVFP4 for Q/K. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
FP8 V still causes color/texture degradation. V carries actual content that gets weighted and summed — any quantization noise is directly visible. Keep V in BF16, matching SA3 design where only Q/K (routing weights) are quantized to NVFP4. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
Previous approach quantized Q and K to NVFP4, which caused visible quality degradation. SA3 (arXiv 2505.11594) applies NVFP4 to the post-softmax P matrix, not to Q/K/V. Implement manual attention (Q@K^T -> scale -> mask -> softmax -> NVFP4(P) -> P@V) with per-row max-based scaling of P to NVFP4 range [0, 6]. Q, K, V stay in BF16. This faithfully simulates SA3 quantization behavior for accuracy measurement. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
NVFP4 (8 levels) caused generation trajectory divergence on WAN2.2 video diffusion due to dense attention patterns. Replace with unsigned INT4 (16 levels, 0..15) per-row scaled quantization of the post-softmax P matrix — same SA3-inspired approach (quantize P not Q/K/V) but with 2x more precision levels to better cover dense probability distributions. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
Previous implementations used per-row scaling, causing trajectory divergence in video diffusion: a single row-max outlier dominated the scale, wasting most of the 8 FP4 levels on near-zero values. SA3 quantizes P at flash-attention tile granularity (per-tile, not per-row). Implement per-tile NVFP4 E2M1 of the post-softmax P matrix: reshape P into 64x64 tiles, compute per-tile max scale, quantize to nearest NVFP4 level, dequantize back. Each local tile's value range is fully covered by the 8 levels. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
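The per-tile scheme this commit describes can be simulated in a few lines of numpy. This is a sketch of the quantize→dequantize round trip (the 8 nonnegative E2M1 levels, per-tile max scaling) for accuracy simulation, not the fused Triton kernel:

```python
import numpy as np

# The 8 nonnegative NVFP4 E2M1 representable values; post-softmax P >= 0,
# so only the positive half of the format is needed.
E2M1_LEVELS = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def quantize_p_tile(p_tile: np.ndarray) -> np.ndarray:
    scale = p_tile.max() / 6.0           # per-tile max maps to the top level
    if scale == 0.0:
        return p_tile                    # all-zero tile: nothing to quantize
    scaled = p_tile / scale
    # Round each value to the nearest representable E2M1 level.
    idx = np.abs(scaled[..., None] - E2M1_LEVELS).argmin(axis=-1)
    return E2M1_LEVELS[idx] * scale      # dequantize back


rng = np.random.default_rng(0)
p = rng.random((64, 64))                 # stand-in for one post-softmax tile
p_q = quantize_p_tile(p)
```

Because the scale is computed per 64×64 tile, a single outlier only compresses the levels inside its own tile, which is the divergence fix the commit describes.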
Materialising the full N×N float32 attention matrix for WAN2.2 (~8k tokens, 24 heads) requires ~6 GB, causing OOM on 48 GB GPUs already loaded with the model. Process Q in 512-row chunks: each chunk produces a (512 x N) P slice, which is quantized per-tile to NVFP4 and accumulated into the output. Peak memory for the attention matrix drops from O(N^2) to O(chunk x N). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
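The chunked-Q idea bounds peak memory at O(chunk × N) instead of O(N²) while giving numerically identical results per row. A toy numpy sketch (small shapes instead of WAN's ~8k tokens; the per-tile quantization step is omitted):

```python
import numpy as np

def chunked_attention(q, k, v, chunk=8):
    """Process query rows in chunks so only a (chunk x N) P slice exists."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    out = np.empty_like(q)
    for start in range(0, q.shape[0], chunk):
        qc = q[start:start + chunk]                  # (chunk, d)
        s = (qc @ k.T) * scale                       # (chunk, N) slice only
        p = np.exp(s - s.max(axis=-1, keepdims=True))
        p /= p.sum(axis=-1, keepdims=True)
        # per-tile NVFP4 quantization of `p` would slot in here
        out[start:start + chunk] = p @ v
    return out


rng = np.random.default_rng(1)
q, k, v = (rng.standard_normal((32, 16)) for _ in range(3))

# Reference: full softmax attention materialised in one shot.
s = (q @ k.T) / np.sqrt(16)
p = np.exp(s - s.max(axis=-1, keepdims=True))
ref = (p / p.sum(axis=-1, keepdims=True)) @ v
```

Softmax is row-wise, so splitting Q along rows changes nothing mathematically; only the size of the materialised attention slice shrinks.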
Smaller tiles give finer-grained scales and better accuracy, at the cost of more overhead (acceptable for Python-level simulation). Add --nvfp4-tile CLI arg in TRxTC format (e.g. '1x64', '32x32', '64x64') to control the tile shape for per-tile NVFP4 quantization of the post-softmax P matrix. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
Implements a ModelOpt sparse attention plugin for diffusers WAN models, building on the triton_fa kernel infrastructure from PR #1127.

New files:
- modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py
- ModelOptWanAttnProcessor: replaces WanAttnProcessor, calls triton_fa.attention() directly with BSND->varlen conversion. Supports I2V cross-attention path and N:M/skip-softmax sparsity.
- WanSparseAttentionModule: subclasses SparseAttentionModule, installs the Triton processor and syncs enabled state on each forward.
- register_wan_sparse_attention(): plugin callback auto-registered in CUSTOM_MODEL_PLUGINS; fires during mtsa.sparsify().

Updated files:
- plugins/__init__.py: lazy-import diffusers plugin via import_plugin()
- config.py: add "diffusers_triton" to validate_backend whitelist
- conversion.py: skip HF attn registration for "diffusers_triton" backend
- wan2_sage_attention.py: add triton-sparse and triton-skip kernel options backed by mtsa.sparsify() with diffusers_triton backend

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Covers: - "diffusers_triton" backend config validation - WanAttention registration with SparseAttentionRegistry - Module and processor type after mtsa.sparsify() - Correct sparse_kw populated for triton_sparse_softmax and triton_skip_softmax - Forward output shape (via SDPA fallback on CPU) - enable/disable flag propagation into ModelOptWanAttnProcessor Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
WanSparseAttentionModule.forward() syncs proc._enabled from self.is_enabled on every call, so directly setting proc._enabled was immediately overwritten. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
- Add _QuantWanAttnProcessor: manual Q@K^T->softmax->softmax_quantizer->@v with BSND layout, I2V cross-attention support, and q/k/v_bmm_quantizer hooks
- Add _QuantWanAttention(_QuantAttentionModuleMixin): installs the processor in _setup(); forward() bypasses _QuantFunctionalMixin patching context since the processor calls TensorQuantizers directly
- Register _QuantWanAttention for WanAttention (replaces generic mixin that went through F.sdpa and never invoked softmax_quantizer)
- Add NVFP4_WAN_SOFTMAX_CFG: enables only *softmax_quantizer with dynamic NVFP4 block_size=16; no calibration needed (dynamic scaling)
- Add nvfp4-modelopt kernel to wan2_sage_attention.py example via apply_nvfp4_modelopt_kernel() using mtq.quantize(transformer, NVFP4_WAN_SOFTMAX_CFG)
- Add unit tests for registry, processor type, quantizer attributes, config structure, and forward shape

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…float32 matrix WAN2.2 video sequences exceed 8k tokens; materialising [B, S, H, T] in float32 before softmax exhausts GPU memory. Process query rows in chunks of 512 (matching the existing nvfp4 Python kernel) so peak memory stays bounded. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
Extends the ModelOpt Triton flash-attention kernel (triton_fa.py) with per-tile NVFP4 E2M1 P-matrix quantization, enabling fused sparse+quantized attention in a single kernel pass.

Key changes:
- triton_fa.py: add _quantize_p_nvfp4() Triton helper (per-tile max scaling, boundary-compare to 8 NVFP4 levels); add QUANTIZE_P constexpr to _attn_fwd (applied in both standard and skip-softmax paths, STE for backward); thread quantize_p: bool = False through _Attention and attention()
- sparsity/attention_sparsity/config.py: add quantize_p field to SparseAttentionAttributeConfig
- sparsity/attention_sparsity/plugins/diffusers.py: extract quantize_p in _build_sparse_kw() so it is forwarded to triton_fa.attention()
- examples/diffusers/quantization/wan2_sage_attention.py: remove nvfp4 and nvfp4-modelopt Python-level kernels; add triton-sparse-nvfp4 and triton-skip-nvfp4 kernel options that use the fused Triton path
- quantization/plugins/diffusion/diffusers.py: revert WAN-specific _QuantWanAttnProcessor/_QuantWanAttention; restore WanAttention -> _QuantAttentionModuleMixin registration
- quantization/config.py: remove NVFP4_WAN_SOFTMAX_CFG
- tests: remove test_wan_quant.py (superseded by Triton-based approach)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
WanAttention is a new-style diffusers AttentionModuleMixin that uses dispatch_attention_fn internally and does NOT call self.processor(). The previous forward() delegated to super().forward() (WanAttention), which bypassed our ModelOptWanAttnProcessor entirely, producing output byte-identical to the SDPA baseline. Fix: intercept in WanSparseAttentionModule.forward() and call the processor directly. The processor's __call__ handles both paths: - enabled=True → _wan_forward_triton (ModelOpt Triton kernel) - enabled=False → _wan_forward_sdpa (dispatch_attention_fn fallback) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
…F generic The HuggingFace generic plugin was imported before the diffusers plugin in plugins/__init__.py. This caused register_sparse_attention_on_the_fly to register WanAttention with _GenericSparseAttention first; the diffusers plugin then found WanAttention already registered and skipped, so WanSparseAttentionModule / ModelOptWanAttnProcessor were never installed. Fix: swap import order so the diffusers plugin registers WanAttention with WanSparseAttentionModule before the HF generic plugin runs. Signed-off-by: Ye Yu <yeyu@nvidia.com>
lambda=0.1 (ln=−2.3) is too aggressive for WAN's 8190-token video sequences with relatively diffuse attention: it skips tiles where the max score is >2.3 below the running max, which includes many tiles with genuine contribution to the output, causing PSNR≈11 dB. Lower default to 0.01 (ln=−4.6) for both triton-skip and triton-skip-nvfp4 kernels. Add --skip-threshold CLI arg so the threshold can be swept without editing code: --skip-threshold 0.1 aggressive (high sparsity, lower quality) --skip-threshold 0.01 moderate (default) --skip-threshold 0.001 conservative (near-baseline quality) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
The BLASST (https://arxiv.org/pdf/2512.12087) criterion checks ln(lambda) on the sm_scale-SCALED attention logits a_ij = q·k/sqrt(d). The Triton kernel stores scores as x = a * log2(e), so the correct threshold in kernel (log2) space is log2(lambda), not log2(lambda)*sm_scale. Previous code multiplied by sm_scale (~0.088 for head_dim=128), making every threshold 11× too aggressive. With lambda=0.1 the kernel-space threshold was -0.29 instead of the correct -3.32, skipping most attention tiles and producing garbage output (PSNR~11 dB). Even lambda=0.0001 was still too aggressive (-1.18 vs correct -13.29). Fix: use `log2(lambda)` directly as SKIP_THRESHOLD_LOG2, and restore the default threshold to 0.1 (the standard BLASST value). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Ye Yu <yeyu@nvidia.com>
- CLIP scoring: use CPU to avoid OOM with WAN pipeline on GPU; catch RuntimeError in addition to OSError
- --skip-threshold help: fix description to match actual exp(tile_max - running_max) < lambda criterion
- vLLM worker: reject unsupported sparse presets (non-triton backend or unknown method) with a clear ValueError instead of silently degrading to dense attention
- PYTHONPATH construction: use os.pathsep and skip empty entries to avoid CWD injection when PYTHONPATH is unset
- diffusers_triton backend: raise ValueError when mixed with other backends instead of silently skipping _attn_implementation setup
- _wan_forward_triton: fall back to SDPA when attention_mask is not None to preserve masking semantics

Signed-off-by: Ye Yu <yeyu@nvidia.com>
- quantize_p: raise NotImplementedError when any of q/k/v requires_grad since backward does not model the quantized P path (inference-only)
- b_start_loc_k: only synthesize dummy tensor in paged mode; raise ValueError in contiguous separate-KV path when b_start_loc_k is None; also validate that v_cache and block_table are provided alongside k_cache

Signed-off-by: Ye Yu <yeyu@nvidia.com>
The * sm_scale factor is intentional: it scales the tile-skip threshold relative to head dimension, so larger head_dim (smaller sm_scale) produces more aggressive sparsity for the same lambda value. The previous 'fix' was incorrect. Signed-off-by: Ye Yu <yeyu@nvidia.com>
708f113 to 3f0bfd3 (Compare)
Codecov Report

❌ Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
## main #1190 +/- ##
==========================================
- Coverage 74.77% 74.27% -0.51%
==========================================
Files 351 359 +8
Lines 40072 43949 +3877
==========================================
+ Hits 29964 32641 +2677
- Misses 10108 11308 +1200
Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
Actionable comments posted: 9
♻️ Duplicate comments (6)
modelopt/torch/kernels/triton_fa.py (2)
1005-1012: ⚠️ Potential issue | 🔴 Critical: Reject paged KV when autograd is enabled.
Forward reads K/V from `k_cache`/`v_cache`, but backward still reconstructs gradients only from contiguous `k`/`v` and returns nothing for the cache tensors. That is silently wrong once the paged path is used with dummy `k`/`v` or differentiable caches, so this mode needs the same inference-only guard as `quantize_p`.

Suggested fix
  if quantize_p and (q.requires_grad or k.requires_grad or v.requires_grad):
      raise NotImplementedError(
          "quantize_p supports inference only; backward does not model the quantized P path"
      )
+ if is_paged and (
+     q.requires_grad
+     or k.requires_grad
+     or v.requires_grad
+     or (k_cache is not None and k_cache.requires_grad)
+     or (v_cache is not None and v_cache.requires_grad)
+ ):
+     raise NotImplementedError(
+         "paged KV cache supports inference only; backward does not model k_cache/v_cache"
+     )
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 1005 - 1012, The forward path currently allows using paged KV (k_cache/v_cache) while autograd is enabled, but the backward only reconstructs grads from contiguous k/v and omits cache tensors; add the same inference-only guard used for quantize_p: detect when paged KV is in use (check k_cache or v_cache presence/usage) and if any of q.requires_grad, k.requires_grad, or v.requires_grad are true, raise NotImplementedError with a clear message like "paged KV supports inference only; backward does not model the paged K/V path" (mirror the existing quantize_p check) so the code in the function handling forward (the block containing quantize_p and references to k_cache/v_cache) rejects autograd-enabled paged KV.
973-997: ⚠️ Potential issue | 🟠 Major: Derive paged layout from `k_cache`/`v_cache`, not from the placeholder `k`.

In paged mode the caller can pass dummy contiguous `k`/`v`, but `num_kv_heads`, `kv_group_num`, and the page geometry are still taken from `k` and `page_size`. A mismatch silently remaps Q heads to the wrong KV heads or indexes the cache with the wrong page size. Validate the cache shapes and derive paged metadata from the cache tensors themselves.

Suggested fix
  HEAD_DIM = q.shape[2]
  num_q_heads = q.shape[1]
- num_kv_heads = k.shape[1]
- kv_group_num = num_q_heads // num_kv_heads
  batch = b_seq_len.shape[0]
- is_paged = k_cache is not None
+ is_paged = any(x is not None for x in (k_cache, v_cache, block_table))
+ if is_paged:
+     if k_cache is None or v_cache is None or block_table is None:
+         raise ValueError("k_cache, v_cache, and block_table must all be provided together")
+     if k_cache.shape[1] != v_cache.shape[1] or page_size != k_cache.shape[1]:
+         raise ValueError("page_size must match k_cache/v_cache.shape[1]")
+     if k_cache.shape[2:] != v_cache.shape[2:]:
+         raise ValueError("k_cache and v_cache must share num_kv_heads/head_dim")
+     num_kv_heads = k_cache.shape[2]
+ else:
+     num_kv_heads = k.shape[1]
+
+ if num_q_heads % num_kv_heads != 0:
+     raise ValueError("num_q_heads must be divisible by num_kv_heads")
+ kv_group_num = num_q_heads // num_kv_heads
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/triton_fa.py` around lines 973 - 997, When paged mode is detected (k_cache/v_cache provided) don't derive KV head counts or page geometry from the placeholder tensor k; instead read and validate num_kv_heads, kv_group_num and page_size (and any page layout metadata) from k_cache and v_cache and use those values to compute kv_group_num and related variables used later in triton_fa.py (e.g., num_kv_heads, kv_group_num, page_size, b_seq_len_k/b_start_loc_k defaults). Add explicit shape checks that k_cache and v_cache agree with each other and with block_table (and that num_q_heads is divisible by num_kv_heads), and raise a clear ValueError if they mismatch; when b_start_loc_k is needed, construct/derive it from the cache metadata rather than the placeholder k. Ensure all references to num_kv_heads/kv_group_num/page_size later in the function use the cache-derived values.

examples/diffusers/quantization/wan2_sage_attention.py (4)
370-385: ⚠️ Potential issue | 🟡 Minor

Caller kwargs still dropped on SDPA fallback paths.

`_patched_sdpa()` accepts `**kwargs` (line 360) but both fallback paths drop them:

- Lines 370-378: the early fallback doesn't forward `**kwargs`
- Line 385: the exception fallback doesn't forward `**kwargs`

This causes silent semantic changes if callers pass `enable_gqa=True` or other SDPA kwargs.

Suggested fix
```diff
     return _orig_sdpa(
         query,
         key,
         value,
         attn_mask=attn_mask,
         dropout_p=dropout_p,
         is_causal=is_causal,
         scale=scale,
+        **kwargs,
     )
@@
-    return _orig_sdpa(query, key, value, is_causal=is_causal, scale=scale)
+    return _orig_sdpa(query, key, value, is_causal=is_causal, scale=scale, **kwargs)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/diffusers/quantization/wan2_sage_attention.py` around lines 370 - 385, _patched_sdpa currently drops caller kwargs on both fallback paths causing silently changed behavior; update the calls so that any received **kwargs are forwarded: call _run_kernel(query, key, value, is_causal=is_causal, scale=scale, **kwargs) and call _orig_sdpa(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, **kwargs) (and likewise in the exception fallback) so _patched_sdpa, _run_kernel and _orig_sdpa receive the original caller options (references: _patched_sdpa, _run_kernel, _orig_sdpa, _sage_calls, _active_kernel).
504-531: ⚠️ Potential issue | 🟡 Minor

Missing validation for non-positive `skip_threshold` values.

The skip-softmax path depends on `log2(skip_softmax_threshold)` in the Triton kernel. Passing `0` or a negative value should fail early with a clear error instead of causing cryptic failures in the kernel.

Suggested fix
```diff
 def apply_triton_sparse_kernel(
     transformer: torch.nn.Module,
     kernel: str,
     skip_threshold: float | None = None,
 ) -> None:
     ...
     config = copy.deepcopy(_TRITON_KERNEL_CONFIGS[kernel])
-    if skip_threshold is not None and kernel in (KERNEL_TRITON_SKIP, KERNEL_TRITON_SKIP_NVFP4):
+    if skip_threshold is not None:
+        if kernel in (KERNEL_TRITON_SKIP, KERNEL_TRITON_SKIP_NVFP4):
+            if skip_threshold <= 0:
+                raise ValueError(f"skip_threshold must be > 0, got {skip_threshold}")
         config["sparse_cfg"]["*"]["skip_softmax_threshold"] = skip_threshold
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/diffusers/quantization/wan2_sage_attention.py` around lines 504 - 531, The apply_triton_sparse_kernel function does not validate skip_threshold for the skip-softmax kernels (KERNEL_TRITON_SKIP, KERNEL_TRITON_SKIP_NVFP4), which can lead to cryptic failures when zero or negative values are used; add an early validation inside apply_triton_sparse_kernel that when kernel is one of the skip-softmax constants and skip_threshold is not None, raise a ValueError if skip_threshold <= 0 with a clear message (e.g., require skip_threshold > 0) before mutating config["sparse_cfg"]["*"]["skip_softmax_threshold"] so invalid values never reach the Triton kernel.
815-817: ⚠️ Potential issue | 🟡 Minor

Pipeline loads before kernel availability validation.

The model loads at line 817 (`load_pipeline`) before kernel availability is checked (lines 824, 827, 898, 903). If the requested kernel isn't available, the multi-GB model download/load is wasted.

Move the `args.kernel in AVAILABLE_KERNELS` check immediately after `parse_args()`.

Suggested fix
```diff
 def main() -> None:
     args = parse_args()
+
+    # Fail fast if kernel is unavailable
+    if args.kernel not in AVAILABLE_KERNELS and not args.baseline and not args.benchmark:
+        raise RuntimeError(
+            f"Kernel {args.kernel!r} is not available. "
+            f"Available kernels: {AVAILABLE_KERNELS}"
+        )
+
     pipe = load_pipeline(args.model)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/diffusers/quantization/wan2_sage_attention.py` around lines 815 - 817, The code currently calls load_pipeline(args.model) before validating kernel availability; move the kernel check immediately after parse_args() in main() so you verify args.kernel in AVAILABLE_KERNELS before any heavy model download. Concretely, inside main(), call args = parse_args(), then perform the AVAILABLE_KERNELS membership check (and exit/raise with a clear message if invalid), and only then call load_pipeline(args.model); update/remove the later redundant kernel checks around load_pipeline/usage (referencing main(), parse_args(), load_pipeline, and AVAILABLE_KERNELS).
589-596: ⚠️ Potential issue | 🟡 Minor

PSNR calculation still triggers a divide-by-zero warning due to `np.where` eager evaluation.

`np.where()` evaluates both branches before selecting, so `255.0**2 / mse_per_frame` is computed even when `mse_per_frame < 1e-10`, triggering divide-by-zero warnings for identical frames.

Suggested fix
```diff
     # PSNR
     mse_per_frame = ((ref - quant) ** 2).mean(axis=(1, 2, 3))  # (N,)
     # Avoid log(0) for identical frames
-    psnr_per_frame = np.where(
-        mse_per_frame < 1e-10,
-        100.0,
-        10.0 * np.log10(255.0**2 / mse_per_frame),
-    )
+    safe_mse = np.maximum(mse_per_frame, 1e-10)
+    psnr_per_frame = 10.0 * np.log10(255.0**2 / safe_mse)
+    psnr_per_frame = np.where(mse_per_frame < 1e-10, 100.0, psnr_per_frame)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/diffusers/quantization/wan2_sage_attention.py` around lines 589 - 596, The PSNR code uses np.where which still triggers divide-by-zero because both branches are evaluated; fix by first protecting the denominator (e.g., compute safe_mse = np.maximum(mse_per_frame, 1e-10)) then compute the PSNR values from safe_mse (psnr_vals = 10.0 * np.log10(255.0**2 / safe_mse)) and finally set exact-zero cases to 100.0 with psnr_per_frame = np.where(mse_per_frame < 1e-10, 100.0, psnr_vals) and keep psnr = float(psnr_per_frame.mean()); update the code that references mse_per_frame, psnr_per_frame, and psnr accordingly.
🧹 Nitpick comments (9)
modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py (1)
387-401: Make the sparse-context setup exception-safe and reentrant.

If one of the later `enter_context()` calls raises, the thread-local flag stays stuck on `True` because the `ExitStack` is never unwound. The callback also always restores `False`, so a nested sparse context would clear an outer scope too early. Build the stack under `with ExitStack() as stack:` and restore the previous flag before `return stack.pop_all()`.

♻️ Proposed fix
```diff
-    from ..kernels import set_skip_softmax_context
+    from ..kernels import get_skip_softmax_context, set_skip_softmax_context

-    stack = ExitStack()
-    set_skip_softmax_context(True)
-    stack.callback(set_skip_softmax_context, False)
-
-    try:
-        from ..kernels.diffusers_eager_attention import get_skip_softmax_attention_backend
-
-        stack.enter_context(get_skip_softmax_attention_backend())
-    except (ImportError, RuntimeError):
-        pass
-
-    stack.enter_context(replace_function(torch.nn.functional, "softmax", sparse_softmax))
-    return stack
+    previous_skip_softmax = get_skip_softmax_context()
+    with ExitStack() as stack:
+        set_skip_softmax_context(True)
+        stack.callback(set_skip_softmax_context, previous_skip_softmax)
+
+        try:
+            from ..kernels.diffusers_eager_attention import (
+                get_skip_softmax_attention_backend,
+            )
+
+            stack.enter_context(get_skip_softmax_attention_backend())
+        except (ImportError, RuntimeError):
+            pass
+
+        stack.enter_context(
+            replace_function(torch.nn.functional, "softmax", sparse_softmax)
+        )
+        return stack.pop_all()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py` around lines 387 - 401, The current setup sets a thread-local flag via set_skip_softmax_context(True) and builds an ExitStack but if any enter_context() raises the flag remains True and nested contexts are not reentrant; change the pattern to build the stack inside a with ExitStack() as stack: block so any exception will unwind and restore, capture the previous flag value before setting it, call set_skip_softmax_context(True), and before returning call and return stack.pop_all() while restoring the previous flag via set_skip_softmax_context(previous_value); make the adjustments around the functions get_skip_softmax_attention_backend, replace_function(torch.nn.functional, "softmax", sparse_softmax), and the ExitStack usage so the callback restoring the flag no longer clobbers outer contexts.

modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py (1)
50-61: Defer these optional backend imports until registration time.

Importing diffusers/LTX backends from `__init__` means every consumer of `modelopt.torch.sparsity.attention_sparsity.kernels` pays those import attempts, and real registration failures degrade into a silent `None` export via `suppress(...)`. Resolve these symbols lazily when registration is actually needed instead. As per coding guidelines: "Use lazy imports and gate optional dependencies via `import_plugin()` for integrations (HuggingFace, Megatron, etc.); avoid hard imports at module level for optional features."
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py` around lines 50 - 61, The module currently does eager imports of optional backends (register_diffusers_eager_attention, register_diffusers_triton_attention, register_ltx_eager_attention, register_ltx_triton_attention) inside __init__.py using contextlib.suppress, which forces all consumers to attempt those imports; instead, remove these top-level imports and make the four registration callables resolved lazily at registration time (e.g., implement an import_plugin() or inline import inside the functions that perform registration), so each of register_diffusers_eager_attention, register_diffusers_triton_attention, register_ltx_eager_attention and register_ltx_triton_attention is imported only when actually invoked and ImportError/RuntimeError are handled there.

modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py (2)
72-131: Significant code duplication with `diffusers_triton_attention.py`.

The `_ltx_triton_attention` function (lines 72-131) shares nearly identical logic with `_diffusers_triton_attention` in `diffusers_triton_attention.py`:
- Same varlen metadata construction
- Same calibration mode branching
- Same inference mode threshold handling
Consider extracting the shared core into a common helper to reduce maintenance burden.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py` around lines 72 - 131, The two functions _ltx_triton_attention and _diffusers_triton_attention duplicate varlen metadata construction, calibration/inference branching and Triton call logic; refactor by extracting a shared helper (e.g., _triton_attention_core) that accepts parameters needed to build q_flat/k_flat/v_flat, seq lengths, heads/dim_head, and optional flags (threshold, calibration trials), move the common kw construction and calibration/inference handling (including use of _thread_local, attention_calibrate, attention) into that helper, then have _ltx_triton_attention and _diffusers_triton_attention prepare their layout-specific tensors and call the new helper to return the reshaped output.
147-153: `register_ltx_triton_attention` doesn't track registration state.

Unlike `register_diffusers_triton_attention`, which uses `_BACKEND_REGISTERED` to prevent double-registration, this function relies only on the `isinstance(fn, _TritonLTXAttentionWrapper)` check. While this works, it's inconsistent with the diffusers pattern and scans all modules on every call.
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py` around lines 147 - 153, register_ltx_triton_attention currently scans all modules every call and only avoids double-wrapping by checking isinstance(fn, _TritonLTXAttentionWrapper); make it consistent with register_diffusers_triton_attention by adding a module-level flag _BACKEND_REGISTERED (or reusing the same symbol) to short-circuit if already registered. Modify register_ltx_triton_attention to check _BACKEND_REGISTERED at the top and set it to True after the first successful registration pass, while still keeping the existing isinstance(fn, _TritonLTXAttentionWrapper) guard and targeting Attention and _TritonLTXAttentionWrapper as before.modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py (1)
169-179: Fragile internal API manipulation; document the diffusers version dependency.

The registration directly manipulates `AttentionBackendName._member_map_` and `AttentionBackendRegistry._backends` internals. This works but is fragile if diffusers changes its internal structure.

Consider adding a comment noting the diffusers version this was tested against, or wrapping in a try/except to provide a clear error if the internal API changes.
Suggested documentation
```diff
 def register_diffusers_triton_attention() -> None:
     """Register ``modelopt_triton`` backend in diffusers.

     Safe to call multiple times; registration happens only once.
+
+    Note: This manipulates diffusers internal APIs (_member_map_, _backends, etc.)
+    and was tested with diffusers >= 0.31. May require updates for future versions.
     """
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py` around lines 169 - 179, The code directly mutates diffusers internals (AttentionBackendName._member_map_, AttentionBackendRegistry._backends, etc.), which is fragile; add a brief comment near the block stating the diffusers version(s) this was tested against and expected internal shape, and wrap the registration logic for AttentionBackendName and _AttentionBackendRegistry (the new_member creation, map assignments, and setting of _supported_arg_names via inspect.signature(_diffusers_triton_attention)) in a try/except that catches AttributeError/KeyError and raises a clear, actionable RuntimeError explaining that the diffusers internal API changed and what fields are expected (e.g., _member_map_, _value2member_map_, _backends, _constraints, _supported_arg_names) so maintainers can update accordingly.

modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py (2)
172-192: Calibration counter retrieval tries diffusers first, then LTX; order may matter.

The code tries to get counters from the diffusers backend first (lines 176-181), then falls back to LTX (lines 183-189). If both backends are active in the same forward pass, only one set of counters will be collected. Per context snippets 1-2, each backend maintains independent thread-local storage.
This is likely fine for single-backend usage, but consider documenting that mixed-backend calibration is not supported, or aggregating counters from both backends if both are present.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py` around lines 172 - 192, _collect_calibration_stats currently only uses the first available get_calibration_counters (diffusers_triton_attention) and ignores the other backend, which prevents collecting counters when both backends are active; change the logic in _collect_calibration_stats to attempt to import get_calibration_counters from both ..kernels.diffusers_triton_attention and ..kernels.ltx_triton_attention, call each present function and merge/aggregate their returned counters (e.g., sum corresponding metrics or combine entries) into a single counters object before proceeding with the existing threshold logic (referencing get_calibration_counters and self._threshold_trials to decide whether aggregation is needed), or alternatively add a clear docstring note on mixed-backend non-support if aggregation is not implemented.
110-128: Hardcoded sequence length `4224` limits calibration accuracy for varying input sizes.

The threshold calculation uses a fixed `seqlen = 4224` (line 127), which may not match actual runtime sequence lengths. The TODO acknowledges this, but for diffusion models with variable resolution/frame counts, this could lead to suboptimal threshold selection.

Consider exposing this as a configurable parameter or computing it from the model/input metadata when available.
Would you like me to suggest an approach for passing actual sequence length at runtime?
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py` around lines 110 - 128, The method _get_effective_threshold currently divides scale_factor by a hardcoded seqlen (4224); change it to use a real sequence length by accepting an explicit seqlen (e.g., add a seqlen parameter to _get_effective_threshold) or by reading a sequence-length property from the passed module (e.g., module.seq_len or module.input_shape) and fall back to the previous default when unavailable; update callers that compute thresholds to pass the actual runtime seqlen (or set a configurable class attribute like self.default_seqlen) so calibration_params/target_sparse_ratio are scaled correctly instead of using 4224, and keep returning self.skip_softmax_threshold when no calibration info exists.modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py (2)
339-353: `_build_sparse_kw` silently ignores unknown methods.

If `_method` is neither `"triton_sparse_softmax"` nor `"triton_skip_softmax"`, the method returns an empty dict (plus optional `quantize_p`). This could lead to silent misconfiguration if a user specifies an invalid method name.

Consider logging a warning for unknown methods.
Suggested improvement
```diff
 def _build_sparse_kw(self) -> dict:
     """Extract triton_fa kwargs from the current method config."""
     cfg = getattr(self, "_method_config", {})
     method = getattr(self, "_method", "")
     kw: dict = {}
     if method == "triton_sparse_softmax":
         kw["sparsity_n"] = cfg.get("sparsity_n", 2)
         kw["sparsity_m"] = cfg.get("sparsity_m", 4)
         kw["num_sink_tokens"] = cfg.get("num_sink_tokens", 0)
         kw["dense_window_size"] = cfg.get("dense_window_size", 64)
     elif method == "triton_skip_softmax":
         kw["skip_softmax_threshold"] = cfg.get("skip_softmax_threshold", 0.1)
+    elif method:
+        import warnings
+        warnings.warn(f"Unknown sparse method '{method}'; sparse_kw will be empty")
     if cfg.get("quantize_p", False):
         kw["quantize_p"] = True
     return kw
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py` around lines 339 - 353, The _build_sparse_kw function currently ignores unknown _method values silently; update it to detect when method is non-empty and not one of the supported values ("triton_sparse_softmax", "triton_skip_softmax") and emit a warning including the invalid method name and available options to aid debugging. Locate the _build_sparse_kw method and add a membership check (e.g. allowed_methods = {"triton_sparse_softmax","triton_skip_softmax"}); if method and method not in allowed_methods, call a logger on the instance (use getattr(self, "_logger", logging.getLogger(__name__))) to log a warning that includes method and allowed_methods, then continue building kw (still respect quantize_p as currently implemented).
224-229: Hardcoded context length `512` is fragile if the WAN architecture changes.

Line 227 uses a hardcoded `512` for the text-encoder context length. The comment acknowledges this is a "WAN hardcoded constant", but if the WAN model changes or different variants use different context lengths, this will break silently.

Consider extracting this as a class constant or reading it from the attention module's configuration if available.
Suggested improvement
```diff
+_WAN_TEXT_CONTEXT_LENGTH = 512  # WAN hardcoded constant for text-encoder context

 class ModelOptWanAttnProcessor:
     ...
     def _wan_forward_triton(self, ...):
         ...
         if attn.add_k_proj is not None:
-            # 512 is the text-encoder context length (WAN hardcoded constant)
             assert encoder_hidden_states is not None
-            image_context_length = encoder_hidden_states.shape[1] - 512
+            image_context_length = encoder_hidden_states.shape[1] - _WAN_TEXT_CONTEXT_LENGTH
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py` around lines 224 - 229, The code slices encoder_hidden_states using a hardcoded 512 (seen in the block checking attn.add_k_proj and computing image_context_length) which is brittle; replace that magic number with a configurable value read from the attention module or plugin class: attempt to read a context length from attn (e.g., attn.config.text_encoder_context_length or a similar attribute) and fall back to a plugin/class constant like TEXT_ENCODER_CONTEXT_LEN if not present, then use that variable when computing image_context_length and slicing encoder_hidden_states; keep the assert for encoder_hidden_states and add a warning/log if neither config nor constant exists and you must fall back to 512.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/diffusers/sparsity/wan22_skip_softmax.py`:
- Around line 22-24: Update the example header text to reflect that calibration
now uses triton_skip_softmax with the Triton backend (not flash_skip_softmax
with the eager backend); locate the header block at the top of the file and
change any mentions of flash_skip_softmax/eager to triton_skip_softmax/Triton so
it matches the implementation in build_sparse_config(), ensuring documentation
and code paths align for debugging and comparison.
In `@examples/vllm_serve/sparse_attn_worker.py`:
- Around line 164-166: The current check uses a falsy test on threshold and will
skip valid zero values; update the condition around threshold retrieval to test
for None explicitly (e.g., replace "if threshold:" with "if threshold is not
None:") so that when layer_cfg.get("skip_softmax_threshold") returns 0 the key
"skip_softmax_threshold" is still set on sparse_kw; locate the threshold
variable, layer_cfg.get call, and sparse_kw assignment in sparse_attn_worker.py
to apply this change.
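A tiny self-contained demonstration of the bug this comment describes; the dict names follow the comment, the concrete value is made up:

```python
layer_cfg = {"skip_softmax_threshold": 0.0}  # valid but falsy threshold
threshold = layer_cfg.get("skip_softmax_threshold")

buggy_kw = {}
if threshold:  # buggy: 0.0 is falsy, so the key is silently dropped
    buggy_kw["skip_softmax_threshold"] = threshold

fixed_kw = {}
if threshold is not None:  # fixed: only a truly missing key is skipped
    fixed_kw["skip_softmax_threshold"] = threshold
```

With the falsy test, a configured threshold of `0.0` never reaches `sparse_kw`; the `is not None` test preserves it.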
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py`:
- Around line 259-266: The current guard `if forward_loop is None and
(calibrate_prefill or calibrate_decode):` prevents building the tokenizer and
RULER decode dataset whenever a custom `forward_loop` is provided, so built-in
decode calibration never runs when `target_sparse_ratio.decode > 0`; change the
condition to always build the tokenizer/calibration_data when `calibrate_decode`
is true (even if `forward_loop` is provided) while keeping the existing prefill
behavior for `calibrate_prefill`. Specifically, call
`_extract_tokenizer_from_model(model)` and construct the RULER dataset whenever
`calibrate_decode` is enabled (or `target_sparse_ratio.decode > 0`), and only
gate prefill dataset construction on `forward_loop is None` and
`calibrate_prefill`, ensuring `tokenizer`/`calibration_data` are available to
the decode calibration code paths (see `forward_loop`, `calibrate_prefill`,
`calibrate_decode`, `tokenizer`, `calibration_data`, and
`_extract_tokenizer_from_model`).
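The corrected gating reduces to a small predicate. This is a sketch under the assumption stated above, that only the decode path unconditionally needs the tokenizer and RULER data; the function name is hypothetical:

```python
def builtin_calibration_assets(forward_loop, calibrate_prefill, calibrate_decode):
    """Return (build_prefill_data, build_decode_data) for the fixed gating."""
    # Prefill data is only needed when no custom forward_loop is supplied.
    build_prefill = forward_loop is None and calibrate_prefill
    # Decode calibration always needs the tokenizer + RULER dataset,
    # even when the caller supplies a custom forward_loop.
    build_decode = calibrate_decode
    return build_prefill, build_decode
```

Compared with the buggy `forward_loop is None and (calibrate_prefill or calibrate_decode)` guard, decode assets are now built regardless of whether a custom `forward_loop` is present.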
In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 127-156: The diffuses/ltx registration blocks currently swallow
all exceptions; change them to only suppress expected optional-dependency
errors: in the ModelMixin/diffusers block catch only ImportError (remove broad
except (ImportError, Exception)) so that after importing ModelMixin you still
check and call register_diffusers_eager_attention and
register_diffusers_triton_attention but let any runtime errors from those calls
raise; in the ltx block stop using contextlib.suppress(Exception) around
register_ltx_eager_attention/register_ltx_triton_attention and instead wrap the
import in except (ImportError, RuntimeError) as you already do, and for the
actual calls use a narrow try/except that only catches ImportError or
RuntimeError so real bugs in
register_ltx_eager_attention/register_ltx_triton_attention surface.
In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py`:
- Around line 79-81: The code currently always adds attn_mask to scores which is
incorrect when attn_mask is a boolean mask; detect boolean masks before the
additive path and convert them to an additive mask (e.g., torch.where(attn_mask,
neg_inf, 0) or attn_mask.to(scores.dtype) * neg_inf) using scores.device and
scores.dtype to construct neg_inf (like -torch.finfo(scores.dtype).max or -1e9
for float32), then add that converted mask to scores; update the code around the
scores and attn_mask usage to only add when attn_mask is already additive or
after converting when attn_mask.dtype is bool.
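A NumPy sketch of the conversion described above. The real fix would use `torch.where` with `scores.dtype`/`scores.device`; here `True` marks positions to keep (SDPA's boolean-mask convention), which is an assumption about the caller:

```python
import numpy as np

def to_additive_mask(attn_mask, dtype=np.float32):
    """Boolean keep-mask -> additive mask: True -> 0.0, False -> large negative."""
    if attn_mask.dtype == np.bool_:
        neg_inf = -np.finfo(dtype).max
        return np.where(attn_mask, 0.0, neg_inf).astype(dtype)
    return attn_mask  # already additive, add as-is

scores = np.zeros((1, 3), dtype=np.float32)
mask = np.array([[True, False, True]])
masked = scores + to_additive_mask(mask)
```

Masked-out positions become a large negative number, so softmax drives their weight to zero; kept positions are unchanged.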
In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`:
- Around line 130-143: When calibration_mode (thread-local "calibration_mode")
is enabled but either "threshold_trials" (local variable trials) is
missing/empty or the calibration kernel "attention_calibrate" is None, the
function should not silently fall through to inference; instead detect this case
and raise a clear RuntimeError (or at least log an error and raise) that
includes which condition failed (missing trials vs missing attention_calibrate)
and mentions the function/mode involved (the calibration branch around
attention_calibrate(q, k, v, **kw, threshold_trials=trials)). Keep the existing
logic that accumulates _thread_local.calibration_counters when calibration runs,
but before falling back, explicitly check calib_mode && (not trials or
attention_calibrate is None) and raise with a descriptive message so callers
know calibration was requested but cannot proceed.
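The fail-fast check can be sketched as a standalone helper; the name is hypothetical and the real logic lives in the calibration branch around `attention_calibrate`:

```python
def check_calibration_preconditions(calib_mode, trials, attention_calibrate):
    """Raise instead of silently falling back to inference when calibration was requested."""
    if not calib_mode:
        return  # inference path: nothing to verify
    if not trials:
        raise RuntimeError(
            "calibration_mode is set but threshold_trials is missing/empty; "
            "cannot run attention_calibrate"
        )
    if attention_calibrate is None:
        raise RuntimeError(
            "calibration_mode is set but the attention_calibrate kernel is unavailable"
        )
```

Each failure mode gets its own message, so callers can tell whether trials were missing or the calibration kernel was not importable.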
In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_eager_attention.py`:
- Around line 103-114: The registration is non-idempotent because
register_ltx_eager_attention only skips when module.attention_function is a
_SkipSoftmaxLTXAttentionWrapper, so if _TritonLTXAttentionWrapper is outermost
it will wrap again; change the logic to always operate on the underlying base
callable: implement a small unwrap helper (e.g., get_base_attention_fn) that,
given a callable, repeatedly drills into known wrapper attributes (the wrappers
used here: _SkipSoftmaxLTXAttentionWrapper and _TritonLTXAttentionWrapper expose
the original as an inner attribute such as .fn or .wrapped_fn) until a
non-wrapper callable is reached, then if the base callable is not already
wrapped by the eager wrapper create a fresh
_SkipSoftmaxLTXAttentionWrapper(base_fn) and assign that to
module.attention_function (or rewrap preserving the other wrapper by
reconstructing outer wrappers around the new eager wrapper), and update
register_ltx_triton_attention similarly so both registrations first unwrap to
the base before reapplying their wrapper.
In `@modelopt/torch/sparsity/attention_sparsity/stats_manager.py`:
- Around line 67-75: The code assumes all incoming sparse_blocks vectors have
the same length and increments total_blocks even when sparse_blocks is missing;
fix by (1) tracking a separate counter like
self.aggregated_stats["sparse_blocks_reports"] (increment only when
stats.get("sparse_blocks") is not None) and use that for denominators, (2) when
receiving incoming = stats.get("sparse_blocks"), ensure the accumulator
self.aggregated_stats["sparse_blocks"] is resized to accommodate varying widths
(extend with zeros if incoming is longer, or pad incoming with zeros if shorter)
before elementwise addition, and (3) do not rely on total_blocks for
sparse_blocks normalization — only count calls that actually reported
sparse_blocks. Ensure all references use self.aggregated_stats, incoming,
stats.get("sparse_blocks"), and "total_blocks" so the changes integrate with
existing logic.
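The accumulation logic described above could look like the following sketch; the `aggregated_stats` keys follow the comment, everything else is illustrative:

```python
def accumulate_sparse_blocks(aggregated_stats, stats):
    """Accumulate variable-width sparse_blocks vectors, counting only real reports."""
    aggregated_stats["total_blocks"] = aggregated_stats.get("total_blocks", 0) + 1
    incoming = stats.get("sparse_blocks")
    if incoming is None:
        return  # this call reported nothing; do not count it for normalization
    acc = aggregated_stats.setdefault("sparse_blocks", [])
    if len(incoming) > len(acc):  # widen the accumulator with zeros
        acc.extend([0] * (len(incoming) - len(acc)))
    for i, value in enumerate(incoming):  # shorter reports are implicitly zero-padded
        acc[i] += value
    aggregated_stats["sparse_blocks_reports"] = (
        aggregated_stats.get("sparse_blocks_reports", 0) + 1
    )
```

Normalization would then divide by `sparse_blocks_reports` rather than `total_blocks`, so calls that never reported `sparse_blocks` cannot dilute the average.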
In `@tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py`:
- Around line 82-96: Tests are reusing cached backend modules across fixture
runs causing stale mocks; ensure you remove cached modules from sys.modules
before reimporting. Inside the patch.dict(sys.modules, _mock_diffusers())
context (and likewise for the triton fixture), pop any existing
"modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_eager_attention"
and
"modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_triton_attention"
entries from sys.modules, then import the module and access
_diffusers_eager_attention, register_diffusers_eager_attention,
get_skip_softmax_attention_backend and reset mod._BACKEND_REGISTERED to
guarantee a fresh module and mocks for each fixture run.
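The fresh-import pattern this comment asks for can be sketched as a small test helper; the helper name is illustrative:

```python
import importlib
import sys
from unittest.mock import patch

def fresh_import(module_name: str, mocked_modules: dict):
    """Import module_name with mocked deps, dropping any cached copy first."""
    with patch.dict(sys.modules, mocked_modules):
        sys.modules.pop(module_name, None)  # avoid reusing a stale cached module
        return importlib.import_module(module_name)
```

Because `patch.dict` restores `sys.modules` on exit, the temporary removal and any mock entries do not leak into other tests.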
---
Duplicate comments:
In `@examples/diffusers/quantization/wan2_sage_attention.py`:
- Around line 370-385: _patched_sdpa currently drops caller kwargs on both
fallback paths causing silently changed behavior; update the calls so that any
received **kwargs are forwarded: call _run_kernel(query, key, value,
is_causal=is_causal, scale=scale, **kwargs) and call _orig_sdpa(query, key,
value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal,
scale=scale, **kwargs) (and likewise in the exception fallback) so
_patched_sdpa, _run_kernel and _orig_sdpa receive the original caller options
(references: _patched_sdpa, _run_kernel, _orig_sdpa, _sage_calls,
_active_kernel).
- Around line 504-531: The apply_triton_sparse_kernel function does not validate
skip_threshold for the skip-softmax kernels (KERNEL_TRITON_SKIP,
KERNEL_TRITON_SKIP_NVFP4), which can lead to cryptic failures when zero or
negative values are used; add an early validation inside
apply_triton_sparse_kernel that when kernel is one of the skip-softmax constants
and skip_threshold is not None, raise a ValueError if skip_threshold <= 0 with a
clear message (e.g., require skip_threshold > 0) before mutating
config["sparse_cfg"]["*"]["skip_softmax_threshold"] so invalid values never
reach the Triton kernel.
- Around line 815-817: The code currently calls load_pipeline(args.model) before
validating kernel availability; move the kernel check immediately after
parse_args() in main() so you verify args.kernel in AVAILABLE_KERNELS before any
heavy model download. Concretely, inside main(), call args = parse_args(), then
perform the AVAILABLE_KERNELS membership check (and exit/raise with a clear
message if invalid), and only then call load_pipeline(args.model); update/remove
the later redundant kernel checks around load_pipeline/usage (referencing
main(), parse_args(), load_pipeline, and AVAILABLE_KERNELS).
- Around line 589-596: The PSNR code uses np.where which still triggers
divide-by-zero because both branches are evaluated; fix by first protecting the
denominator (e.g., compute safe_mse = np.maximum(mse_per_frame, 1e-10)) then
compute the PSNR values from safe_mse (psnr_vals = 10.0 * np.log10(255.0**2 /
safe_mse)) and finally set exact-zero cases to 100.0 with psnr_per_frame =
np.where(mse_per_frame < 1e-10, 100.0, psnr_vals) and keep psnr =
float(psnr_per_frame.mean()); update the code that references mse_per_frame,
psnr_per_frame, and psnr accordingly.
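The early skip-threshold validation described above could be as small as the following sketch. The kernel-constant values here are hypothetical placeholders; only the names come from the example script:

```python
# Sketch of the fail-fast check; constant values are assumed, not the script's.
KERNEL_TRITON_SKIP = "triton_skip_softmax"
KERNEL_TRITON_SKIP_NVFP4 = "triton_skip_softmax_nvfp4"
_SKIP_SOFTMAX_KERNELS = {KERNEL_TRITON_SKIP, KERNEL_TRITON_SKIP_NVFP4}

def validate_skip_threshold(kernel: str, skip_threshold) -> None:
    """Reject non-positive thresholds before they reach the Triton kernel."""
    if kernel in _SKIP_SOFTMAX_KERNELS and skip_threshold is not None and skip_threshold <= 0:
        raise ValueError(
            f"skip_threshold must be > 0 for kernel {kernel!r}, got {skip_threshold}"
        )
```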
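The protected PSNR computation described above can be sketched as a standalone helper (assuming 8-bit frames; `psnr_per_frame` is an illustrative name, not the example script's code):

```python
import numpy as np

def psnr_per_frame(mse_per_frame: np.ndarray) -> np.ndarray:
    """PSNR per frame with a protected denominator and a 100 dB cap for exact matches."""
    # Clamp first so np.log10 never sees zero; np.where alone does not help
    # because both branches are evaluated before selection.
    safe_mse = np.maximum(mse_per_frame, 1e-10)
    psnr_vals = 10.0 * np.log10(255.0**2 / safe_mse)
    return np.where(mse_per_frame < 1e-10, 100.0, psnr_vals)
```

`psnr = float(psnr_per_frame(mse).mean())` then matches the final step the comment describes.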
In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 1005-1012: The forward path currently allows using paged KV
(k_cache/v_cache) while autograd is enabled, but the backward only reconstructs
grads from contiguous k/v and omits cache tensors; add the same inference-only
guard used for quantize_p: detect when paged KV is in use (check k_cache or
v_cache presence/usage) and if any of q.requires_grad, k.requires_grad, or
v.requires_grad are true, raise NotImplementedError with a clear message like
"paged KV supports inference only; backward does not model the paged K/V path"
(mirror the existing quantize_p check) so the code in the function handling
forward (the block containing quantize_p and references to k_cache/v_cache)
rejects autograd-enabled paged KV.
- Around line 973-997: When paged mode is detected (k_cache/v_cache provided)
don't derive KV head counts or page geometry from the placeholder tensor k;
instead read and validate num_kv_heads, kv_group_num and page_size (and any page
layout metadata) from k_cache and v_cache and use those values to compute
kv_group_num and related variables used later in triton_fa.py (e.g.,
num_kv_heads, kv_group_num, page_size, b_seq_len_k/b_start_loc_k defaults). Add
explicit shape checks that k_cache and v_cache agree with each other and with
block_table (and that num_q_heads is divisible by num_kv_heads), and raise a
clear ValueError if they mismatch; when b_start_loc_k is needed,
construct/derive it from the cache metadata rather than the placeholder k.
Ensure all references to num_kv_heads/kv_group_num/page_size later in the
function use the cache-derived values.
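The inference-only guard described above might look like this sketch; the argument names mirror the comment rather than the actual kernel signature:

```python
import torch

def check_paged_kv_inference_only(q, k, v, k_cache=None, v_cache=None) -> None:
    """Reject autograd-enabled inputs when paged KV caches are in use."""
    paged = k_cache is not None or v_cache is not None
    if paged and any(t.requires_grad for t in (q, k, v)):
        raise NotImplementedError(
            "paged KV supports inference only; backward does not model the paged K/V path"
        )
```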
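The cache-derived geometry checks described above can be sketched as follows, assuming a `[num_pages, page_size, num_kv_heads, head_dim]` cache layout (an assumption for illustration, not the kernel's documented layout):

```python
import torch

def paged_kv_geometry(k_cache, v_cache, block_table, num_q_heads):
    """Read num_kv_heads / page_size from the caches and cross-check shapes."""
    if k_cache.shape != v_cache.shape:
        raise ValueError(f"k_cache {tuple(k_cache.shape)} != v_cache {tuple(v_cache.shape)}")
    num_pages, page_size, num_kv_heads, _head_dim = k_cache.shape
    if num_q_heads % num_kv_heads:
        raise ValueError(
            f"num_q_heads={num_q_heads} not divisible by num_kv_heads={num_kv_heads}"
        )
    if block_table is not None and int(block_table.max()) >= num_pages:
        raise ValueError("block_table references pages beyond the cache")
    # (num_kv_heads, kv_group_num, page_size)
    return num_kv_heads, num_q_heads // num_kv_heads, page_size
```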
---
Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py`:
- Around line 50-61: The module currently does eager imports of optional
backends (register_diffusers_eager_attention,
register_diffusers_triton_attention, register_ltx_eager_attention,
register_ltx_triton_attention) inside __init__.py using contextlib.suppress,
which forces all consumers to attempt those imports; instead, remove these
top-level imports and make the four registration callables resolved lazily at
registration time (e.g., implement an import_plugin() or inline import inside
the functions that perform registration), so each of
register_diffusers_eager_attention, register_diffusers_triton_attention,
register_ltx_eager_attention and register_ltx_triton_attention is imported only
when actually invoked and ImportError/RuntimeError are handled there.
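One way to realize the lazy resolution is a generic factory; `lazy_register` and its arguments below are illustrative, not the module's actual helpers:

```python
import importlib

def lazy_register(module_name: str, func_name: str):
    """Return a callable that imports the optional backend only when invoked."""
    def _register(*args, **kwargs):
        try:
            mod = importlib.import_module(module_name)
        except ImportError as exc:  # optional dependency missing
            raise RuntimeError(f"optional backend {module_name!r} is unavailable") from exc
        return getattr(mod, func_name)(*args, **kwargs)
    return _register
```

`__init__.py` could then export the four registration names via `lazy_register(...)` instead of importing each backend eagerly under `contextlib.suppress`.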
In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`:
- Around line 169-179: The code directly mutates diffusers internals
(AttentionBackendName._member_map_, AttentionBackendRegistry._backends, etc.),
which is fragile; add a brief comment near the block stating the diffusers
version(s) this was tested against and expected internal shape, and wrap the
registration logic for AttentionBackendName and _AttentionBackendRegistry (the
new_member creation, map assignments, and setting of _supported_arg_names via
inspect.signature(_diffusers_triton_attention)) in a try/except that catches
AttributeError/KeyError and raises a clear, actionable RuntimeError explaining
that the diffusers internal API changed and what fields are expected (e.g.,
_member_map_, _value2member_map_, _backends, _constraints, _supported_arg_names)
so maintainers can update accordingly.
In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py`:
- Around line 72-131: The two functions _ltx_triton_attention and
_diffusers_triton_attention duplicate varlen metadata construction,
calibration/inference branching and Triton call logic; refactor by extracting a
shared helper (e.g., _triton_attention_core) that accepts parameters needed to
build q_flat/k_flat/v_flat, seq lengths, heads/dim_head, and optional flags
(threshold, calibration trials), move the common kw construction and
calibration/inference handling (including use of _thread_local,
attention_calibrate, attention) into that helper, then have
_ltx_triton_attention and _diffusers_triton_attention prepare their
layout-specific tensors and call the new helper to return the reshaped output.
- Around line 147-153: register_ltx_triton_attention currently scans all modules
every call and only avoids double-wrapping by checking isinstance(fn,
_TritonLTXAttentionWrapper); make it consistent with
register_diffusers_triton_attention by adding a module-level flag
_BACKEND_REGISTERED (or reusing the same symbol) to short-circuit if already
registered. Modify register_ltx_triton_attention to check _BACKEND_REGISTERED at
the top and set it to True after the first successful registration pass, while
still keeping the existing isinstance(fn, _TritonLTXAttentionWrapper) guard and
targeting Attention and _TritonLTXAttentionWrapper as before.
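The idempotence flag suggested above is simple to sketch (names assumed; the real code would keep the existing `isinstance` guard as a second line of defense):

```python
_BACKEND_REGISTERED = False

def register_once(do_register) -> bool:
    """Run `do_register` only on the first call; later calls short-circuit."""
    global _BACKEND_REGISTERED
    if _BACKEND_REGISTERED:
        return False
    do_register()
    _BACKEND_REGISTERED = True
    return True
```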
In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`:
- Around line 387-401: The current setup sets a thread-local flag via
set_skip_softmax_context(True) and builds an ExitStack but if any
enter_context() raises the flag remains True and nested contexts are not
reentrant; change the pattern to build the stack inside a with ExitStack() as
stack: block so any exception will unwind and restore, capture the previous flag
value before setting it, call set_skip_softmax_context(True), and before
returning call and return stack.pop_all() while restoring the previous flag via
set_skip_softmax_context(previous_value); make the adjustments around the
functions get_skip_softmax_attention_backend,
replace_function(torch.nn.functional, "softmax", sparse_softmax), and the
ExitStack usage so the callback restoring the flag no longer clobbers outer
contexts.
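A reentrant variant along the lines described could use a context manager instead of the module's `pop_all()` plumbing; `skip_softmax_enabled` and the thread-local names below are illustrative:

```python
import threading
from contextlib import ExitStack, contextmanager

_local = threading.local()

def set_skip_softmax_context(enabled: bool) -> None:
    _local.enabled = enabled

def get_skip_softmax_context() -> bool:
    return getattr(_local, "enabled", False)

@contextmanager
def skip_softmax_enabled(*contexts):
    # Build the stack inside `with` so a failing enter_context unwinds cleanly,
    # and capture the previous flag so nested uses restore instead of clobbering.
    previous = get_skip_softmax_context()
    with ExitStack() as stack:
        for cm in contexts:
            stack.enter_context(cm)
        set_skip_softmax_context(True)
        try:
            yield
        finally:
            set_skip_softmax_context(previous)
```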
In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`:
- Around line 172-192: _collect_calibration_stats currently only uses the first
available get_calibration_counters (diffusers_triton_attention) and ignores the
other backend, which prevents collecting counters when both backends are active;
change the logic in _collect_calibration_stats to attempt to import
get_calibration_counters from both ..kernels.diffusers_triton_attention and
..kernels.ltx_triton_attention, call each present function and merge/aggregate
their returned counters (e.g., sum corresponding metrics or combine entries)
into a single counters object before proceeding with the existing threshold
logic (referencing get_calibration_counters and self._threshold_trials to decide
whether aggregation is needed), or alternatively add a clear docstring note on
mixed-backend non-support if aggregation is not implemented.
- Around line 110-128: The method _get_effective_threshold currently divides
scale_factor by a hardcoded seqlen (4224); change it to use a real sequence
length by accepting an explicit seqlen (e.g., add a seqlen parameter to
_get_effective_threshold) or by reading a sequence-length property from the
passed module (e.g., module.seq_len or module.input_shape) and fall back to the
previous default when unavailable; update callers that compute thresholds to
pass the actual runtime seqlen (or set a configurable class attribute like
self.default_seqlen) so calibration_params/target_sparse_ratio are scaled
correctly instead of using 4224, and keep returning self.skip_softmax_threshold
when no calibration info exists.
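Merging counters across both backends could look like the sketch below; counters are assumed to support `+`, matching the `prev + counters` accumulation the kernels already use:

```python
def collect_calibration_counters(counter_getters):
    """Aggregate counters from every backend that currently has any."""
    merged = None
    for get_counters in counter_getters:
        counters = get_counters()
        if counters is None:
            continue
        merged = counters if merged is None else merged + counters
    return merged
```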
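The seqlen-aware scaling described above can be sketched as follows; the argument names and the `seq_len` attribute are assumptions, and only the 4224 default comes from the comment:

```python
def get_effective_threshold(base_threshold, scale_factor, module=None,
                            seqlen=None, default_seqlen=4224):
    """Scale a calibrated threshold by the real sequence length when known."""
    if seqlen is None:
        seqlen = getattr(module, "seq_len", None) or default_seqlen
    return base_threshold * scale_factor / seqlen
```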
In `@modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py`:
- Around line 339-353: The _build_sparse_kw function currently ignores unknown
_method values silently; update it to detect when method is non-empty and not
one of the supported values ("triton_sparse_softmax", "triton_skip_softmax") and
emit a warning including the invalid method name and available options to aid
debugging. Locate the _build_sparse_kw method and add a membership check (e.g.
allowed_methods = {"triton_sparse_softmax","triton_skip_softmax"}); if method
and method not in allowed_methods, call a logger on the instance (use
getattr(self, "_logger", logging.getLogger(__name__))) to log a warning that
includes method and allowed_methods, then continue building kw (still respect
quantize_p as currently implemented).
- Around line 224-229: The code slices encoder_hidden_states using a hardcoded
512 (seen in the block checking attn.add_k_proj and computing
image_context_length) which is brittle; replace that magic number with a
configurable value read from the attention module or plugin class: attempt to
read a context length from attn (e.g., attn.config.text_encoder_context_length
or a similar attribute) and fall back to a plugin/class constant like
TEXT_ENCODER_CONTEXT_LEN if not present, then use that variable when computing
image_context_length and slicing encoder_hidden_states; keep the assert for
encoder_hidden_states and add a warning/log if neither config nor constant
exists and you must fall back to 512.
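The unknown-method warning could be as small as the sketch below (the surrounding `kw` construction is elided; names are illustrative):

```python
import logging

_ALLOWED_METHODS = {"triton_sparse_softmax", "triton_skip_softmax"}

def warn_on_unknown_method(method, logger=None) -> bool:
    """Return True when method is usable; warn (but continue) otherwise."""
    logger = logger or logging.getLogger(__name__)
    if method and method not in _ALLOWED_METHODS:
        logger.warning(
            "Unknown sparse attention method %r; expected one of %s",
            method, sorted(_ALLOWED_METHODS),
        )
        return False
    return True
```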
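The configurable context length could be resolved like this; `text_encoder_context_length` is the attribute name the comment proposes, not a documented diffusers field:

```python
TEXT_ENCODER_CONTEXT_LEN = 512  # fallback matching the previous hardcoded value

def get_text_context_len(attn, default=TEXT_ENCODER_CONTEXT_LEN) -> int:
    """Prefer a length configured on the attention module, else fall back to 512."""
    cfg = getattr(attn, "config", None)
    configured = getattr(cfg, "text_encoder_context_length", None)
    return configured if configured is not None else default
```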
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: f4a94f9e-f0a7-436f-bbed-e4e2ad821a33
📒 Files selected for processing (30)
- examples/diffusers/README.md
- examples/diffusers/quantization/wan2_sage_attention.py
- examples/diffusers/sparsity/README.md
- examples/diffusers/sparsity/wan22_skip_softmax.py
- examples/vllm_serve/sparse_attn_worker.py
- examples/vllm_serve/vllm_serve_sparse_attn.py
- modelopt/torch/kernels/__init__.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/quantization/plugins/diffusion/diffusers.py
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/conversion.py
- modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py
- modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py
- modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py
- modelopt/torch/sparsity/attention_sparsity/kernels/ltx_eager_attention.py
- modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py
- modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
- modelopt/torch/sparsity/attention_sparsity/methods/registry.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
- modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
- modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py
- modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py
- modelopt/torch/sparsity/attention_sparsity/plugins/vllm.py
- modelopt/torch/sparsity/attention_sparsity/stats_manager.py
- pyproject.toml
- tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa_paged.py
- tests/unit/torch/sparsity/attention_sparsity/test_diffusers_plugin.py
- tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py
✅ Files skipped from review due to trivial changes (4)
- modelopt/torch/quantization/plugins/diffusion/diffusers.py
- modelopt/torch/sparsity/attention_sparsity/methods/registry.py
- examples/diffusers/README.md
- pyproject.toml
🚧 Files skipped from review as they are similar to previous changes (3)
- modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- tests/unit/torch/sparsity/attention_sparsity/test_diffusers_plugin.py
> During calibration, ``flash_skip_softmax`` with the eager attention backend
> collects sparsity statistics across multiple threshold trials. The fitted
> exponential model then allows runtime control of the target sparsity ratio
Sync the example header with the actual calibration path.
The file header still says calibration uses flash_skip_softmax with the eager backend, but build_sparse_config() now wires triton_skip_softmax with the Triton path. That will send users to the wrong codepath when they debug or compare results.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/diffusers/sparsity/wan22_skip_softmax.py` around lines 22 - 24,
Update the example header text to reflect that calibration now uses
triton_skip_softmax with the Triton backend (not flash_skip_softmax with the
eager backend); locate the header block at the top of the file and change any
mentions of flash_skip_softmax/eager to triton_skip_softmax/Triton so it matches
the implementation in build_sparse_config(), ensuring documentation and code
paths align for debugging and comparison.
```python
threshold = layer_cfg.get("skip_softmax_threshold")
if threshold:
    sparse_kw["skip_softmax_threshold"] = threshold
```
The threshold check `if threshold:` treats 0 as falsy, skipping a valid zero threshold.
Line 165 uses `if threshold:`, which evaluates to False when threshold=0. If a user intentionally sets `skip_softmax_threshold: 0` to disable skip-softmax, the key won't be added to `sparse_kw`, potentially causing different behavior than expected.
Consider using `if threshold is not None:` for an explicit None check.
Suggested fix

```diff
 threshold = layer_cfg.get("skip_softmax_threshold")
-if threshold:
+if threshold is not None:
     sparse_kw["skip_softmax_threshold"] = threshold
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/vllm_serve/sparse_attn_worker.py` around lines 164 - 166, The
current check uses a falsy test on threshold and will skip valid zero values;
update the condition around threshold retrieval to test for None explicitly
(e.g., replace "if threshold:" with "if threshold is not None:") so that when
layer_cfg.get("skip_softmax_threshold") returns 0 the key
"skip_softmax_threshold" is still set on sparse_kw; locate the threshold
variable, layer_cfg.get call, and sparse_kw assignment in sparse_attn_worker.py
to apply this change.
```diff
     # Extract tokenizer and build calibration data only if no forward_loop is provided.
     # When the user supplies their own forward_loop (e.g. for diffusion models),
     # RULER dataset generation is skipped entirely.
     tokenizer = None
     calibration_data = None

-    if calibrate_prefill or calibrate_decode:
+    if forward_loop is None and (calibrate_prefill or calibrate_decode):
         tokenizer = _extract_tokenizer_from_model(model)
```
Still build decode calibration data when forward_loop is provided.
forward_loop is documented as prefill-only, but this guard skips tokenizer/dataset construction for decode as soon as any custom loop is passed. With target_sparse_ratio.decode > 0, the function now always hits Lines 313-314 instead of running the built-in decode calibration. Build the RULER dataset whenever decode calibration is enabled, even if prefill uses a custom loop.
🔧 Proposed fix

```diff
-    if forward_loop is None and (calibrate_prefill or calibrate_decode):
+    needs_builtin_dataset = calibrate_decode or forward_loop is None
+    if needs_builtin_dataset and (calibrate_prefill or calibrate_decode):
         tokenizer = _extract_tokenizer_from_model(model)
         builder = RulerDatasetBuilder(
             samples=calib_config.samples,
             max_seqlen=calib_config.max_seqlen,
             tokenizer_name_or_path=tokenizer,
```

Also applies to: 313-314
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` around
lines 259 - 266, The current guard `if forward_loop is None and
(calibrate_prefill or calibrate_decode):` prevents building the tokenizer and
RULER decode dataset whenever a custom `forward_loop` is provided, so built-in
decode calibration never runs when `target_sparse_ratio.decode > 0`; change the
condition to always build the tokenizer/calibration_data when `calibrate_decode`
is true (even if `forward_loop` is provided) while keeping the existing prefill
behavior for `calibrate_prefill`. Specifically, call
`_extract_tokenizer_from_model(model)` and construct the RULER dataset whenever
`calibrate_decode` is enabled (or `target_sparse_ratio.decode > 0`), and only
gate prefill dataset construction on `forward_loop is None` and
`calibrate_prefill`, ensuring `tokenizer`/`calibration_data` are available to
the decode calibration code paths (see `forward_loop`, `calibrate_prefill`,
`calibrate_decode`, `tokenizer`, `calibration_data`, and
`_extract_tokenizer_from_model`).
```python
# Apply attention mask if provided
if attn_mask is not None:
    scores = scores + attn_mask
```
🧩 Analysis chain
🏁 Scripts executed (outputs omitted):

```shell
find . -name "diffusers_eager_attention.py" -type f
wc -l modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py
cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py | head -110
rg "attn_mask|attention_mask" modelopt/torch/sparsity/attention_sparsity/kernels/ -A 3 -B 3
rg "scaled_dot_product_attention" --type py -A 10 -B 5
cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py | sed -n '1,150p'
cat -n modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py | sed -n '1,200p'
rg "attention_mask" modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py -B 10 -A 10
```
🌐 Web query:
PyTorch torch.nn.functional.scaled_dot_product_attention attn_mask parameter documentation
💡 Result:
In torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None, ...), the attn_mask parameter is:

- Type: optional Tensor
- Shape: must be broadcastable to the attention-weight shape (N, ..., L, S), where L = query.size(-2) and S = key.size(-2). (docs.pytorch.org)
- Supported mask kinds:
  - Boolean mask (dtype=torch.bool): True means "allowed to attend / takes part in attention"; False positions are effectively masked out (internally turned into -inf bias before softmax). (docs.pytorch.org)
  - Float/additive mask (same dtype as query/key/value): values are added to the attention scores (commonly 0 for keep and -inf for mask). (docs.pytorch.org)
- Interaction with is_causal: it's an error to set both attn_mask and is_causal=True. (docs.pytorch.org)

Note: PyTorch's Transformer module explicitly warns that its boolean mask semantics are the opposite of SDPA's attn_mask boolean semantics (Transformer uses True = "not allowed"). (docs.pytorch.org)
Citations:
- 1: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention?utm_source=openai
- 2: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention?utm_source=openai
- 3: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention?utm_source=openai
- 4: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention?utm_source=openai
- 5: https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.transformer.Transformer.html?utm_source=openai
🏁 Scripts executed (outputs omitted):

```shell
fd test -type f -name "*.py" modelopt/torch/sparsity/attention_sparsity/ | xargs grep -l "diffusers_eager_attention" 2>/dev/null | head -5
rg "dispatch_attention_fn" modelopt/torch/sparsity/attention_sparsity/ -B 5 -A 10 --type py | head -80
cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py | sed -n '110,150p'
cat -n tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py
cat -n tests/unit/torch/sparsity/attention_sparsity/test_diffusers_plugin.py | sed -n '1,100p'
grep -n "attn_mask" tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py
cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py | sed -n '87,160p'
```

🏁 Script executed (verifying the claimed mask semantics):

```python
# Verify the masking semantics claimed in the review
import torch

# Simulate attention scores
scores = torch.randn(2, 4, 8, 8)

# Case 1: Boolean mask (SDPA-style)
bool_mask = torch.ones(2, 4, 8, 8, dtype=torch.bool)
bool_mask[0, 0, :, 7] = False  # Mask out last position

# What the current code does (WRONG for boolean)
wrong_result = scores + bool_mask.float()  # Treats False as 0, True as 1
print("Boolean mask treated as float (WRONG):")
print(f"  masked position bias: {wrong_result[0, 0, 0, 7].item():.4f}")
print(f"  unmasked position bias: {wrong_result[0, 0, 0, 6].item():.4f}")

# What SDPA expects for boolean mask (right semantics)
correct_result = scores.clone()
correct_result = correct_result.masked_fill(bool_mask.logical_not(), float("-inf"))
print("\nBoolean mask handled correctly (what SDPA expects):")
print(f"  masked position bias: {correct_result[0, 0, 0, 7].item():.4f}")
print(f"  unmasked position bias: {correct_result[0, 0, 0, 6].item():.4f}")

# Case 2: Float/additive mask (what current code assumes)
float_mask = torch.zeros(2, 4, 8, 8)
float_mask[0, 0, :, 7] = float("-inf")  # Mask out last position
proper_add = scores + float_mask
print("\nFloat additive mask (current code works):")
print(f"  masked position bias: {proper_add[0, 0, 0, 7].item():.4f}")
print(f"  unmasked position bias: {proper_add[0, 0, 0, 6].item():.4f}")
```
Handle boolean masks before the additive path.
attn_mask on SDPA-style backends can be boolean or additive. This code always adds it, so a boolean mask becomes 0/1 instead of -inf and returns wrong attention whenever diffusers passes a bool mask.
Suggested fix

```diff
 # Apply attention mask if provided
 if attn_mask is not None:
-    scores = scores + attn_mask
+    if attn_mask.dtype == torch.bool:
+        scores = scores.masked_fill(attn_mask.logical_not(), float("-inf"))
+    else:
+        scores = scores + attn_mask
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py`
around lines 79 - 81, The code currently always adds attn_mask to scores which
is incorrect when attn_mask is a boolean mask; detect boolean masks before the
additive path and convert them to an additive mask (e.g., torch.where(attn_mask,
neg_inf, 0) or attn_mask.to(scores.dtype) * neg_inf) using scores.device and
scores.dtype to construct neg_inf (like -torch.finfo(scores.dtype).max or -1e9
for float32), then add that converted mask to scores; update the code around the
scores and attn_mask usage to only add when attn_mask is already additive or
after converting when attn_mask.dtype is bool.
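The suggested fix generalizes to a small helper that normalizes either mask kind to an additive bias (a sketch, not the file's actual code):

```python
import torch

def to_additive_mask(attn_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Convert a boolean SDPA-style mask to an additive bias; pass floats through."""
    if attn_mask.dtype == torch.bool:
        # True = attend (bias 0), False = masked (bias -inf), per SDPA semantics.
        bias = torch.zeros(attn_mask.shape, dtype=dtype)
        return bias.masked_fill(attn_mask.logical_not(), float("-inf"))
    return attn_mask
```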
```python
calib_mode = getattr(_thread_local, "calibration_mode", False)
if calib_mode:
    trials = getattr(_thread_local, "threshold_trials", None)
    if trials and attention_calibrate is not None:
        o, counters = attention_calibrate(q, k, v, **kw, threshold_trials=trials)

        # Accumulate counters across all attention calls in this forward pass
        prev = getattr(_thread_local, "calibration_counters", None)
        if prev is None:
            _thread_local.calibration_counters = counters
        else:
            _thread_local.calibration_counters = prev + counters

        return o.view(batch, seq_q, num_heads_q, head_dim)
```
Calibration mode silently falls back to inference if trials is empty or attention_calibrate is None.
When calibration_mode=True but trials is empty/None or attention_calibrate is unavailable, the code falls through to inference mode instead of raising an error or warning. This could silently produce unexpected results if a user enables calibration but the calibration kernel isn't available.
Suggested improvement

```diff
 calib_mode = getattr(_thread_local, "calibration_mode", False)
 if calib_mode:
     trials = getattr(_thread_local, "threshold_trials", None)
-    if trials and attention_calibrate is not None:
+    if not trials:
+        raise ValueError("calibration_mode=True but threshold_trials is empty")
+    if attention_calibrate is None:
+        raise RuntimeError("calibration_mode=True but attention_calibrate is unavailable (requires CUDA + triton)")
+    if trials and attention_calibrate is not None:
         o, counters = attention_calibrate(q, k, v, **kw, threshold_trials=trials)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`
around lines 130 - 143, When calibration_mode (thread-local "calibration_mode")
is enabled but either "threshold_trials" (local variable trials) is
missing/empty or the calibration kernel "attention_calibrate" is None, the
function should not silently fall through to inference; instead detect this case
and raise a clear RuntimeError (or at least log an error and raise) that
includes which condition failed (missing trials vs missing attention_calibrate)
and mentions the function/mode involved (the calibration branch around
attention_calibrate(q, k, v, **kw, threshold_trials=trials)). Keep the existing
logic that accumulates _thread_local.calibration_counters when calibration runs,
but before falling back, explicitly check calib_mode && (not trials or
attention_calibrate is None) and raise with a descriptive message so callers
know calibration was requested but cannot proceed.
```python
def register_ltx_eager_attention(model: torch.nn.Module) -> None:
    """Walk *model* and patch all ``ltx_core.model.transformer.attention.Attention`` modules.

    Patches modules so their ``attention_function`` is routed through the eager wrapper.
    Safe to call multiple times on the same model — already-wrapped modules are
    skipped.
    """
    for module in model.modules():
        if isinstance(module, Attention):
            fn = module.attention_function
            if not isinstance(fn, _SkipSoftmaxLTXAttentionWrapper):
                module.attention_function = _SkipSoftmaxLTXAttentionWrapper(fn)
```
This registration becomes non-idempotent after both wrappers are applied.
The guard only recognizes _SkipSoftmaxLTXAttentionWrapper. After register_ltx_triton_attention wraps the same attention_function with _TritonLTXAttentionWrapper, a second call to register_ltx_eager_attention will wrap again because the outer callable is no longer the eager wrapper type. Since conversion.py:151–156 calls both registrations sequentially, repeated conversions will stack nested wrappers instead of being safely reapplied. Both functions should either detect any existing wrapper (not just their own type), or restore to the base callable before rewrapping.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_eager_attention.py`
around lines 103 - 114, The registration is non-idempotent because
register_ltx_eager_attention only skips when module.attention_function is a
_SkipSoftmaxLTXAttentionWrapper, so if _TritonLTXAttentionWrapper is outermost
it will wrap again; change the logic to always operate on the underlying base
callable: implement a small unwrap helper (e.g., get_base_attention_fn) that,
given a callable, repeatedly drills into known wrapper attributes (the wrappers
used here: _SkipSoftmaxLTXAttentionWrapper and _TritonLTXAttentionWrapper expose
the original as an inner attribute such as .fn or .wrapped_fn) until a
non-wrapper callable is reached, then if the base callable is not already
wrapped by the eager wrapper create a fresh
_SkipSoftmaxLTXAttentionWrapper(base_fn) and assign that to
module.attention_function (or rewrap preserving the other wrapper by
reconstructing outer wrappers around the new eager wrapper), and update
register_ltx_triton_attention similarly so both registrations first unwrap to
the base before reapplying their wrapper.
```diff
 self.aggregated_stats["total_blocks"] += stats.get("total_blocks", 0)

-incoming = stats["sparse_blocks"]
-if "sparse_blocks" not in self.aggregated_stats:
-    self.aggregated_stats["sparse_blocks"] = list(incoming)
-else:
-    for i, val in enumerate(incoming):
-        self.aggregated_stats["sparse_blocks"][i] += val
+incoming = stats.get("sparse_blocks")
+if incoming is not None:
+    if "sparse_blocks" not in self.aggregated_stats:
+        self.aggregated_stats["sparse_blocks"] = list(incoming)
+    else:
+        for i, val in enumerate(incoming):
+            self.aggregated_stats["sparse_blocks"][i] += val
```
Handle changing sparse_blocks widths before accumulating.
Line 75 assumes every incoming vector has the same length as the first one. That is no longer guaranteed once calibration paths can emit multi-threshold stats while inference emits single-threshold stats, so this can either raise IndexError or mix incompatible totals. Also, total_blocks keeps growing even when sparse_blocks is missing, which will underreport summary sparsity. Reset/resize the accumulator and use a denominator that only counts calls that actually reported block data.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/stats_manager.py` around lines 67
- 75, The code assumes all incoming sparse_blocks vectors have the same length
and increments total_blocks even when sparse_blocks is missing; fix by (1)
tracking a separate counter like self.aggregated_stats["sparse_blocks_reports"]
(increment only when stats.get("sparse_blocks") is not None) and use that for
denominators, (2) when receiving incoming = stats.get("sparse_blocks"), ensure
the accumulator self.aggregated_stats["sparse_blocks"] is resized to accommodate
varying widths (extend with zeros if incoming is longer, or pad incoming with
zeros if shorter) before elementwise addition, and (3) do not rely on
total_blocks for sparse_blocks normalization — only count calls that actually
reported sparse_blocks. Ensure all references use self.aggregated_stats,
incoming, stats.get("sparse_blocks"), and "total_blocks" so the changes
integrate with existing logic.
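The suggested fix can be sketched as a standalone helper, where `agg` plays the role of `self.aggregated_stats` (the helper name and the `sparse_blocks_reports` key are illustrative):

```python
def accumulate_call_stats(agg, stats):
    """Width-safe accumulation of per-threshold sparse-block counts.

    Incoming vectors may change length between calibration (multi-threshold)
    and inference (single-threshold) calls; the accumulator is padded with
    zeros as needed. "sparse_blocks_reports" counts only the calls that
    actually reported block data, for use as the normalization denominator.
    """
    agg["total_blocks"] = agg.get("total_blocks", 0) + stats.get("total_blocks", 0)

    incoming = stats.get("sparse_blocks")
    if incoming is None:
        return
    acc = agg.setdefault("sparse_blocks", [])
    if len(acc) < len(incoming):
        acc.extend([0] * (len(incoming) - len(acc)))
    for i, val in enumerate(incoming):
        acc[i] += val
    agg["sparse_blocks_reports"] = agg.get("sparse_blocks_reports", 0) + 1
```

A shorter incoming vector simply leaves the tail of the accumulator untouched, and calls without block data only contribute to `total_blocks`.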
```python
with patch.dict(sys.modules, _mock_diffusers()):
    from modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_eager_attention import (
        _diffusers_eager_attention,
        get_skip_softmax_attention_backend,
        register_diffusers_eager_attention,
    )

    self._fn = _diffusers_eager_attention
    self._register = register_diffusers_eager_attention
    self._get_backend = get_skip_softmax_attention_backend

    import modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_eager_attention as mod

    mod._BACKEND_REGISTERED = False
    yield
```
Reload backend modules on each fixture to ensure test isolation with fresh mocks.
The patch.dict(sys.modules, ...) context patches the module namespace, but diffusers_eager_attention and diffusers_triton_attention persist in sys.modules after the first import. On subsequent fixture runs, the cached modules retain references to the previous patch's mock objects, so resetting _BACKEND_REGISTERED alone does not provide a clean state. Pop and reimport the backend module before using it in each fixture.
Also applies to: 140-158
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py` around
lines 82 - 96, Tests are reusing cached backend modules across fixture runs
causing stale mocks; ensure you remove cached modules from sys.modules before
reimporting. Inside the patch.dict(sys.modules, _mock_diffusers()) context (and
likewise for the triton fixture), pop any existing
"modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_eager_attention"
and
"modelopt.torch.sparsity.attention_sparsity.kernels.diffusers_triton_attention"
entries from sys.modules, then import the module and access
_diffusers_eager_attention, register_diffusers_eager_attention,
get_skip_softmax_attention_backend and reset mod._BACKEND_REGISTERED to
guarantee a fresh module and mocks for each fixture run.
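The pop-and-reimport step can be factored into a tiny helper that the fixtures call inside their `patch.dict` context, once per kernel module name listed above (a sketch; the helper name is illustrative):

```python
import importlib
import sys

def fresh_import(module_name: str):
    """Drop any cached copy of *module_name* and re-execute it.

    Re-executing under the active unittest.mock.patch.dict(sys.modules, ...)
    context means the module binds to the *current* mock objects instead of
    whatever mocks were live when it was first imported.
    """
    sys.modules.pop(module_name, None)
    return importlib.import_module(module_name)
```

Because the fresh module re-runs its top-level imports, resetting flags like `_BACKEND_REGISTERED` afterwards operates on state that actually belongs to this fixture's mocks.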
Move NVFP4 P-matrix quantization (quantize_p) out of the sparsity module
and into a new modelopt/torch/quantization/sage_attention/ module.
Key changes:
- Add modelopt/torch/quantization/sage_attention/__init__.py with
apply_sage_attention(transformer) API exposed via mtq namespace.
Wraps the transformer forward to activate the modelopt_triton diffusers
backend and set quantize_p=True in thread-local for every call.
- Remove quantize_p from SparseAttentionAttributeConfig (config.py),
TritonSkipSoftmaxMethod, and TritonSparseSoftmaxMethod — sparsity
methods no longer control quantization.
- Split thread-local management in diffusers_triton_attention.py:
* set_triton_skip_softmax_config() no longer accepts quantize_p
* clear_triton_skip_softmax_config() does NOT reset quantize_p
* New set_sage_attention_config() / clear_sage_attention_config()
manage quantize_p independently
This enables transparent composition: apply_sage_attention() sets
quantize_p=True at the outer forward level; per-layer sparsity
contexts clear only their own params without clobbering quantize_p.
- Delete plugins/diffusers.py (WanSparseAttentionModule) — superseded
by PR #1166's diffusers_triton_attention.py backend approach.
- Update wan2_sage_attention.py example: apply_triton_sparse_kernel()
no longer accepts quantize_p; --quantize-p now calls
apply_sage_attention() from modelopt.torch.quantization.
- Update tests to reflect the new API boundaries.
Usage:
from modelopt.torch.quantization import apply_sage_attention
apply_sage_attention(pipe.transformer) # standalone
# combined with N:M sparse softmax:
mtsa.sparsify(transformer, mtsa.SPARSE_SOFTMAX_DEFAULT)
apply_sage_attention(transformer)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Adds --kernel nvfp4 as an explicit entry point for standalone NVFP4
P-matrix quantization via apply_sage_attention(), removing the previous
awkward implicit path through --quantize-p without a triton kernel.
Usage:
python wan2_sage_attention.py --prompt "..." --kernel nvfp4
python wan2_sage_attention.py --prompt "..." --kernel nvfp4 --compare
python wan2_sage_attention.py --prompt "..." --kernel triton-sparse --quantize-p
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Summary

This PR adds four Triton-backed attention kernel modes for WAN2.2 diffusion models, integrated into the `modelopt.torch.sparsity.attention_sparsity` framework. It also fixes a registration bug and adds NVFP4 P-matrix quantization to the Triton kernel.

New Features

Four kernel modes are available via `mtsa.sparsify()` with `backend="diffusers_triton"`:

| Kernel mode | Sparsity method | Notes |
| --- | --- | --- |
| `triton-sparse` | `triton_sparse_softmax` | |
| `triton-skip` | `triton_skip_softmax` | |
| `triton-sparse-nvfp4` | `triton_sparse_softmax` | `triton-sparse` with NVFP4 E2M1 per-tile quantization of the post-softmax P matrix |
| `triton-skip-nvfp4` | `triton_skip_softmax` | `triton-skip` with NVFP4 E2M1 per-tile quantization of the post-softmax P matrix |

Usage: see `examples/diffusers/quantization/wan2_sage_attention.py` (the `--kernel` flag selects the mode).
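The NVFP4 modes round-trip each post-softmax tile through the FP4 E2M1 value grid with a per-tile scale. A minimal NumPy sketch of that round-trip (illustrative only: the real kernel does this per `[BLOCK_M × BLOCK_N]` tile in registers inside Triton and never materializes P; the grid values are the standard nonnegative E2M1 magnitudes, and post-softmax P is nonnegative so the sign bit is unused):

```python
import numpy as np

# Representable nonnegative FP4 E2M1 magnitudes.
E2M1 = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fake_quantize_p_tile(p_tile):
    """Fake-quantize one attention tile to NVFP4 with a per-tile scale."""
    scale = p_tile.max() / E2M1[-1]   # map the tile max onto 6.0
    if scale == 0.0:
        return p_tile.copy()
    scaled = p_tile / scale
    # snap each scaled value to the nearest representable magnitude
    idx = np.abs(scaled[..., None] - E2M1).argmin(axis=-1)
    return E2M1[idx] * scale
```

Values whose scaled magnitudes land on the grid survive exactly; everything else rounds to the nearest grid point, which is where the quality cost of the `-nvfp4` modes comes from.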
Design note: why `quantize_p` is not exposed via `mtq.quantize()`

ModelOpt already has a `softmax_quantizer` in `_QuantAttention` (see `modelopt/torch/quantization/plugins/diffusion/diffusers.py`) which conceptually covers the same operation. We considered routing NVFP4 P-matrix quantization through `mtq.quantize()`, but it is not feasible for three reasons:

1. Per-tile granularity. `TensorQuantizer` calibrates a global amax over the full tensor. Our NVFP4 quantization is per-Triton-tile (`[BLOCK_M × BLOCK_N]`), with the scale computed on-the-fly inside the kernel. There is no PyTorch-level tensor to calibrate against.
2. The P matrix never materializes. In the Triton kernel, P only exists tile-by-tile as a register-level intermediate. There is no `[seq_q × seq_kv]` PyTorch tensor for `mtq` to intercept.
3. Inseparable from the Triton path. `quantize_p=True` is only meaningful when the Triton kernel is already active (via `backend="diffusers_triton"`).

Bug Fix
`WanSparseAttentionModule` never executed (PSNR = 100 dB / byte-identical to baseline): `plugins/__init__.py` imported `huggingface.py` before `diffusers.py`. The HF generic plugin registered `WanAttention` with `_GenericSparseAttention` first; the diffusers-specific plugin then saw `WanAttention` already registered and skipped, so `WanSparseAttentionModule`/`ModelOptWanAttnProcessor` were never installed. Fix: reorder the imports in `plugins/__init__.py` so `diffusers` comes before `huggingface`.

Files Changed
- `modelopt/torch/kernels/triton_fa.py` — NVFP4 per-tile P-matrix quantization
- `modelopt/torch/sparsity/attention_sparsity/plugins/diffusers.py` — `WanSparseAttentionModule`, `ModelOptWanAttnProcessor`, `register_wan_sparse_attention`
- `modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py` — Fix import order (diffusers before huggingface)
- `modelopt/torch/quantization/plugins/diffusion/diffusers.py` — `_QuantWanAttnProcessor` for the NVFP4 quantization path
- `tests/unit/` — Unit tests for the new diffusers WAN sparse attention plugin
- `examples/diffusers/quantization/wan2_sage_attention.py` — Example script with all four kernel modes and `--kernel`, `--compare`, `--skip-threshold` CLI flags

Test plan
- `python -m pytest tests/unit -k "wan"` — unit tests for `WanSparseAttentionModule` and `ModelOptWanAttnProcessor`
- `python wan2_sage_attention.py --kernel triton-sparse --compare --seed 42` — verify PSNR > 30 dB vs baseline
- `python wan2_sage_attention.py --kernel triton-skip --compare --seed 42` — verify PSNR > 30 dB vs baseline
- `python wan2_sage_attention.py --kernel triton-skip-nvfp4 --compare --seed 42` — verify visually acceptable output
- `python wan2_sage_attention.py --kernel triton-sparse-nvfp4 --compare --seed 42` — verify visually acceptable output

🤖 Generated with Claude Code
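The `--compare` runs above gate on PSNR against the baseline video. A hypothetical helper showing how such a check is typically computed (the example script may use different peak or frame conventions, e.g. uint8 frames with `peak=255`):

```python
import numpy as np

def psnr(ref, test, peak=1.0):
    """Peak signal-to-noise ratio in dB; higher means closer to baseline."""
    ref = np.asarray(ref, dtype=np.float64)
    test = np.asarray(test, dtype=np.float64)
    mse = np.mean((ref - test) ** 2)
    if mse == 0.0:
        # byte-identical output; reports often cap this at e.g. 100 dB
        return float("inf")
    return 10.0 * np.log10(peak * peak / mse)
```

This also explains the "PSNR = 100 dB / byte-identical" symptom in the bug-fix section: a zero MSE means the sparse module never actually ran.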
Summary by CodeRabbit
New Features
Documentation
Tests
Chores