Skip to content

DFlash speculative decoding for MiniMax-M2.7 (FSDP2): auto mask-token, FSDP2 resume fixes, per-checkpoint draft export#1621

Open
yeyu-nvidia wants to merge 10 commits into
mainfrom
yeyu/dflash-auto-mask-token
Open

DFlash speculative decoding for MiniMax-M2.7 (FSDP2): auto mask-token, FSDP2 resume fixes, per-checkpoint draft export#1621
yeyu-nvidia wants to merge 10 commits into
mainfrom
yeyu/dflash-auto-mask-token

Conversation

@yeyu-nvidia

@yeyu-nvidia yeyu-nvidia commented Jun 3, 2026

Copy link
Copy Markdown
Contributor

What

Brings up DFlash block-diffusion speculative decoding for large MoE targets (MiniMax-M2.7, 229B) trained under accelerate FSDP2, and fixes the regressions that broke checkpoint resume and per-checkpoint draft export.

Commits

  • auto-add mask token for DFlash when the tokenizer lacks one (resize embeddings, restore dtype).
  • requeue support in build_slurm_executor + FSDP2 cpu_ram_efficient_loading for 229B on multi-node.
  • FSDP2 buffer patch (fsdp2_buffer_patch.py): handle non-DTensor buffers in fsdp2_load_full_state_dict, broadcast dtype codes from rank 0, and an FSDP2-safe clip_grad_norm_. Required because MiniMax-M2.7 pins transformers 4.57.x (no native ParallelismConfig).
  • dtype fix: use the broadcast dtype (rank 0) rather than the local meta-device param dtype, so non-leader ranks don't cast bf16 back to fp32 on resume.
  • restore DFlashExportCallback (this PR's headline): the Pydantic-recipe refactor (7038dec) dropped the callback that exported the draft submodule after each checkpoint save, leaving a stale "export happens during training via DFlashExportCallback" comment with no callback. FSDP2 SHARDED_STATE_DICT checkpoints carry no model.safetensors, so without it there is nothing for vLLM / acceptance-length eval to load. The callback gathers only the ~328 MB draft submodule across shards via get_model_state_dict(..., submodules={dflash_module}, full_state_dict=True, cpu_offload=True) — works under SHARDED_STATE_DICT without materializing the 229B base — and writes exported-checkpoint-{step}/.

Testing

  • Resume from FSDP2 sharded checkpoints verified end-to-end (loss/AR continuity).
  • Draft export validated against vLLM: exported drafts load and produce acceptance-length metrics on MT-Bench across the full checkpoint sweep.

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features

    • Export draft-submodule weights to dedicated exported checkpoints during training.
    • FSDP2 buffer compatibility and DTensor-aware gradient clipping for safer distributed loading/training.
    • Detect HF-format checkpoints for smarter resume/load behavior.
    • Auto-add and handle a mask special token for draft workflows.
    • vLLM: disable prefix caching to preserve full prompt hidden states.
    • Add CLI option for answer-only loss and save aligned loss masks.
    • Add a SPEED-Bench config for DFLASH/vLLM benchmarking.
  • Chores

    • Robust package version fallback to avoid import failures.

@yeyu-nvidia yeyu-nvidia requested a review from a team as a code owner June 3, 2026 18:42
@yeyu-nvidia yeyu-nvidia requested a review from h-guo18 June 3, 2026 18:42
@coderabbitai

coderabbitai Bot commented Jun 3, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds a DFlash draft-weight export callback, an FSDP2 buffer/DTensor monkey-patch (with DTensor-safe grad clipping), integrates both into the speculative-decoding training script (checkpoint-format detection and mask-token init), disables vLLM prefix caching for hidden-state dumps, and forwards Slurm requeue settings to the launcher executor.

Changes

Speculative Decoding Training Enhancements

Layer / File(s) Summary
DFlash export callback
examples/speculative_decoding/eagle_utils.py
DFlashExportCallback gathers model state (distributed-aware with fallbacks), filters dflash_module.* (excludes rotary_emb), and writes model.safetensors and config.json on master rank; skips empty exports and logs failures.
FSDP2 buffer & DTensor patching
examples/speculative_decoding/fsdp2_buffer_patch.py
Adds monkey-patch to accelerate.utils.fsdp_utils.fsdp2_load_full_state_dict to synchronize dtypes, broadcast non-DTensor buffers from rank 0, reconstruct DTensors via distribute_tensor(), and provides DTensor-aware _clip_grad_norm plus patch_accelerator.
Main script integration and checkpoint/token handling
examples/speculative_decoding/main.py
Imports/conditionally applies fsdp2_buffer_patch, detects HF-formatted checkpoints vs. sharded checkpoints and falls back to base-model loading, ensures dflash_mask_token_id exists (derives or adds `<
vLLM hidden-state config and loss-mask capture
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py
Adds --answer-only-loss arg, allows chat-template overrides/cleanup, tokenizes with aligned loss_mask via tokenize_with_loss_mask, disables enable_prefix_caching for full-prompt hidden states, and includes loss_mask in saved .pt aligned to hidden-state length.`

Launcher Slurm Requeue Configuration

Layer / File(s) Summary
Slurm requeue and retries
tools/launcher/core.py
build_slurm_executor forwards additional_parameters to run.SlurmExecutor, sets requeue=True and ensures retries>=3 when slurm_config.requeue is enabled.

Package initialization

Layer / File(s) Summary
modelopt version fallback
modelopt/__init__.py
Wraps importlib.metadata.version in try/except PackageNotFoundError, falling back to __version__ = "0.0.0+unknown" when distribution metadata is absent.

Benchmark spec

Layer / File(s) Summary
MiniMax DFlash SPEED-Bench spec
tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml
Adds YAML configuration for running MiniMax-M2.7 DFLASH benchmarks with vLLM, two tasks, and SLURM/container settings pointing to the exported draft checkpoint.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • ChenhanYu
  • shengliangxu
  • kevalmorabia97
  • h-guo18
🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 53.85% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and comprehensively summarizes the three main changes: auto mask-token support, FSDP2 resume fixes, and per-checkpoint draft export, all in context of DFlash speculative decoding for MiniMax-M2.7.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed No security anti-patterns found: no unsafe torch.load/numpy.load, trust_remote_code is configurable (not hardcoded), no eval/exec of untrusted input, and no new unsafe dependencies.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch yeyu/dflash-auto-mask-token

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

👉 Steps to fix this

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/speculative_decoding/eagle_utils.py (1)

53-53: 🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win

Update __all__ to include DFlashExportCallback.

The coding guidelines require defining the public API with __all__. Since DFlashExportCallback is imported by main.py (line 39), it should be exported.

-__all__ = ["EagleOfflineDataCollator", "OfflineSupervisedDataset"]
+__all__ = ["DFlashExportCallback", "EagleOfflineDataCollator", "OfflineSupervisedDataset"]

As per coding guidelines: "Define the public API with __all__ at the top of each Python module."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/speculative_decoding/eagle_utils.py` at line 53, The module's public
API list __all__ is missing DFlashExportCallback; update the __all__ declaration
(which currently lists "EagleOfflineDataCollator" and
"OfflineSupervisedDataset") to also include "DFlashExportCallback" so the symbol
is exported for consumers like main.py that import it.
🧹 Nitpick comments (5)
examples/speculative_decoding/fsdp2_buffer_patch.py (4)

238-240: ⚡ Quick win

Use print_rank_0 to avoid noisy logs in multi-rank environments.

These print statements execute on every rank, which can produce excessive output on large clusters. Consider using print_rank_0 from modelopt.torch.utils or guarding with a rank check.

+from modelopt.torch.utils import print_rank_0
+
 # In apply() function:
-        print("[fsdp2_buffer_patch] Patched fsdp2_load_full_state_dict for buffer compatibility")
+        print_rank_0("[fsdp2_buffer_patch] Patched fsdp2_load_full_state_dict for buffer compatibility")
     except Exception as e:
-        print(f"[fsdp2_buffer_patch] Patch skipped: {e}")
+        print_rank_0(f"[fsdp2_buffer_patch] Patch skipped: {e}")

As per coding guidelines: "use print_rank_0 or warn_rank_0 to avoid noisy logs."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 238 - 240,
The two print calls in fsdp2_buffer_patch (the success message and the except
message around fsdp2_load_full_state_dict) should be replaced with rank-safe
logging: import and call print_rank_0 from modelopt.torch.utils (or guard with a
rank check) so messages only appear on rank 0; update the success print and the
exception print to use print_rank_0 and include the exception variable in the
error message (e) while keeping the same text context.

1-3: 💤 Low value

Add __all__ to define the public API.

Per coding guidelines, each Python module should define __all__ to make the public API explicit.

+__all__ = ["apply", "patch_accelerator"]
+
 """Monkey-patch for accelerate's fsdp2_load_full_state_dict buffer handling.

As per coding guidelines: "Define the public API with __all__ at the top of each Python module."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 1 - 3, This
module is missing an explicit public API; add a module-level __all__ declaration
at the top (after the SPDX headers) listing the public names exported by this
file (e.g. __all__ = ["Name1", "function_name", "CLASS_NAME"]), ensuring each
symbol included matches the actual top-level functions/classes/variables defined
later in the file; place the __all__ immediately below the license lines to
satisfy the coding guideline.

322-326: ⚡ Quick win

Use print_rank_0 here as well.

 def patch_accelerator(accelerator):
     """Replace accelerator's clip_grad_norm_ with FSDP2-safe version."""
     accelerator.clip_grad_norm_ = _clip_grad_norm
-    print("[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ "
-          "for FSDP2 DTensor compatibility")
+    from modelopt.torch.utils import print_rank_0
+    print_rank_0("[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ "
+                 "for FSDP2 DTensor compatibility")

As per coding guidelines: "use print_rank_0 or warn_rank_0 to avoid noisy logs."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 322 - 326,
The patch_accelerator function currently uses print to log the patch; replace
that call with print_rank_0 to follow logging guidelines. Update the function to
call print_rank_0("[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ for
FSDP2 DTensor compatibility") and ensure print_rank_0 is imported at top of the
module (the same utility used elsewhere), leaving accelerator.clip_grad_norm_ =
_clip_grad_norm unchanged.

269-270: 💤 Low value

Return value should be on the same device as gradients.

When there are no gradients, the function returns a CPU tensor. For consistency with the non-empty case (which returns total_norm on device), consider returning on the same device.

     if len(grads) == 0:
-        return torch.tensor(0.0)
+        device = parameters[0].device if parameters else "cpu"
+        return torch.tensor(0.0, device=device)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 269 - 270,
The early-return creates a CPU tensor when grads is empty; change it to return a
zero tensor on the same device as the gradients by selecting the first available
grad device (e.g. device = next((g.device for g in grads if g is not None),
torch.device('cpu'))) and then return torch.tensor(0.0, device=device) so the
empty-case matches the device of the non-empty case that returns total_norm.
examples/speculative_decoding/main.py (1)

297-303: ⚡ Quick win

Consider gating debug output or removing before merge.

This debug print executes on every rank and will produce verbose output on large clusters. If this is temporary debugging code, consider removing it or guarding with a debug flag.

-    rank = int(os.environ.get("RANK", 0))
-    dtypes = {}
-    for name, p in trainer.model.named_parameters():
-        dt_key = str(p.dtype) if not hasattr(p, "_local_tensor") else str(p._local_tensor.dtype)
-        dtypes.setdefault(dt_key, []).append(name)
-    for dt, names in dtypes.items():
-        print(f"[dtype_check rank={rank}] {dt}: {len(names)} params (e.g. {names[0]})")
+    if os.environ.get("DEBUG_DTYPES"):
+        rank = int(os.environ.get("RANK", 0))
+        dtypes = {}
+        for name, p in trainer.model.named_parameters():
+            dt_key = str(p.dtype) if not hasattr(p, "_local_tensor") else str(p._local_tensor.dtype)
+            dtypes.setdefault(dt_key, []).append(name)
+        for dt, names in dtypes.items():
+            print(f"[dtype_check rank={rank}] {dt}: {len(names)} params (e.g. {names[0]})")

As per coding guidelines: "use print_rank_0 or warn_rank_0 to avoid noisy logs."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/speculative_decoding/main.py` around lines 297 - 303, The debug loop
printing per-rank dtype info (uses rank, dtypes, iterating
trainer.model.named_parameters()) should be gated or replaced to avoid noisy
logs: either remove the prints or wrap them so only rank 0 logs (use existing
print_rank_0 or warn_rank_0 utility) and/or guard with a debug flag (e.g., if
DEBUG:). Update the block that builds dtypes and the final print to call
print_rank_0 (or warn_rank_0) with the formatted message so only the main
process emits the output, or conditionally execute the entire loop behind a
debug configuration toggle.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tools/launcher/core.py`:
- Around line 280-287: The code assumes slurm_config.additional_parameters is a
mutable dict and mutates it directly, which can cause shared-state bugs; before
assigning to executor.additional_parameters (and before mutating it to set
"requeue"), validate and normalize slurm_config.additional_parameters to a plain
dict (e.g., treat None, mappings, or other types safely), create a shallow copy
for executor.additional_parameters, and then mutate that copy; also ensure
executor.retries is updated via executor.retries = max(executor.retries, 3) as
shown. Reference: slurm_config.additional_parameters,
executor.additional_parameters, and executor.retries.

---

Outside diff comments:
In `@examples/speculative_decoding/eagle_utils.py`:
- Line 53: The module's public API list __all__ is missing DFlashExportCallback;
update the __all__ declaration (which currently lists "EagleOfflineDataCollator"
and "OfflineSupervisedDataset") to also include "DFlashExportCallback" so the
symbol is exported for consumers like main.py that import it.

---

Nitpick comments:
In `@examples/speculative_decoding/fsdp2_buffer_patch.py`:
- Around line 238-240: The two print calls in fsdp2_buffer_patch (the success
message and the except message around fsdp2_load_full_state_dict) should be
replaced with rank-safe logging: import and call print_rank_0 from
modelopt.torch.utils (or guard with a rank check) so messages only appear on
rank 0; update the success print and the exception print to use print_rank_0 and
include the exception variable in the error message (e) while keeping the same
text context.
- Around line 1-3: This module is missing an explicit public API; add a
module-level __all__ declaration at the top (after the SPDX headers) listing the
public names exported by this file (e.g. __all__ = ["Name1", "function_name",
"CLASS_NAME"]), ensuring each symbol included matches the actual top-level
functions/classes/variables defined later in the file; place the __all__
immediately below the license lines to satisfy the coding guideline.
- Around line 322-326: The patch_accelerator function currently uses print to
log the patch; replace that call with print_rank_0 to follow logging guidelines.
Update the function to call print_rank_0("[fsdp2_buffer_patch] Patched
accelerator.clip_grad_norm_ for FSDP2 DTensor compatibility") and ensure
print_rank_0 is imported at top of the module (the same utility used elsewhere),
leaving accelerator.clip_grad_norm_ = _clip_grad_norm unchanged.
- Around line 269-270: The early-return creates a CPU tensor when grads is
empty; change it to return a zero tensor on the same device as the gradients by
selecting the first available grad device (e.g. device = next((g.device for g in
grads if g is not None), torch.device('cpu'))) and then return torch.tensor(0.0,
device=device) so the empty-case matches the device of the non-empty case that
returns total_norm.

In `@examples/speculative_decoding/main.py`:
- Around line 297-303: The debug loop printing per-rank dtype info (uses rank,
dtypes, iterating trainer.model.named_parameters()) should be gated or replaced
to avoid noisy logs: either remove the prints or wrap them so only rank 0 logs
(use existing print_rank_0 or warn_rank_0 utility) and/or guard with a debug
flag (e.g., if DEBUG:). Update the block that builds dtypes and the final print
to call print_rank_0 (or warn_rank_0) with the formatted message so only the
main process emits the output, or conditionally execute the entire loop behind a
debug configuration toggle.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: ec55dcce-a920-44ca-8e39-ee3167ca3eeb

📥 Commits

Reviewing files that changed from the base of the PR and between 88fd7ff and d2d0558.

📒 Files selected for processing (4)
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/fsdp2_buffer_patch.py
  • examples/speculative_decoding/main.py
  • tools/launcher/core.py

Comment thread tools/launcher/core.py
Comment on lines +280 to +287
additional_parameters=getattr(slurm_config, "additional_parameters", None) or {},
)
if getattr(slurm_config, "requeue", False):
executor.additional_parameters["requeue"] = True
# The nemo-run sbatch wrapper only calls `scontrol requeue` when
# TORCHX_MAX_RETRIES > SLURM_RESTART_COUNT. retries=0 (the default)
# disables this, so bump it when requeue is requested.
executor.retries = max(executor.retries, 3)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Harden additional_parameters normalization before mutation.

At Line 280 and Line 283, this assumes additional_parameters is always a mutable mapping. Since these attrs are externally supplied, normalize/validate to dict before assignment and mutate a local copy to avoid shared-state side effects.

Proposed fix
-    executor = run.SlurmExecutor(
+    raw_additional_parameters = getattr(slurm_config, "additional_parameters", None)
+    additional_parameters = {}
+    if raw_additional_parameters is not None:
+        if not isinstance(raw_additional_parameters, dict):
+            raise TypeError("slurm_config.additional_parameters must be a dict")
+        additional_parameters = dict(raw_additional_parameters)
+
+    executor = run.SlurmExecutor(
         account=slurm_config.account,
         partition=slurm_config.partition,
         qos=slurm_config.qos,
@@
-        additional_parameters=getattr(slurm_config, "additional_parameters", None) or {},
+        additional_parameters=additional_parameters,
     )

As per coding guidelines: "Validate external input once at the interface boundary; internal code can trust those checks and avoid redundant assertions".

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
additional_parameters=getattr(slurm_config, "additional_parameters", None) or {},
)
if getattr(slurm_config, "requeue", False):
executor.additional_parameters["requeue"] = True
# The nemo-run sbatch wrapper only calls `scontrol requeue` when
# TORCHX_MAX_RETRIES > SLURM_RESTART_COUNT. retries=0 (the default)
# disables this, so bump it when requeue is requested.
executor.retries = max(executor.retries, 3)
raw_additional_parameters = getattr(slurm_config, "additional_parameters", None)
additional_parameters = {}
if raw_additional_parameters is not None:
if not isinstance(raw_additional_parameters, dict):
raise TypeError("slurm_config.additional_parameters must be a dict")
additional_parameters = dict(raw_additional_parameters)
executor = run.SlurmExecutor(
account=slurm_config.account,
partition=slurm_config.partition,
qos=slurm_config.qos,
additional_parameters=additional_parameters,
)
if getattr(slurm_config, "requeue", False):
executor.additional_parameters["requeue"] = True
# The nemo-run sbatch wrapper only calls `scontrol requeue` when
# TORCHX_MAX_RETRIES > SLURM_RESTART_COUNT. retries=0 (the default)
# disables this, so bump it when requeue is requested.
executor.retries = max(executor.retries, 3)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tools/launcher/core.py` around lines 280 - 287, The code assumes
slurm_config.additional_parameters is a mutable dict and mutates it directly,
which can cause shared-state bugs; before assigning to
executor.additional_parameters (and before mutating it to set "requeue"),
validate and normalize slurm_config.additional_parameters to a plain dict (e.g.,
treat None, mappings, or other types safely), create a shallow copy for
executor.additional_parameters, and then mutate that copy; also ensure
executor.retries is updated via executor.retries = max(executor.retries, 3) as
shown. Reference: slurm_config.additional_parameters,
executor.additional_parameters, and executor.retries.

@codecov

codecov Bot commented Jun 3, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 60.00000% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.71%. Comparing base (d3acf45) to head (f7844eb).
⚠️ Report is 2 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/__init__.py 60.00% 2 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1621       +/-   ##
===========================================
+ Coverage   56.59%   76.71%   +20.11%     
===========================================
  Files         507      508        +1     
  Lines       55794    55866       +72     
===========================================
+ Hits        31579    42855    +11276     
+ Misses      24215    13011    -11204     
Flag Coverage Δ
examples 42.32% <60.00%> (+23.76%) ⬆️
gpu 57.85% <60.00%> (+37.31%) ⬆️
regression 14.74% <60.00%> (-0.11%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

yeyu-nvidia and others added 7 commits June 9, 2026 10:01
Models without a <|mask|> token (e.g., MiniMax-M2.7) would fail with
ValueError during DFlash training. Instead of requiring the user to
manually set dflash_mask_token_id, add the token to the tokenizer
and resize model embeddings automatically.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When slurm_config.requeue is True, set additional_parameters["requeue"]
= True so nemo-run emits #SBATCH --requeue in the sbatch script.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1. main.py: When FSDP2 cpu_ram_efficient_loading is active, only rank 0
   loads real weights on CPU; other ranks use meta device. FSDP2
   distributes from rank 0. Also adds dp_replicate_size auto-computation
   so dp_replicate * dp_shard * cp == world_size.

2. core.py: Set retries=3 when requeue is requested. The nemo-run sbatch
   wrapper only calls scontrol requeue when TORCHX_MAX_RETRIES >
   SLURM_RESTART_COUNT — retries=0 (the default) disabled requeue.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…heckpoint resume

The Pydantic recipe refactor dropped fsdp2_buffer_patch.apply() and
patch_accelerator() calls and added a buffer-to-CUDA block that moved
DFlash buffers before FSDP wrapping. With cpu_ram_efficient_loading,
non-rank-0 processes have meta-device params, causing
_infer_parameter_dtype() to return fp32 instead of bf16 on resume.

Also detects FSDP distributed checkpoints (no HF model files) and
loads the base model instead of trying from_pretrained on them.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… patch

_infer_parameter_dtype() reads the model's current param dtype to cast
the broadcasted tensor. With cpu_ram_efficient_loading, non-rank-0
processes have fp32 meta-device params for DFlash, so
_infer_parameter_dtype returns fp32 and _finish() casts the correctly-
broadcasted bf16 tensor back to fp32. Use bcast_dtype (from rank 0)
instead.

Also prints dtype_check on all ranks to verify consistency.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The Pydantic-recipe refactor (7038dec) dropped DFlashExportCallback,
which had exported the DFlash draft submodule after every checkpoint save.
Without it, FSDP2 sharded checkpoints (pytorch_model_fsdp_0/, no
model.safetensors) get no exported-checkpoint-{step}/, so downstream
vLLM deployment / acceptance-length eval has nothing to load. The
verify-only comment 'export happens during training via DFlashExportCallback'
was left behind but the callback itself was gone.

Restore the callback (gathers only the ~328MB draft submodule across
shards via get_model_state_dict, so it works under SHARDED_STATE_DICT
without materializing the full base model) and wire it into main.py for
DFlash recipes.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…c_bench)

Adds a self-contained launcher example so MiniMax-M2.7 (229B MoE) DFlash
training is reproducible end-to-end, plus two common-script enablers it needs:

- tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/:
  - hf_online_dflash.yaml    (train -> vLLM smoke -> AR eval; 8-node FSDP2)
  - hf_offline_dflash.yaml   (vLLM hidden-state dump -> offline train)
  - specdec_bench.yaml       (qualitative + throughput_32k, DFLASH)
  - accelerate_fsdp2_hybrid.yaml, chat_template_train.jinja
- common/specdec/dflash_online_training.sh: honor OVERRIDE_TRANSFORMERS,
  ACCELERATE_CONFIG, and MIXED_PRECISION so trust_remote_code MoE models that
  need FSDP2 via an accelerate config (MiniMax-M2.7 on transformers 4.57.x) work.
- collect_hidden_states/compute_hidden_states_vllm.py: disable prefix caching.
  With it on, vLLM serves shared prefixes from cache in block chunks and the
  hidden-state connector emits only the fresh suffix, so dumped hidden_states
  came out short by N*block_size vs input_ids/loss_mask (observed gaps 0/16/32).

Validated: all three YAMLs resolve under launch.py --dryrun; online training and
the vLLM hidden-state dump (aligned output) were exercised on CW-DFW.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@yeyu-nvidia yeyu-nvidia force-pushed the yeyu/dflash-auto-mask-token branch from d2d0558 to 5496efc Compare June 9, 2026 17:26

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

👉 Steps to fix this

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/speculative_decoding/fsdp2_buffer_patch.py`:
- Around line 269-270: The early-return creates a CPU tensor when no grads
exist, causing device mismatch with the normal path which returns total_norm on
the GPU (the variable device is set around line 274); update the early-return to
produce a tensor on the same device as total_norm by constructing the zero
tensor on the same device (e.g., using the device variable or
total_norm/new_tensor style) so the returned tensor's device matches the normal
path (refer to grads, total_norm, and device to locate the code).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: b39e1775-7aa8-482b-84ba-391c5f4eaef7

📥 Commits

Reviewing files that changed from the base of the PR and between d2d0558 and 5496efc.

📒 Files selected for processing (4)
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/fsdp2_buffer_patch.py
  • examples/speculative_decoding/main.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/main.py

Comment on lines +269 to +270
if len(grads) == 0:
return torch.tensor(0.0)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Return tensor device inconsistency when no gradients exist.

When there are no gradients, this returns a CPU tensor, but the normal path (line 319) returns total_norm which resides on the GPU device determined at line 274. This inconsistency could cause device mismatch errors if callers operate on the returned tensor.

Suggested fix
     grads = [p.grad for p in parameters if p.grad is not None]
     if len(grads) == 0:
-        return torch.tensor(0.0)
+        # Return on CPU; caller should handle empty-grad case gracefully
+        return torch.tensor(0.0, device="cpu")

Or, if you want consistency with the normal path's device:

     grads = [p.grad for p in parameters if p.grad is not None]
     if len(grads) == 0:
+        # Determine device from parameters if possible, else CPU
+        device = next((p.device for p in parameters if p.device.type != "meta"), "cpu")
+        if hasattr(next(iter(parameters), None), "_local_tensor"):
+            device = next(iter(parameters))._local_tensor.device
-        return torch.tensor(0.0)
+        return torch.tensor(0.0, device=device)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if len(grads) == 0:
return torch.tensor(0.0)
if len(grads) == 0:
# Return on CPU; caller should handle empty-grad case gracefully
return torch.tensor(0.0, device="cpu")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 269 - 270,
The early-return creates a CPU tensor when no grads exist, causing device
mismatch with the normal path which returns total_norm on the GPU (the variable
device is set around line 274); update the early-return to produce a tensor on
the same device as total_norm by constructing the zero tensor on the same device
(e.g., using the device variable or total_norm/new_tensor style) so the returned
tensor's device matches the normal path (refer to grads, total_norm, and device
to locate the code).

The vLLM hidden-state dump rejected --answer-only-loss / --chat-template and
emitted no loss_mask, so DFlash offline training with answer-only loss could not
use it (the HF dump supports this but is impractical for 229B on one GPU). Mirror
the HF dump: register add_answer_only_loss_args, apply an optional override chat
template, verify {% generation %} tags, and tokenize via tokenize_with_loss_mask so
each .pt carries an aligned loss_mask. Prefix caching is already disabled, so the
dumped hidden states line up 1:1 with input_ids/loss_mask.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py (1)

246-246: ⚡ Quick win

Add defensive length check before truncating loss_mask.

If output_hidden_states is longer than loss_mask, Python slice semantics return the full (too-short) loss_mask, creating a mismatch with the saved input_ids length. The downstream OfflineSupervisedDataset loader does not validate shape alignment, risking silent training errors. With enable_prefix_caching=False (line 194), lengths should match exactly—consider adding a warning to catch violations:

🛡️ Suggested defensive check
+        expected_len = output_hidden_states.shape[0]
+        if loss_mask.shape[0] != expected_len:
+            import warnings
+            warnings.warn(
+                f"Conversation {conv_id}: loss_mask length {loss_mask.shape[0]} != "
+                f"hidden_states length {expected_len}; may indicate tokenization mismatch."
+            )
+
         output_file = output_dir / f"{conv_id}.pt"
         with open(output_file, "wb") as f:
             torch.save(
                 {
                     "input_ids": token_ids.cpu(),
                     "hidden_states": output_hidden_states,
                     "aux_hidden_states": aux_hidden_states,
-                    "loss_mask": loss_mask[: output_hidden_states.shape[0]].cpu(),
+                    "loss_mask": loss_mask[:expected_len].cpu(),
                     "conversation_id": conv_id,
                 },
                 f,
             )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py`
at line 246, Before truncating loss_mask for storage, ensure its length matches
output_hidden_states.shape[0]; check if output_hidden_states.shape[0] >
loss_mask.shape[0] and if so log a warning (or raise an error) calling out the
mismatch between loss_mask and output_hidden_states lengths (mention
enable_prefix_caching if relevant), otherwise perform the slice as before:
"loss_mask = loss_mask[: output_hidden_states.shape[0]].cpu()". Reference the
variables loss_mask and output_hidden_states and the OfflineSupervisedDataset
consumer when adding this defensive check.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py`:
- Line 246: Before truncating loss_mask for storage, ensure its length matches
output_hidden_states.shape[0]; check if output_hidden_states.shape[0] >
loss_mask.shape[0] and if so log a warning (or raise an error) calling out the
mismatch between loss_mask and output_hidden_states lengths (mention
enable_prefix_caching if relevant), otherwise perform the slice as before:
"loss_mask = loss_mask[: output_hidden_states.shape[0]].cpu()". Reference the
variables loss_mask and output_hidden_states and the OfflineSupervisedDataset
consumer when adding this defensive check.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 898bf778-31d1-4a22-a5f1-2672bcf3b208

📥 Commits

Reviewing files that changed from the base of the PR and between 5496efc and d5ea663.

📒 Files selected for processing (1)
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py

modelopt/__init__.py did __version__ = importlib.metadata.version('nvidia-modelopt')
with no guard, so importing modelopt crashes with PackageNotFoundError whenever the
source tree is on the path without dist metadata — e.g. the launcher mounts the
modelopt source into a vLLM/TRT-LLM container's site-packages rather than pip-installing
it. That broke any modelopt import in those containers: collect_hidden_states'
resolve_aux_layers (DFlash/EAGLE presets) and specdec_bench (whose guard only caught
ModuleNotFoundError, not PackageNotFoundError). Fall back to a sentinel version instead.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@github-actions

github-actions Bot commented Jun 9, 2026

Copy link
Copy Markdown
Contributor
PR Preview Action v1.8.1

QR code for preview link

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1621/

Built to branch gh-pages at 2026-06-09 19:30 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

…FLASH bench

Two fixes for the DFlash launcher paths in a stock vLLM container:
- compute_hidden_states_vllm.py: resolve_aux_layers('dflash') imports
  modelopt.torch.speculative.plugins, which drags in the full modelopt.torch init
  chain (omegaconf, ...) that the vLLM container lacks. Resolve the 'dflash' preset /
  explicit layer list inline so the dump needs no modelopt at all.
- examples/.../specdec_bench.yaml: pass --block_size 8. run.py maps
  speculative_num_draft_tokens=args.block_size; for DFLASH this must be set or
  num_speculative_tokens is None and vLLM's max_num_seqs=concurrency*None crashes.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

👉 Steps to fix this

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py`:
- Around line 47-73: The _resolve_aux_layers_standalone function currently
parses comma-separated aux layer IDs but doesn't validate they lie in [0,
num_hidden_layers); update it to check each parsed id (from aux_layers.split)
against 0 <= id < num_hidden_layers and raise a ValueError with a clear message
(mirroring resolve_aux_layers) if any id is out of range or negative; keep the
existing behavior for the "dflash" preset and ensure the error references the
original aux_layers input and num_hidden_layers for clarity.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: b02075cf-1305-4259-b54f-425362cba18c

📥 Commits

Reviewing files that changed from the base of the PR and between e6c552f and f7844eb.

📒 Files selected for processing (2)
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py
  • tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml

Comment on lines +47 to +73
def _resolve_aux_layers_standalone(aux_layers: str, num_hidden_layers: int) -> list[int]:
"""Resolve aux-layer ids without importing modelopt.

This dump runs in a stock vLLM container. ``common.resolve_aux_layers`` resolves the
'dflash'/'eagle' presets by importing ``modelopt.torch.speculative.plugins`` — which
pulls in the full ``modelopt.torch`` init chain (omegaconf, etc.) that the vLLM
container does not have, so the import fails. Resolve the 'dflash' preset inline
(mirroring ``modeling_dflash.build_target_layer_ids`` with num_draft=5, the recipe
default) and accept an explicit comma-separated int list. Keep in sync with modelopt.
"""
spec = aux_layers.strip().lower()
if spec == "dflash":
num_draft = 5
if num_draft == 1:
return [num_hidden_layers // 2]
start = min(1, num_hidden_layers - 1)
end = max(start, num_hidden_layers - 3)
span = end - start
return sorted({round(start + (i * span) / (num_draft - 1)) for i in range(num_draft)})
ids = sorted({int(t) for t in aux_layers.split(",") if t.strip()})
if not ids:
raise ValueError(
f"--aux-layers={aux_layers!r}: in the stock vLLM container (no modelopt) only the "
"'dflash' preset or an explicit comma-separated layer-id list are supported."
)
return ids

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add bounds validation for comma-separated layer IDs.

The standalone aux-layer resolver doesn't validate that parsed layer IDs fall within [0, num_hidden_layers), unlike the shared resolve_aux_layers helper (see context snippet 1). If a user passes out-of-range indices (e.g., --aux-layers=999 for a 40-layer model), the error will surface later when vLLM tries to extract hidden states, producing a less clear failure message.

🛡️ Add validation to match the shared helper's contract
 def _resolve_aux_layers_standalone(aux_layers: str, num_hidden_layers: int) -> list[int]:
     """Resolve aux-layer ids without importing modelopt.
 
     This dump runs in a stock vLLM container. ``common.resolve_aux_layers`` resolves the
     'dflash'/'eagle' presets by importing ``modelopt.torch.speculative.plugins`` — which
     pulls in the full ``modelopt.torch`` init chain (omegaconf, etc.) that the vLLM
     container does not have, so the import fails. Resolve the 'dflash' preset inline
     (mirroring ``modeling_dflash.build_target_layer_ids`` with num_draft=5, the recipe
     default) and accept an explicit comma-separated int list. Keep in sync with modelopt.
     """
     spec = aux_layers.strip().lower()
     if spec == "dflash":
         num_draft = 5
         if num_draft == 1:
             return [num_hidden_layers // 2]
         start = min(1, num_hidden_layers - 1)
         end = max(start, num_hidden_layers - 3)
         span = end - start
         return sorted({round(start + (i * span) / (num_draft - 1)) for i in range(num_draft)})
     ids = sorted({int(t) for t in aux_layers.split(",") if t.strip()})
     if not ids:
         raise ValueError(
             f"--aux-layers={aux_layers!r}: in the stock vLLM container (no modelopt) only the "
             "'dflash' preset or an explicit comma-separated layer-id list are supported."
         )
+    for i in ids:
+        if not 0 <= i < num_hidden_layers:
+            raise ValueError(f"--aux-layers index {i} out of range [0, {num_hidden_layers})")
     return ids
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _resolve_aux_layers_standalone(aux_layers: str, num_hidden_layers: int) -> list[int]:
"""Resolve aux-layer ids without importing modelopt.
This dump runs in a stock vLLM container. ``common.resolve_aux_layers`` resolves the
'dflash'/'eagle' presets by importing ``modelopt.torch.speculative.plugins``which
pulls in the full ``modelopt.torch`` init chain (omegaconf, etc.) that the vLLM
container does not have, so the import fails. Resolve the 'dflash' preset inline
(mirroring ``modeling_dflash.build_target_layer_ids`` with num_draft=5, the recipe
default) and accept an explicit comma-separated int list. Keep in sync with modelopt.
"""
spec = aux_layers.strip().lower()
if spec == "dflash":
num_draft = 5
if num_draft == 1:
return [num_hidden_layers // 2]
start = min(1, num_hidden_layers - 1)
end = max(start, num_hidden_layers - 3)
span = end - start
return sorted({round(start + (i * span) / (num_draft - 1)) for i in range(num_draft)})
ids = sorted({int(t) for t in aux_layers.split(",") if t.strip()})
if not ids:
raise ValueError(
f"--aux-layers={aux_layers!r}: in the stock vLLM container (no modelopt) only the "
"'dflash' preset or an explicit comma-separated layer-id list are supported."
)
return ids
def _resolve_aux_layers_standalone(aux_layers: str, num_hidden_layers: int) -> list[int]:
"""Resolve aux-layer ids without importing modelopt.
This dump runs in a stock vLLM container. ``common.resolve_aux_layers`` resolves the
'dflash'/'eagle' presets by importing ``modelopt.torch.speculative.plugins``which
pulls in the full ``modelopt.torch`` init chain (omegaconf, etc.) that the vLLM
container does not have, so the import fails. Resolve the 'dflash' preset inline
(mirroring ``modeling_dflash.build_target_layer_ids`` with num_draft=5, the recipe
default) and accept an explicit comma-separated int list. Keep in sync with modelopt.
"""
spec = aux_layers.strip().lower()
if spec == "dflash":
num_draft = 5
if num_draft == 1:
return [num_hidden_layers // 2]
start = min(1, num_hidden_layers - 1)
end = max(start, num_hidden_layers - 3)
span = end - start
return sorted({round(start + (i * span) / (num_draft - 1)) for i in range(num_draft)})
ids = sorted({int(t) for t in aux_layers.split(",") if t.strip()})
if not ids:
raise ValueError(
f"--aux-layers={aux_layers!r}: in the stock vLLM container (no modelopt) only the "
"'dflash' preset or an explicit comma-separated layer-id list are supported."
)
for i in ids:
if not 0 <= i < num_hidden_layers:
raise ValueError(f"--aux-layers index {i} out of range [0, {num_hidden_layers})")
return ids
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py`
around lines 47 - 73, The _resolve_aux_layers_standalone function currently
parses comma-separated aux layer IDs but doesn't validate they lie in [0,
num_hidden_layers); update it to check each parsed id (from aux_layers.split)
against 0 <= id < num_hidden_layers and raise a ValueError with a clear message
(mirroring resolve_aux_layers) if any id is out of range or negative; keep the
existing behavior for the "dflash" preset and ensure the error references the
original aux_layers input and num_hidden_layers for clarity.

This dump runs in a stock vLLM container. ``common.resolve_aux_layers`` resolves the
'dflash'/'eagle' presets by importing ``modelopt.torch.speculative.plugins`` — which
pulls in the full ``modelopt.torch`` init chain (omegaconf, etc.) that the vLLM
container does not have, so the import fails. Resolve the 'dflash' preset inline

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM as a workaround for now, but eventually I think it's beneficial to clean up common.resolve_aux_layers from unnecessary dependencies and reuse it. Could you add a TODO here in comment?

Comment thread modelopt/__init__.py
try:
__version__ = _version("nvidia-modelopt")
except PackageNotFoundError:
# No dist metadata — e.g. the modelopt source tree is mounted directly into a

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure the impact of this change. cc @kevalmorabia97 Could you please take a look? Thanks

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants