Skip to content

add: DFlash block diffusion speculative decoding#1211

Open
ChenhanYu wants to merge 6 commits intomainfrom
chenhany/dflash-v2
Open

add: DFlash block diffusion speculative decoding#1211
ChenhanYu wants to merge 6 commits intomainfrom
chenhany/dflash-v2

Conversation

@ChenhanYu
Copy link
Copy Markdown
Collaborator

@ChenhanYu ChenhanYu commented Apr 8, 2026

DFlash (Block Diffusion for Flash Speculative Decoding) predicts an entire block of tokens in a single forward pass using masked parallel prediction with KV injection from the target model's hidden states.

Key features:

  • Feature fusion (multi-layer hidden states -> FC + RMSNorm)
  • KV injection (fused features as K/V in every draft layer with QK-norm)
  • Random anchor sampling with bidirectional intra-block attention
  • Logit distillation with exponential loss decay (gamma weighting)
  • Multi-node DDP training with checkpoint resume
  • Export to z-lab compatible HF format
  • Online validation (context-dependent ground truth)

Training recipe: modelopt_recipes/general/speculative_decoding/dflash.yaml
Results: examples/speculative_decoding/doc/dflash_results.md

ModelOpt Eval (online validation, osl=512)

Dataset z-lab ModelOpt (306K) Diff
gsm8k 4.10 5.19 +1.09
MT-Bench 3.58 4.36 +0.78

z-lab Official Eval (dflash.benchmark, osl=512)

Dataset z-lab ModelOpt (306K) Diff
gsm8k 5.00 4.08 -0.92
MT-Bench 3.28 2.99 -0.29

z-lab model trained with block_size=16. ModelOpt trained with block_size=8.

Evaluation Method Impact (gsm8k)

Eval Method z-lab checkpoint ModelOpt (306K)
Fixed GT (ModelOpt eval) 2.95 4.23
Online GT (ModelOpt eval) 4.10 5.19
z-lab official eval 5.00 4.08

What does this PR do?

Type of change: ?

Usage

# Add a code snippet demonstrating how to use this

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Added DFlash speculative decoding mode with parallel block prediction support.
    • Included training launchers and MT-Bench evaluation scripts for DFlash models.
    • Added online acceptance rate validation for improved inference verification.
  • Documentation

    • DFlash quick start guide with configuration parameters and training examples.
    • Performance results and benchmarks for DFlash-trained models.

@ChenhanYu ChenhanYu requested review from a team as code owners April 8, 2026 20:12
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Apr 8, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

This PR introduces comprehensive DFlash (Block Diffusion for Speculative Decoding) support to ModelOpt, including configuration layers, model conversion and plugins for HuggingFace, training infrastructure updates, export utilities, validation methods, and end-to-end recipe and launcher tooling for training and evaluation.

Changes

Cohort / File(s) Summary
Core Configuration & Mode Registry
modelopt/torch/speculative/config.py, modelopt/torch/speculative/mode.py
Added DFlashConfig class with 8 configurable fields, DFLASH_DEFAULT_CFG constant, and _get_dflash_default_config() helper. Registered DFlashModeDescriptor in mode registry with conversion/restore routing to dflash-specific entrypoints.
DFlash Module Package
modelopt/torch/speculative/dflash/__init__.py, modelopt/torch/speculative/dflash/dflash_model.py, modelopt/torch/speculative/dflash/conversion.py, modelopt/torch/speculative/dflash/default_config.py
Established new DFlash subpackage with base DFlashModel class (extending DynamicModule), conversion/restoration functions managing model normalization and config merging, and default architecture parameters (layer count, norm epsilon, attention settings).
HuggingFace DFlash Integration
modelopt/torch/speculative/plugins/hf_dflash.py, modelopt/torch/speculative/plugins/__init__.py
Introduced HFDFlashModel with 897 lines implementing block-wise parallel draft generation, feature fusion, multi-layer draft decoder with KV injection, anchor sampling, hard/soft loss computation, training accuracy reporting, and pseudo-speculative generation. Updated plugin initializer for conditional import.
Export Infrastructure
modelopt/torch/export/plugins/hf_spec_export.py
Added DFlashExporter class extracting dflash-prefixed state dict entries, generating config.json with DFlashDraftModel architecture and dflash-specific metadata, and writing safetensors checkpoint. Expanded ALL_SPEC_MODES constant to include "dflash".
Training Script Updates
examples/speculative_decoding/main.py
Extended TrainingArguments.mode to support "dflash". Modified _load_config to parse dflash YAML section and return triple instead of pair. Added dflash model conversion path and set answer_only_loss=True when building dflash datasets.
Training Utilities & Data Handling
examples/speculative_decoding/eagle_utils.py, modelopt/torch/utils/plugins/transformers_dataset.py
Added answer_only_loss parameter to make_eagle_supervised_data_module. Enhanced EagleTrainingPlot with per-step accuracy logging, single-GPU validation enforcement, raw-model unwrapping, and improved AR validation flow. Implemented auto-detection of chat template style, generation tag injection, and assistant-only masking in LanguageDataCollator.
Validation & Evaluation
modelopt/torch/speculative/utils.py
Added AcceptanceRateValidation.validate_online(...) method implementing online AR validation by comparing draft tokens against target-model posterior predictions without ground-truth dataset dependency.
Export Script Updates
examples/speculative_decoding/scripts/export_hf_checkpoint.py
Removed --trust_remote_code CLI flag and corresponding argument from load_vlm_or_llm call. Reformatted export_speculative_decoding invocation to multi-line format.
Unit & GPU Tests
tests/unit/torch/speculative/plugins/test_hf_dflash.py, tests/gpu/torch/speculative/plugins/test_hf_dflash.py
Added comprehensive CPU unit tests for conversion end-to-end flow, parameter freezing, save/restore round-trips, and mask/layer-id helper functions. Added GPU tests validating draft-module output shape, determinism in eval mode, training forward pass with loss/accuracy, and gradient connectivity.
Training Recipes & Launchers
modelopt_recipes/general/speculative_decoding/dflash.yaml, tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml, tools/launcher/common/dflash/online_training.sh, tools/launcher/common/dflash/ar_eval_mtbench.sh
Introduced base DFlash recipe YAML with model, data, training, and dflash configuration sections. Created Qwen3-8B launcher pipeline with two-task training and MT-Bench evaluation. Added training bootstrap script with multi-node distributed setup and auto-IP detection. Implemented MT-Bench AR evaluation script with online validation and per-category reporting.
Documentation
examples/speculative_decoding/README.md, examples/speculative_decoding/doc/dflash_results.md
Appended DFlash introduction section to README with quick-start, configuration parameters, export command, and results reference. Created results document with Qwen3-8B training metrics, online validation acceptance rates, and comparison against baselines with methodology notes.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The PR title 'add: DFlash block diffusion speculative decoding' directly and clearly describes the main feature being added—DFlash block diffusion speculative decoding—which aligns with the comprehensive changeset introducing the DFlash optimization method across configuration, model implementation, export, and training pipeline components.
Docstring Coverage ✅ Passed Docstring coverage is 84.69% which is sufficient. The required threshold is 80.00%.
Security Anti-Patterns ✅ Passed PR does not introduce torch.load weights_only vulnerabilities, numpy.load allow_pickle issues, hardcoded trust_remote_code=True, eval/exec on untrusted input, or # nosec bypasses.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch chenhany/dflash-v2

Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Apr 8, 2026

PR Preview Action v1.8.1

QR code for preview link

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1211/

Built to branch gh-pages at 2026-04-09 03:26 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@ChenhanYu ChenhanYu force-pushed the chenhany/dflash-v2 branch from f990e5a to e45cc37 Compare April 8, 2026 20:18
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 17

🧹 Nitpick comments (1)
examples/speculative_decoding/doc/dflash_results.md (1)

5-85: Add reproducibility metadata alongside reported metrics.

Please include the exact eval command(s), seed(s), and checkpoint artifact identifier(s) used for these tables so others can reproduce the numbers without guessing.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/doc/dflash_results.md` around lines 5 - 85, Add
reproducibility metadata to dflash_results.md by appending exact evaluation
commands, random seed values, and checkpoint artifact identifiers used to
produce the reported tables (e.g., the Key Metrics, MT-Bench Per-Category AR,
Comparison with z-lab, and Evaluation Method Impact sections). For each
table/experiment (such as the gsm8k and MT-Bench runs and the ModelOpt 306K
checkpoint), include the full CLI or python invocation (including flags like
block_size, osl, sequence length, draft layers, anchors per sample), the seed(s)
used, and the storage/registry identifiers or S3/GS/artifact names for the
specific checkpoint(s) (e.g., the 306K checkpoint), plus the environment (GPU
count/nodes) and any non-default preprocessing or eval-mode choices (Fixed GT vs
Online GT). Place this reproducibility block near the top or directly under "Key
Metrics" so readers can immediately reproduce the results.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/speculative_decoding/eagle_utils.py`:
- Around line 274-279: The call to load_dataset in the validate_ar invocation
hard-codes a local mirror path and should be made configurable or
fault-tolerant; update the validate_ar call site (where model=raw_model,
tokenizer=kwargs["processing_class"], ds=load_dataset(...),
device=next(raw_model.parameters()).device, num_samples=8) to derive the dataset
source from a new parameter or env var (e.g. mt_bench_dataset) with a default of
"HuggingFaceH4/mt_bench_prompts", and wrap the load_dataset call in a try/except
that falls back to the default Hub ID if the local path load fails. Ensure the
new parameter is plumbed from kwargs or config so single-GPU / public setups use
the Hub dataset by default.

In `@examples/speculative_decoding/main.py`:
- Around line 317-334: The fallback currently loads trainer_state.json into
trainer.state but calls trainer.train() without resume_from_checkpoint, so the
dataloader/loop restarts at step 0; update the fallback to load the JSON into
trainer.state (using trainer.state.load_from_json(state_file)) and then call
trainer.train(resume_from_checkpoint=checkpoint) so Hugging Face Trainer
receives the checkpoint and correctly resumes the dataloader and training loop;
ensure you still print the resumed step/max_steps using resumed_step and
resumed_max_steps as before.

In `@examples/speculative_decoding/README.md`:
- Around line 369-378: The README table lists
dflash.dflash_architecture_config.mask_token_id as having default "auto" but the
recipe doesn't define it; either add a mask_token_id key with value "auto" to
the dflash.yaml recipe that backs speculative decoding (so the recipe explicitly
documents the default), or update the README table entry to indicate
mask_token_id is optional/auto-inferred (e.g., mark as "auto
(optional/inferred)") so the documentation matches the actual recipe; locate the
mask_token_id entry under the dflash architecture config in the spec/config for
speculative decoding and make the corresponding change.

In `@examples/speculative_decoding/scripts/export_hf_checkpoint.py`:
- Line 41: The script calls load_vlm_or_llm(args.model_path, torch_dtype="auto")
without exposing trust_remote_code; add a CLI flag/argument (e.g.,
args.trust_remote_code defaulting to False) and pass it through to
load_vlm_or_llm as load_vlm_or_llm(args.model_path, torch_dtype="auto",
trust_remote_code=args.trust_remote_code) so callers can opt into remote code
when needed; update the argument parser to document the flag and set the default
to False.

In `@examples/speculative_decoding/train_dflash.py`:
- Around line 292-293: The code currently hardcodes
load_dataset("/hf-local/HuggingFaceH4/mt_bench_prompts") which requires an
internal mirror; change the load to accept an external dataset identifier via an
argument or environment variable and fall back to the public
"HuggingFaceH4/mt_bench_prompts" if none is provided. Update where ds is created
(the load_dataset call near validator = HFARValidation(raw_model, tokenizer)) to
read a CLI flag or os.environ key (e.g., MT_BENCH_DATASET) and pass that value
into load_dataset, defaulting to "HuggingFaceH4/mt_bench_prompts" so the script
works standalone while still allowing an internal path when supplied.
- Around line 150-153: Add a new CLI flag (e.g., --trust-remote-code) that
defaults to False (use action='store_true') and expose it as
args.trust_remote_code; then remove the hardcoded True and pass
args.trust_remote_code into both AutoModelForCausalLM.from_pretrained(...) and
AutoTokenizer.from_pretrained(...). Update any argument parsing logic where args
is created so the new flag is available to the model/tokenizer loading calls.

In `@modelopt/torch/export/plugins/hf_spec_export.py`:
- Around line 272-316: The _export_config method currently hardcodes
"torch_dtype": "bfloat16" causing a mismatch when export(dtype=...) saves
model.safetensors in a different dtype; update the export flow so export(...)
passes the chosen dtype into _export_config (add a dtype parameter to
_export_config) and have _export_config serialize that dtype value for the
"torch_dtype" field instead of the literal "bfloat16"; update all call sites
(e.g., wherever export calls _export_config) and ensure the same change is
applied for the similar block around lines 328-343 so the config.json always
matches the actual exported tensor dtype.

In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 83-85: The module-level assignment of _MLP_CLS, _NORM_CLS,
_ROTARY_CLS, and _rotate_half must be made instance-local so converting a second
model doesn't mutate shared globals: remove or stop relying on the top-level
assignment from _resolve_model_components("llama") and instead have modify()
store the resolved classes/func (result of _resolve_model_components) on the
converter/model instance (e.g., self._mlp_cls, self._norm_cls, self._rotary_cls,
self._rotate_half). Update apply_rotary_pos_emb(), the meta-buffer recovery
path, and any other callers that read the globals (including the locations
referenced around lines ~99-106, ~287-289, ~510-514) to read the instance
attributes instead of module globals, and ensure any factory or conversion
helpers receive the instance (or the specific classes) so they use the
per-instance types rather than the module-level variables.
- Around line 832-845: The one-off debug block guarded by self._psg_debug
accesses base_token.item() which fails for batch size B>1; update that debug
block (the code that sets self._psg_debug, selects base_outputs.hidden_states
for self.target_layer_ids, and prints seq_len / dflash_block_size /
self.mask_token_id) to either remove the prints entirely or log a batch-safe
representation of base_token (e.g., replace base_token.item() with
base_token[:,0].cpu().tolist() or base_token.reshape(-1).cpu().tolist()), and
ensure any other printed tensors (sel/th_dbg) are summarized (e.g.,
shapes/norms) to avoid per-sample indexing errors.

In `@modelopt/torch/utils/plugins/transformers_dataset.py`:
- Around line 156-157: The initializer currently calls
self._ensure_generation_tags() whenever self.answer_only_loss is true, which
swaps templates to text-only variants and breaks multimodal flows (e.g.,
VisionLanguageDataCollator) that expect message['content'] as blocks; modify the
logic in the __init__ (and the similar block around the 220-271 region) to only
perform the template rewrite for text-only collators — e.g., add a guard that
checks a multimodal flag or the collator class/type (or introduce an explicit
is_text_only property) before calling _ensure_generation_tags(), or
alternatively supply multimodal-safe fallback templates and use those when the
collator indicates multimodal input; ensure VisionLanguageDataCollator path does
not trigger the text-only swap.
- Around line 391-418: The collator is dropping samples without an assistant
turn regardless of answer-only mode; update the logic in the block using
messages/conversations (and the call to _sharegpt_to_openai_messages and
print_rank_0) so that the assistant-role existence check and skipping only run
when self.answer_only_loss is True, otherwise accept and batch the sample as-is
(i.e., append messages or converted conversations without the assistant-role
filter); keep references to the same symbols (messages, conversations, batch,
_sharegpt_to_openai_messages, print_rank_0) and ensure the existing dummy-batch
fallback remains unchanged.
- Around line 349-356: The assistant_masks are aligned to the original input_ids
but labels have been shifted with labels[..., :-1] = input_ids[..., 1:], so
shift assistant_mask by one before applying it to labels; inside the
answer_only_loss block (where assistant_mask is read from tokenized_examples)
compute a shifted mask like mask_shifted = assistant_mask[..., 1:] (or align to
labels' shape) and then set labels[..., :-1][mask_shifted == 0] =
IGNORE_TOKEN_ID (ensure tensor type and shape match, e.g., only do this when
assistant_mask is a torch.Tensor and mask_shifted.any()).

In `@tools/launcher/common/dflash/ar_validate.sh`:
- Around line 111-113: The code is calling validator.validate(...) which runs
the offline HFAR path; change it to call the online validation loop by invoking
validator.validate_online(osl=32, input_ids=input_ids, steps=3) (or the correct
parameter names for validate_online) and keep extracting the AR result (e.g.,
"_, ar = validator.validate_online(...)") and appending ar to ars so the script
uses AcceptanceRateValidation.validate_online() instead of
HFARValidation.validate().
- Around line 63-66: The calls to AutoModelForCausalLM.from_pretrained and
AutoTokenizer.from_pretrained currently hardcode trust_remote_code=True; change
them to read a new environment variable (e.g., ALLOW_TRUST_REMOTE_CODE) that
defaults to false and convert it to a boolean (treat "1", "true", "yes"
case-insensitively as true). Pass that boolean into the trust_remote_code
parameter for both AutoModelForCausalLM.from_pretrained and
AutoTokenizer.from_pretrained so remote-code execution is opt-in when
HF_MODEL_CKPT is used.

In `@tools/launcher/common/dflash/online_training.sh`:
- Line 34: The pip install line currently uses an unquoted comparison operator
so the shell interprets ">" as redirection; update the package spec in the
script by quoting or escaping the version constraint (e.g., change the pip
command to use "huggingface-hub>=1.2.1" or huggi ngface-hub\>=1.2.1) so the
minimum version constraint is passed to pip rather than redirecting stdout.
- Around line 181-223: The inline python invoked via the python3 -c block
insecurely interpolates shell variables (e.g., ${DFLASH_BLOCK_SIZE},
${DFLASH_NUM_LAYERS}, ${MASK_ARG}, ${HF_MODEL_CKPT}) and hardcodes
trust_remote_code=True in AutoModelForCausalLM.from_pretrained and
AutoTokenizer.from_pretrained; fix by changing the script to read values from
environment variables or sys.argv inside the Python snippet (use os.environ or
argparse) instead of shell interpolation, validate/convert numeric values
(dflash_block_size, num_hidden_layers) there, and make trust_remote_code
configurable (read from env/default to False) before calling from_pretrained so
no untrusted remote code is loaded by default.

In `@tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml`:
- Around line 50-54: Replace the speculative algorithm flag value from EAGLE3 to
DRAFT_TARGET so VLLM treats the DFlash output as a generic draft model; update
the entry in the YAML where "--speculative_algorithm EAGLE3" appears to
"--speculative_algorithm DRAFT_TARGET" (this ensures the DFlash draft produced
by common/dflash/online_training.sh is loaded instead of mapping to the EAGLE3
backend in examples/specdec_bench/specdec_bench/models/vllm.py).

---

Nitpick comments:
In `@examples/speculative_decoding/doc/dflash_results.md`:
- Around line 5-85: Add reproducibility metadata to dflash_results.md by
appending exact evaluation commands, random seed values, and checkpoint artifact
identifiers used to produce the reported tables (e.g., the Key Metrics, MT-Bench
Per-Category AR, Comparison with z-lab, and Evaluation Method Impact sections).
For each table/experiment (such as the gsm8k and MT-Bench runs and the ModelOpt
306K checkpoint), include the full CLI or python invocation (including flags
like block_size, osl, sequence length, draft layers, anchors per sample), the
seed(s) used, and the storage/registry identifiers or S3/GS/artifact names for
the specific checkpoint(s) (e.g., the 306K checkpoint), plus the environment
(GPU count/nodes) and any non-default preprocessing or eval-mode choices (Fixed
GT vs Online GT). Place this reproducibility block near the top or directly
under "Key Metrics" so readers can immediately reproduce the results.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 57539396-3987-4d1b-9a04-2d7606707bb5

📥 Commits

Reviewing files that changed from the base of the PR and between 4a70040 and f990e5a.

📒 Files selected for processing (24)
  • doc/results/dflash_results.html
  • examples/speculative_decoding/README.md
  • examples/speculative_decoding/doc/dflash_results.md
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/main.py
  • examples/speculative_decoding/scripts/export_hf_checkpoint.py
  • examples/speculative_decoding/train_dflash.py
  • modelopt/torch/export/plugins/hf_spec_export.py
  • modelopt/torch/speculative/config.py
  • modelopt/torch/speculative/dflash/__init__.py
  • modelopt/torch/speculative/dflash/conversion.py
  • modelopt/torch/speculative/dflash/default_config.py
  • modelopt/torch/speculative/dflash/dflash_model.py
  • modelopt/torch/speculative/mode.py
  • modelopt/torch/speculative/plugins/__init__.py
  • modelopt/torch/speculative/plugins/hf_dflash.py
  • modelopt/torch/speculative/utils.py
  • modelopt/torch/utils/plugins/transformers_dataset.py
  • modelopt_recipes/general/speculative_decoding/dflash.yaml
  • tests/gpu/torch/speculative/plugins/test_hf_dflash.py
  • tests/unit/torch/speculative/plugins/test_hf_dflash.py
  • tools/launcher/common/dflash/ar_validate.sh
  • tools/launcher/common/dflash/online_training.sh
  • tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml

Comment on lines +317 to +334
try:
trainer.train(resume_from_checkpoint=checkpoint)
except ValueError as e:
if "parameter group" in str(e):
print_rank_0(
f"Optimizer state mismatch: {e}\n"
f"Resuming with fresh optimizer from {checkpoint}"
)
state_file = os.path.join(checkpoint, "trainer_state.json")
if os.path.isfile(state_file):
state = json.load(open(state_file))
resumed_step = state.get("global_step", 0)
resumed_max_steps = state.get("max_steps", -1)
print_rank_0(f"Resuming from step {resumed_step}/{resumed_max_steps}")
if resumed_max_steps > 0:
training_args.max_steps = resumed_max_steps
trainer.state = trainer.state.load_from_json(state_file)
trainer.train()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cd /repo && find . -name "main.py" -path "*/speculative_decoding/*" -type f

Repository: NVIDIA/Model-Optimizer

Length of output: 121


🏁 Script executed:

cd /repo && cat -n examples/speculative_decoding/main.py | sed -n '310,340p'

Repository: NVIDIA/Model-Optimizer

Length of output: 121


🏁 Script executed:

cd /repo && wc -l examples/speculative_decoding/main.py

Repository: NVIDIA/Model-Optimizer

Length of output: 121


🏁 Script executed:

pwd && ls -la

Repository: NVIDIA/Model-Optimizer

Length of output: 2199


🏁 Script executed:

find . -name "main.py" -path "*/speculative_decoding/*" -type f 2>/dev/null

Repository: NVIDIA/Model-Optimizer

Length of output: 105


🏁 Script executed:

if [ -f "examples/speculative_decoding/main.py" ]; then wc -l examples/speculative_decoding/main.py; fi

Repository: NVIDIA/Model-Optimizer

Length of output: 107


🏁 Script executed:

if [ -f "examples/speculative_decoding/main.py" ]; then cat -n examples/speculative_decoding/main.py | sed -n '310,340p'; fi

Repository: NVIDIA/Model-Optimizer

Length of output: 1764


🏁 Script executed:

web_search
HuggingFace Transformers Trainer.train() resume_from_checkpoint state initialization behavior

Repository: NVIDIA/Model-Optimizer

Length of output: 295


🏁 Script executed:

rg -A 10 "trainer\.train\(" examples/ --type py | head -50

Repository: NVIDIA/Model-Optimizer

Length of output: 3879


🏁 Script executed:

rg "trainer\.state\s*=" examples/ --type py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 496


🏁 Script executed:

grep -r "load_from_json" examples/ --include="*.py" -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 496


🌐 Web query:

HuggingFace Transformers Trainer.train() resume_from_checkpoint parameter state initialization

💡 Result:

In 🤗 Transformers, Trainer.train(resume_from_checkpoint=...) resumes by loading saved state on top of freshly created objects:

  • What the parameter means

    • If resume_from_checkpoint is a path, it resumes from that checkpoint folder.
    • If it is True, it finds and uses the last checkpoint in args.output_dir.
    • When provided, training resumes from the model + optimizer + LR scheduler states found in that checkpoint. [1]
  • What gets (re)initialized vs loaded

    1. Model weights
      • The model is first present in memory (from your model= or model_init), then Trainer loads the checkpoint weights into it (typically from pytorch_model.bin / model.safetensors, depending on setup/version). In the Trainer implementation you can see it loading WEIGHTS_NAME and calling model.load_state_dict(..., strict=False). [2]
    2. Trainer state (self.state)
      • If {checkpoint}/trainer_state.json exists, Trainer loads it into self.state and uses it to compute how much training was already done (e.g., derives epochs_trained from global_step). [2]
    3. Optimizer + LR scheduler state
      • The optimizer and scheduler objects are created as usual (i.e., “fresh”), and then Trainer calls _load_optimizer_and_scheduler(checkpoint).
      • If {checkpoint}/optimizer.pt and {checkpoint}/scheduler.pt exist (and you are not using DeepSpeed), Trainer loads those state dicts into the already-created optimizer/scheduler via load_state_dict. [2]
      • If you are using DeepSpeed, Transformers’ Trainer code path skips this because DeepSpeed handles resume as part of its own init/checkpoint logic. [2]
    4. RNG state
      • Checkpoints also save RNG state (e.g., rng_state.pth / rank-specific variants), so resuming can continue deterministically when those files are present. [2]
  • Common implication

    • “Resume” is not “don’t initialize anything”; it’s “initialize normal training objects, then overwrite their internal state from checkpoint files if they exist.” This is why mismatches can happen if your optimizer parameter groups differ between runs (e.g., changed trainable params, LoRA targets, etc.), causing optimizer.load_state_dict(...) to fail.

Sources: Hugging Face Trainer docs for resume_from_checkpoint [1] and Trainer implementation showing loading of trainer_state.json, optimizer/scheduler state, and RNG save/load logic [2].

[1] (huggingface.co)
[2] (huggingface.co)

Citations:


🏁 Script executed:

rg "def train" examples/speculative_decoding/main.py -A 5

Repository: NVIDIA/Model-Optimizer

Length of output: 211


🏁 Script executed:

# Check what imports are in the file
head -30 examples/speculative_decoding/main.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1490


🏁 Script executed:

# Look for any trainer initialization or state handling patterns
rg "Trainer\(" examples/speculative_decoding/main.py -B 5 -A 10

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check if there's any documentation or comments about state resumption in similar files
rg "resume" examples/llm_qat/main.py -B 3 -A 3

Repository: NVIDIA/Model-Optimizer

Length of output: 520


🏁 Script executed:

# Check the actual behavior by looking at what the code tries to accomplish
cat -n examples/speculative_decoding/main.py | sed -n '314,340p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1581


🏁 Script executed:

# Search for any notes or context about why the manual state loading is done
rg "Optimizer state mismatch" . --type py -B 5 -A 10

Repository: NVIDIA/Model-Optimizer

Length of output: 1543


🌐 Web query:

HuggingFace Trainer trainer.state initialization without resume_from_checkpoint parameter fresh initialization

💡 Result:

In 🤗 Transformers, a fresh (non-resumed) trainer.state comes from the default TrainerState created when you instantiate Trainer, and it is only replaced if you pass (or propagate) a resume_from_checkpoint and a trainer_state.json is found.

What happens on a fresh run (no resume)

  1. Trainer.__init__ creates a new state

    self.state = TrainerState(
        is_local_process_zero=...,
        is_world_process_zero=...,
        stateful_callbacks=[...],
    )

    (see Trainer.__init__) [1]

  2. Default TrainerState values
    TrainerState defaults include global_step=0, max_steps=0, etc. [2]

  3. At training start, epoch is set to 0
    Inside _inner_training_loop, Trainer sets:

    self.state.epoch = 0

    and only then checks whether it should load a checkpoint state. [3]

When/why it stops being “fresh”

  • In Trainer.train(...), if you pass resume_from_checkpoint=True, it will resolve the last checkpoint in output_dir via get_last_checkpoint(...). [4]
  • In _inner_training_loop, if resume_from_checkpoint points to a checkpoint and that folder contains trainer_state.json, Trainer does:
    self.state = TrainerState.load_from_json(.../trainer_state.json)
    [3]

How to force “fresh initialization” in practice

  • Make sure you are not doing any of these:
    • trainer.train(resume_from_checkpoint=True) (auto-picks last checkpoint) [4]
    • trainer.train(resume_from_checkpoint=".../checkpoint-XXXX") [3]
    • Passing TrainingArguments(resume_from_checkpoint=...) through an example script/CLI (many scripts forward it into trainer.train(...)) [5]

A robust pattern:

training_args = TrainingArguments(..., resume_from_checkpoint=None)
trainer = Trainer(..., args=training_args)
trainer.train(resume_from_checkpoint=None)  # or just trainer.train()

If you still see resume behavior, it’s almost always because some wrapper/script is passing resume_from_checkpoint=True/path under the hood.


Sources:
[1] Trainer initializes self.state = TrainerState(...) in __init__ (Transformers trainer.py). (github.com)
[2] TrainerState default fields (HF docs: Callbacks → TrainerState). (huggingface.co)
[3] _inner_training_loop sets epoch=0 and loads TrainerState from trainer_state.json only when resuming. (github.com)
[4] Trainer.train: if resume_from_checkpoint is True, it resolves get_last_checkpoint(output_dir). (github.com)
[5] TrainingArguments includes a resume_from_checkpoint field (can be propagated by scripts). (huggingface.co)

Citations:


🏁 Script executed:

# Check if there's any setup or initialization of trainer.state before the resume logic
rg "trainer\.state" examples/speculative_decoding/main.py -B 3 -A 3

Repository: NVIDIA/Model-Optimizer

Length of output: 419


🏁 Script executed:

# Look for the trainer initialization to understand the baseline state
rg "Trainer\(" examples/speculative_decoding/main.py -B 10 | head -40

Repository: NVIDIA/Model-Optimizer

Length of output: 48


The fallback doesn't properly resume training—it restarts the dataloader from step 0.

Loading trainer_state.json into trainer.state before trainer.train() without resume_from_checkpoint is insufficient. In Hugging Face Trainer, the dataloader position and training loop resumption depend on passing the resume_from_checkpoint parameter to trainer.train(). Without it, the dataloader begins at step 0 despite the manually-loaded state values. The printed "Resuming from step X" message is misleading—training actually replays from the start, defeating the purpose of the fallback recovery path (lines 315–316).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 317 - 334, The fallback
currently loads trainer_state.json into trainer.state but calls trainer.train()
without resume_from_checkpoint, so the dataloader/loop restarts at step 0;
update the fallback to load the JSON into trainer.state (using
trainer.state.load_from_json(state_file)) and then call
trainer.train(resume_from_checkpoint=checkpoint) so Hugging Face Trainer
receives the checkpoint and correctly resumes the dataloader and training loop;
ensure you still print the resumed step/max_steps using resumed_step and
resumed_max_steps as before.

model = load_vlm_or_llm(
args.model_path, torch_dtype="auto", trust_remote_code=args.trust_remote_code
)
model = load_vlm_or_llm(args.model_path, torch_dtype="auto")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "Check exporter CLI and model load call:"
rg -n -C2 'add_argument\("--trust_remote_code"|load_vlm_or_llm\(' examples/speculative_decoding/scripts/export_hf_checkpoint.py

echo
echo "Check loader API supports trust_remote_code and default value:"
rg -n -C2 'def load_vlm_or_llm\(|trust_remote_code' modelopt/torch/speculative/utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1457


🏁 Script executed:

cat -n examples/speculative_decoding/scripts/export_hf_checkpoint.py

Repository: NVIDIA/Model-Optimizer

Length of output: 2038


Expose caller-controlled trust_remote_code parameter to allow exporting models requiring custom code.

The exporter script does not expose control for the trust_remote_code parameter, preventing users from loading models that require custom remote code during export. The underlying load_vlm_or_llm() function already accepts and properly propagates this parameter with a safe default of False; add CLI exposure to let users opt in when needed.

Suggested patch
 def parse_args():
     parser = argparse.ArgumentParser(
         description="Export a HF checkpoint (with ModelOpt state) for deployment."
     )
     parser.add_argument("--model_path", type=str, default="Path of the trained checkpoint.")
     parser.add_argument(
         "--export_path", type=str, default="Destination directory for exported files."
     )
+    parser.add_argument(
+        "--trust_remote_code",
+        action="store_true",
+        help="Allow loading custom remote code from model repos (default: False).",
+    )
     return parser.parse_args()
 
 
 mto.enable_huggingface_checkpointing()
 
 args = parse_args()
-model = load_vlm_or_llm(args.model_path, torch_dtype="auto")
+model = load_vlm_or_llm(
+    args.model_path,
+    torch_dtype="auto",
+    trust_remote_code=args.trust_remote_code,
+)

This follows the guideline to "let the caller decide via a parameter; default to False" for trust_remote_code.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
model = load_vlm_or_llm(args.model_path, torch_dtype="auto")
model = load_vlm_or_llm(
args.model_path,
torch_dtype="auto",
trust_remote_code=args.trust_remote_code,
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/scripts/export_hf_checkpoint.py` at line 41,
The script calls load_vlm_or_llm(args.model_path, torch_dtype="auto") without
exposing trust_remote_code; add a CLI flag/argument (e.g.,
args.trust_remote_code defaulting to False) and pass it through to
load_vlm_or_llm as load_vlm_or_llm(args.model_path, torch_dtype="auto",
trust_remote_code=args.trust_remote_code) so callers can opt into remote code
when needed; update the argument parser to document the flag and set the default
to False.

Comment on lines +150 to +153
model = AutoModelForCausalLM.from_pretrained(
args.model, torch_dtype=torch.bfloat16, device_map={"": device}, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

fd train_dflash.py

Repository: NVIDIA/Model-Optimizer

Length of output: 111


🏁 Script executed:

cat -n examples/speculative_decoding/train_dflash.py | sed -n '140,165p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1291


🏁 Script executed:

# Check for --model argument definition
grep -n "model" examples/speculative_decoding/train_dflash.py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 1203


🏁 Script executed:

# Check if there are any comments or configurations around the trust_remote_code lines
cat -n examples/speculative_decoding/train_dflash.py | sed -n '145,158p'

Repository: NVIDIA/Model-Optimizer

Length of output: 724


Make trust_remote_code opt-in via CLI flag.

Lines 150–153 hardcode trust_remote_code=True for both model and tokenizer loads. Since --model is caller-controlled, this forces execution of arbitrary Python from any untrusted checkpoint during startup. Add a CLI flag (e.g., --trust-remote-code) that defaults to False and thread it through both from_pretrained() calls.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/train_dflash.py` around lines 150 - 153, Add a
new CLI flag (e.g., --trust-remote-code) that defaults to False (use
action='store_true') and expose it as args.trust_remote_code; then remove the
hardcoded True and pass args.trust_remote_code into both
AutoModelForCausalLM.from_pretrained(...) and
AutoTokenizer.from_pretrained(...). Update any argument parsing logic where args
is created so the new flag is available to the model/tokenizer loading calls.

Comment on lines +63 to +66
model = AutoModelForCausalLM.from_pretrained(
'${HF_MODEL_CKPT}', torch_dtype=torch.bfloat16, device_map={'': 0}, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained('${HF_MODEL_CKPT}', trust_remote_code=True)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# Find the file
fd ar_validate.sh

Repository: NVIDIA/Model-Optimizer

Length of output: 109


🏁 Script executed:

# Read the entire file to see the context
cat -n tools/launcher/common/dflash/ar_validate.sh | head -100

Repository: NVIDIA/Model-Optimizer

Length of output: 4475


🏁 Script executed:

# Read the rest of the file
cat -n tools/launcher/common/dflash/ar_validate.sh | tail -50

Repository: NVIDIA/Model-Optimizer

Length of output: 2292


Make remote-code execution opt-in; do not hardcode trust_remote_code=True.

Lines 63-66 hardcode trust_remote_code=True for both model and tokenizer loading. Since HF_MODEL_CKPT is environment-controlled, this allows any untrusted model checkpoint to execute arbitrary Python code during loading. Add an environment variable to control this flag and default it to False.

🔒 Proposed fix
 import torch
+import os
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from modelopt.torch.speculative.plugins.transformers import HFARValidation
 import modelopt.torch.opt as mto
 import modelopt.torch.speculative as mtsp
 
 mto.enable_huggingface_checkpointing()
 
+trust_remote_code = os.getenv("TRUST_REMOTE_CODE", "0") == "1"
 model = AutoModelForCausalLM.from_pretrained(
-    '${HF_MODEL_CKPT}', torch_dtype=torch.bfloat16, device_map={'': 0}, trust_remote_code=True
+    '${HF_MODEL_CKPT}',
+    torch_dtype=torch.bfloat16,
+    device_map={'': 0},
+    trust_remote_code=trust_remote_code,
 )
-tokenizer = AutoTokenizer.from_pretrained('${HF_MODEL_CKPT}', trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained(
+    '${HF_MODEL_CKPT}', trust_remote_code=trust_remote_code
+)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/launcher/common/dflash/ar_validate.sh` around lines 63 - 66, The calls
to AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained
currently hardcode trust_remote_code=True; change them to read a new environment
variable (e.g., ALLOW_TRUST_REMOTE_CODE) that defaults to false and convert it
to a boolean (treat "1", "true", "yes" case-insensitively as true). Pass that
boolean into the trust_remote_code parameter for both
AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained so
remote-code execution is opt-in when HF_MODEL_CKPT is used.

Comment on lines +111 to +113
try:
_, ar = validator.validate(osl=32, input_ids=input_ids, steps=3)
ars.append(ar)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Call the online AR path here.

This still invokes HFARValidation.validate(), so the script computes AR against fixed ground truth instead of the new context-dependent verification loop in AcceptanceRateValidation.validate_online(). That means the reported DFlash metric does not match the new online-validation path added in this PR.

♻️ Proposed fix
-        _, ar = validator.validate(osl=32, input_ids=input_ids, steps=3)
+        _, ar = validator.validate_online(osl=32, input_ids=input_ids, steps=3)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
try:
_, ar = validator.validate(osl=32, input_ids=input_ids, steps=3)
ars.append(ar)
try:
_, ar = validator.validate_online(osl=32, input_ids=input_ids, steps=3)
ars.append(ar)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/launcher/common/dflash/ar_validate.sh` around lines 111 - 113, The code
is calling validator.validate(...) which runs the offline HFAR path; change it
to call the online validation loop by invoking validator.validate_online(osl=32,
input_ids=input_ids, steps=3) (or the correct parameter names for
validate_online) and keep extracting the AR result (e.g., "_, ar =
validator.validate_online(...)") and appending ar to ars so the script uses
AcceptanceRateValidation.validate_online() instead of HFARValidation.validate().

Comment on lines +50 to +54
- --engine VLLM
- --tp_size 4
- --ep_size 1
- --speculative_algorithm EAGLE3
- --mtbench /hf-local/HuggingFaceH4/mt_bench_prompts/raw/question.jsonl
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Use DRAFT_TARGET for the DFlash benchmark.

examples/specdec_bench/specdec_bench/models/vllm.py maps EAGLE3 to the eagle3 backend, and there is no DFlash handler on that path. This task will benchmark the wrong speculative-decoding implementation for the draft produced by common/dflash/online_training.sh. Switch this to DRAFT_TARGET so VLLM loads it as a generic draft model instead.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml` around lines 50
- 54, Replace the speculative algorithm flag value from EAGLE3 to DRAFT_TARGET
so VLLM treats the DFlash output as a generic draft model; update the entry in
the YAML where "--speculative_algorithm EAGLE3" appears to
"--speculative_algorithm DRAFT_TARGET" (this ensures the DFlash draft produced
by common/dflash/online_training.sh is loaded instead of mapping to the EAGLE3
backend in examples/specdec_bench/specdec_bench/models/vllm.py).

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

♻️ Duplicate comments (13)
tools/launcher/common/dflash/ar_validate.sh (2)

111-113: ⚠️ Potential issue | 🟠 Major

Use the online AR path here.

HFARValidation.validate() still measures against fixed ground truth, so this script reports a different metric than the new context-dependent DFlash validation path. Call validate_online(...) in this loop instead.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/launcher/common/dflash/ar_validate.sh` around lines 111 - 113, Replace
the offline validation call to validator.validate(...) with the
online/context-dependent path by invoking validator.validate_online(...) in the
loop; keep the same arguments (e.g., osl=32, input_ids=input_ids, steps=3) and
continue to append the returned ar to ars so the script measures DFlash’s
context-dependent AR instead of comparing to fixed ground truth.

63-66: ⚠️ Potential issue | 🔴 Critical

Make remote-code execution opt-in.

HF_MODEL_CKPT is environment-controlled, but both loads hardcode trust_remote_code=True. That lets an arbitrary checkpoint execute Python during validation. Thread this through an env flag that defaults to False. As per coding guidelines, "Never hardcode trust_remote_code=True; remote-code execution is an RCE vector."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/launcher/common/dflash/ar_validate.sh` around lines 63 - 66, The code
currently passes trust_remote_code=True into
AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained for
HF_MODEL_CKPT; change this to be controlled by a new environment flag (e.g.,
TRUST_REMOTE_CODE or ENABLE_TRUST_REMOTE_CODE) that defaults to False, parse it
as a boolean, and pass that variable into the trust_remote_code parameter of
both AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained so
remote-code execution is opt-in.
examples/speculative_decoding/train_dflash.py (2)

150-153: ⚠️ Potential issue | 🔴 Critical

Make trust_remote_code opt-in via CLI.

--model is caller-controlled, but both from_pretrained() calls hardcode trust_remote_code=True, which executes arbitrary checkpoint code during startup. Add a flag like --trust-remote-code defaulting to False and thread it through both loads. As per coding guidelines, "Never hardcode trust_remote_code=True; remote-code execution is an RCE vector."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/train_dflash.py` around lines 150 - 153, Add a
new CLI boolean flag (e.g., args.trust_remote_code defaulting to False) and pass
it to both AutoModelForCausalLM.from_pretrained and
AutoTokenizer.from_pretrained instead of hardcoding trust_remote_code=True;
update the argument parser where args.model is defined to include
"--trust-remote-code" (store_true) and then change the two calls
(AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained) to use
trust_remote_code=args.trust_remote_code so remote-code execution remains
opt-in.

292-293: ⚠️ Potential issue | 🟠 Major

Don't hard-require an internal MT-Bench mirror in a standalone script.

The post-train AR check now depends on /hf-local/HuggingFaceH4/mt_bench_prompts, so this entrypoint fails outside the internal environment. Accept the dataset source via CLI/env and default to HuggingFaceH4/mt_bench_prompts.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/train_dflash.py` around lines 292 - 293, The
script currently hardcodes the MT-Bench dataset path when creating ds with
load_dataset; make it configurable via a CLI flag and env var with a sensible
default: add an argparse option (e.g. --mtbench-dataset) that falls back to
os.environ.get("MT_BENCH_DATASET") and then to the default string
"HuggingFaceH4/mt_bench_prompts"; replace the literal in the load_dataset call
(the line that sets ds = load_dataset(... )["train"]) to use that resolved
dataset identifier so HFARValidation(raw_model, tokenizer) continues to run
against the user-provided or default dataset.
modelopt/torch/utils/plugins/transformers_dataset.py (3)

391-418: ⚠️ Potential issue | 🟠 Major

Only skip assistant-free chats when answer_only_loss=True.

This branch now drops prompt-only/system-user chats even in the normal full-loss path. If an entire batch is filtered, the dummy user/assistant sample becomes synthetic training data instead of a no-op. Keep the assistant-turn check behind self.answer_only_loss.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/utils/plugins/transformers_dataset.py` around lines 391 - 418,
The current filtering drops samples without an assistant turn regardless of loss
mode; only skip samples when self.answer_only_loss is True. Modify the branch
that checks for assistant turns (the blocks handling messages and conversations
and calling _sharegpt_to_openai_messages) to perform the "no assistant turn ->
print warning and continue" only when self.answer_only_loss is True; otherwise
append the messages/conversations as-is. Ensure references: messages,
conversations, converted (from _sharegpt_to_openai_messages), batch, and
self.answer_only_loss are used to gate the skip logic so dummy batch creation
behavior only occurs when answer_only_loss is enabled.

156-157: ⚠️ Potential issue | 🟠 Major

Still unresolved: answer-only template rewriting breaks multimodal collators.

VisionLanguageDataCollator reaches this initializer too, but the fallback templates here explicitly drop VLM content handling and treat message["content"] as plain text. With answer_only_loss=True, multimodal batches still get rewritten to a text-only template before processor.apply_chat_template(...) runs, so list-of-block content will be misformatted or fail.

Also applies to: 177-215, 441-448

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/utils/plugins/transformers_dataset.py` around lines 156 - 157,
The answer-only template rewrite in the constructor (triggered by
answer_only_loss calling _ensure_generation_tags) incorrectly converts
multimodal messages into text-only content, breaking VisionLanguageDataCollator
and processor.apply_chat_template; fix by detecting multimodal inputs (e.g.,
presence of non-text keys or list-of-blocks/image/block structures in
message["content"] used by VisionLanguageDataCollator) and skip the answer-only
template rewrite for those cases, i.e., in _ensure_generation_tags (or the place
where answer_only_loss is handled) add a guard that returns early when message
content is multimodal so processor.apply_chat_template receives the original
multimodal structure. Ensure the check references the same shapes used by
VisionLanguageDataCollator (e.g., 'blocks' or image fields) so multimodal
batches are preserved.

350-356: ⚠️ Potential issue | 🟠 Major

Shift assistant_masks into label space before masking.

labels[..., :-1] already contains next-token targets, but assistant_masks is aligned to the original token positions. Applying it unshifted drops the first assistant target token and keeps one token past each assistant span.

Suggested fix
             if self.answer_only_loss:
                 if "assistant_masks" in tokenized_examples:
                     assistant_mask = tokenized_examples["assistant_masks"]
                     if isinstance(assistant_mask, torch.Tensor) and assistant_mask.any():
-                        labels[assistant_mask == 0] = IGNORE_TOKEN_ID
+                        shifted_assistant_mask = torch.zeros_like(assistant_mask)
+                        shifted_assistant_mask[..., :-1] = assistant_mask[..., 1:]
+                        labels[shifted_assistant_mask == 0] = IGNORE_TOKEN_ID
                     else:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/utils/plugins/transformers_dataset.py` around lines 350 - 356,
The assistant_mask is aligned to original input positions but labels are
next-token targets (labels[..., :-1] = input_ids[..., 1:]), so you must shift
assistant_masks into label-space before masking; compute a shifted mask (e.g.,
shifted = assistant_mask[..., 1:]) and apply it to labels[..., :-1] (or
equivalently mask labels where shifted == 0) instead of using the unshifted
assistant_mask; keep the existing checks for torch.Tensor and .any() and set
masked positions to IGNORE_TOKEN_ID when answer_only_loss is true.
modelopt/torch/export/plugins/hf_spec_export.py (1)

272-316: ⚠️ Potential issue | 🟠 Major

Keep config.json dtype aligned with the exported tensors.

export(dtype=...) casts model.safetensors, but _export_config() still hardcodes "torch_dtype": "bfloat16". Exporting fp16/fp32 will advertise the wrong dtype to downstream loaders.

Suggested fix
-    def _export_config(self):
+    def _export_config(self, dtype: torch.dtype | None = None):
         """Build config.json matching z-lab DFlash format."""
@@
-            "torch_dtype": "bfloat16",
+            "torch_dtype": (
+                str(dtype).replace("torch.", "")
+                if dtype is not None
+                else str(
+                    getattr(
+                        draft_config,
+                        "torch_dtype",
+                        getattr(base_config, "torch_dtype", torch.bfloat16),
+                    )
+                ).replace("torch.", "")
+            ),
@@
-        drafter_config = self._export_config()
+        drafter_config = self._export_config(dtype=dtype)

Also applies to: 328-343

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/plugins/hf_spec_export.py` around lines 272 - 316, The
config currently hardcodes "torch_dtype": "bfloat16" in _export_config(),
causing a mismatch when export(dtype=...) is used; update _export_config to
determine the dtype dynamically (prefer the explicit export dtype passed to
export, e.g., self.export_dtype or self.dtype if present) and set "torch_dtype"
to that dtype's string (falling back to the current bfloat16 value if no export
dtype is available); apply the same change to the other config block around
lines 328-343 so the advertised dtype matches the exported safetensors.
examples/speculative_decoding/main.py (1)

311-334: ⚠️ Potential issue | 🟠 Major

The optimizer-mismatch fallback still doesn't actually resume.

Loading trainer.state and then calling trainer.train() restarts the dataloader at step 0. This path will replay already-seen batches while logging Resuming from step ..., so it is not a safe recovery path for optimizer state mismatches.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 311 - 334, The fallback
path sets trainer.state from trainer_state.json but then calls trainer.train()
with no resume flag, which restarts from step 0 and replays batches; change the
fallback to call trainer.train(resume_from_checkpoint=checkpoint) so the Trainer
resumes at the saved step, and also clear any existing optimizer (e.g., set
trainer.optimizer = None) before training to ensure a fresh optimizer when you
intentionally fall back from an optimizer-state mismatch; keep the existing use
of trainer.state.load_from_json(state_file) and the checkpoint/state_file
variables.
tools/launcher/common/dflash/online_training.sh (2)

30-35: ⚠️ Potential issue | 🟡 Minor

Quote shell variables and version specifiers.

Line 31: ${SCRIPT_DIR} should be double-quoted to prevent globbing/word splitting.

Line 34: The unquoted >= is parsed as shell redirection, causing the version constraint to be lost.

🔧 Proposed fix
-source ${SCRIPT_DIR}/../service_utils.sh
+source "${SCRIPT_DIR}/../service_utils.sh"

 pip install -r modules/Model-Optimizer/examples/speculative_decoding/requirements.txt
-pip install huggingface-hub>=1.2.1
+pip install "huggingface-hub>=1.2.1"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/launcher/common/dflash/online_training.sh` around lines 30 - 35, The
script uses unquoted shell variables and an unquoted version specifier which can
cause globbing/word-splitting and shell redirection: quote ${SCRIPT_DIR} when
sourcing service_utils.sh (refer to SCRIPT_DIR and service_utils.sh) and quote
the pip requirement string so the shell doesn't treat >= as a redirection (refer
to the pip install command for huggingface-hub), and also quote PATH expansions
when exporting (refer to the export PATH line) to avoid word-splitting.

181-223: ⚠️ Potential issue | 🔴 Critical

Hardcoded trust_remote_code=True and shell variable interpolation pose security risks.

Lines 192 and 194 hardcode trust_remote_code=True for model and tokenizer loading with no override capability. Per coding guidelines, this should be configurable and default to False.

Additionally, shell variables (${HF_MODEL_CKPT}, ${DFLASH_BLOCK_SIZE}, ${DFLASH_NUM_LAYERS}, ${MASK_ARG}) are interpolated directly into the Python heredoc. Maliciously crafted values could inject arbitrary code.

Pass values via environment variables and read them with os.environ.get() inside Python:

import os
model_path = os.environ.get("HF_MODEL_CKPT")
trust_remote = os.environ.get("TRUST_REMOTE_CODE", "false").lower() == "true"
# ...
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map={'': 0}, trust_remote_code=trust_remote
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/launcher/common/dflash/online_training.sh` around lines 181 - 223, The
heredoc currently hardcodes trust_remote_code=True and interpolates shell
variables directly, which is unsafe; change the script to pass HF_MODEL_CKPT,
DFLASH_BLOCK_SIZE, DFLASH_NUM_LAYERS, MASK_ARG, AR_CKPT and a new
TRUST_REMOTE_CODE via environment variables and read them inside Python with
os.environ.get() (casting block/num to int and TRUST_REMOTE_CODE to a boolean
defaulting to False) before calling AutoModelForCausalLM.from_pretrained and
AutoTokenizer.from_pretrained; remove all ${...} interpolations from the Python
snippet, use the env-derived variables when building the dflash config and when
loading checkpoints (model.load_state_dict and
model.dflash_module.load_state_dict), and ensure a safe default
trust_remote_code=False unless the env explicitly sets it to "true".
modelopt/torch/speculative/plugins/hf_dflash.py (2)

83-84: ⚠️ Potential issue | 🟠 Major

Module-level globals create cross-model interference.

_MLP_CLS, _NORM_CLS, _ROTARY_CLS, and _rotate_half are module-scope variables that modify() (lines 510-514) reassigns. If two DFlash models with different base architectures are instantiated in the same process, the second modify() call will overwrite the globals, potentially corrupting the first model's behavior at runtime (e.g., in apply_rotary_pos_emb, DFlashModule.__init__).

Store these resolved components on the model instance (self._mlp_cls, etc.) and thread them through to DFlashModule and DFlashAttention.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 83 - 84, The
module-level globals _MLP_CLS, _NORM_CLS, _ROTARY_CLS, and _rotate_half cause
cross-model interference; change the code to resolve these per-model and store
them on the model instance (e.g., self._mlp_cls, self._norm_cls,
self._rotary_cls, self._rotate_half) inside modify()/convert() instead of
reassigning the module globals. Update DFlashModule.__init__ to accept or read
instance attributes (self._mlp_cls, self._norm_cls, etc.) and pass them into
DFlashAttention and apply_rotary_pos_emb so those functions/classes no longer
reference module-level names; ensure all call sites (modify/convert,
DFlashModule creation, DFlashAttention usage, and apply_rotary_pos_emb) are
updated to thread the instance-specific components through.

832-845: ⚠️ Potential issue | 🟠 Major

base_token.item() fails for batch size > 1.

Line 841 calls .item() on base_token which has shape [B, 1]. For batched inputs with B > 1, this raises a RuntimeError. Either remove the debug block or use a batch-safe representation:

-            print(f"[psg] base_token: {base_token.item()}, mask_token_id: {self.mask_token_id}")
+            print(f"[psg] base_token: {base_token[:, 0].tolist()}, mask_token_id: {self.mask_token_id}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 832 - 845, The
debug block uses base_token.item() which fails for batch size >1; update the
debug printing in hf_dflash.py (the block that sets self._psg_debug and prints
using base_token and mask_token_id) to use a batch-safe representation instead
of .item() — e.g., convert the tensor to CPU, detach it and render as a list
(base_token.detach().cpu().view(-1).tolist() or
base_token.detach().cpu()[:,0].tolist()) or print only the first batch element
(base_token[0].item()) so it won't raise for B>1; keep other debug prints
(th_dbg, seq_len, dflash_block_size, target_layer_ids) unchanged.
🧹 Nitpick comments (3)
tools/launcher/common/dflash/online_training.sh (2)

149-152: Quote path variables.

OUTPUT_DIR and EXPORT_DIR should be double-quoted to handle paths containing spaces.

🔧 Proposed fix
 python3 modules/Model-Optimizer/examples/speculative_decoding/scripts/export_hf_checkpoint.py \
-    --model_path ${OUTPUT_DIR} \
-    --export_path ${EXPORT_DIR} \
+    --model_path "${OUTPUT_DIR}" \
+    --export_path "${EXPORT_DIR}" \
     || echo "WARNING: Export failed, continuing with AR validation"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/launcher/common/dflash/online_training.sh` around lines 149 - 152, The
shell command invoking export_hf_checkpoint uses unquoted path variables which
will break on spaces; update the invocation that references OUTPUT_DIR and
EXPORT_DIR (the python3 ... --model_path ${OUTPUT_DIR} --export_path
${EXPORT_DIR} || ...) to wrap both variables in double quotes (e.g. --model_path
"${OUTPUT_DIR}" --export_path "${EXPORT_DIR}") so paths with spaces are handled
correctly and the fallback echo behavior remains unchanged.

124-128: Quote variables where appropriate.

CONFIG_FILE and NUM_NODES should be quoted to guard against paths with spaces or other edge cases. OVERRIDES is intentionally unquoted for word splitting, but consider using an array for safer argument passing.

🛠️ Minimal fix
 bash modules/Model-Optimizer/examples/speculative_decoding/launch_train.sh \
-    --config ${CONFIG_FILE} \
-    --num_nodes ${NUM_NODES:-1} \
+    --config "${CONFIG_FILE}" \
+    --num_nodes "${NUM_NODES:-1}" \
     --head_node_ip ${HEAD_NODE_IP:-} \
     ${OVERRIDES}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/launcher/common/dflash/online_training.sh` around lines 124 - 128, The
command invocation to launch_train.sh should quote CONFIG_FILE and NUM_NODES to
protect against spaces and edge cases: update the call to use "--config
\"${CONFIG_FILE}\"" and "--num_nodes \"${NUM_NODES:-1}\""; keep OVERRIDES
unquoted for word-splitting but preferably refactor OVERRIDES into an array
(e.g., OVERRIDES_ARGS) and expand it safely (e.g., "${OVERRIDES_ARGS[@]}") when
invoking launch_train.sh so arguments are passed reliably; adjust the invocation
around launch_train.sh and the variables CONFIG_FILE, NUM_NODES, and OVERRIDES
accordingly.
modelopt/torch/speculative/plugins/hf_dflash.py (1)

499-499: Consider using logging instead of print statements.

Lines 499 and 548 emit unconditional print() calls. In production, these add noise to output. Use Python's logging module or gate behind a debug flag.

🔧 Example using logging
+import logging
+
+_logger = logging.getLogger(__name__)
+
 # In modify():
-        print(f"DFlash mask_token_id: {self.mask_token_id}")
+        _logger.info(f"DFlash mask_token_id: {self.mask_token_id}")
 # ...
-        print(f"DFlash: using {original_cls.__name__}.forward as base forward")
+        _logger.info(f"DFlash: using {original_cls.__name__}.forward as base forward")

Also applies to: 548-548

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` at line 499, Replace the
unconditional print calls that output the mask token id with Python logging:
create/get a module/class logger (eg. logging.getLogger(__name__) or
self.logger) and replace the prints (the occurrences that print DFlash
mask_token_id) with logger.debug or logger.info as appropriate; ensure the log
message includes the same context (e.g., "DFlash mask_token_id: %s") and, if
desired, gate emission behind a debug flag or log level so these messages don't
appear in production.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tools/launcher/common/dflash/ar_eval_mtbench.sh`:
- Around line 147-163: The script currently hardcodes trust_remote_code=True in
calls to AutoTokenizer.from_pretrained and AutoModelForCausalLM.from_pretrained
(for MODEL), allowing execution of untrusted HF repo code; add a CLI boolean
flag (e.g., --trust_remote_code defaulting to False) parsed at startup and pass
that flag's value into both AutoTokenizer.from_pretrained(...) and
AutoModelForCausalLM.from_pretrained(...), ensuring MODEL load calls use
trust_remote_code=trust_remote_code_flag so remote code execution is opt-in.
- Around line 104-123: The Python snippet currently injects shell variables
(MODEL, LAST_CKPT, MASK_TOKEN_ID, ONLINE, etc.) directly into the python -c
string and hardcodes trust_remote_code=True; change it to read those values from
os.environ or sys.argv inside the Python block (e.g., os.environ['MODEL'],
os.environ['LAST_CKPT'], os.environ.get('MASK_TOKEN_ID')) instead of
interpolating into the source, and remove the hardcoded trust_remote_code=True
from AutoTokenizer.from_pretrained and AutoModelForCausalLM.from_pretrained so
the flag is controlled by an explicit parameter or environment variable
(defaulting to False) that callers can opt into.

In `@tools/launcher/common/dflash/ar_validate.sh`:
- Around line 53-127: The script currently injects shell variables directly into
the embedded Python code (HF_MODEL_CKPT, DFLASH_BLOCK_SIZE, DFLASH_NUM_LAYERS,
MASK_ARG, DFLASH_CKPT, NUM_SAMPLES), which risks code injection and quoting
errors; change the launcher to pass these values via environment variables or
explicit CLI args and read them inside the Python block (e.g., os.environ or
argparse) and remove all ${...} interpolation from the python3 -c string, then
parse/convert types (ints, dict parts) inside Python before using them (used
when calling mtsp.convert and loading checkpoints and dataset sampling). Also
make trust_remote_code configurable instead of hardcoding True in
AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained (add a
boolean env/arg like TRUST_REMOTE_CODE defaulting to False and pass it into both
calls), so loading remote code requires explicit opt-in.

---

Duplicate comments:
In `@examples/speculative_decoding/main.py`:
- Around line 311-334: The fallback path sets trainer.state from
trainer_state.json but then calls trainer.train() with no resume flag, which
restarts from step 0 and replays batches; change the fallback to call
trainer.train(resume_from_checkpoint=checkpoint) so the Trainer resumes at the
saved step, and also clear any existing optimizer (e.g., set trainer.optimizer =
None) before training to ensure a fresh optimizer when you intentionally fall
back from an optimizer-state mismatch; keep the existing use of
trainer.state.load_from_json(state_file) and the checkpoint/state_file
variables.

In `@examples/speculative_decoding/train_dflash.py`:
- Around line 150-153: Add a new CLI boolean flag (e.g., args.trust_remote_code
defaulting to False) and pass it to both AutoModelForCausalLM.from_pretrained
and AutoTokenizer.from_pretrained instead of hardcoding trust_remote_code=True;
update the argument parser where args.model is defined to include
"--trust-remote-code" (store_true) and then change the two calls
(AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained) to use
trust_remote_code=args.trust_remote_code so remote-code execution remains
opt-in.
- Around line 292-293: The script currently hardcodes the MT-Bench dataset path
when creating ds with load_dataset; make it configurable via a CLI flag and env
var with a sensible default: add an argparse option (e.g. --mtbench-dataset)
that falls back to os.environ.get("MT_BENCH_DATASET") and then to the default
string "HuggingFaceH4/mt_bench_prompts"; replace the literal in the load_dataset
call (the line that sets ds = load_dataset(... )["train"]) to use that resolved
dataset identifier so HFARValidation(raw_model, tokenizer) continues to run
against the user-provided or default dataset.

In `@modelopt/torch/export/plugins/hf_spec_export.py`:
- Around line 272-316: The config currently hardcodes "torch_dtype": "bfloat16"
in _export_config(), causing a mismatch when export(dtype=...) is used; update
_export_config to determine the dtype dynamically (prefer the explicit export
dtype passed to export, e.g., self.export_dtype or self.dtype if present) and
set "torch_dtype" to that dtype's string (falling back to the current bfloat16
value if no export dtype is available); apply the same change to the other
config block around lines 328-343 so the advertised dtype matches the exported
safetensors.

In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 83-84: The module-level globals _MLP_CLS, _NORM_CLS, _ROTARY_CLS,
and _rotate_half cause cross-model interference; change the code to resolve
these per-model and store them on the model instance (e.g., self._mlp_cls,
self._norm_cls, self._rotary_cls, self._rotate_half) inside modify()/convert()
instead of reassigning the module globals. Update DFlashModule.__init__ to
accept or read instance attributes (self._mlp_cls, self._norm_cls, etc.) and
pass them into DFlashAttention and apply_rotary_pos_emb so those
functions/classes no longer reference module-level names; ensure all call sites
(modify/convert, DFlashModule creation, DFlashAttention usage, and
apply_rotary_pos_emb) are updated to thread the instance-specific components
through.
- Around line 832-845: The debug block uses base_token.item() which fails for
batch size >1; update the debug printing in hf_dflash.py (the block that sets
self._psg_debug and prints using base_token and mask_token_id) to use a
batch-safe representation instead of .item() — e.g., convert the tensor to CPU,
detach it and render as a list (base_token.detach().cpu().view(-1).tolist() or
base_token.detach().cpu()[:,0].tolist()) or print only the first batch element
(base_token[0].item()) so it won't raise for B>1; keep other debug prints
(th_dbg, seq_len, dflash_block_size, target_layer_ids) unchanged.

In `@modelopt/torch/utils/plugins/transformers_dataset.py`:
- Around line 391-418: The current filtering drops samples without an assistant
turn regardless of loss mode; only skip samples when self.answer_only_loss is
True. Modify the branch that checks for assistant turns (the blocks handling
messages and conversations and calling _sharegpt_to_openai_messages) to perform
the "no assistant turn -> print warning and continue" only when
self.answer_only_loss is True; otherwise append the messages/conversations
as-is. Ensure references: messages, conversations, converted (from
_sharegpt_to_openai_messages), batch, and self.answer_only_loss are used to gate
the skip logic so dummy batch creation behavior only occurs when
answer_only_loss is enabled.
- Around line 156-157: The answer-only template rewrite in the constructor
(triggered by answer_only_loss calling _ensure_generation_tags) incorrectly
converts multimodal messages into text-only content, breaking
VisionLanguageDataCollator and processor.apply_chat_template; fix by detecting
multimodal inputs (e.g., presence of non-text keys or list-of-blocks/image/block
structures in message["content"] used by VisionLanguageDataCollator) and skip
the answer-only template rewrite for those cases, i.e., in
_ensure_generation_tags (or the place where answer_only_loss is handled) add a
guard that returns early when message content is multimodal so
processor.apply_chat_template receives the original multimodal structure. Ensure
the check references the same shapes used by VisionLanguageDataCollator (e.g.,
'blocks' or image fields) so multimodal batches are preserved.
- Around line 350-356: The assistant_mask is aligned to original input positions
but labels are next-token targets (labels[..., :-1] = input_ids[..., 1:]), so
you must shift assistant_masks into label-space before masking; compute a
shifted mask (e.g., shifted = assistant_mask[..., 1:]) and apply it to
labels[..., :-1] (or equivalently mask labels where shifted == 0) instead of
using the unshifted assistant_mask; keep the existing checks for torch.Tensor
and .any() and set masked positions to IGNORE_TOKEN_ID when answer_only_loss is
true.

In `@tools/launcher/common/dflash/ar_validate.sh`:
- Around line 111-113: Replace the offline validation call to
validator.validate(...) with the online/context-dependent path by invoking
validator.validate_online(...) in the loop; keep the same arguments (e.g.,
osl=32, input_ids=input_ids, steps=3) and continue to append the returned ar to
ars so the script measures DFlash’s context-dependent AR instead of comparing to
fixed ground truth.
- Around line 63-66: The code currently passes trust_remote_code=True into
AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained for
HF_MODEL_CKPT; change this to be controlled by a new environment flag (e.g.,
TRUST_REMOTE_CODE or ENABLE_TRUST_REMOTE_CODE) that defaults to False, parse it
as a boolean, and pass that variable into the trust_remote_code parameter of
both AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained so
remote-code execution is opt-in.

In `@tools/launcher/common/dflash/online_training.sh`:
- Around line 30-35: The script uses unquoted shell variables and an unquoted
version specifier which can cause globbing/word-splitting and shell redirection:
quote ${SCRIPT_DIR} when sourcing service_utils.sh (refer to SCRIPT_DIR and
service_utils.sh) and quote the pip requirement string so the shell doesn't
treat >= as a redirection (refer to the pip install command for
huggingface-hub), and also quote PATH expansions when exporting (refer to the
export PATH line) to avoid word-splitting.
- Around line 181-223: The heredoc currently hardcodes trust_remote_code=True
and interpolates shell variables directly, which is unsafe; change the script to
pass HF_MODEL_CKPT, DFLASH_BLOCK_SIZE, DFLASH_NUM_LAYERS, MASK_ARG, AR_CKPT and
a new TRUST_REMOTE_CODE via environment variables and read them inside Python
with os.environ.get() (casting block/num to int and TRUST_REMOTE_CODE to a
boolean defaulting to False) before calling AutoModelForCausalLM.from_pretrained
and AutoTokenizer.from_pretrained; remove all ${...} interpolations from the
Python snippet, use the env-derived variables when building the dflash config
and when loading checkpoints (model.load_state_dict and
model.dflash_module.load_state_dict), and ensure a safe default
trust_remote_code=False unless the env explicitly sets it to "true".

---

Nitpick comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Line 499: Replace the unconditional print calls that output the mask token id
with Python logging: create/get a module/class logger (eg.
logging.getLogger(__name__) or self.logger) and replace the prints (the
occurrences that print DFlash mask_token_id) with logger.debug or logger.info as
appropriate; ensure the log message includes the same context (e.g., "DFlash
mask_token_id: %s") and, if desired, gate emission behind a debug flag or log
level so these messages don't appear in production.

In `@tools/launcher/common/dflash/online_training.sh`:
- Around line 149-152: The shell command invoking export_hf_checkpoint uses
unquoted path variables which will break on spaces; update the invocation that
references OUTPUT_DIR and EXPORT_DIR (the python3 ... --model_path ${OUTPUT_DIR}
--export_path ${EXPORT_DIR} || ...) to wrap both variables in double quotes
(e.g. --model_path "${OUTPUT_DIR}" --export_path "${EXPORT_DIR}") so paths with
spaces are handled correctly and the fallback echo behavior remains unchanged.
- Around line 124-128: The command invocation to launch_train.sh should quote
CONFIG_FILE and NUM_NODES to protect against spaces and edge cases: update the
call to use "--config \"${CONFIG_FILE}\"" and "--num_nodes \"${NUM_NODES:-1}\"";
keep OVERRIDES unquoted for word-splitting but preferably refactor OVERRIDES
into an array (e.g., OVERRIDES_ARGS) and expand it safely (e.g.,
"${OVERRIDES_ARGS[@]}") when invoking launch_train.sh so arguments are passed
reliably; adjust the invocation around launch_train.sh and the variables
CONFIG_FILE, NUM_NODES, and OVERRIDES accordingly.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 6e6d829b-957a-4ee3-a64b-949827a0b751

📥 Commits

Reviewing files that changed from the base of the PR and between f990e5a and e45cc37.

📒 Files selected for processing (25)
  • doc/results/dflash_results.html
  • examples/speculative_decoding/README.md
  • examples/speculative_decoding/doc/dflash_results.md
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/main.py
  • examples/speculative_decoding/scripts/export_hf_checkpoint.py
  • examples/speculative_decoding/train_dflash.py
  • modelopt/torch/export/plugins/hf_spec_export.py
  • modelopt/torch/speculative/config.py
  • modelopt/torch/speculative/dflash/__init__.py
  • modelopt/torch/speculative/dflash/conversion.py
  • modelopt/torch/speculative/dflash/default_config.py
  • modelopt/torch/speculative/dflash/dflash_model.py
  • modelopt/torch/speculative/mode.py
  • modelopt/torch/speculative/plugins/__init__.py
  • modelopt/torch/speculative/plugins/hf_dflash.py
  • modelopt/torch/speculative/utils.py
  • modelopt/torch/utils/plugins/transformers_dataset.py
  • modelopt_recipes/general/speculative_decoding/dflash.yaml
  • tests/gpu/torch/speculative/plugins/test_hf_dflash.py
  • tests/unit/torch/speculative/plugins/test_hf_dflash.py
  • tools/launcher/common/dflash/ar_eval_mtbench.sh
  • tools/launcher/common/dflash/ar_validate.sh
  • tools/launcher/common/dflash/online_training.sh
  • tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml
✅ Files skipped from review due to trivial changes (5)
  • modelopt/torch/speculative/dflash/default_config.py
  • examples/speculative_decoding/README.md
  • tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml
  • modelopt_recipes/general/speculative_decoding/dflash.yaml
  • examples/speculative_decoding/doc/dflash_results.md
🚧 Files skipped from review as they are similar to previous changes (9)
  • examples/speculative_decoding/scripts/export_hf_checkpoint.py
  • modelopt/torch/speculative/plugins/init.py
  • modelopt/torch/speculative/dflash/init.py
  • modelopt/torch/speculative/mode.py
  • modelopt/torch/speculative/config.py
  • modelopt/torch/speculative/utils.py
  • modelopt/torch/speculative/dflash/conversion.py
  • tests/unit/torch/speculative/plugins/test_hf_dflash.py
  • examples/speculative_decoding/eagle_utils.py

DFlash (Block Diffusion for Flash Speculative Decoding) predicts an entire
block of tokens in a single forward pass using masked parallel prediction
with KV injection from the target model's hidden states.

Key features:
- Feature fusion (multi-layer hidden states -> FC + RMSNorm)
- KV injection (fused features as K/V in every draft layer with QK-norm)
- Random anchor sampling with bidirectional intra-block attention
- Logit distillation with exponential loss decay (gamma weighting)
- Multi-node DDP training with checkpoint resume
- Export to z-lab compatible HF format
- Online validation (context-dependent ground truth)

Training recipe: modelopt_recipes/general/speculative_decoding/dflash.yaml
Results: examples/speculative_decoding/doc/dflash_results.md

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@ChenhanYu ChenhanYu force-pushed the chenhany/dflash-v2 branch from e45cc37 to 5f8d004 Compare April 8, 2026 21:14
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/speculative_decoding/main.py (1)

287-301: ⚠️ Potential issue | 🔴 Critical

medusa mode still reaches the trainer with data_module unset.

data_module is only initialized inside the ("eagle3", "dflash") branch, but TrainingArguments.mode and the conversion block still accept "medusa". A medusa run will hit EagleTrainerWithAccLog(..., **data_module) with data_module undefined.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 287 - 301, The code only
sets data_module inside the ("eagle3", "dflash") branch so a run with
training_args.mode == "medusa" leaves data_module undefined and crashes when
passed into EagleTrainerWithAccLog; fix by initializing data_module before the
conditional (e.g. data_module = {}) and either add an explicit branch to create
a medusa-specific module (call the appropriate builder if one exists) or raise a
clear error for unsupported modes, making sure the symbols training_args.mode,
make_eagle_supervised_data_module, and EagleTrainerWithAccLog are updated
accordingly.
♻️ Duplicate comments (7)
tools/launcher/common/dflash/ar_eval_mtbench.sh (2)

147-163: ⚠️ Potential issue | 🔴 Critical

Make remote code execution opt-in for MT-Bench eval.

Both AutoTokenizer.from_pretrained() and AutoModelForCausalLM.from_pretrained() hardcode trust_remote_code=True. That executes arbitrary repository code from HF_MODEL_CKPT on the launcher node. Thread this through a flag or env var and default it to False.

As per coding guidelines, "Do not hardcode trust_remote_code=True when loading Hugging Face Transformers models. Let the caller decide via a parameter; default to False."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/launcher/common/dflash/ar_eval_mtbench.sh` around lines 147 - 163, The
code currently hardcodes trust_remote_code=True when calling
AutoTokenizer.from_pretrained and AutoModelForCausalLM.from_pretrained (with
MODEL), which enables remote code execution; add a configurable flag (e.g., a
function/CLI param or env var like TRUST_REMOTE_CODE defaulting to False) and
pass that flag into both AutoTokenizer.from_pretrained(...) and
AutoModelForCausalLM.from_pretrained(...). Ensure the new flag is read early
(before tokenizer/model load) and used in the model load call that also includes
ATTN_IMPL/device_map so behavior is consistent and safe by default.

104-123: ⚠️ Potential issue | 🔴 Critical

Don’t splice shell variables directly into the python -c source.

MODEL, LAST_CKPT, MASK_TOKEN_ID, and ONLINE are interpolated straight into Python literals. A checkpoint path containing a quote breaks the script, and a user-controlled value can turn into arbitrary Python. Pass these values via env vars or argv and read them inside the Python block instead.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/launcher/common/dflash/ar_eval_mtbench.sh` around lines 104 - 123, The
current python -c block injects shell variables directly into Python literals
(see MODEL, CKPT_PATH/ LAST_CKPT, MASK_TOKEN_ID_STR, ONLINE) which can break on
quotes or allow code injection; instead pass those values via environment
variables or command-line args and read them inside the Python snippet (use
os.environ.get(...) or sys.argv parsing), convert MASK_TOKEN_ID and numeric
flags (BLOCK_SIZE, NUM_LAYERS, OSL, STEPS) to ints and ONLINE to a boolean
safely, and replace the direct interpolations in the python -c source with
references to the env/argv variables to eliminate quoting/injection issues when
loading the checkpoint or interpreting flags.
modelopt/torch/utils/plugins/transformers_dataset.py (3)

351-356: ⚠️ Potential issue | 🟠 Major

Shift assistant_masks into label space before masking.

Line 351 pre-shifts labels for next-token prediction, but Line 356 applies the unshifted tokenizer mask. That drops the first assistant token from loss and keeps one token after the assistant span.

🛠️ Suggested change
             if self.answer_only_loss:
                 if "assistant_masks" in tokenized_examples:
                     assistant_mask = tokenized_examples["assistant_masks"]
                     if isinstance(assistant_mask, torch.Tensor) and assistant_mask.any():
-                        labels[assistant_mask == 0] = IGNORE_TOKEN_ID
+                        shifted_assistant_mask = torch.zeros_like(assistant_mask)
+                        shifted_assistant_mask[..., :-1] = assistant_mask[..., 1:]
+                        labels[shifted_assistant_mask == 0] = IGNORE_TOKEN_ID
                     else:
                         # All assistant content truncated or no assistant in batch — mask all
                         labels[:] = IGNORE_TOKEN_ID
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/utils/plugins/transformers_dataset.py` around lines 351 - 356,
The code shifts labels for next-token prediction (labels[..., :-1] =
input_ids[..., 1:]) but then applies the unshifted assistant mask, causing
misalignment; update the masking to shift assistant_masks into label space
before applying IGNORE_TOKEN_ID (e.g., use assistant_mask[..., :-1] or
equivalent) so that assistant token positions align with labels when setting
labels[assistant_mask == 0] = IGNORE_TOKEN_ID; adjust the block around labels,
self.answer_only_loss, and tokenized_examples["assistant_masks"] accordingly to
use the shifted mask.

220-271: ⚠️ Potential issue | 🟠 Major

These fallback templates are still text-only.

All three variants concatenate message["content"] as a string, but VisionLanguageDataCollator turns multimodal content into block lists before templating. Enabling answer_only_loss on that path will misformat or fail batches. Gate this rewrite to text-only collators or provide multimodal-safe templates.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/utils/plugins/transformers_dataset.py` around lines 220 - 271,
The fallback _GENERATION_TEMPLATES are text-only but are applied to multimodal
inputs by VisionLanguageDataCollator (especially when answer_only_loss is used),
causing misformatted batches; update the code that selects or applies
_GENERATION_TEMPLATES so it only uses these string-concatenating templates for
text-only messages (e.g., detect collator type or check message['content'] is a
str/does not contain block lists) and either (a) provide multimodal-safe
templates/serialization for non-text message['content'] or (b) raise/explicitly
gate with a clear error when VisionLanguageDataCollator or answer_only_loss
would feed non-str content into _GENERATION_TEMPLATES (reference
_GENERATION_TEMPLATES, VisionLanguageDataCollator, and answer_only_loss to
locate where to implement the guard).

391-409: ⚠️ Potential issue | 🟠 Major

Only drop assistant-free chats in answer-only mode.

This now skips prompt-only/system-user samples even when answer_only_loss=False, which changes the generic collator behavior and can collapse a whole batch into the dummy sample. The assistant-turn filter should only run in answer-only mode.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/utils/plugins/transformers_dataset.py` around lines 391 - 409,
The current logic in the collator (the block handling example.get("messages")
and example.get("conversations") and calling _sharegpt_to_openai_messages)
unconditionally drops samples with no assistant turn; change it so the "no
assistant turn" checks (the any(m.get("role") == "assistant") guards and the
print_rank_0 warnings) only run when answer_only_loss is True. Concretely,
update the branches around messages/conversations in the collator to check
answer_only_loss before skipping or warning (leave normal batching behavior
intact when answer_only_loss is False), referencing the variables/messages,
conversations, _sharegpt_to_openai_messages, and print_rank_0 to locate and
modify the conditions.
examples/speculative_decoding/main.py (1)

317-334: ⚠️ Potential issue | 🟠 Major

The optimizer-mismatch fallback still replays training data from step 0.

Loading trainer_state.json into trainer.state is not enough by itself. Because the retry drops resume_from_checkpoint, this branch logs a resumed step but restarts the input pipeline from the beginning. Keep the checkpoint on the second call and bypass only optimizer/scheduler restoration.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 317 - 334, The retry
branch currently calls trainer.train() without resume_from_checkpoint so the
data pipeline restarts; instead call
trainer.train(resume_from_checkpoint=checkpoint) and ensure only
optimizer/scheduler state is bypassed by clearing or reinitializing those
objects after loading trainer.state: after trainer.state =
trainer.state.load_from_json(state_file) explicitly set trainer.optimizer = None
and trainer.lr_scheduler = None (or reinitialize them as appropriate) so
checkpoint is used for data/resume but optimizer/scheduler are not restored.
modelopt/torch/export/plugins/hf_spec_export.py (1)

272-318: ⚠️ Potential issue | 🟠 Major

Keep config.json dtype aligned with the exported tensors.

export(dtype=...) casts model.safetensors, but _export_config() still writes "torch_dtype": "bfloat16". fp16/fp32 exports will therefore advertise the wrong dtype to downstream loaders.

🛠️ Suggested change
-    def _export_config(self):
+    def _export_config(self, dtype: torch.dtype | None = None):
         """Build config.json matching z-lab DFlash format."""
         model = self.model
         base_config = (
             getattr(model.config, "text_config", None)
             or getattr(model.config, "llm_config", None)
@@
             "attention_dropout": getattr(draft_config, "attention_dropout", 0.0),
             "rope_theta": getattr(base_config, "rope_theta", 1000000.0),
             "rope_scaling": getattr(base_config, "rope_scaling", None),
             "tie_word_embeddings": False,
-            "torch_dtype": "bfloat16",
+            "torch_dtype": (
+                str(dtype).replace("torch.", "")
+                if dtype is not None
+                else str(
+                    getattr(
+                        draft_config,
+                        "torch_dtype",
+                        getattr(base_config, "torch_dtype", torch.bfloat16),
+                    )
+                ).replace("torch.", "")
+            ),
             "num_target_layers": getattr(base_config, "num_hidden_layers", 36),
         }
@@
-        drafter_config = self._export_config()
+        drafter_config = self._export_config(dtype=dtype)
         with open(f"{export_dir}/config.json", "w") as f:
             json.dump(drafter_config, f, indent=2)

Also applies to: 329-343

🧹 Nitpick comments (8)
examples/speculative_decoding/train_dflash.py (3)

264-270: Guard against missing train_acc attribute.

If the model output lacks train_acc, accessing output.train_acc[0][0] before the hasattr check completes would fail. The current code handles this correctly with hasattr, but consider using getattr for cleaner extraction.

-            acc = output.train_acc[0][0] if hasattr(output, "train_acc") else 0.0
+            acc = getattr(output, "train_acc", [[0.0]])[0][0]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/train_dflash.py` around lines 264 - 270,
Replace the hasattr pattern with a safe getattr extraction so acc is read
atomically: use getattr(output, "train_acc", None) to fetch train_acc, handle
None by defaulting acc to 0.0, and keep using scheduler.get_last_lr()[0] and
print_rank0 for logging; update the block around global_step / args.log_interval
where acc is computed to reference output.train_acc via getattr to avoid any
race or attribute-access issues.

102-108: Silent exception swallowing hides data pipeline errors.

The bare except Exception: discards the error details. Logging the exception would help diagnose tokenization or parsing failures during debugging.

Proposed fix
         try:
             input_ids, loss_mask = parser.parse(convs, max_length=max_length)
             processed["input_ids"].append(input_ids)
             processed["loss_mask"].append(loss_mask)
-        except Exception:
+        except Exception as e:
+            if is_rank0():
+                print(f"Skipping sample: {e}")
             skipped += 1
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/train_dflash.py` around lines 102 - 108, The
try/except in the loop around parser.parse(convs, max_length=max_length) is
swallowing errors silently; change the bare except to catch Exception as e and
log the exception (e.g., using logger.exception or logger.error(...,
exc_info=True)) before incrementing skipped so tokenization/parsing failures are
recorded for debugging while preserving the existing behavior that appends to
processed["input_ids"]/["loss_mask"] only on success; update the block around
parser.parse and the skipped increment accordingly.

207-212: Consider implications of find_unused_parameters=True on performance.

This setting adds overhead by tracking parameter usage each iteration. It's necessary here since only dflash_module parameters are trained, but worth documenting for future maintainers.

     # Wrap with DDP
+    # find_unused_parameters=True needed because only dflash_module is trained
     model = torch.nn.parallel.DistributedDataParallel(
         model,
         device_ids=[local_rank],
         find_unused_parameters=True,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/train_dflash.py` around lines 207 - 212, The
DDP wrapper uses find_unused_parameters=True which incurs runtime overhead;
update the DistributedDataParallel instantiation
(torch.nn.parallel.DistributedDataParallel for the model variable) to include a
clear inline comment explaining that find_unused_parameters=True is required
because only dflash_module parameters are being trained (so many parameters
remain unused), and add a TODO/docs note suggesting to set it to False when all
model parameters are trained or to conditionally set this flag based on whether
only dflash_module is being optimized; keep the DDP call otherwise unchanged.
tools/launcher/common/dflash/online_training.sh (2)

124-128: Quote variables to prevent word splitting and globbing.

Multiple unquoted variables could cause issues with paths containing spaces or special characters.

Proposed fix
 bash modules/Model-Optimizer/examples/speculative_decoding/launch_train.sh \
-    --config ${CONFIG_FILE} \
-    --num_nodes ${NUM_NODES:-1} \
-    --head_node_ip ${HEAD_NODE_IP:-} \
-    ${OVERRIDES}
+    --config "${CONFIG_FILE}" \
+    --num_nodes "${NUM_NODES:-1}" \
+    --head_node_ip "${HEAD_NODE_IP:-}" \
+    ${OVERRIDES}  # OVERRIDES intentionally unquoted for word splitting
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/launcher/common/dflash/online_training.sh` around lines 124 - 128, The
invocation of launch_train.sh in online_training.sh uses unquoted shell
variables which can cause word-splitting and globbing; update the call to quote
variables like "${CONFIG_FILE}", "${NUM_NODES:-1}", "${HEAD_NODE_IP:-}" and
"${OVERRIDES}" (or handle OVERRIDES as an array if it may contain multiple args)
so the command in the launch_train.sh call is robust to spaces and special
characters.

31-31: Quote variable to prevent word splitting.

-source ${SCRIPT_DIR}/../service_utils.sh
+source "${SCRIPT_DIR}/../service_utils.sh"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/launcher/common/dflash/online_training.sh` at line 31, The source
command uses an unquoted SCRIPT_DIR which can cause word-splitting for paths
with spaces; update the invocation that sources service_utils.sh to quote the
SCRIPT_DIR expansion (i.e. use the quoted form of the existing source command)
so the path is treated as a single token when calling source
"${SCRIPT_DIR}/../service_utils.sh".
modelopt/torch/speculative/plugins/hf_dflash.py (2)

362-410: Mask token ID detection relies on fragile heuristics.

The auto-detection uses hardcoded offsets (26, 25, 24) and magic numbers (128002 for Llama3, vocab_size thresholds). These assumptions may break with new model versions.

Consider:

  1. Adding a warning when falling back to heuristics
  2. Documenting the known model-specific token IDs
  3. Encouraging explicit mask_token_id configuration
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 362 - 410, The
_auto_detect_mask_token_id function relies on fragile hardcoded heuristics
(offsets [26,25,24], magic 128002, vocab_size thresholds); update it to log a
warning via the module logger when any heuristic/fallback path is used (e.g.,
when returning a candidate from offsets, the vocab_size heuristics, falling back
to 128002, pad/eos or final fallback), add a short docstring comment inside
_auto_detect_mask_token_id enumerating the known model-specific IDs (Qwen mask
region, Llama3 reserved_special_token_0) and clarify they are heuristics, and
update any public-facing docs or function docstring to recommend explicitly
supplying mask_token_id in config (mention symbol base_config.mask_token_id) so
callers can avoid autodetection.

499-499: Replace unconditional print statements with proper logging.

Direct print() calls bypass the logging framework, making it harder to control output verbosity in production.

Proposed fix
+import logging
+
+logger = logging.getLogger(__name__)
+
 # In modify():
-        print(f"DFlash mask_token_id: {self.mask_token_id}")
+        logger.info(f"DFlash mask_token_id: {self.mask_token_id}")
 
 # Later:
-        print(f"DFlash: using {original_cls.__name__}.forward as base forward")
+        logger.info(f"DFlash: using {original_cls.__name__}.forward as base forward")

Also applies to: 548-548

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` at line 499, Replace the
unconditional print statements that output the mask token id (e.g.
print(f"DFlash mask_token_id: {self.mask_token_id}")) with calls to the logging
framework; obtain a logger (module-level logging.getLogger(__name__) or reuse an
existing self.logger if the class provides one) and emit the message at an
appropriate level (debug/info). Ensure the logging import is added if missing
and update both occurrences (the one referencing self.mask_token_id and the
other at the noted second location) to use logger.debug(...) or logger.info(...)
instead of print so output respects configured log levels.
tools/launcher/common/dflash/ar_validate.sh (1)

30-31: Quote variable to prevent word splitting.

The unquoted ${SCRIPT_DIR} could cause issues with paths containing spaces.

-source ${SCRIPT_DIR}/../service_utils.sh
+source "${SCRIPT_DIR}/../service_utils.sh"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/launcher/common/dflash/ar_validate.sh` around lines 30 - 31, The source
invocation uses an unquoted variable which can break on paths with spaces;
update the source command to quote SCRIPT_DIR (use
"${SCRIPT_DIR}/../service_utils.sh") so the shell treats the path as a single
token and keep the trap 'error_handler $0 $LINENO' ERR unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/utils/plugins/transformers_dataset.py`:
- Around line 156-157: The bug is that _ensure_generation_tags() mutates
tokenizer.chat_template in-place, causing later apply_chat_template() calls
(used by trainer/validation) to see the mutated template; change the
implementation so it does not rewrite the shared tokenizer template in place —
instead operate on a shallow copy of the chat template (or create a new template
object/string) and assign that copy to the collator/processing_class or return
it from _ensure_generation_tags() without modifying tokenizer.chat_template;
ensure code paths that use the generated tags (answer_only_loss branch,
_ensure_generation_tags, and apply_chat_template) consume the copied template
rather than the original tokenizer.chat_template so reuse of the tokenizer
elsewhere remains unchanged.

In `@tools/launcher/common/dflash/ar_validate.sh`:
- Line 102: Replace the hardcoded internal dataset path in the load_dataset call
so it reads the dataset source from an environment variable (e.g.,
DATASET_SOURCE) with a fallback to the public "HuggingFaceH4/mt_bench_prompts";
update the line that calls load_dataset (the expression assigning ds) to use
process.env.DATASET_SOURCE (or equivalent shell/env expansion) and default to
"HuggingFaceH4/mt_bench_prompts" when the env var is not set, ensuring
portability across environments.

In `@tools/launcher/common/dflash/online_training.sh`:
- Line 228: Replace the hardcoded dataset path in the load_dataset call with an
environment-variable-driven value: read an env var (e.g., MT_BENCH_DATASET) with
a public fallback (for example "HuggingFaceH4/mt_bench_prompts") and pass that
variable into the load_dataset invocation (the line that calls
load_dataset('/hf-local/HuggingFaceH4/mt_bench_prompts')). Ensure the script
documents the env var and uses the fallback when the env var is unset.

---

Outside diff comments:
In `@examples/speculative_decoding/main.py`:
- Around line 287-301: The code only sets data_module inside the ("eagle3",
"dflash") branch so a run with training_args.mode == "medusa" leaves data_module
undefined and crashes when passed into EagleTrainerWithAccLog; fix by
initializing data_module before the conditional (e.g. data_module = {}) and
either add an explicit branch to create a medusa-specific module (call the
appropriate builder if one exists) or raise a clear error for unsupported modes,
making sure the symbols training_args.mode, make_eagle_supervised_data_module,
and EagleTrainerWithAccLog are updated accordingly.

---

Duplicate comments:
In `@examples/speculative_decoding/main.py`:
- Around line 317-334: The retry branch currently calls trainer.train() without
resume_from_checkpoint so the data pipeline restarts; instead call
trainer.train(resume_from_checkpoint=checkpoint) and ensure only
optimizer/scheduler state is bypassed by clearing or reinitializing those
objects after loading trainer.state: after trainer.state =
trainer.state.load_from_json(state_file) explicitly set trainer.optimizer = None
and trainer.lr_scheduler = None (or reinitialize them as appropriate) so
checkpoint is used for data/resume but optimizer/scheduler are not restored.

In `@modelopt/torch/utils/plugins/transformers_dataset.py`:
- Around line 351-356: The code shifts labels for next-token prediction
(labels[..., :-1] = input_ids[..., 1:]) but then applies the unshifted assistant
mask, causing misalignment; update the masking to shift assistant_masks into
label space before applying IGNORE_TOKEN_ID (e.g., use assistant_mask[..., :-1]
or equivalent) so that assistant token positions align with labels when setting
labels[assistant_mask == 0] = IGNORE_TOKEN_ID; adjust the block around labels,
self.answer_only_loss, and tokenized_examples["assistant_masks"] accordingly to
use the shifted mask.
- Around line 220-271: The fallback _GENERATION_TEMPLATES are text-only but are
applied to multimodal inputs by VisionLanguageDataCollator (especially when
answer_only_loss is used), causing misformatted batches; update the code that
selects or applies _GENERATION_TEMPLATES so it only uses these
string-concatenating templates for text-only messages (e.g., detect collator
type or check message['content'] is a str/does not contain block lists) and
either (a) provide multimodal-safe templates/serialization for non-text
message['content'] or (b) raise/explicitly gate with a clear error when
VisionLanguageDataCollator or answer_only_loss would feed non-str content into
_GENERATION_TEMPLATES (reference _GENERATION_TEMPLATES,
VisionLanguageDataCollator, and answer_only_loss to locate where to implement
the guard).
- Around line 391-409: The current logic in the collator (the block handling
example.get("messages") and example.get("conversations") and calling
_sharegpt_to_openai_messages) unconditionally drops samples with no assistant
turn; change it so the "no assistant turn" checks (the any(m.get("role") ==
"assistant") guards and the print_rank_0 warnings) only run when
answer_only_loss is True. Concretely, update the branches around
messages/conversations in the collator to check answer_only_loss before skipping
or warning (leave normal batching behavior intact when answer_only_loss is
False), referencing the variables/messages, conversations,
_sharegpt_to_openai_messages, and print_rank_0 to locate and modify the
conditions.

In `@tools/launcher/common/dflash/ar_eval_mtbench.sh`:
- Around line 147-163: The code currently hardcodes trust_remote_code=True when
calling AutoTokenizer.from_pretrained and AutoModelForCausalLM.from_pretrained
(with MODEL), which enables remote code execution; add a configurable flag
(e.g., a function/CLI param or env var like TRUST_REMOTE_CODE defaulting to
False) and pass that flag into both AutoTokenizer.from_pretrained(...) and
AutoModelForCausalLM.from_pretrained(...). Ensure the new flag is read early
(before tokenizer/model load) and used in the model load call that also includes
ATTN_IMPL/device_map so behavior is consistent and safe by default.
- Around line 104-123: The current python -c block injects shell variables
directly into Python literals (see MODEL, CKPT_PATH/ LAST_CKPT,
MASK_TOKEN_ID_STR, ONLINE) which can break on quotes or allow code injection;
instead pass those values via environment variables or command-line args and
read them inside the Python snippet (use os.environ.get(...) or sys.argv
parsing), convert MASK_TOKEN_ID and numeric flags (BLOCK_SIZE, NUM_LAYERS, OSL,
STEPS) to ints and ONLINE to a boolean safely, and replace the direct
interpolations in the python -c source with references to the env/argv variables
to eliminate quoting/injection issues when loading the checkpoint or
interpreting flags.

---

Nitpick comments:
In `@examples/speculative_decoding/train_dflash.py`:
- Around line 264-270: Replace the hasattr pattern with a safe getattr
extraction so acc is read atomically: use getattr(output, "train_acc", None) to
fetch train_acc, handle None by defaulting acc to 0.0, and keep using
scheduler.get_last_lr()[0] and print_rank0 for logging; update the block around
global_step / args.log_interval where acc is computed to reference
output.train_acc via getattr to avoid any race or attribute-access issues.
- Around line 102-108: The try/except in the loop around parser.parse(convs,
max_length=max_length) is swallowing errors silently; change the bare except to
catch Exception as e and log the exception (e.g., using logger.exception or
logger.error(..., exc_info=True)) before incrementing skipped so
tokenization/parsing failures are recorded for debugging while preserving the
existing behavior that appends to processed["input_ids"]/["loss_mask"] only on
success; update the block around parser.parse and the skipped increment
accordingly.
- Around line 207-212: The DDP wrapper uses find_unused_parameters=True which
incurs runtime overhead; update the DistributedDataParallel instantiation
(torch.nn.parallel.DistributedDataParallel for the model variable) to include a
clear inline comment explaining that find_unused_parameters=True is required
because only dflash_module parameters are being trained (so many parameters
remain unused), and add a TODO/docs note suggesting to set it to False when all
model parameters are trained or to conditionally set this flag based on whether
only dflash_module is being optimized; keep the DDP call otherwise unchanged.

In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 362-410: The _auto_detect_mask_token_id function relies on fragile
hardcoded heuristics (offsets [26,25,24], magic 128002, vocab_size thresholds);
update it to log a warning via the module logger when any heuristic/fallback
path is used (e.g., when returning a candidate from offsets, the vocab_size
heuristics, falling back to 128002, pad/eos or final fallback), add a short
docstring comment inside _auto_detect_mask_token_id enumerating the known
model-specific IDs (Qwen mask region, Llama3 reserved_special_token_0) and
clarify they are heuristics, and update any public-facing docs or function
docstring to recommend explicitly supplying mask_token_id in config (mention
symbol base_config.mask_token_id) so callers can avoid autodetection.
- Line 499: Replace the unconditional print statements that output the mask
token id (e.g. print(f"DFlash mask_token_id: {self.mask_token_id}")) with calls
to the logging framework; obtain a logger (module-level
logging.getLogger(__name__) or reuse an existing self.logger if the class
provides one) and emit the message at an appropriate level (debug/info). Ensure
the logging import is added if missing and update both occurrences (the one
referencing self.mask_token_id and the other at the noted second location) to
use logger.debug(...) or logger.info(...) instead of print so output respects
configured log levels.

In `@tools/launcher/common/dflash/ar_validate.sh`:
- Around line 30-31: The source invocation uses an unquoted variable which can
break on paths with spaces; update the source command to quote SCRIPT_DIR (use
"${SCRIPT_DIR}/../service_utils.sh") so the shell treats the path as a single
token and keep the trap 'error_handler $0 $LINENO' ERR unchanged.

In `@tools/launcher/common/dflash/online_training.sh`:
- Around line 124-128: The invocation of launch_train.sh in online_training.sh
uses unquoted shell variables which can cause word-splitting and globbing;
update the call to quote variables like "${CONFIG_FILE}", "${NUM_NODES:-1}",
"${HEAD_NODE_IP:-}" and "${OVERRIDES}" (or handle OVERRIDES as an array if it
may contain multiple args) so the command in the launch_train.sh call is robust
to spaces and special characters.
- Line 31: The source command uses an unquoted SCRIPT_DIR which can cause
word-splitting for paths with spaces; update the invocation that sources
service_utils.sh to quote the SCRIPT_DIR expansion (i.e. use the quoted form of
the existing source command) so the path is treated as a single token when
calling source "${SCRIPT_DIR}/../service_utils.sh".
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: d4913aed-3f2b-413c-9f8a-98c6809e510e

📥 Commits

Reviewing files that changed from the base of the PR and between e45cc37 and 5f8d004.

📒 Files selected for processing (25)
  • doc/results/dflash_results.html
  • examples/speculative_decoding/README.md
  • examples/speculative_decoding/doc/dflash_results.md
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/main.py
  • examples/speculative_decoding/scripts/export_hf_checkpoint.py
  • examples/speculative_decoding/train_dflash.py
  • modelopt/torch/export/plugins/hf_spec_export.py
  • modelopt/torch/speculative/config.py
  • modelopt/torch/speculative/dflash/__init__.py
  • modelopt/torch/speculative/dflash/conversion.py
  • modelopt/torch/speculative/dflash/default_config.py
  • modelopt/torch/speculative/dflash/dflash_model.py
  • modelopt/torch/speculative/mode.py
  • modelopt/torch/speculative/plugins/__init__.py
  • modelopt/torch/speculative/plugins/hf_dflash.py
  • modelopt/torch/speculative/utils.py
  • modelopt/torch/utils/plugins/transformers_dataset.py
  • modelopt_recipes/general/speculative_decoding/dflash.yaml
  • tests/gpu/torch/speculative/plugins/test_hf_dflash.py
  • tests/unit/torch/speculative/plugins/test_hf_dflash.py
  • tools/launcher/common/dflash/ar_eval_mtbench.sh
  • tools/launcher/common/dflash/ar_validate.sh
  • tools/launcher/common/dflash/online_training.sh
  • tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml
✅ Files skipped from review due to trivial changes (5)
  • examples/speculative_decoding/README.md
  • modelopt/torch/speculative/dflash/init.py
  • tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml
  • modelopt_recipes/general/speculative_decoding/dflash.yaml
  • examples/speculative_decoding/doc/dflash_results.md
🚧 Files skipped from review as they are similar to previous changes (7)
  • examples/speculative_decoding/scripts/export_hf_checkpoint.py
  • modelopt/torch/speculative/dflash/default_config.py
  • modelopt/torch/speculative/plugins/init.py
  • modelopt/torch/speculative/mode.py
  • modelopt/torch/speculative/utils.py
  • modelopt/torch/speculative/dflash/conversion.py
  • examples/speculative_decoding/eagle_utils.py

Comment on lines +156 to +157
if self.answer_only_loss:
self._ensure_generation_tags()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Don’t rewrite the shared tokenizer template in place.

_ensure_generation_tags() overwrites tokenizer.chat_template on the same tokenizer instance that is later reused as the trainer/validation processing_class. In DFlash mode that changes every later apply_chat_template() call, not just collation, so evaluation and generation can silently diverge from the model’s original prompt format.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/utils/plugins/transformers_dataset.py` around lines 156 - 157,
The bug is that _ensure_generation_tags() mutates tokenizer.chat_template
in-place, causing later apply_chat_template() calls (used by trainer/validation)
to see the mutated template; change the implementation so it does not rewrite
the shared tokenizer template in place — instead operate on a shallow copy of
the chat template (or create a new template object/string) and assign that copy
to the collator/processing_class or return it from _ensure_generation_tags()
without modifying tokenizer.chat_template; ensure code paths that use the
generated tags (answer_only_loss branch, _ensure_generation_tags, and
apply_chat_template) consume the copied template rather than the original
tokenizer.chat_template so reuse of the tokenizer elsewhere remains unchanged.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

♻️ Duplicate comments (1)
examples/speculative_decoding/main.py (1)

323-324: ⚠️ Potential issue | 🟠 Major

The fallback resume still doesn't properly resume the dataloader position.

Loading trainer_state.json into trainer.state and then calling trainer.train() without resume_from_checkpoint causes the dataloader to restart from step 0. The HuggingFace Trainer only resumes dataloader position when resume_from_checkpoint is passed to trainer.train().

If the intent is a "fresh optimizer" resume (model weights + state, skip optimizer), this approach needs a different strategy—either patching the checkpoint to remove optimizer state before calling trainer.train(resume_from_checkpoint=checkpoint), or accepting that dataloader replay from step 0 is the intended behavior and documenting it clearly.

🛠️ Possible approaches

Option A: If dataloader replay is acceptable, add a comment explaining this tradeoff:

                     trainer.state = trainer.state.load_from_json(state_file)
-                trainer.train()
+                # Note: dataloader restarts from step 0; only trainer.state is restored
+                trainer.train()

Option B: For true resume, remove/rename optimizer files before retrying:

+                # Remove optimizer state to allow resume without optimizer mismatch
+                optimizer_file = os.path.join(checkpoint, "optimizer.pt")
+                if os.path.isfile(optimizer_file):
+                    os.rename(optimizer_file, optimizer_file + ".bak")
                 trainer.state = trainer.state.load_from_json(state_file)
-                trainer.train()
+                trainer.train(resume_from_checkpoint=checkpoint)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 323 - 324, The current
fallback replaces trainer.state from trainer_state.json and then calls
trainer.train(), which does not restore the dataloader position; change the
logic where trainer.state = trainer.state.load_from_json(state_file) followed by
trainer.train() to one of two fixes: (A) call
trainer.train(resume_from_checkpoint=checkpoint) so HF Trainer resumes
dataloader/step position, or (B) implement a "fresh-optimizer" resume path that
patches the checkpoint before calling
trainer.train(resume_from_checkpoint=checkpoint) by removing/renaming
optimizer-related files (optimizer.pt/scheduler state) so weights and
trainer.state are used but optimizer state is skipped; also add a short comment
explaining the chosen behavior.
🧹 Nitpick comments (3)
examples/speculative_decoding/main.py (1)

170-171: Consider moving import json to module-level.

The json import is placed inside the function. While this works, Python convention prefers module-level imports for standard library modules. This is a minor style preference.

♻️ Suggested change

Move import json to the top of the file with other imports (around line 32-47).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 170 - 171, The local
"import json" inside the function should be moved to the module-level with the
other imports at the top of examples/speculative_decoding/main.py: remove the
in-function "import json" and add "import json" alongside the other imports near
lines 32–47 so the standard-library import follows Python conventions and avoids
repeated imports; update any references to json (e.g., in the function where it
was imported) to use the top-level import.
modelopt/torch/speculative/plugins/hf_dflash.py (2)

475-478: Infer the placement device without assuming .layers[-1].

_find_base_model_parts() probes several backbone layouts, but this placement path only works when the resolved base model exposes .layers. Pulling the device from next(self._base_model.parameters()) or the embeddings module would match the broader probing logic and avoid conversion failures on other supported layouts.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 475 - 478, The
placement of the DFlashModule assumes the base model has .layers; change the
device resolution in the DFlashModule initialization so it does not rely on
self._base_model.layers[-1]. Instead, determine the target device from the base
model's parameters or embeddings—e.g., use
next(self._base_model.parameters()).device (or the resolved embeddings module if
present) when calling self.dflash_module.to(self._base_model.dtype).to(...);
update references in the DFlashModule creation sequence (DFlashModule,
self.dflash_module, self.dflash_config, and _base_model) accordingly so it
matches the probing logic in _find_base_model_parts().

48-58: Add type hints to the new plugin entry points.

This module adds several public helpers and runtime entry points without annotations, which leaves mypy blind on a config- and tensor-heavy surface. Please type the arguments and return values for helpers like build_target_layer_ids() and the main HFDFlashModel methods before this lands.

As per coding guidelines, "Ensure type hints are properly annotated for static type checking with mypy".

Also applies to: 401-402, 557-570, 761-762

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 48 - 58, Add
static type hints for the public helpers and runtime entry points so mypy can
check them: annotate build_target_layer_ids(num_target_layers: int,
num_draft_layers: int) -> list[int] (or List[int]) and annotate
apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin:
torch.Tensor) -> tuple[torch.Tensor, torch.Tensor] (or Tuple[torch.Tensor,
torch.Tensor]); likewise add type annotations to the HFDFlashModel public
methods referenced (the constructor and all methods around the 557-570 and
761-762 regions), using torch.Tensor for tensor params, int/float/bool for
scalars, and Optional[...] or List[...] where appropriate, and import typing
names (List, Optional, Tuple) as needed to satisfy mypy.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/speculative_decoding/main.py`:
- Line 317: The line that reads state = json.load(open(state_file)) leaks a file
handle; change it to open the file using a context manager so the handle is
closed automatically (e.g., use with open(state_file) as f: then json.load(f)) —
update the code around the state and state_file usage (the assignment to state)
to use a with-block to ensure proper resource cleanup.

In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 48-55: The build_target_layer_ids function can produce negative or
out-of-range indices for tiny backbones; change it to special-case small models
and/or clamp every returned id into the valid range [0, num_target_layers - 1].
Specifically, in build_target_layer_ids ensure when num_target_layers < 4 you
return safe indices (e.g., center or 0..n-1) and after computing the list, map
each id to max(0, min(id, num_target_layers - 1)) so downstream logic that uses
lid + offset (the decoder embedding lookup) never receives a negative or
>=num_target_layers index.
- Around line 599-605: The forward call to the teacher/base model
(super().forward) is invoked while self.training may be True, so dropout remains
active; wrap the base-model forward in a context that sets the teacher to eval
mode (e.g., call model.eval() on the teacher/base instance) before calling
super().forward to produce base_outputs/target_hidden, and restore the original
training mode afterwards (use a try/finally or a small context manager) to
ensure deterministic hidden states during distillation while not changing the
overall module training flag.

---

Duplicate comments:
In `@examples/speculative_decoding/main.py`:
- Around line 323-324: The current fallback replaces trainer.state from
trainer_state.json and then calls trainer.train(), which does not restore the
dataloader position; change the logic where trainer.state =
trainer.state.load_from_json(state_file) followed by trainer.train() to one of
two fixes: (A) call trainer.train(resume_from_checkpoint=checkpoint) so HF
Trainer resumes dataloader/step position, or (B) implement a "fresh-optimizer"
resume path that patches the checkpoint before calling
trainer.train(resume_from_checkpoint=checkpoint) by removing/renaming
optimizer-related files (optimizer.pt/scheduler state) so weights and
trainer.state are used but optimizer state is skipped; also add a short comment
explaining the chosen behavior.

---

Nitpick comments:
In `@examples/speculative_decoding/main.py`:
- Around line 170-171: The local "import json" inside the function should be
moved to the module-level with the other imports at the top of
examples/speculative_decoding/main.py: remove the in-function "import json" and
add "import json" alongside the other imports near lines 32–47 so the
standard-library import follows Python conventions and avoids repeated imports;
update any references to json (e.g., in the function where it was imported) to
use the top-level import.

In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 475-478: The placement of the DFlashModule assumes the base model
has .layers; change the device resolution in the DFlashModule initialization so
it does not rely on self._base_model.layers[-1]. Instead, determine the target
device from the base model's parameters or embeddings—e.g., use
next(self._base_model.parameters()).device (or the resolved embeddings module if
present) when calling self.dflash_module.to(self._base_model.dtype).to(...);
update references in the DFlashModule creation sequence (DFlashModule,
self.dflash_module, self.dflash_config, and _base_model) accordingly so it
matches the probing logic in _find_base_model_parts().
- Around line 48-58: Add static type hints for the public helpers and runtime
entry points so mypy can check them: annotate
build_target_layer_ids(num_target_layers: int, num_draft_layers: int) ->
list[int] (or List[int]) and annotate apply_rotary_pos_emb(q: torch.Tensor, k:
torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> tuple[torch.Tensor,
torch.Tensor] (or Tuple[torch.Tensor, torch.Tensor]); likewise add type
annotations to the HFDFlashModel public methods referenced (the constructor and
all methods around the 557-570 and 761-762 regions), using torch.Tensor for
tensor params, int/float/bool for scalars, and Optional[...] or List[...] where
appropriate, and import typing names (List, Optional, Tuple) as needed to
satisfy mypy.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: fd57b686-ff1c-4721-afd4-c52f8a65f250

📥 Commits

Reviewing files that changed from the base of the PR and between 5f8d004 and 3061868.

📒 Files selected for processing (2)
  • examples/speculative_decoding/main.py
  • modelopt/torch/speculative/plugins/hf_dflash.py

Comment on lines +48 to +55
def build_target_layer_ids(num_target_layers, num_draft_layers):
"""Select layers uniformly from the target model for feature extraction."""
if num_draft_layers == 1:
return [num_target_layers // 2]
start = 1
end = num_target_layers - 3
span = end - start
return [round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Clamp target-layer selection for tiny backbones.

For num_target_layers < 4, Line 53 makes end negative. A 2-layer target with 2 draft layers returns [1, -1], and the later lid + offset at Line 608 then pulls the embedding state for the -1 slot instead of a decoder layer. Please special-case small models or clamp the result so every id stays in [0, num_target_layers - 1].

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 48 - 55, The
build_target_layer_ids function can produce negative or out-of-range indices for
tiny backbones; change it to special-case small models and/or clamp every
returned id into the valid range [0, num_target_layers - 1]. Specifically, in
build_target_layer_ids ensure when num_target_layers < 4 you return safe indices
(e.g., center or 0..n-1) and after computing the list, map each id to max(0,
min(id, num_target_layers - 1)) so downstream logic that uses lid + offset (the
decoder embedding lookup) never receives a negative or >=num_target_layers
index.

Comment on lines +599 to +605
# 1. Run base model → hidden states
with torch.no_grad():
base_outputs = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cat -n modelopt/torch/speculative/plugins/hf_dflash.py | sed -n '580,620p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1922


🏁 Script executed:

cat -n modelopt/torch/speculative/plugins/hf_dflash.py | sed -n '1,80p'

Repository: NVIDIA/Model-Optimizer

Length of output: 3962


🏁 Script executed:

cat -n modelopt/torch/speculative/plugins/hf_dflash.py | sed -n '500,580p'

Repository: NVIDIA/Model-Optimizer

Length of output: 3866


🏁 Script executed:

cat -n modelopt/torch/speculative/plugins/hf_dflash.py | sed -n '400,500p'

Repository: NVIDIA/Model-Optimizer

Length of output: 5296


🏁 Script executed:

rg "\.eval\(\)" modelopt/torch/speculative/plugins/hf_dflash.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

rg "\.train\(\)" modelopt/torch/speculative/plugins/hf_dflash.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

cat -n modelopt/torch/speculative/plugins/hf_dflash.py | sed -n '300,400p'

Repository: NVIDIA/Model-Optimizer

Length of output: 4738


🏁 Script executed:

cat -n modelopt/torch/speculative/plugins/hf_dflash.py | sed -n '200,300p'

Repository: NVIDIA/Model-Optimizer

Length of output: 5251


🏁 Script executed:

rg -A 20 "class DFlashModel" modelopt/torch/speculative/dflash/dflash_model.py | head -40

Repository: NVIDIA/Model-Optimizer

Length of output: 877


🏁 Script executed:

cat -n modelopt/torch/speculative/plugins/hf_dflash.py | sed -n '595,650p'

Repository: NVIDIA/Model-Optimizer

Length of output: 3089


🏁 Script executed:

cat -n modelopt/torch/speculative/plugins/hf_dflash.py | sed -n '1,100p' | grep -E "(\.eval|\.train|training)"

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

rg "self.training" modelopt/torch/speculative/plugins/hf_dflash.py

Repository: NVIDIA/Model-Optimizer

Length of output: 169


🏁 Script executed:

cat -n modelopt/torch/speculative/plugins/hf_dflash.py | sed -n '468,482p'

Repository: NVIDIA/Model-Optimizer

Length of output: 652


🏁 Script executed:

rg -B 5 -A 15 "torch.no_grad" modelopt/torch/speculative/plugins/hf_dflash.py | head -50

Repository: NVIDIA/Model-Optimizer

Length of output: 1961


Switch the teacher base model to eval mode for the forward pass.

torch.no_grad() disables gradients but does not disable dropout. Since the model is in training mode (self.training==True at this point), stochastic layers will be active, causing target_hidden to jitter across identical batches. This destabilizes distillation training. Before this forward pass, set the base model to eval mode using a context manager, then restore training mode afterward.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 599 - 605, The
forward call to the teacher/base model (super().forward) is invoked while
self.training may be True, so dropout remains active; wrap the base-model
forward in a context that sets the teacher to eval mode (e.g., call model.eval()
on the teacher/base instance) before calling super().forward to produce
base_outputs/target_hidden, and restore the original training mode afterwards
(use a try/finally or a small context manager) to ensure deterministic hidden
states during distillation while not changing the overall module training flag.

ChenhanYu and others added 2 commits April 8, 2026 17:00
- Use Qwen3 components directly (no dynamic _resolve_model_components)
- Add sliding window attention support (config.layer_types)
- Move rotary meta buffer fix to DFlashModule._apply() with detailed docs
- Remove DFlash-specific resume code from main.py (standard resume works)
- Remove unused train_dflash.py and ar_validate.sh
- Simplify online_training.sh: direct accelerate launch, no arg parsing
- YAML uses OmegaConf overrides directly (matching eagle3 pattern)
- Update README to point to launcher example
- Add extension docs for MoE and MLA support

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Revert on_step_end AR validation to upstream (DFlash deadlocks with DDP)
- Revert checkpoint resume to upstream (load from checkpoint directly)
- Keep: answer_only_loss pass-through, accuracy console/tensorboard logging
- Document sliding window support in README and recipe YAML

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

♻️ Duplicate comments (4)
tools/launcher/common/dflash/online_training.sh (1)

42-42: ⚠️ Potential issue | 🟠 Major

Quote the package constraint to avoid shell redirection.

Line 42 is parsed by Bash as redirection (>), so pip may not receive the >=1.2.1 constraint.

Minimal fix
-pip install huggingface-hub>=1.2.1
+pip install "huggingface-hub>=1.2.1"
#!/bin/bash
set -euo pipefail
target="tools/launcher/common/dflash/online_training.sh"

echo "Inspect current command:"
nl -ba "$target" | sed -n '40,44p'

echo "Reproduce Bash parsing safely in temp dir:"
tmpdir="$(mktemp -d)"
(
  cd "$tmpdir"
  mkdir -p bin
  cat > bin/pip <<'EOF'
#!/bin/bash
printf 'pip args:\n'
for a in "$@"; do printf '  [%s]\n' "$a"; done
EOF
  chmod +x bin/pip
  PATH="$tmpdir/bin:$PATH" bash -lc 'pip install huggingface-hub>=1.2.1' || true
  echo "Created files (redirection artifact expected):"
  ls -la
)
rm -rf "$tmpdir"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/launcher/common/dflash/online_training.sh` at line 42, The pip install
line in online_training.sh currently uses an unquoted version specifier (`pip
install huggingface-hub>=1.2.1`) which Bash will parse as a redirection; update
that command in tools/launcher/common/dflash/online_training.sh to quote the
constraint (e.g. use pip install 'huggingface-hub>=1.2.1' or pip install
"huggingface-hub>=1.2.1") so the version operator is passed to pip rather than
treated as shell redirection.
modelopt/torch/speculative/plugins/hf_dflash.py (3)

843-856: ⚠️ Potential issue | 🟡 Minor

Debug block fails for batch size > 1.

base_token.item() on line 852 raises an error when base_token has shape [B, 1] with B > 1. This breaks the first batched call to pseudo_speculative_generate().

Proposed fix
         if not hasattr(self, "_psg_debug"):
             self._psg_debug = True
             sel = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids]
             th_dbg = torch.cat(sel, dim=-1)
             n_layers = len(base_outputs.hidden_states)
             th_norm = th_dbg.norm().item()
             print(
                 f"[psg] hidden layers: {n_layers}, target_hidden: {th_dbg.shape}, norm: {th_norm:.2f}"
             )
-            print(f"[psg] base_token: {base_token.item()}, mask_token_id: {self.mask_token_id}")
+            print(f"[psg] base_token: {base_token.squeeze().tolist()}, mask_token_id: {self.mask_token_id}")
             seq_len = input_ids.shape[1]
             blk = self.dflash_block_size
             print(f"[psg] pos: ctx=[0..{seq_len - 1}], blk=[{seq_len}..{seq_len + blk - 1}]")

Or consider removing the debug prints entirely and using logging.debug() if needed.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 843 - 856, The
debug block guarded by self._psg_debug calls base_token.item(), which fails for
batch sizes >1 in pseudo_speculative_generate; change the debug to handle
batched base_token (e.g., log a representative value or use
base_token.tolist()/base_token.flatten() or base_token[0].item()) or remove the
prints and use logging.debug; update the debug block around _psg_debug in
hf_dflash.py (references: _psg_debug, base_token, pseudo_speculative_generate,
target_layer_ids, dflash_block_size) so it no longer calls .item() on a batched
tensor and safely formats the token(s) for any batch size.

74-82: ⚠️ Potential issue | 🟡 Minor

Edge case: negative or out-of-order indices for small models.

For num_target_layers < 4, end becomes negative or zero, producing potentially invalid or reversed layer indices. For example, with num_target_layers=2 and num_draft_layers=2, this returns [1, -1].

Consider adding bounds clamping or a special case for small models:

Proposed fix
 def build_target_layer_ids(num_target_layers, num_draft_layers):
     """Select layers uniformly from the target model for feature extraction."""
+    if num_target_layers < 4:
+        # For tiny models, return evenly spaced indices within valid range
+        return [i * (num_target_layers - 1) // max(num_draft_layers - 1, 1) 
+                for i in range(min(num_draft_layers, num_target_layers))]
     if num_draft_layers == 1:
         return [num_target_layers // 2]
     start = 1
     end = num_target_layers - 3
     span = end - start
-    return [round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)]
+    ids = [round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)]
+    return [max(0, min(i, num_target_layers - 1)) for i in ids]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 74 - 82, The
function build_target_layer_ids can produce negative or out-of-range indices for
small models; update it to handle small/edge cases by adding guards: if
num_target_layers <= 0 or num_draft_layers <= 0 return an empty list; if
num_target_layers < 4 treat target layers as 0..num_target_layers-1 and clip
selection accordingly (e.g., spread indices across that valid range); compute
start = max(0, 1) and end = max(0, num_target_layers-1) (or special-case when
num_draft_layers == 1 to return the middle valid index), generate indices with
the existing uniform formula but clamp each resulting index to the inclusive
range [0, num_target_layers-1] before returning to ensure no negative or
out-of-order values from build_target_layer_ids.

649-656: ⚠️ Potential issue | 🟠 Major

Teacher forward should use eval mode to disable dropout.

The base model forward runs under torch.no_grad(), which disables gradient computation but not dropout. When self.training=True, any dropout layers in the base model remain active, causing target_hidden to vary stochastically across identical inputs. This can destabilize distillation training.

Proposed fix
         # 1. Run base model → hidden states
+        # Temporarily switch to eval to disable dropout in teacher
+        was_training = self.training
         with torch.no_grad():
+            self.eval()
             base_outputs = super().forward(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 output_hidden_states=True,
             )
+            if was_training:
+                self.train()

Note: Alternatively, only set eval mode on the base model submodules rather than self to avoid affecting dflash_module's training state.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 649 - 656, The
base model forward currently runs under torch.no_grad() but not eval(), so
dropout can remain active; wrap the call to super().forward(...) in evaluation
mode (call eval() on the base model submodule) and restore its original training
state after the call so only the base model is in eval and not the whole dflash
module—i.e., around the super().forward(...) invocation (which produces
base_outputs/target_hidden from input_ids and attention_mask with
output_hidden_states=True) set the base model to eval, call
super().forward(...), then revert the base model to its prior training flag to
avoid affecting self.training.
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/hf_dflash.py (1)

516-516: Consider using logging instead of print statements.

Debug information like mask_token_id and original_cls (lines 516, 559) should use the logging module for better control over verbosity levels in production.

Proposed fix
+import logging
+
+logger = logging.getLogger(__name__)
+
 # In modify():
-        print(f"DFlash mask_token_id: {self.mask_token_id}")
+        logger.info(f"DFlash mask_token_id: {self.mask_token_id}")
         ...
-        print(f"DFlash: using {original_cls.__name__}.forward as base forward")
+        logger.info(f"DFlash: using {original_cls.__name__}.forward as base forward")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` at line 516, Replace ad-hoc
print statements that output debug info (e.g., the print of self.mask_token_id
and the one referencing self.original_cls) with the Python logging module: add
import logging and a module-level logger = logging.getLogger(__name__), then
change the prints to logger.debug (or logger.info if more appropriate) inside
the class/method where they occur (references: self.mask_token_id and
self.original_cls) so verbosity can be controlled in production.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tools/launcher/common/dflash/online_training.sh`:
- Around line 77-87: When NUM_NODES != "1" ensure HEAD_NODE_IP is validated
before building MULTI_NODE_ARGS: check that HEAD_NODE_IP is non-empty and a
plausible IP/hostname (non-empty string) after the auto-detection logic, and if
it is empty print a clear error message (including value of NUM_NODES and hint
about missing detection/SLURM variables) to stderr and exit with non-zero status
to fail fast; update the multi-node branch that constructs MULTI_NODE_ARGS
(referencing MULTI_NODE_ARGS, NUM_NODES, GPU_PER_NODE, SLURM_PROCID, and
HEAD_NODE_IP) to perform this validation immediately before using HEAD_NODE_IP.
- Around line 41-47: Move the failure handling and enable fail-fast before any
package installs: add "set -e" (or "set -o errexit") and install the trap
invocation for error_handler (the line "trap 'error_handler $0 $LINENO' ERR")
before the two pip install lines (the commands that install requirements and
huggingface-hub). This ensures any failure in the pip install commands triggers
error_handler immediately and the script exits instead of continuing into
training with partial dependencies.

---

Duplicate comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 843-856: The debug block guarded by self._psg_debug calls
base_token.item(), which fails for batch sizes >1 in
pseudo_speculative_generate; change the debug to handle batched base_token
(e.g., log a representative value or use
base_token.tolist()/base_token.flatten() or base_token[0].item()) or remove the
prints and use logging.debug; update the debug block around _psg_debug in
hf_dflash.py (references: _psg_debug, base_token, pseudo_speculative_generate,
target_layer_ids, dflash_block_size) so it no longer calls .item() on a batched
tensor and safely formats the token(s) for any batch size.
- Around line 74-82: The function build_target_layer_ids can produce negative or
out-of-range indices for small models; update it to handle small/edge cases by
adding guards: if num_target_layers <= 0 or num_draft_layers <= 0 return an
empty list; if num_target_layers < 4 treat target layers as
0..num_target_layers-1 and clip selection accordingly (e.g., spread indices
across that valid range); compute start = max(0, 1) and end = max(0,
num_target_layers-1) (or special-case when num_draft_layers == 1 to return the
middle valid index), generate indices with the existing uniform formula but
clamp each resulting index to the inclusive range [0, num_target_layers-1]
before returning to ensure no negative or out-of-order values from
build_target_layer_ids.
- Around line 649-656: The base model forward currently runs under
torch.no_grad() but not eval(), so dropout can remain active; wrap the call to
super().forward(...) in evaluation mode (call eval() on the base model
submodule) and restore its original training state after the call so only the
base model is in eval and not the whole dflash module—i.e., around the
super().forward(...) invocation (which produces base_outputs/target_hidden from
input_ids and attention_mask with output_hidden_states=True) set the base model
to eval, call super().forward(...), then revert the base model to its prior
training flag to avoid affecting self.training.

In `@tools/launcher/common/dflash/online_training.sh`:
- Line 42: The pip install line in online_training.sh currently uses an unquoted
version specifier (`pip install huggingface-hub>=1.2.1`) which Bash will parse
as a redirection; update that command in
tools/launcher/common/dflash/online_training.sh to quote the constraint (e.g.
use pip install 'huggingface-hub>=1.2.1' or pip install
"huggingface-hub>=1.2.1") so the version operator is passed to pip rather than
treated as shell redirection.

---

Nitpick comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Line 516: Replace ad-hoc print statements that output debug info (e.g., the
print of self.mask_token_id and the one referencing self.original_cls) with the
Python logging module: add import logging and a module-level logger =
logging.getLogger(__name__), then change the prints to logger.debug (or
logger.info if more appropriate) inside the class/method where they occur
(references: self.mask_token_id and self.original_cls) so verbosity can be
controlled in production.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 9d2ccb5b-e089-4948-8b21-8672dcbead83

📥 Commits

Reviewing files that changed from the base of the PR and between 3061868 and c694c53.

📒 Files selected for processing (5)
  • examples/speculative_decoding/README.md
  • examples/speculative_decoding/main.py
  • modelopt/torch/speculative/plugins/hf_dflash.py
  • tools/launcher/common/dflash/online_training.sh
  • tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml
✅ Files skipped from review due to trivial changes (2)
  • examples/speculative_decoding/README.md
  • tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/speculative_decoding/main.py

Comment on lines +77 to +87
if [[ "$NUM_NODES" != "1" ]]; then
GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)}
TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE))
echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)"
MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \
--num_machines $NUM_NODES \
--machine_rank $SLURM_PROCID \
--rdzv_backend c10d \
--main_process_ip $HEAD_NODE_IP \
--main_process_port 29500"
else
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail
target="tools/launcher/common/dflash/online_training.sh"
rg -n 'HEAD_NODE_IP|main_process_ip|NUM_NODES' "$target" -n -C3

Repository: NVIDIA/Model-Optimizer

Length of output: 2312


Add validation for HEAD_NODE_IP in multi-node mode to fail fast.

The auto-detection logic (lines 49–72) attempts to resolve HEAD_NODE_IP but does not validate success. If all detection methods fail (e.g., in containerized environments or when SLURM variables are absent), the variable remains empty. The multi-node branch at line 77 then passes --main_process_ip $HEAD_NODE_IP to accelerate with an empty value, causing unclear failures downstream.

Suggested guard
 if [[ "$NUM_NODES" != "1" ]]; then
+    if [[ -z "${HEAD_NODE_IP:-}" ]]; then
+        echo "[ERROR] HEAD_NODE_IP is empty for NUM_NODES=$NUM_NODES" >&2
+        exit 1
+    fi
     GPU_PER_NODE=${GPU_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)}
     TOTAL_GPU=$((NUM_NODES * GPU_PER_NODE))
     echo "Total GPUs: $TOTAL_GPU (NUM_NODES: $NUM_NODES, GPU_PER_NODE: $GPU_PER_NODE)"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tools/launcher/common/dflash/online_training.sh` around lines 77 - 87, When
NUM_NODES != "1" ensure HEAD_NODE_IP is validated before building
MULTI_NODE_ARGS: check that HEAD_NODE_IP is non-empty and a plausible
IP/hostname (non-empty string) after the auto-detection logic, and if it is
empty print a clear error message (including value of NUM_NODES and hint about
missing detection/SLURM variables) to stderr and exit with non-zero status to
fail fast; update the multi-node branch that constructs MULTI_NODE_ARGS
(referencing MULTI_NODE_ARGS, NUM_NODES, GPU_PER_NODE, SLURM_PROCID, and
HEAD_NODE_IP) to perform this validation immediately before using HEAD_NODE_IP.

ChenhanYu and others added 2 commits April 8, 2026 19:46
- Consolidate dflash_results.md into comprehensive dflash.md
- Simplify ar_validate.py: online GT as default, per-category support
- Simplify ar_eval_mtbench.sh: calls ar_validate.py instead of inline Python
- Error on unsupported mask_token_id instead of falling back to pad/eos
- Add sliding window, FP8/NVFP4, offline training, MLA docs

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- export.sh: standalone checkpoint export to z-lab format
- ptq_and_export.sh: FP8/NVFP4 quantization via hf_ptq.py
- Fix rope_theta export (prefer draft_config over base_config)
- Document vLLM integration gap, FP8/NVFP4 flow in dflash.md

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant