
Added support for MoE for vllm >= 0.14.0rc1 #1162

Open
kinjalpatel27 wants to merge 6 commits into main from kinjal/vllm_super_nano_support
Conversation

@kinjalpatel27
Contributor

kinjalpatel27 commented Apr 1, 2026

What does this PR do?

Type of change: Bug fix

_QuantFusedMoEBase.forward() previously monkey-patched vllm_fused_moe_package.invoke_fused_moe_kernel, which was renamed to dispatch_fused_moe_kernel in vLLM v0.14.0rc1. As a result, any MoE model quantized with vLLM >= v0.14.0rc1 failed with an AttributeError / assertion failure.

The fix refactors the kernel-patching logic into a _patch_moe_kernel() context manager that probes for both attribute names. The two names are mutually exclusive across vLLM versions, as confirmed by inspecting every release from v0.10.0 to v0.18.1.

Usage

NA

Testing

docker run --gpus all -it --shm-size=160GB --network host --rm -v <modelopt path>:/home/modelopt \
  vllm/vllm-openai:v0.15.0 bash -c "cd /home/modelopt && pip install . && pip install datasets && \
  QUANT_CFG=NVFP4_DEFAULT_CFG python3 /home/modelopt/examples/vllm_serve/vllm_serve_fakequant.py \
  nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 -tp 1 --served-model-name NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \
  --host 0.0.0.0 --port 8001 --trust-remote-code --enforce-eager --disable-custom-all-reduce \
  --gpu-memory-utilization 0.8"

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: N/A
  • Did you update Changelog?: N/A

Additional Information

Summary by CodeRabbit

  • Refactor
    • Ensures quantized expert weights are correctly used by the fused-MoE execution path so inference uses the intended quantized tensors.
    • Replaces fragile manual swapping of the runtime kernel with a safer, context-managed swap that reliably caches and restores the original.
    • Adds runtime detection and selection among available fused-MoE kernel entrypoints to support multiple variants.

@copy-pr-bot

copy-pr-bot bot commented Apr 1, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Apr 1, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Resolve and cache the runtime vLLM fused-MoE kernel entrypoint name, rebind quantized expert weights as the kernel B argument when quantization is enabled, and swap the kernel entrypoint during forward() using replace_function(...) instead of direct monkey-patching.

Changes

Cohort / File(s) Summary
vLLM MoE Quantization Integration
modelopt/torch/quantization/plugins/vllm.py
Resolve and cache available fused-MoE entrypoint name into self.invoke_fused_moe_kernel_func in _setup. invoke_fused_moe_quantized() quantizes expert weights when enabled, rebinds the kernel B argument to the quantized tensor for the call, and restores originals afterwards. forward() uses replace_function(...) to swap the resolved entrypoint with the quantized wrapper instead of try/finally monkey-patching. Introduces internal attribute _QuantFusedMoEBase.invoke_fused_moe_kernel_func.

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant QuantMoE as _QuantFusedMoEBase
    participant Quantizer
    participant vLLM as vllm_fused_moe_package

    Caller->>QuantMoE: forward(...)
    Note right of QuantMoE: ensure entrypoint resolved and cached\n(self.invoke_fused_moe_kernel_func)
    QuantMoE->>QuantMoE: enter replace_function -> swap entrypoint to invoke_fused_moe_quantized
    alt weight quantizer enabled
        QuantMoE->>Quantizer: quantize(expert_weight)
        Quantizer-->>QuantMoE: quantized_weight
        QuantMoE->>vLLM: invoke_fused_moe_kernel(..., B=quantized_weight)
    else quantizer disabled
        QuantMoE->>vLLM: invoke_fused_moe_kernel(..., B=original_weight)
    end
    vLLM-->>QuantMoE: kernel result
    QuantMoE->>QuantMoE: exit replace_function -> restore original entrypoint
    QuantMoE-->>Caller: return result

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (3 passed)

  • Title check ✅ Passed: The title accurately describes the main change, adding MoE support for vLLM v0.14.0rc1 and later, which aligns with the PR's objective to fix the AttributeError caused by vLLM's function rename.
  • Security Anti-Patterns ✅ Passed: No forbidden security patterns detected in modelopt/torch/quantization/plugins/vllm.py.
  • Description Check ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions
Contributor

github-actions bot commented Apr 1, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1162/

Built to branch gh-pages at 2026-04-08 22:33 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

Contributor

coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
modelopt/torch/quantization/plugins/vllm.py (1)

387-388: Please add regression tests for both symbol paths and patch restore behavior.

Given the compatibility branch and runtime patching, add tests that cover: (1) invoke_fused_moe_kernel, (2) invoke_fused_moe_triton_kernel, and (3) restoration on exceptions during super().forward(...).
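Tests along those lines could be sketched as follows. Everything here is a hypothetical stand-in: `make_package`, `patch_kernel`, and the entrypoint names are placeholders for the real `_patch_moe_kernel` and the vLLM module, chosen only to demonstrate the three cases the review asks for (both symbol paths, plus restoration after an exception).

```python
import contextlib
import types

_NAMES = ("invoke_fused_moe_kernel", "invoke_fused_moe_triton_kernel")


def make_package(name):
    # Fake fused-MoE module exposing exactly one of the two entrypoints.
    return types.SimpleNamespace(**{name: lambda *a, **k: "original"})


@contextlib.contextmanager
def patch_kernel(pkg, replacement):
    # Minimal stand-in for the PR's kernel-patching context manager.
    name = next(n for n in _NAMES if hasattr(pkg, n))
    original = getattr(pkg, name)
    setattr(pkg, name, replacement)
    try:
        yield
    finally:
        setattr(pkg, name, original)


def test_both_symbols_and_restore_on_error():
    for name in _NAMES:
        # (1)/(2): each symbol path is patched inside the context...
        pkg = make_package(name)
        with patch_kernel(pkg, lambda *a, **k: "patched"):
            assert getattr(pkg, name)() == "patched"
        # ...and restored on normal exit.
        assert getattr(pkg, name)() == "original"

        # (3): restoration also happens when forward() raises mid-context.
        pkg = make_package(name)
        try:
            with patch_kernel(pkg, lambda *a, **k: "patched"):
                raise RuntimeError("simulated failure in super().forward(...)")
        except RuntimeError:
            pass
        assert getattr(pkg, name)() == "original"
```

Against the real code, `patch_kernel`/`make_package` would be replaced by monkeypatching the actual vLLM module attributes and invoking the module's forward().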

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/vllm.py` around lines 387 - 388, Add
regression tests that exercise both symbol paths and verify the runtime patching
in _patch_moe_kernel is applied and always restored: write tests that (1)
trigger the branch where invoke_fused_moe_kernel is used, (2) trigger the branch
where invoke_fused_moe_triton_kernel is used, and (3) simulate an exception
raised during super().forward(...) to assert the original symbols are restored
after the exception. Locate and call the class/method that uses
_patch_moe_kernel and forward to run these cases, patch or monkeypatch the
target symbols to observable fakes, and assert pre-/post-conditions on the
original functions to confirm restore behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/quantization/plugins/vllm.py`:
- Around line 369-383: The current _patch_moe_kernel contextmanager mutates
vllm_fused_moe_package globals unsafely; make it concurrency-safe by serializing
patch/unpatch with a module-level threading.RLock and per-attribute reference
counting (or a reentrancy counter) so nested/concurrent entries don't clobber
originals: on entry, acquire the lock, for each attr
("invoke_fused_moe_kernel","invoke_fused_moe_triton_kernel") save the original
into a local map only if not already saved and replace the attr with
self.invoke_fused_moe_quantized while incrementing a refcount; yield; in
finally, decrement the refcount and only when it reaches zero restore the
original to vllm_fused_moe_package[attr] and remove the saved original, then
release the lock—use the symbols _patch_moe_kernel, vllm_fused_moe_package,
invoke_fused_moe_kernel, invoke_fused_moe_triton_kernel,
_invoke_fused_moe_kernel and invoke_fused_moe_quantized to locate and implement
this change.
- Around line 349-351: The kernel is currently called with the original weight
because B was bound before swapping, and restoration isn't in a finally block so
exceptions leave the quantized tensor in place; fix by assigning orig =
self.w13_weight, replacing self.w13_weight with
self.w13_weight_quantizer(self.w13_weight), then invoke
vllm_fused_moe_package._invoke_fused_moe_kernel using the swapped
self.w13_weight (not the previously bound B), and always restore self.w13_weight
= orig in a finally block; apply the identical pattern to the other symmetric
block (the one that swaps/restores the other weight at lines 359-361).

---

Nitpick comments:
In `@modelopt/torch/quantization/plugins/vllm.py`:
- Around line 387-388: Add regression tests that exercise both symbol paths and
verify the runtime patching in _patch_moe_kernel is applied and always restored:
write tests that (1) trigger the branch where invoke_fused_moe_kernel is used,
(2) trigger the branch where invoke_fused_moe_triton_kernel is used, and (3)
simulate an exception raised during super().forward(...) to assert the original
symbols are restored after the exception. Locate and call the class/method that
uses _patch_moe_kernel and forward to run these cases, patch or monkeypatch the
target symbols to observable fakes, and assert pre-/post-conditions on the
original functions to confirm restore behavior.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 476612da-4375-40a3-b6c6-f1dfd76df7b5

📥 Commits

Reviewing files that changed from the base of the PR and between de55e8a and 122b935.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/plugins/vllm.py

@codecov

codecov bot commented Apr 1, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 76.12%. Comparing base (abf4558) to head (8f1b28c).
⚠️ Report is 3 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1162      +/-   ##
==========================================
+ Coverage   71.86%   76.12%   +4.25%     
==========================================
  Files         352      353       +1     
  Lines       40351    41169     +818     
==========================================
+ Hits        28999    31338    +2339     
+ Misses      11352     9831    -1521     
Flag      Coverage    Δ
examples  44.35% <ø>  (+1.15%) ⬆️
gpu       56.90% <ø>  (+9.47%) ⬆️
unit      55.02% <ø>  (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

kinjalpatel27 force-pushed the kinjal/vllm_super_nano_support branch 3 times, most recently from a99f27f to 435caab (April 7, 2026 23:34)
kinjalpatel27 marked this pull request as ready for review April 7, 2026 23:35
kinjalpatel27 requested a review from a team as a code owner April 7, 2026 23:35
kinjalpatel27 requested a review from ajrasane April 7, 2026 23:35
Contributor

coderabbitai bot left a comment


♻️ Duplicate comments (1)
modelopt/torch/quantization/plugins/vllm.py (1)

364-373: ⚠️ Potential issue | 🔴 Critical

Restore swapped expert weights in finally blocks.

On Line 371 and Line 387, an exception in the kernel call will skip restoration at Line 372 and Line 388, leaving self.w13_weight/self.w2_weight permanently swapped for the rest of execution.

Proposed fix
         if B is self.w13_weight:
             # First layer of expert
             A = self.w13_input_quantizer(A)  # noqa: N806
             if self.w13_weight_quantizer.is_enabled:
                 original_weight, self.w13_weight = (
                     self.w13_weight,
                     self.w13_weight_quantizer(self.w13_weight),
                 )
                 # In case the weight quantizer isn't folded yet in vllm_serve_fakequant, pass the
                 # quantized weight to the kernel.
                 B = self.w13_weight  # noqa: N806
-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-                self.w13_weight = original_weight
+                try:
+                    vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
+                finally:
+                    self.w13_weight = original_weight
             else:
                 vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
@@
         elif B is self.w2_weight:
             A = self.w2_input_quantizer(A)  # noqa: N806
             if self.w2_weight_quantizer.is_enabled:
                 original_weight, self.w2_weight = (
                     self.w2_weight,
                     self.w2_weight_quantizer(self.w2_weight),
                 )
                 # In case the weight quantizer isn't folded yet in vllm_serve_fakequant, pass the
                 # quantized weight to the kernel.
                 B = self.w2_weight  # noqa: N806
-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-                self.w2_weight = original_weight
+                try:
+                    vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
+                finally:
+                    self.w2_weight = original_weight
             else:
                 vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)

Also applies to: 380-389

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/vllm.py` around lines 364 - 373, The
kernel invocation swaps in quantized expert weights (e.g., setting
original_weight, self.w13_weight = (self.w13_weight,
self.w13_weight_quantizer(self.w13_weight)) and similarly for self.w2_weight)
but restores them only after the kernel call, so exceptions can leave weights
swapped; wrap the kernel call and the restore in a try/finally so that
self.w13_weight and self.w2_weight are always set back to original_weight even
if vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
raises—i.e., perform the swap, call kernel inside try, and put self.w13_weight =
original_weight (and self.w2_weight = original_weight for the other block) in
the finally block.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@modelopt/torch/quantization/plugins/vllm.py`:
- Around line 364-373: The kernel invocation swaps in quantized expert weights
(e.g., setting original_weight, self.w13_weight = (self.w13_weight,
self.w13_weight_quantizer(self.w13_weight)) and similarly for self.w2_weight)
but restores them only after the kernel call, so exceptions can leave weights
swapped; wrap the kernel call and the restore in a try/finally so that
self.w13_weight and self.w2_weight are always set back to original_weight even
if vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
raises—i.e., perform the swap, call kernel inside try, and put self.w13_weight =
original_weight (and self.w2_weight = original_weight for the other block) in
the finally block.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 6d0f9a81-1879-41e8-89c7-1365685a7041

📥 Commits

Reviewing files that changed from the base of the PR and between 122b935 and 435caab.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/plugins/vllm.py

Contributor

coderabbitai bot left a comment


♻️ Duplicate comments (2)
modelopt/torch/quantization/plugins/vllm.py (2)

358-367: ⚠️ Potential issue | 🟠 Major

Restore the temporary expert weights in a finally.

A kernel/CUDA failure here leaves self.w13_weight / self.w2_weight permanently swapped to the quantized tensor, which can poison later forwards on the same module.

🩹 Proposed fix
             if self.w13_weight_quantizer.is_enabled:  # pragma: no cover
                 original_weight, self.w13_weight = (
                     self.w13_weight,
                     self.w13_weight_quantizer(self.w13_weight),
                 )
                 # In case the weight quantizer isn't folded yet in vllm_serve_fakequant, pass the
                 # quantized weight to the kernel.
                 B = self.w13_weight  # noqa: N806
-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-                self.w13_weight = original_weight
+                try:
+                    vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
+                finally:
+                    self.w13_weight = original_weight
@@
             if self.w2_weight_quantizer.is_enabled:  # pragma: no cover
                 original_weight, self.w2_weight = (
                     self.w2_weight,
                     self.w2_weight_quantizer(self.w2_weight),
                 )
                 # In case the weight quantizer isn't folded yet in vllm_serve_fakequant, pass the
                 # quantized weight to the kernel.
                 B = self.w2_weight  # noqa: N806
-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-                self.w2_weight = original_weight
+                try:
+                    vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
+                finally:
+                    self.w2_weight = original_weight

Also applies to: 374-383

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/vllm.py` around lines 358 - 367, The
temporary swap of expert weights with their quantized version (e.g., when
self.w13_weight_quantizer.is_enabled and similarly for self.w2_weight) must be
restored in a finally block so kernel/CUDA exceptions don't leave
self.w13_weight/self.w2_weight permanently replaced; update the blocks around
self.w13_weight = self.w13_weight_quantizer(...)/B = self.w13_weight and the
call to vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, ...) so that
original_weight is assigned back inside a finally (use try: set quantized, call
kernel; finally: self.w13_weight = original_weight) and apply the same change to
the corresponding w2_weight section.

392-398: ⚠️ Potential issue | 🔴 Critical

replace_function() is not safe enough for this module-global kernel swap.

replace_function in modelopt/torch/quantization/utils/core_utils.py restores only after yield, with no try/finally or synchronization. A failing or overlapping forward can therefore leave vllm_fused_moe_package patched to the wrong bound method and break later MoE calls. Please harden the helper or use a dedicated locked try/finally patch on this path.

Run this read-only check to confirm the helper is exception-unsafe and non-reentrant:

#!/bin/bash
set -euo pipefail

sed -n '326,336p' modelopt/torch/quantization/utils/core_utils.py

python - <<'PY'
from contextlib import contextmanager

@contextmanager
def replace(obj, name, new, cache="_cached"):
    old = getattr(obj, name)
    setattr(obj, name, new)
    setattr(obj, cache, old)
    yield
    setattr(obj, name, old)
    delattr(obj, cache)

class Obj:
    pass

o = Obj()
o.f = "orig"

try:
    with replace(o, "f", "A"):
        with replace(o, "f", "B"):
            pass
except Exception as exc:
    print(type(exc).__name__, exc)

events = []

@contextmanager
def cm():
    events.append("enter")
    yield
    events.append("cleanup")

try:
    with cm():
        raise RuntimeError("boom")
except RuntimeError:
    pass

print(events)
PY

Expected results:

  • The helper source shows restore statements after yield, but no try/finally or locking.
  • The nested toy example raises during teardown, showing the pattern is not re-entrant.
  • The last line prints ['enter'], showing code after yield is skipped on exceptions.
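One way to harden the helper along the lines suggested above is a module-level lock plus per-attribute reference counting, so nested or concurrent entries don't clobber the cached original, with restoration in a finally block. This is a sketch under those assumptions; `locked_replace` and `_patch_state` are illustrative names, not the actual core_utils API.

```python
import threading
from contextlib import contextmanager

_patch_lock = threading.RLock()
_patch_state = {}  # (id(obj), attr name) -> [original value, refcount]


@contextmanager
def locked_replace(obj, name, new):
    """Exception-safe, re-entrant attribute patch (sketch of a hardened replace_function)."""
    key = (id(obj), name)
    with _patch_lock:
        if key not in _patch_state:
            # Cache the original only on the outermost entry.
            _patch_state[key] = [getattr(obj, name), 0]
        _patch_state[key][1] += 1
        setattr(obj, name, new)
    try:
        yield
    finally:
        with _patch_lock:
            _patch_state[key][1] -= 1
            if _patch_state[key][1] == 0:
                # Outermost exit: restore the original and drop the cache entry.
                original, _ = _patch_state.pop(key)
                setattr(obj, name, original)
```

Unlike the toy `replace` above, nesting this version does not raise, and an exception inside the context still restores the original attribute.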
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/vllm.py` around lines 392 - 398, The
current use of replace_function for swapping the module-global
vllm_fused_moe_package binding in the forward override is exception-unsafe and
non-reentrant; instead, wrap the attribute swap around a locked try/finally in
the forward implementation: acquire a module-level threading.RLock, save the
original bound method from vllm_fused_moe_package (the symbol cached under
"_invoke_fused_moe_kernel"), set the attribute to
self.invoke_fused_moe_quantized, call super().forward(hidden_states,
router_logits), and in a finally block restore the original attribute and remove
any cache attribute; alternatively harden replace_function itself to perform the
same lock + try/finally restore behavior to protect invoke_fused_moe_kernel_func
/ invoke_fused_moe_quantized swaps from exceptions and nested calls.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@modelopt/torch/quantization/plugins/vllm.py`:
- Around line 358-367: The temporary swap of expert weights with their quantized
version (e.g., when self.w13_weight_quantizer.is_enabled and similarly for
self.w2_weight) must be restored in a finally block so kernel/CUDA exceptions
don't leave self.w13_weight/self.w2_weight permanently replaced; update the
blocks around self.w13_weight = self.w13_weight_quantizer(...)/B =
self.w13_weight and the call to
vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, ...) so that
original_weight is assigned back inside a finally (use try: set quantized, call
kernel; finally: self.w13_weight = original_weight) and apply the same change to
the corresponding w2_weight section.
- Around line 392-398: The current use of replace_function for swapping the
module-global vllm_fused_moe_package binding in the forward override is
exception-unsafe and non-reentrant; instead, wrap the attribute swap around a
locked try/finally in the forward implementation: acquire a module-level
threading.RLock, save the original bound method from vllm_fused_moe_package (the
symbol cached under "_invoke_fused_moe_kernel"), set the attribute to
self.invoke_fused_moe_quantized, call super().forward(hidden_states,
router_logits), and in a finally block restore the original attribute and remove
any cache attribute; alternatively harden replace_function itself to perform the
same lock + try/finally restore behavior to protect invoke_fused_moe_kernel_func
/ invoke_fused_moe_quantized swaps from exceptions and nested calls.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 12eed472-d865-4c37-bdc1-2eaa551309b2

📥 Commits

Reviewing files that changed from the base of the PR and between 435caab and 96d081a.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/plugins/vllm.py

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
kinjalpatel27 force-pushed the kinjal/vllm_super_nano_support branch from 96d081a to 80c400e (April 8, 2026 16:39)
Contributor

coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
modelopt/torch/quantization/plugins/vllm.py (1)

358-367: ⚠️ Potential issue | 🟠 Major

Restore the temporary expert weights in finally.

These branches still leave self.w13_weight / self.w2_weight swapped to the temporary quantized tensor if _invoke_fused_moe_kernel(...) throws, so later fake-quant calls run with corrupted layer state.

♻️ Proposed fix
             if self.w13_weight_quantizer.is_enabled:  # pragma: no cover
                 original_weight, self.w13_weight = (
                     self.w13_weight,
                     self.w13_weight_quantizer(self.w13_weight),
                 )
                 # In case the weight quantizer isn't folded yet in vllm_serve_fakequant, pass the
                 # quantized weight to the kernel.
                 B = self.w13_weight  # noqa: N806
-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-                self.w13_weight = original_weight
+                try:
+                    vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
+                finally:
+                    self.w13_weight = original_weight
@@
             if self.w2_weight_quantizer.is_enabled:  # pragma: no cover
                 original_weight, self.w2_weight = (
                     self.w2_weight,
                     self.w2_weight_quantizer(self.w2_weight),
                 )
                 # In case the weight quantizer isn't folded yet in vllm_serve_fakequant, pass the
                 # quantized weight to the kernel.
                 B = self.w2_weight  # noqa: N806
-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
-                self.w2_weight = original_weight
+                try:
+                    vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
+                finally:
+                    self.w2_weight = original_weight

Also applies to: 374-383

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/vllm.py` around lines 358 - 367, The
temporary assignment of quantized expert weights (e.g., swapping self.w13_weight
and self.w2_weight before calling
vllm_fused_moe_package._invoke_fused_moe_kernel) can leave the model in a
corrupted state if the kernel throws; wrap the kernel invocation in a
try/finally so the original weight is always restored (store original_weight =
self.w13_weight / self.w2_weight, assign the quantized tensor, call
vllm_fused_moe_package._invoke_fused_moe_kernel(...) inside try, and restore
self.w13_weight / self.w2_weight in finally). Apply the same try/finally pattern
to both the w13 branch (uses self.w13_weight_quantizer and B variable) and the
w2 branch to ensure consistent restoration.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/quantization/plugins/vllm.py`:
- Around line 392-398: The context manager replace_function currently yields
without guaranteeing cleanup on exceptions, which can leave
vllm_fused_moe_package permanently patched and the _invoke_fused_moe_kernel
cache attribute lingering; change replace_function to wrap the yield in a
try/finally so the patch/unpatch and any attribute restore/deletion (the same
cleanup logic already present after the yield) always runs even if code inside
the context (e.g., super().forward in the caller that uses
invoke_fused_moe_kernel_func/invoke_fused_moe_quantized) raises an exception.

---

Duplicate comments:
In `@modelopt/torch/quantization/plugins/vllm.py`:
- Around line 358-367: The temporary assignment of quantized expert weights
(e.g., swapping self.w13_weight and self.w2_weight before calling
vllm_fused_moe_package._invoke_fused_moe_kernel) can leave the model in a
corrupted state if the kernel throws; wrap the kernel invocation in a
try/finally so the original weight is always restored (store original_weight =
self.w13_weight / self.w2_weight, assign the quantized tensor, call
vllm_fused_moe_package._invoke_fused_moe_kernel(...) inside try, and restore
self.w13_weight / self.w2_weight in finally). Apply the same try/finally pattern
to both the w13 branch (uses self.w13_weight_quantizer and B variable) and the
w2 branch to ensure consistent restoration.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 96b8d5b6-4e3c-4e72-a2b8-32129861fd90

📥 Commits

Reviewing files that changed from the base of the PR and between 96d081a and 80c400e.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/plugins/vllm.py

kinjalpatel27 force-pushed the kinjal/vllm_super_nano_support branch from 31be877 to c42e844 (April 8, 2026 18:09)
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
kinjalpatel27 force-pushed the kinjal/vllm_super_nano_support branch from c42e844 to 037f9a9 (April 8, 2026 18:17)
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
