Add fused Triton kernel for local-Hessian NVFP4 weight-scale search #1659
Fridah-nv wants to merge 3 commits into
📝 Walkthrough
Adds a Hessian-weighted Triton kernel and public wrapper for the NVFP4 FP8 scale sweep, integrates an optional per-block Hessian fast path into NVFP4MSECalibrator and the calibration plumbing, normalizes and caches Hessians via the accumulator, and adds tests for parity, input validation, and performance.
Changes
Hessian-weighted NVFP4 scale sweep
sequenceDiagram
participant Cal as NVFP4MSECalibrator.collect
participant Wrapper as nvfp4_fp8_scale_sweep_hessian
participant Prep as _prepare_block_sweep
participant Scales as compute_fp4_scales
participant Kernel as _fp8_scale_sweep_hessian_kernel
participant Acc as _LocalHessianAccumulator
Cal->>Wrapper: call(x, global_amax, hessian)
Wrapper->>Prep: validate & flatten x, allocate best_amax
Wrapper->>Scales: compute candidate_amaxes / scales
Acc->>Wrapper: provide normalized hessian
Wrapper->>Kernel: launch with grid(cout, n_cin_blocks), pass flattened hessian
Kernel->>Kernel: load x tile & hessian tile, compute dwᵀHdw per candidate
Kernel->>Wrapper: write best_amax per block
Wrapper->>Cal: return best_amax
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks: ✅ 6 passed
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #1659 +/- ##
===========================================
+ Coverage 56.59% 76.69% +20.09%
===========================================
Files 507 508 +1
Lines 55794 55928 +134
===========================================
+ Hits 31579 42894 +11315
+ Misses 24215 13034 -11181
/claude review
🧹 Nitpick comments (1)
tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py (1)
578-578: 💤 Low value. Hard 30x speedup assertion may be flaky on heterogeneous CI hardware.
The existing `test_speedup_report` (line 594) explicitly avoids gating on a minimum factor because "kernel timing is noisy on shared CI." This new test takes the opposite approach with a hard 30x requirement.
If this is intentional (the PR objectives mention 30x as a requirement), consider adding a `@pytest.mark.timeout()` decorator to guard against unexpectedly slow reference runs, and/or documenting that this test may need adjustment for slower GPU SKUs.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py` at line 578, The hard-coded assertion requiring speedup >= 30.0 is flaky on shared CI; update the test that contains the line "assert speedup >= 30.0, f\"Hessian fast path speedup {speedup:.1f}x below 30x target\"" to avoid unconditional failures: make the minimum factor configurable (read e.g. TEST_MIN_SPEEDUP from env with default 30.0) and replace the hard assert with a conditional that uses pytest.fail only if speedup < threshold, and add a pytest.mark.timeout(...) decorator to the enclosing test function (or use the same non-gating/reporting approach as test_speedup_report) so slow reference runs are bounded; reference the assertion string and the existing test_speedup_report behavior when making the change.
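A minimal sketch of the configurable threshold suggested in this comment, assuming the `TEST_MIN_SPEEDUP` variable name proposed above. The helper names `min_speedup_threshold` and `speedup_failure` are hypothetical and do not come from the PR.

```python
import os

DEFAULT_MIN_SPEEDUP = 30.0  # the 30x target quoted in the review comment


def min_speedup_threshold(env=os.environ):
    """Return the gating threshold, overridable via TEST_MIN_SPEEDUP (assumed name)."""
    raw = env.get("TEST_MIN_SPEEDUP")
    return DEFAULT_MIN_SPEEDUP if raw is None else float(raw)


def speedup_failure(speedup, threshold):
    """Return a failure message, or None when the speedup meets the threshold."""
    if speedup >= threshold:
        return None
    return f"Hessian fast path speedup {speedup:.1f}x below {threshold:g}x target"
```

Inside the test, the message could be passed to `pytest.fail` when it is not `None`. Note that `@pytest.mark.timeout(...)` is provided by the pytest-timeout plugin, so that plugin would need to be installed for the decorator to bound slow reference runs.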
📒 Files selected for processing (4)
- modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
- modelopt/torch/quantization/calib/mse.py
- modelopt/torch/quantization/model_calib.py
- tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py
cjluo-nv left a comment
Bot review — DM the bot to share feedback.
Additive Triton sibling kernel for Hessian-weighted NVFP4 weight-scale search. Design review: this is not a new subsystem — it extends the existing nvfp4_fp8_scale_sweep family with a Hessian counterpart, reusing in-repo helpers (compute_fp4_scales, fp4_round_magnitude, fp8_scale_candidates) and factoring shared input validation into _prepare_block_sweep. The reference Python path is preserved as fallback (CPU / no Triton / MODELOPT_NVFP4_TRITON_SWEEP=0), and the kernel uses the same FP8-quantized block scales as the reference for bit-identical residuals. Calibrator and local_hessian_calibrate plumbing is backward-compatible (new optional kwargs default to None). Tests cover parity across dtypes/shapes, input validation, dispatch precedence, and a ≥30x speedup assertion; PR body documents end-to-end Qwen3-8B and MoE validation. ~220 LOC of non-test changes; no licensing changes.
Complex PR: spans 5 directories (≥ 5). Looping in a human for approval.
…ssian Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Claude review summary
A focused, well-tested addition: a Hessian-weighted FP8 scale-sweep Triton kernel for local_hessian calibration. The kernel matches the reference fake-quant bit-exactly on the per-block residual (precomputing scales via compute_fp4_scales is a clean win), the dispatch gating cleanly mirrors the existing fast-path eligibility helper, and tests cover parity, dispatch, validation, and a 30x speedup floor. Algorithm-wise (signed dw, Hessian load + tensor-core tl.dot, fp32 accumulation, scale==0 guard) the math is correct and the cross-term sign handling vs the squared-error sweep is right.
Findings (5 total)
- CRITICAL: 0
- IMPORTANT: 1
  - `NVFP4MSECalibrator` class docstring is stale: it still says the Triton fast path is gated on `error_func is None`, but the new Hessian fast path runs even when `error_func` is set (and `local_hessian_calibrate` always sets both). Easy to mislead future callers.
- SUGGESTION: 4
  - `compute_fp4_scales` (which dispatches into a CUDA-ext FP8 round-trip on the current device) runs outside the `with torch.cuda.device(x.device):` block in `nvfp4_fp8_scale_sweep_hessian`. On a multi-GPU host where `x.device != torch.cuda.current_device()` this can land on the wrong device. Move the precompute inside the block.
  - Single-config `@triton.autotune` adds dispatch overhead with no tuning benefit: either add more configs or lift to direct constexpr kwargs.
  - Duplicate `p`/`q` aranges in the kernel: one variable + two broadcasts reads cleaner.
  - The ordering of `error_funcs = {...}` then `hessians = {...}` in `local_hessian_calibrate` (phase 3) is load-bearing for memory: `build_error_func` populates the cache and frees the raw buffer when `debug=False`. Worth a comment so the order isn't accidentally inverted later.
Risk assessment
Low. Public API unchanged (internal fast-path), reference fallback preserved, MODELOPT_NVFP4_TRITON_SWEEP=0 opt-out kept. The one IMPORTANT finding is a docstring/precedence-clarity issue, not a behavior bug. The multi-device suggestion is the highest-impact item that can become a real bug under specific deployment conditions.
cee55a4 to ec3ebec
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
ec3ebec to 0975de8
/claude review
What does this PR do?
Type of change: new feature
Adds a fused Triton kernel that replaces the 126-step Python reference sweep in `local_hessian` weight calibration (the Hessian-weighted variant of the NVFP4 FP8 scale search). For each NVFP4 block it minimizes the Hessian-weighted error `dwᵀ H dw` (`dw = w − quant(w)`) over the 126 valid FP8-E4M3 candidate scales, using the per-cin-block local Hessian `H` shared across output rows.
- `nvfp4_fp8_scale_sweep_hessian` (in `modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py`): one program sweeps a tile of output rows of a single cin-block, loading that block's `[16,16]` Hessian once and computing the per-candidate quadratic form `dwᵀ H dw` as a `tl.dot` tensor-core matmul. Candidate block scales are precomputed on the host via the reference `compute_fp4_scales`, so the kernel's quantization is bit-identical to the reference fake-quant.
- `NVFP4MSECalibrator` (`calib/mse.py`): gains an optional `hessian=` fast path; `error_func` remains the CPU/non-CUDA reference fallback.
- `model_calib.py`: `_LocalHessianAccumulator.normalized_hessian()` exposes a shared normalized Hessian; a `hessian_for` channel threads it through the
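The search described above can be illustrated with a pure-Python sketch. This is not the PR's Triton kernel or its reference path; it rests on several assumptions: the 126 candidates are taken to be the positive finite FP8-E4M3 bit patterns (0x01 to 0x7E), the candidates are used directly as block scales without the global-amax factor the real pipeline applies, the block has 4 elements instead of 16 for brevity, and rounding ties resolve to the smaller FP4 magnitude, which may differ from the reference rounding. All function names are hypothetical.

```python
# Illustrative sketch only; names and simplifications are assumptions.
FP4_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)  # E2M1 magnitudes


def e4m3_positive_values():
    """Decode the positive finite FP8-E4M3 bit patterns 0x01..0x7E."""
    values = []
    for bits in range(1, 0x7F):  # skip 0x00 (zero) and 0x7F (NaN)
        exp, man = bits >> 3, bits & 0x7
        if exp == 0:  # subnormal: no implicit leading one
            values.append((man / 8.0) * 2.0 ** -6)
        else:  # normal: implicit leading one, exponent bias 7
            values.append((1.0 + man / 8.0) * 2.0 ** (exp - 7))
    return values


def fake_quant(w, scale):
    """Round each |w| / scale to the nearest FP4 magnitude, then rescale."""
    out = []
    for v in w:
        mag = min(FP4_MAGNITUDES, key=lambda m: abs(abs(v) / scale - m))
        out.append((mag if v >= 0 else -mag) * scale)
    return out


def hessian_error(w, wq, H):
    """Quadratic form dw^T H dw with dw = w - quant(w)."""
    dw = [a - b for a, b in zip(w, wq)]
    n = len(dw)
    return sum(dw[i] * H[i][j] * dw[j] for i in range(n) for j in range(n))


def best_scale(w, H, candidates):
    """Pick the candidate scale with the smallest Hessian-weighted error."""
    return min(candidates, key=lambda s: hessian_error(w, fake_quant(w, s), H))
```

With an identity Hessian the objective reduces to the plain squared-error sweep; off-diagonal entries of `H` are what introduce the cross terms between elements of `dw`.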
tl.float8e4nvrequirement); falls back to the reference sweep when Triton is unavailable, on CPU, or viaMODELOPT_NVFP4_TRITON_SWEEP=0.Usage
Testing
tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py): parity vs the reference 126-step sweep — bit-exact for fp32/fp16, bf16 within a tight bound; plus dispatch, input-validation, and a ≥30x speedup assertion. All pass.Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
CONTRIBUTING.md: N/A
Additional Information
The kernel only accelerates the per-block scale search (phase 3 of `local_hessian`); on large MoE models the end-to-end time becomes dominated by the calibration forwards (max-calibrate + Hessian capture), so a parallel/tensor-parallel calibration forward is the next lever for further e2e speedup. Separately, ensuring all MoE experts are routed during calibration would shrink the small residual mismatch attributable to never-routed (degenerate) experts.