Skip to content

Add fused Triton kernel for local-Hessian NVFP4 weight-scale search#1659

Open
Fridah-nv wants to merge 3 commits into
mainfrom
fridah/local-hessian-triton
Open

Add fused Triton kernel for local-Hessian NVFP4 weight-scale search#1659
Fridah-nv wants to merge 3 commits into
mainfrom
fridah/local-hessian-triton

Conversation

@Fridah-nv

@Fridah-nv Fridah-nv commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

What does this PR do?

Type of change: new feature

Adds a fused Triton kernel that replaces the 126-step Python reference sweep in local_hessian weight calibration (the Hessian-weighted variant of the NVFP4 FP8 scale search). For each NVFP4 block it minimizes the Hessian-weighted error dwᵀ H dw (dw = w − quant(w)) over the 126 valid FP8-E4M3 candidate scales, using the per-cin-block local Hessian H shared across output rows.

  • Kernel (nvfp4_fp8_scale_sweep_hessian in modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py): one program sweeps a tile of output rows of a single cin-block, loading that block's [16,16] Hessian once and computing the per-candidate quadratic form dwᵀ H dw as a tl.dot tensor-core matmul. Candidate block scales are precomputed on the host via the reference compute_fp4_scales, so the kernel's quantization is bit-identical to the reference fake-quant.
  • Calibrator (NVFP4MSECalibrator, calib/mse.py): gains an optional hessian= fast path; error_func remains the CPU/non-CUDA reference fallback.
  • Plumbing (model_calib.py): _LocalHessianAccumulator.normalized_hessian() exposes a shared normalized Hessian; a hessian_for channel threads it through the
    Works on any CUDA GPU with Triton (no tl.float8e4nv requirement); falls back to the reference sweep when Triton is unavailable, on CPU, or via MODELOPT_NVFP4_TRITON_SWEEP=0.

Usage

import modelopt.torch.quantization as mtq

# NVFP4 W4A4 with STATIC per-block weight scales, searched with local-Hessian.
cfg = mtq.NVFP4_DEFAULT_CFG
for entry in cfg["quant_cfg"]:
    if entry.get("quantizer_name") == "*weight_quantizer" and "cfg" in entry:
        entry["cfg"]["block_sizes"]["type"] = "static"
cfg["algorithm"] = {"method": "local_hessian", "fp8_scale_sweep": True}

# The Triton fast path is used automatically during calibration; no API change.
model = mtq.quantize(model, cfg, forward_loop=forward_loop)

Testing

  • GPU unit tests (tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py): parity vs the reference 126-step sweep — bit-exact for fp32/fp16, bf16 within a tight bound; plus dispatch, input-validation, and a ≥30x speedup assertion. All pass.
  • Single-weight (8192x4096, ~2M NVFP4 blocks): ~34x vs the reference sweep (15.6 ms vs 535 ms).
  • End-to-end PTQ: Qwen3-8B (dense) 9.2x e2e, 0.003% weight-scale mismatch; Qwen3.6-35B-A3B (fused-MoE) sweep itself ~35x (e2e 4.5x, forward-bound). An fp64 ground-truth dissection of the dense mismatches showed 99.98% are equivalent-quality fp32 reduction-order ties (worst loss gap 3e-6) and 0 blocks where the kernel selects a worse scale — confirming no precision/correctness regression.

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ✅
  • Did you get Claude approval on this PR?: N/A

Additional Information

The kernel only accelerates the per-block scale search (phase 3 of local_hessian); on large MoE models the end-to-end time becomes dominated by the calibration forwards (max-calibrate + Hessian capture), so a parallel/tensor-parallel calibration forward is the next lever for further e2e speedup. Separately, ensuring all MoE experts are routed during calibration would shrink the small residual mismatch attributable to never-routed (degenerate) experts.

Summary by CodeRabbit

  • New Features

    • Added a Hessian-weighted NVFP4 FP8 scale sweep with an accelerated GPU fast path; calibrators and calibration flow can accept Hessian data to use this path and will fall back safely when unavailable.
  • Tests

    • Added parity, input-validation, and performance tests (including a benchmark comparing reference vs GPU fast path).
  • Documentation

    • Changelog entry describing the new Hessian-weighted Triton fast path, fallbacks, and measured speedup (~30–34×, bit-exact for fp32/fp16).

@copy-pr-bot

copy-pr-bot Bot commented Jun 9, 2026

Copy link
Copy Markdown

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: d085011b-3aa8-4e67-8c18-97b64c001670

📥 Commits

Reviewing files that changed from the base of the PR and between 0975de8 and 70dc7e9.

📒 Files selected for processing (1)
  • modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py

📝 Walkthrough

Walkthrough

Adds a Hessian-weighted Triton kernel and public wrapper for NVFP4 FP8 scale sweep, integrates an optional per-block Hessian fast path into NVFP4MSECalibrator and the calibration plumbing, normalizes/caches Hessians via the accumulator, and adds tests for parity, input validation, and performance.

Changes

Hessian-weighted NVFP4 scale sweep

Layer / File(s) Summary
Hessian Triton kernel and shared sweep setup
modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
New _fp8_scale_sweep_hessian_kernel evaluates Hessian-weighted quadratic loss across FP8 candidates, with wrapper nvfp4_fp8_scale_sweep_hessian orchestrating validation, candidate generation (compute_fp4_scales), and kernel launch. Shared _prepare_block_sweep centralizes CUDA/block-size validation, flattening, and best_amax allocation for both sweeps.
Hessian dispatch in NVFP4MSECalibrator
modelopt/torch/quantization/calib/mse.py
Constructor accepts optional hessian tensor; adds _triton_sweep_eligible helper; separate _can_use_hessian_fast_path predicate gates Hessian sweep when available; collect() uses nvfp4_fp8_scale_sweep_hessian when eligible and caches per-block amax.
Hessian plumbing through calibration functions
modelopt/torch/quantization/model_calib.py
_make_weight_mse_calibrator accepts optional hessian and forwards it; _mse_calibrate_weights accepts hessian_for mapping and forwards per-quantizer Hessians into the calibrator.
Hessian normalization and accumulator caching
modelopt/torch/quantization/model_calib.py
_LocalHessianAccumulator caches normalized Hessian via normalized_hessian(); build_error_func() reuses the normalized Hessian for both reference and Triton fast paths.
Calibration phase orchestration and cleanup
modelopt/torch/quantization/model_calib.py
local_hessian_calibrate phase 3 builds per-quantizer hessians from accumulators, passes hessian_for to the MSE sweep loop, and performs cleanup clearing error functions, Hessians, and calibrator _hessian state (debug retains accumulators).
Test infrastructure and helpers & test cases
tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py
Adds imports for the Hessian wrapper and _LocalHessianAccumulator; helpers to construct accumulators, run Hessian reference and Triton sweeps, and compute total Hessian-weighted loss; and three tests: test_hessian_parity_random_weights, test_hessian_sweep_input_validation, and test_hessian_speedup_report.
sequenceDiagram
  participant Cal as NVFP4MSECalibrator.collect
  participant Wrapper as nvfp4_fp8_scale_sweep_hessian
  participant Prep as _prepare_block_sweep
  participant Scales as compute_fp4_scales
  participant Kernel as _fp8_scale_sweep_hessian_kernel
  participant Acc as _LocalHessianAccumulator
  Cal->>Wrapper: call(x, global_amax, hessian)
  Wrapper->>Prep: validate & flatten x, allocate best_amax
  Wrapper->>Scales: compute candidate_amaxes / scales
  Acc->>Wrapper: provide normalized hessian
  Wrapper->>Kernel: launch with grid(cout, n_cin_blocks), pass flattened hessian
  Kernel->>Kernel: load x tile & hessian tile, compute dwᵀHdw per candidate
  Kernel->>Wrapper: write best_amax per block
  Wrapper->>Cal: return best_amax
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • NVIDIA/Model-Optimizer#1578: Both PRs modify the local-Hessian calibration flow and per-quantizer Hessian plumbing used for NVFP4/Hessian-weighted calibration.

Suggested reviewers

  • realAsma
  • ChenhanYu
🚥 Pre-merge checks | ✅ 6
✅ Passed checks (6 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately describes the main change: adding a fused Triton kernel for Hessian-weighted NVFP4 weight-scale search, which is the central focus across all modified files.
Docstring Coverage ✅ Passed Docstring coverage is 85.19% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed No security anti-patterns found: no torch.load(weights_only=False), numpy.load(allow_pickle=True), hardcoded trust_remote_code=True, eval/exec, nosec comments, or new pip dependencies.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch fridah/local-hessian-triton

Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions

github-actions Bot commented Jun 9, 2026

Copy link
Copy Markdown
Contributor
PR Preview Action v1.8.1

QR code for preview link

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1659/

Built to branch gh-pages at 2026-06-10 00:13 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov

codecov Bot commented Jun 9, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 68.75000% with 25 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.69%. Comparing base (d3acf45) to head (70dc7e9).
⚠️ Report is 3 commits behind head on main.

Files with missing lines Patch % Lines
...torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py 53.70% 25 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1659       +/-   ##
===========================================
+ Coverage   56.59%   76.69%   +20.09%     
===========================================
  Files         507      508        +1     
  Lines       55794    55928      +134     
===========================================
+ Hits        31579    42894    +11315     
+ Misses      24215    13034    -11181     
Flag Coverage Δ
examples 42.28% <15.00%> (+23.72%) ⬆️
gpu 57.86% <68.75%> (+37.32%) ⬆️
regression 14.69% <15.00%> (-0.16%) ⬇️
unit 54.36% <21.25%> (-0.06%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@Fridah-nv Fridah-nv changed the title Add Hessian-weighted NVFP4 FP8 scale-sweep Triton kernel for local_he… Add fused Triton kernel for local-Hessian NVFP4 weight-scale search Jun 9, 2026
@Fridah-nv Fridah-nv marked this pull request as ready for review June 9, 2026 22:57
@Fridah-nv Fridah-nv requested review from a team as code owners June 9, 2026 22:57
@Fridah-nv

Copy link
Copy Markdown
Contributor Author

/claude review

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py (1)

578-578: 💤 Low value

Hard 30x speedup assertion may be flaky on heterogeneous CI hardware.

The existing test_speedup_report (line 594) explicitly avoids gating on a minimum factor because "kernel timing is noisy on shared CI." This new test takes the opposite approach with a hard 30x requirement.

If this is intentional (the PR objectives mention 30x as a requirement), consider adding a @pytest.mark.timeout() decorator to guard against unexpectedly slow reference runs, and/or documenting that this test may need adjustment for slower GPU SKUs.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py` at line 578, The
hard-coded assertion requiring speedup >= 30.0 is flaky on shared CI; update the
test that contains the line "assert speedup >= 30.0, f\"Hessian fast path
speedup {speedup:.1f}x below 30x target\"" to avoid unconditional failures: make
the minimum factor configurable (read e.g. TEST_MIN_SPEEDUP from env with
default 30.0) and replace the hard assert with a conditional that uses
pytest.fail only if speedup < threshold, and add a pytest.mark.timeout(...)
decorator to the enclosing test function (or use the same non-gating/reporting
approach as test_speedup_report) so slow reference runs are bounded; reference
the assertion string and the existing test_speedup_report behavior when making
the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py`:
- Line 578: The hard-coded assertion requiring speedup >= 30.0 is flaky on
shared CI; update the test that contains the line "assert speedup >= 30.0,
f\"Hessian fast path speedup {speedup:.1f}x below 30x target\"" to avoid
unconditional failures: make the minimum factor configurable (read e.g.
TEST_MIN_SPEEDUP from env with default 30.0) and replace the hard assert with a
conditional that uses pytest.fail only if speedup < threshold, and add a
pytest.mark.timeout(...) decorator to the enclosing test function (or use the
same non-gating/reporting approach as test_speedup_report) so slow reference
runs are bounded; reference the assertion string and the existing
test_speedup_report behavior when making the change.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 7ad84528-c5e5-47c0-9eec-26e3544e0ad8

📥 Commits

Reviewing files that changed from the base of the PR and between d87f810 and 27e381e.

📒 Files selected for processing (4)
  • modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
  • modelopt/torch/quantization/calib/mse.py
  • modelopt/torch/quantization/model_calib.py
  • tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py

@cjluo-nv cjluo-nv left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot review — DM the bot to share feedback.

Additive Triton sibling kernel for Hessian-weighted NVFP4 weight-scale search. Design review: this is not a new subsystem — it extends the existing nvfp4_fp8_scale_sweep family with a Hessian counterpart, reusing in-repo helpers (compute_fp4_scales, fp4_round_magnitude, fp8_scale_candidates) and factoring shared input validation into _prepare_block_sweep. The reference Python path is preserved as fallback (CPU / no Triton / MODELOPT_NVFP4_TRITON_SWEEP=0), and the kernel uses the same FP8-quantized block scales as the reference for bit-identical residuals. Calibrator and local_hessian_calibrate plumbing is backward-compatible (new optional kwargs default to None). Tests cover parity across dtypes/shapes, input validation, dispatch precedence, and a ≥30x speedup assertion; PR body documents end-to-end Qwen3-8B and MoE validation. ~220 LOC of non-test changes; no licensing changes.

Complex PR: spans 5 directories (≥ 5). Looping in a human for approval.

…ssian

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude review summary

A focused, well-tested addition: a Hessian-weighted FP8 scale-sweep Triton kernel for local_hessian calibration. The kernel matches the reference fake-quant bit-exactly on the per-block residual (precomputing scales via compute_fp4_scales is a clean win), the dispatch gating cleanly mirrors the existing fast-path eligibility helper, and tests cover parity, dispatch, validation, and a 30x speedup floor. Algorithm-wise (signed dw, Hessian load + tensor-core tl.dot, fp32 accumulation, scale==0 guard) the math is correct and the cross-term sign handling vs the squared-error sweep is right.

Findings (5 total)

  • CRITICAL: 0
  • IMPORTANT: 1
    • NVFP4MSECalibrator class docstring is stale — still says the Triton fast path is gated on error_func is None, but the new Hessian fast path runs even when error_func is set (and local_hessian_calibrate always sets both). Easy to mislead future callers.
  • SUGGESTION: 4
    • compute_fp4_scales (which dispatches into a CUDA-ext FP8 round-trip on the current device) runs outside the with torch.cuda.device(x.device): block in nvfp4_fp8_scale_sweep_hessian. On a multi-GPU host where x.device != torch.cuda.current_device() this can land on the wrong device. Move the precompute inside the block.
    • Single-config @triton.autotune adds dispatch overhead with no tuning benefit — either add more configs or lift to direct constexpr kwargs.
    • Duplicate p/q aranges in the kernel — one variable + two broadcasts reads cleaner.
    • The ordering of error_funcs = {...} then hessians = {...} in local_hessian_calibrate (phase 3) is load-bearing for memory: build_error_func populates the cache and frees the raw buffer when debug=False. Worth a comment so the order isn't accidentally inverted later.

Risk assessment

Low. Public API unchanged (internal fast-path), reference fallback preserved, MODELOPT_NVFP4_TRITON_SWEEP=0 opt-out kept. The one IMPORTANT finding is a docstring/precedence-clarity issue, not a behavior bug. The multi-device suggestion is the highest-impact item that can become a real bug under specific deployment conditions.

Comment thread modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
Comment thread modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
Comment thread modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
Comment thread modelopt/torch/quantization/model_calib.py
@Fridah-nv Fridah-nv force-pushed the fridah/local-hessian-triton branch from cee55a4 to ec3ebec Compare June 9, 2026 23:15
@Fridah-nv Fridah-nv requested a review from meenchen June 9, 2026 23:18
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
@Fridah-nv Fridah-nv force-pushed the fridah/local-hessian-triton branch from ec3ebec to 0975de8 Compare June 9, 2026 23:57
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
@Fridah-nv

Copy link
Copy Markdown
Contributor Author

/claude review

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants