DFlash speculative decoding for MiniMax-M2.7 (FSDP2): auto mask-token, FSDP2 resume fixes, per-checkpoint draft export#1621
DFlash speculative decoding for MiniMax-M2.7 (FSDP2): auto mask-token, FSDP2 resume fixes, per-checkpoint draft export#1621yeyu-nvidia wants to merge 10 commits into
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdds a DFlash draft-weight export callback, an FSDP2 buffer/DTensor monkey-patch (with DTensor-safe grad clipping), integrates both into the speculative-decoding training script (checkpoint-format detection and mask-token init), disables vLLM prefix caching for hidden-state dumps, and forwards Slurm requeue settings to the launcher executor. ChangesSpeculative Decoding Training Enhancements
Launcher Slurm Requeue Configuration
Package initialization
Benchmark spec
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Suggested reviewers
🚥 Pre-merge checks | ✅ 5 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/speculative_decoding/eagle_utils.py (1)
53-53: 🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick winUpdate
__all__to includeDFlashExportCallback.The coding guidelines require defining the public API with
__all__. SinceDFlashExportCallbackis imported bymain.py(line 39), it should be exported.-__all__ = ["EagleOfflineDataCollator", "OfflineSupervisedDataset"] +__all__ = ["DFlashExportCallback", "EagleOfflineDataCollator", "OfflineSupervisedDataset"]As per coding guidelines: "Define the public API with
__all__at the top of each Python module."🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/eagle_utils.py` at line 53, The module's public API list __all__ is missing DFlashExportCallback; update the __all__ declaration (which currently lists "EagleOfflineDataCollator" and "OfflineSupervisedDataset") to also include "DFlashExportCallback" so the symbol is exported for consumers like main.py that import it.
🧹 Nitpick comments (5)
examples/speculative_decoding/fsdp2_buffer_patch.py (4)
238-240: ⚡ Quick winUse
print_rank_0to avoid noisy logs in multi-rank environments.These print statements execute on every rank, which can produce excessive output on large clusters. Consider using
print_rank_0frommodelopt.torch.utilsor guarding with a rank check.+from modelopt.torch.utils import print_rank_0 + # In apply() function: - print("[fsdp2_buffer_patch] Patched fsdp2_load_full_state_dict for buffer compatibility") + print_rank_0("[fsdp2_buffer_patch] Patched fsdp2_load_full_state_dict for buffer compatibility") except Exception as e: - print(f"[fsdp2_buffer_patch] Patch skipped: {e}") + print_rank_0(f"[fsdp2_buffer_patch] Patch skipped: {e}")As per coding guidelines: "use
print_rank_0orwarn_rank_0to avoid noisy logs."🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 238 - 240, The two print calls in fsdp2_buffer_patch (the success message and the except message around fsdp2_load_full_state_dict) should be replaced with rank-safe logging: import and call print_rank_0 from modelopt.torch.utils (or guard with a rank check) so messages only appear on rank 0; update the success print and the exception print to use print_rank_0 and include the exception variable in the error message (e) while keeping the same text context.
1-3: 💤 Low valueAdd
__all__to define the public API.Per coding guidelines, each Python module should define
__all__to make the public API explicit.+__all__ = ["apply", "patch_accelerator"] + """Monkey-patch for accelerate's fsdp2_load_full_state_dict buffer handling.As per coding guidelines: "Define the public API with
__all__at the top of each Python module."🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 1 - 3, This module is missing an explicit public API; add a module-level __all__ declaration at the top (after the SPDX headers) listing the public names exported by this file (e.g. __all__ = ["Name1", "function_name", "CLASS_NAME"]), ensuring each symbol included matches the actual top-level functions/classes/variables defined later in the file; place the __all__ immediately below the license lines to satisfy the coding guideline.
322-326: ⚡ Quick winUse
print_rank_0here as well.def patch_accelerator(accelerator): """Replace accelerator's clip_grad_norm_ with FSDP2-safe version.""" accelerator.clip_grad_norm_ = _clip_grad_norm - print("[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ " - "for FSDP2 DTensor compatibility") + from modelopt.torch.utils import print_rank_0 + print_rank_0("[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ " + "for FSDP2 DTensor compatibility")As per coding guidelines: "use
print_rank_0orwarn_rank_0to avoid noisy logs."🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 322 - 326, The patch_accelerator function currently uses print to log the patch; replace that call with print_rank_0 to follow logging guidelines. Update the function to call print_rank_0("[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ for FSDP2 DTensor compatibility") and ensure print_rank_0 is imported at top of the module (the same utility used elsewhere), leaving accelerator.clip_grad_norm_ = _clip_grad_norm unchanged.
269-270: 💤 Low valueReturn value should be on the same device as gradients.
When there are no gradients, the function returns a CPU tensor. For consistency with the non-empty case (which returns
total_normon device), consider returning on the same device.if len(grads) == 0: - return torch.tensor(0.0) + device = parameters[0].device if parameters else "cpu" + return torch.tensor(0.0, device=device)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 269 - 270, The early-return creates a CPU tensor when grads is empty; change it to return a zero tensor on the same device as the gradients by selecting the first available grad device (e.g. device = next((g.device for g in grads if g is not None), torch.device('cpu'))) and then return torch.tensor(0.0, device=device) so the empty-case matches the device of the non-empty case that returns total_norm.examples/speculative_decoding/main.py (1)
297-303: ⚡ Quick winConsider gating debug output or removing before merge.
This debug print executes on every rank and will produce verbose output on large clusters. If this is temporary debugging code, consider removing it or guarding with a debug flag.
- rank = int(os.environ.get("RANK", 0)) - dtypes = {} - for name, p in trainer.model.named_parameters(): - dt_key = str(p.dtype) if not hasattr(p, "_local_tensor") else str(p._local_tensor.dtype) - dtypes.setdefault(dt_key, []).append(name) - for dt, names in dtypes.items(): - print(f"[dtype_check rank={rank}] {dt}: {len(names)} params (e.g. {names[0]})") + if os.environ.get("DEBUG_DTYPES"): + rank = int(os.environ.get("RANK", 0)) + dtypes = {} + for name, p in trainer.model.named_parameters(): + dt_key = str(p.dtype) if not hasattr(p, "_local_tensor") else str(p._local_tensor.dtype) + dtypes.setdefault(dt_key, []).append(name) + for dt, names in dtypes.items(): + print(f"[dtype_check rank={rank}] {dt}: {len(names)} params (e.g. {names[0]})")As per coding guidelines: "use
print_rank_0orwarn_rank_0to avoid noisy logs."🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/main.py` around lines 297 - 303, The debug loop printing per-rank dtype info (uses rank, dtypes, iterating trainer.model.named_parameters()) should be gated or replaced to avoid noisy logs: either remove the prints or wrap them so only rank 0 logs (use existing print_rank_0 or warn_rank_0 utility) and/or guard with a debug flag (e.g., if DEBUG:). Update the block that builds dtypes and the final print to call print_rank_0 (or warn_rank_0) with the formatted message so only the main process emits the output, or conditionally execute the entire loop behind a debug configuration toggle.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tools/launcher/core.py`:
- Around line 280-287: The code assumes slurm_config.additional_parameters is a
mutable dict and mutates it directly, which can cause shared-state bugs; before
assigning to executor.additional_parameters (and before mutating it to set
"requeue"), validate and normalize slurm_config.additional_parameters to a plain
dict (e.g., treat None, mappings, or other types safely), create a shallow copy
for executor.additional_parameters, and then mutate that copy; also ensure
executor.retries is updated via executor.retries = max(executor.retries, 3) as
shown. Reference: slurm_config.additional_parameters,
executor.additional_parameters, and executor.retries.
---
Outside diff comments:
In `@examples/speculative_decoding/eagle_utils.py`:
- Line 53: The module's public API list __all__ is missing DFlashExportCallback;
update the __all__ declaration (which currently lists "EagleOfflineDataCollator"
and "OfflineSupervisedDataset") to also include "DFlashExportCallback" so the
symbol is exported for consumers like main.py that import it.
---
Nitpick comments:
In `@examples/speculative_decoding/fsdp2_buffer_patch.py`:
- Around line 238-240: The two print calls in fsdp2_buffer_patch (the success
message and the except message around fsdp2_load_full_state_dict) should be
replaced with rank-safe logging: import and call print_rank_0 from
modelopt.torch.utils (or guard with a rank check) so messages only appear on
rank 0; update the success print and the exception print to use print_rank_0 and
include the exception variable in the error message (e) while keeping the same
text context.
- Around line 1-3: This module is missing an explicit public API; add a
module-level __all__ declaration at the top (after the SPDX headers) listing the
public names exported by this file (e.g. __all__ = ["Name1", "function_name",
"CLASS_NAME"]), ensuring each symbol included matches the actual top-level
functions/classes/variables defined later in the file; place the __all__
immediately below the license lines to satisfy the coding guideline.
- Around line 322-326: The patch_accelerator function currently uses print to
log the patch; replace that call with print_rank_0 to follow logging guidelines.
Update the function to call print_rank_0("[fsdp2_buffer_patch] Patched
accelerator.clip_grad_norm_ for FSDP2 DTensor compatibility") and ensure
print_rank_0 is imported at top of the module (the same utility used elsewhere),
leaving accelerator.clip_grad_norm_ = _clip_grad_norm unchanged.
- Around line 269-270: The early-return creates a CPU tensor when grads is
empty; change it to return a zero tensor on the same device as the gradients by
selecting the first available grad device (e.g. device = next((g.device for g in
grads if g is not None), torch.device('cpu'))) and then return torch.tensor(0.0,
device=device) so the empty-case matches the device of the non-empty case that
returns total_norm.
In `@examples/speculative_decoding/main.py`:
- Around line 297-303: The debug loop printing per-rank dtype info (uses rank,
dtypes, iterating trainer.model.named_parameters()) should be gated or replaced
to avoid noisy logs: either remove the prints or wrap them so only rank 0 logs
(use existing print_rank_0 or warn_rank_0 utility) and/or guard with a debug
flag (e.g., if DEBUG:). Update the block that builds dtypes and the final print
to call print_rank_0 (or warn_rank_0) with the formatted message so only the
main process emits the output, or conditionally execute the entire loop behind a
debug configuration toggle.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: ec55dcce-a920-44ca-8e39-ee3167ca3eeb
📒 Files selected for processing (4)
examples/speculative_decoding/eagle_utils.pyexamples/speculative_decoding/fsdp2_buffer_patch.pyexamples/speculative_decoding/main.pytools/launcher/core.py
| additional_parameters=getattr(slurm_config, "additional_parameters", None) or {}, | ||
| ) | ||
| if getattr(slurm_config, "requeue", False): | ||
| executor.additional_parameters["requeue"] = True | ||
| # The nemo-run sbatch wrapper only calls `scontrol requeue` when | ||
| # TORCHX_MAX_RETRIES > SLURM_RESTART_COUNT. retries=0 (the default) | ||
| # disables this, so bump it when requeue is requested. | ||
| executor.retries = max(executor.retries, 3) |
There was a problem hiding this comment.
Harden additional_parameters normalization before mutation.
At Line 280 and Line 283, this assumes additional_parameters is always a mutable mapping. Since these attrs are externally supplied, normalize/validate to dict before assignment and mutate a local copy to avoid shared-state side effects.
Proposed fix
- executor = run.SlurmExecutor(
+ raw_additional_parameters = getattr(slurm_config, "additional_parameters", None)
+ additional_parameters = {}
+ if raw_additional_parameters is not None:
+ if not isinstance(raw_additional_parameters, dict):
+ raise TypeError("slurm_config.additional_parameters must be a dict")
+ additional_parameters = dict(raw_additional_parameters)
+
+ executor = run.SlurmExecutor(
account=slurm_config.account,
partition=slurm_config.partition,
qos=slurm_config.qos,
@@
- additional_parameters=getattr(slurm_config, "additional_parameters", None) or {},
+ additional_parameters=additional_parameters,
)As per coding guidelines: "Validate external input once at the interface boundary; internal code can trust those checks and avoid redundant assertions".
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| additional_parameters=getattr(slurm_config, "additional_parameters", None) or {}, | |
| ) | |
| if getattr(slurm_config, "requeue", False): | |
| executor.additional_parameters["requeue"] = True | |
| # The nemo-run sbatch wrapper only calls `scontrol requeue` when | |
| # TORCHX_MAX_RETRIES > SLURM_RESTART_COUNT. retries=0 (the default) | |
| # disables this, so bump it when requeue is requested. | |
| executor.retries = max(executor.retries, 3) | |
| raw_additional_parameters = getattr(slurm_config, "additional_parameters", None) | |
| additional_parameters = {} | |
| if raw_additional_parameters is not None: | |
| if not isinstance(raw_additional_parameters, dict): | |
| raise TypeError("slurm_config.additional_parameters must be a dict") | |
| additional_parameters = dict(raw_additional_parameters) | |
| executor = run.SlurmExecutor( | |
| account=slurm_config.account, | |
| partition=slurm_config.partition, | |
| qos=slurm_config.qos, | |
| additional_parameters=additional_parameters, | |
| ) | |
| if getattr(slurm_config, "requeue", False): | |
| executor.additional_parameters["requeue"] = True | |
| # The nemo-run sbatch wrapper only calls `scontrol requeue` when | |
| # TORCHX_MAX_RETRIES > SLURM_RESTART_COUNT. retries=0 (the default) | |
| # disables this, so bump it when requeue is requested. | |
| executor.retries = max(executor.retries, 3) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tools/launcher/core.py` around lines 280 - 287, The code assumes
slurm_config.additional_parameters is a mutable dict and mutates it directly,
which can cause shared-state bugs; before assigning to
executor.additional_parameters (and before mutating it to set "requeue"),
validate and normalize slurm_config.additional_parameters to a plain dict (e.g.,
treat None, mappings, or other types safely), create a shallow copy for
executor.additional_parameters, and then mutate that copy; also ensure
executor.retries is updated via executor.retries = max(executor.retries, 3) as
shown. Reference: slurm_config.additional_parameters,
executor.additional_parameters, and executor.retries.
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1621 +/- ##
===========================================
+ Coverage 56.59% 76.71% +20.11%
===========================================
Files 507 508 +1
Lines 55794 55866 +72
===========================================
+ Hits 31579 42855 +11276
+ Misses 24215 13011 -11204
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
Models without a <|mask|> token (e.g., MiniMax-M2.7) would fail with ValueError during DFlash training. Instead of requiring the user to manually set dflash_mask_token_id, add the token to the tokenizer and resize model embeddings automatically. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
When slurm_config.requeue is True, set additional_parameters["requeue"] = True so nemo-run emits #SBATCH --requeue in the sbatch script. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1. main.py: When FSDP2 cpu_ram_efficient_loading is active, only rank 0 loads real weights on CPU; other ranks use meta device. FSDP2 distributes from rank 0. Also adds dp_replicate_size auto-computation so dp_replicate * dp_shard * cp == world_size. 2. core.py: Set retries=3 when requeue is requested. The nemo-run sbatch wrapper only calls scontrol requeue when TORCHX_MAX_RETRIES > SLURM_RESTART_COUNT — retries=0 (the default) disabled requeue. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…heckpoint resume The Pydantic recipe refactor dropped fsdp2_buffer_patch.apply() and patch_accelerator() calls and added a buffer-to-CUDA block that moved DFlash buffers before FSDP wrapping. With cpu_ram_efficient_loading, non-rank-0 processes have meta-device params, causing _infer_parameter_dtype() to return fp32 instead of bf16 on resume. Also detects FSDP distributed checkpoints (no HF model files) and loads the base model instead of trying from_pretrained on them. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… patch _infer_parameter_dtype() reads the model's current param dtype to cast the broadcasted tensor. With cpu_ram_efficient_loading, non-rank-0 processes have fp32 meta-device params for DFlash, so _infer_parameter_dtype returns fp32 and _finish() casts the correctly- broadcasted bf16 tensor back to fp32. Use bcast_dtype (from rank 0) instead. Also prints dtype_check on all ranks to verify consistency. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The Pydantic-recipe refactor (7038dec) dropped DFlashExportCallback, which had exported the DFlash draft submodule after every checkpoint save. Without it, FSDP2 sharded checkpoints (pytorch_model_fsdp_0/, no model.safetensors) get no exported-checkpoint-{step}/, so downstream vLLM deployment / acceptance-length eval has nothing to load. The verify-only comment 'export happens during training via DFlashExportCallback' was left behind but the callback itself was gone. Restore the callback (gathers only the ~328MB draft submodule across shards via get_model_state_dict, so it works under SHARDED_STATE_DICT without materializing the full base model) and wire it into main.py for DFlash recipes. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…c_bench) Adds a self-contained launcher example so MiniMax-M2.7 (229B MoE) DFlash training is reproducible end-to-end, plus two common-script enablers it needs: - tools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/: - hf_online_dflash.yaml (train -> vLLM smoke -> AR eval; 8-node FSDP2) - hf_offline_dflash.yaml (vLLM hidden-state dump -> offline train) - specdec_bench.yaml (qualitative + throughput_32k, DFLASH) - accelerate_fsdp2_hybrid.yaml, chat_template_train.jinja - common/specdec/dflash_online_training.sh: honor OVERRIDE_TRANSFORMERS, ACCELERATE_CONFIG, and MIXED_PRECISION so trust_remote_code MoE models that need FSDP2 via an accelerate config (MiniMax-M2.7 on transformers 4.57.x) work. - collect_hidden_states/compute_hidden_states_vllm.py: disable prefix caching. With it on, vLLM serves shared prefixes from cache in block chunks and the hidden-state connector emits only the fresh suffix, so dumped hidden_states came out short by N*block_size vs input_ids/loss_mask (observed gaps 0/16/32). Validated: all three YAMLs resolve under launch.py --dryrun; online training and the vLLM hidden-state dump (aligned output) were exercised on CW-DFW. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
d2d0558 to
5496efc
Compare
There was a problem hiding this comment.
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/speculative_decoding/fsdp2_buffer_patch.py`:
- Around line 269-270: The early-return creates a CPU tensor when no grads
exist, causing device mismatch with the normal path which returns total_norm on
the GPU (the variable device is set around line 274); update the early-return to
produce a tensor on the same device as total_norm by constructing the zero
tensor on the same device (e.g., using the device variable or
total_norm/new_tensor style) so the returned tensor's device matches the normal
path (refer to grads, total_norm, and device to locate the code).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: b39e1775-7aa8-482b-84ba-391c5f4eaef7
📒 Files selected for processing (4)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.pyexamples/speculative_decoding/eagle_utils.pyexamples/speculative_decoding/fsdp2_buffer_patch.pyexamples/speculative_decoding/main.py
🚧 Files skipped from review as they are similar to previous changes (2)
- examples/speculative_decoding/eagle_utils.py
- examples/speculative_decoding/main.py
| if len(grads) == 0: | ||
| return torch.tensor(0.0) |
There was a problem hiding this comment.
Return tensor device inconsistency when no gradients exist.
When there are no gradients, this returns a CPU tensor, but the normal path (line 319) returns total_norm which resides on the GPU device determined at line 274. This inconsistency could cause device mismatch errors if callers operate on the returned tensor.
Suggested fix
grads = [p.grad for p in parameters if p.grad is not None]
if len(grads) == 0:
- return torch.tensor(0.0)
+ # Return on CPU; caller should handle empty-grad case gracefully
+ return torch.tensor(0.0, device="cpu")Or, if you want consistency with the normal path's device:
grads = [p.grad for p in parameters if p.grad is not None]
if len(grads) == 0:
+ # Determine device from parameters if possible, else CPU
+ device = next((p.device for p in parameters if p.device.type != "meta"), "cpu")
+ if hasattr(next(iter(parameters), None), "_local_tensor"):
+ device = next(iter(parameters))._local_tensor.device
- return torch.tensor(0.0)
+ return torch.tensor(0.0, device=device)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if len(grads) == 0: | |
| return torch.tensor(0.0) | |
| if len(grads) == 0: | |
| # Return on CPU; caller should handle empty-grad case gracefully | |
| return torch.tensor(0.0, device="cpu") |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@examples/speculative_decoding/fsdp2_buffer_patch.py` around lines 269 - 270,
The early-return creates a CPU tensor when no grads exist, causing device
mismatch with the normal path which returns total_norm on the GPU (the variable
device is set around line 274); update the early-return to produce a tensor on
the same device as total_norm by constructing the zero tensor on the same device
(e.g., using the device variable or total_norm/new_tensor style) so the returned
tensor's device matches the normal path (refer to grads, total_norm, and device
to locate the code).
The vLLM hidden-state dump rejected --answer-only-loss / --chat-template and
emitted no loss_mask, so DFlash offline training with answer-only loss could not
use it (the HF dump supports this but is impractical for 229B on one GPU). Mirror
the HF dump: register add_answer_only_loss_args, apply an optional override chat
template, verify {% generation %} tags, and tokenize via tokenize_with_loss_mask so
each .pt carries an aligned loss_mask. Prefix caching is already disabled, so the
dumped hidden states line up 1:1 with input_ids/loss_mask.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
🧹 Nitpick comments (1)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py (1)
246-246: ⚡ Quick winAdd defensive length check before truncating loss_mask.
If
output_hidden_statesis longer thanloss_mask, Python slice semantics return the full (too-short)loss_mask, creating a mismatch with the savedinput_idslength. The downstreamOfflineSupervisedDatasetloader does not validate shape alignment, risking silent training errors. Withenable_prefix_caching=False(line 194), lengths should match exactly—consider adding a warning to catch violations:🛡️ Suggested defensive check
+ expected_len = output_hidden_states.shape[0] + if loss_mask.shape[0] != expected_len: + import warnings + warnings.warn( + f"Conversation {conv_id}: loss_mask length {loss_mask.shape[0]} != " + f"hidden_states length {expected_len}; may indicate tokenization mismatch." + ) + output_file = output_dir / f"{conv_id}.pt" with open(output_file, "wb") as f: torch.save( { "input_ids": token_ids.cpu(), "hidden_states": output_hidden_states, "aux_hidden_states": aux_hidden_states, - "loss_mask": loss_mask[: output_hidden_states.shape[0]].cpu(), + "loss_mask": loss_mask[:expected_len].cpu(), "conversation_id": conv_id, }, f, )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py` at line 246, Before truncating loss_mask for storage, ensure its length matches output_hidden_states.shape[0]; check if output_hidden_states.shape[0] > loss_mask.shape[0] and if so log a warning (or raise an error) calling out the mismatch between loss_mask and output_hidden_states lengths (mention enable_prefix_caching if relevant), otherwise perform the slice as before: "loss_mask = loss_mask[: output_hidden_states.shape[0]].cpu()". Reference the variables loss_mask and output_hidden_states and the OfflineSupervisedDataset consumer when adding this defensive check.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py`:
- Line 246: Before truncating loss_mask for storage, ensure its length matches
output_hidden_states.shape[0]; check if output_hidden_states.shape[0] >
loss_mask.shape[0] and if so log a warning (or raise an error) calling out the
mismatch between loss_mask and output_hidden_states lengths (mention
enable_prefix_caching if relevant), otherwise perform the slice as before:
"loss_mask = loss_mask[: output_hidden_states.shape[0]].cpu()". Reference the
variables loss_mask and output_hidden_states and the OfflineSupervisedDataset
consumer when adding this defensive check.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 898bf778-31d1-4a22-a5f1-2672bcf3b208
📒 Files selected for processing (1)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py
modelopt/__init__.py did __version__ = importlib.metadata.version('nvidia-modelopt')
with no guard, so importing modelopt crashes with PackageNotFoundError whenever the
source tree is on the path without dist metadata — e.g. the launcher mounts the
modelopt source into a vLLM/TRT-LLM container's site-packages rather than pip-installing
it. That broke any modelopt import in those containers: collect_hidden_states'
resolve_aux_layers (DFlash/EAGLE presets) and specdec_bench (whose guard only caught
ModuleNotFoundError, not PackageNotFoundError). Fall back to a sentinel version instead.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
|
…FLASH bench
Two fixes for the DFlash launcher paths in a stock vLLM container:
- compute_hidden_states_vllm.py: resolve_aux_layers('dflash') imports
modelopt.torch.speculative.plugins, which drags in the full modelopt.torch init
chain (omegaconf, ...) that the vLLM container lacks. Resolve the 'dflash' preset /
explicit layer list inline so the dump needs no modelopt at all.
- examples/.../specdec_bench.yaml: pass --block_size 8. run.py maps
speculative_num_draft_tokens=args.block_size; for DFLASH this must be set or
num_speculative_tokens is None and vLLM's max_num_seqs=concurrency*None crashes.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py`:
- Around line 47-73: The _resolve_aux_layers_standalone function currently
parses comma-separated aux layer IDs but doesn't validate they lie in [0,
num_hidden_layers); update it to check each parsed id (from aux_layers.split)
against 0 <= id < num_hidden_layers and raise a ValueError with a clear message
(mirroring resolve_aux_layers) if any id is out of range or negative; keep the
existing behavior for the "dflash" preset and ensure the error references the
original aux_layers input and num_hidden_layers for clarity.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: b02075cf-1305-4259-b54f-425362cba18c
📒 Files selected for processing (2)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.pytools/launcher/examples/MiniMax/MiniMax-M2.7-DFlash/specdec_bench.yaml
| def _resolve_aux_layers_standalone(aux_layers: str, num_hidden_layers: int) -> list[int]: | ||
| """Resolve aux-layer ids without importing modelopt. | ||
|
|
||
| This dump runs in a stock vLLM container. ``common.resolve_aux_layers`` resolves the | ||
| 'dflash'/'eagle' presets by importing ``modelopt.torch.speculative.plugins`` — which | ||
| pulls in the full ``modelopt.torch`` init chain (omegaconf, etc.) that the vLLM | ||
| container does not have, so the import fails. Resolve the 'dflash' preset inline | ||
| (mirroring ``modeling_dflash.build_target_layer_ids`` with num_draft=5, the recipe | ||
| default) and accept an explicit comma-separated int list. Keep in sync with modelopt. | ||
| """ | ||
| spec = aux_layers.strip().lower() | ||
| if spec == "dflash": | ||
| num_draft = 5 | ||
| if num_draft == 1: | ||
| return [num_hidden_layers // 2] | ||
| start = min(1, num_hidden_layers - 1) | ||
| end = max(start, num_hidden_layers - 3) | ||
| span = end - start | ||
| return sorted({round(start + (i * span) / (num_draft - 1)) for i in range(num_draft)}) | ||
| ids = sorted({int(t) for t in aux_layers.split(",") if t.strip()}) | ||
| if not ids: | ||
| raise ValueError( | ||
| f"--aux-layers={aux_layers!r}: in the stock vLLM container (no modelopt) only the " | ||
| "'dflash' preset or an explicit comma-separated layer-id list are supported." | ||
| ) | ||
| return ids | ||
|
|
There was a problem hiding this comment.
Add bounds validation for comma-separated layer IDs.
The standalone aux-layer resolver doesn't validate that parsed layer IDs fall within [0, num_hidden_layers), unlike the shared resolve_aux_layers helper (see context snippet 1). If a user passes out-of-range indices (e.g., --aux-layers=999 for a 40-layer model), the error will surface later when vLLM tries to extract hidden states, producing a less clear failure message.
🛡️ Add validation to match the shared helper's contract
def _resolve_aux_layers_standalone(aux_layers: str, num_hidden_layers: int) -> list[int]:
"""Resolve aux-layer ids without importing modelopt.
This dump runs in a stock vLLM container. ``common.resolve_aux_layers`` resolves the
'dflash'/'eagle' presets by importing ``modelopt.torch.speculative.plugins`` — which
pulls in the full ``modelopt.torch`` init chain (omegaconf, etc.) that the vLLM
container does not have, so the import fails. Resolve the 'dflash' preset inline
(mirroring ``modeling_dflash.build_target_layer_ids`` with num_draft=5, the recipe
default) and accept an explicit comma-separated int list. Keep in sync with modelopt.
"""
spec = aux_layers.strip().lower()
if spec == "dflash":
num_draft = 5
if num_draft == 1:
return [num_hidden_layers // 2]
start = min(1, num_hidden_layers - 1)
end = max(start, num_hidden_layers - 3)
span = end - start
return sorted({round(start + (i * span) / (num_draft - 1)) for i in range(num_draft)})
ids = sorted({int(t) for t in aux_layers.split(",") if t.strip()})
if not ids:
raise ValueError(
f"--aux-layers={aux_layers!r}: in the stock vLLM container (no modelopt) only the "
"'dflash' preset or an explicit comma-separated layer-id list are supported."
)
+ for i in ids:
+ if not 0 <= i < num_hidden_layers:
+ raise ValueError(f"--aux-layers index {i} out of range [0, {num_hidden_layers})")
return ids📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def _resolve_aux_layers_standalone(aux_layers: str, num_hidden_layers: int) -> list[int]: | |
| """Resolve aux-layer ids without importing modelopt. | |
| This dump runs in a stock vLLM container. ``common.resolve_aux_layers`` resolves the | |
| 'dflash'/'eagle' presets by importing ``modelopt.torch.speculative.plugins`` — which | |
| pulls in the full ``modelopt.torch`` init chain (omegaconf, etc.) that the vLLM | |
| container does not have, so the import fails. Resolve the 'dflash' preset inline | |
| (mirroring ``modeling_dflash.build_target_layer_ids`` with num_draft=5, the recipe | |
| default) and accept an explicit comma-separated int list. Keep in sync with modelopt. | |
| """ | |
| spec = aux_layers.strip().lower() | |
| if spec == "dflash": | |
| num_draft = 5 | |
| if num_draft == 1: | |
| return [num_hidden_layers // 2] | |
| start = min(1, num_hidden_layers - 1) | |
| end = max(start, num_hidden_layers - 3) | |
| span = end - start | |
| return sorted({round(start + (i * span) / (num_draft - 1)) for i in range(num_draft)}) | |
| ids = sorted({int(t) for t in aux_layers.split(",") if t.strip()}) | |
| if not ids: | |
| raise ValueError( | |
| f"--aux-layers={aux_layers!r}: in the stock vLLM container (no modelopt) only the " | |
| "'dflash' preset or an explicit comma-separated layer-id list are supported." | |
| ) | |
| return ids | |
| def _resolve_aux_layers_standalone(aux_layers: str, num_hidden_layers: int) -> list[int]: | |
| """Resolve aux-layer ids without importing modelopt. | |
| This dump runs in a stock vLLM container. ``common.resolve_aux_layers`` resolves the | |
| 'dflash'/'eagle' presets by importing ``modelopt.torch.speculative.plugins`` — which | |
| pulls in the full ``modelopt.torch`` init chain (omegaconf, etc.) that the vLLM | |
| container does not have, so the import fails. Resolve the 'dflash' preset inline | |
| (mirroring ``modeling_dflash.build_target_layer_ids`` with num_draft=5, the recipe | |
| default) and accept an explicit comma-separated int list. Keep in sync with modelopt. | |
| """ | |
| spec = aux_layers.strip().lower() | |
| if spec == "dflash": | |
| num_draft = 5 | |
| if num_draft == 1: | |
| return [num_hidden_layers // 2] | |
| start = min(1, num_hidden_layers - 1) | |
| end = max(start, num_hidden_layers - 3) | |
| span = end - start | |
| return sorted({round(start + (i * span) / (num_draft - 1)) for i in range(num_draft)}) | |
| ids = sorted({int(t) for t in aux_layers.split(",") if t.strip()}) | |
| if not ids: | |
| raise ValueError( | |
| f"--aux-layers={aux_layers!r}: in the stock vLLM container (no modelopt) only the " | |
| "'dflash' preset or an explicit comma-separated layer-id list are supported." | |
| ) | |
| for i in ids: | |
| if not 0 <= i < num_hidden_layers: | |
| raise ValueError(f"--aux-layers index {i} out of range [0, {num_hidden_layers})") | |
| return ids |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py`
around lines 47 - 73, The _resolve_aux_layers_standalone function currently
parses comma-separated aux layer IDs but doesn't validate they lie in [0,
num_hidden_layers); update it to check each parsed id (from aux_layers.split)
against 0 <= id < num_hidden_layers and raise a ValueError with a clear message
(mirroring resolve_aux_layers) if any id is out of range or negative; keep the
existing behavior for the "dflash" preset and ensure the error references the
original aux_layers input and num_hidden_layers for clarity.
| This dump runs in a stock vLLM container. ``common.resolve_aux_layers`` resolves the | ||
| 'dflash'/'eagle' presets by importing ``modelopt.torch.speculative.plugins`` — which | ||
| pulls in the full ``modelopt.torch`` init chain (omegaconf, etc.) that the vLLM | ||
| container does not have, so the import fails. Resolve the 'dflash' preset inline |
There was a problem hiding this comment.
LGTM as a workaround for now, but eventually I think it's beneficial to clean up common.resolve_aux_layers from unnecessary dependencies and reuse it. Could you add a TODO here in comment?
| try: | ||
| __version__ = _version("nvidia-modelopt") | ||
| except PackageNotFoundError: | ||
| # No dist metadata — e.g. the modelopt source tree is mounted directly into a |
There was a problem hiding this comment.
Not sure the impact of this change. cc @kevalmorabia97 Could you please take a look? Thanks
What
Brings up DFlash block-diffusion speculative decoding for large MoE targets (MiniMax-M2.7, 229B) trained under accelerate FSDP2, and fixes the regressions that broke checkpoint resume and per-checkpoint draft export.
Commits
build_slurm_executor+ FSDP2 cpu_ram_efficient_loading for 229B on multi-node.fsdp2_buffer_patch.py): handle non-DTensor buffers infsdp2_load_full_state_dict, broadcast dtype codes from rank 0, and an FSDP2-safeclip_grad_norm_. Required because MiniMax-M2.7 pins transformers 4.57.x (no nativeParallelismConfig).DFlashExportCallback(this PR's headline): the Pydantic-recipe refactor (7038dec) dropped the callback that exported the draft submodule after each checkpoint save, leaving a stale "export happens during training via DFlashExportCallback" comment with no callback. FSDP2 SHARDED_STATE_DICT checkpoints carry nomodel.safetensors, so without it there is nothing for vLLM / acceptance-length eval to load. The callback gathers only the ~328 MB draft submodule across shards viaget_model_state_dict(..., submodules={dflash_module}, full_state_dict=True, cpu_offload=True)— works under SHARDED_STATE_DICT without materializing the 229B base — and writesexported-checkpoint-{step}/.Testing
🤖 Generated with Claude Code
Summary by CodeRabbit
New Features
Chores