
Simplify KDTrainer and enhance ModelOptHFTrainer #1191

Open
realAsma wants to merge 2 commits into asma/new-qat-1 from asma/new-qat-2

Conversation

Contributor

@realAsma realAsma commented Apr 7, 2026

Summary

This PR simplifies the HuggingFace knowledge distillation trainer and enhances the base ModelOptHFTrainer with Liger fused loss, per-parameter learning rates, and training utilities.

Model-agnostic Liger kernel fused loss

Adds custom Liger kernel integration in ModelOptHFTrainer that extends HuggingFace's built-in support in three ways:

  1. Model-agnostic: Works with any causal LM that has an lm_head, unlike HF's Liger which only supports a fixed set of model architectures.
  2. DeepSpeed ZeRO-3 support: HF's Liger integration only works with FSDP. ModelOpt adds distributed param gathering for DeepSpeed ZeRO-3 and DDP as well.
  3. KD loss support: KDTrainer extends fused loss to knowledge distillation via LigerFusedLinearJSD for fused lm_head + Jensen-Shannon divergence.
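The fused path combines the lm_head projection with a Jensen-Shannon divergence loss in one kernel. As a refresher on the math involved, here is a minimal unfused JSD between two logit vectors in plain Python (a sketch for illustration only; function names are hypothetical and this is not the Liger API):

```python
import math

def softmax(logits):
    m = max(logits)                       # subtract max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def kl_div(p, q):
    # KL(p || q) for discrete distributions; terms with p_i = 0 contribute 0
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def jsd(student_logits, teacher_logits, beta=0.5):
    # Jensen-Shannon divergence: beta-weighted KLs against the
    # mixture distribution m = beta*p + (1-beta)*q.
    p = softmax(student_logits)
    q = softmax(teacher_logits)
    m = [beta * pi + (1 - beta) * qi for pi, qi in zip(p, q)]
    return beta * kl_div(p, m) + (1 - beta) * kl_div(q, m)
```

With `beta=0.5` this is symmetric in student and teacher, which is one reason JSD is a popular logit-distillation objective; the fused kernel avoids ever materializing the full-vocabulary logit tensors.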

ModelOptHFTrainer enhancements

  • ModelOptTrainerArguments with --trainable_params, --frozen_params, --lr_config, --save_dtype, and --manual_gc flags
  • Per-parameter learning rate support via YAML config (lr_config)
  • _prepare_model and _update_config_json_dtype promoted to base class
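The PR does not show the lr_config YAML schema, but the usual shape of per-parameter learning rates is a mapping from name patterns to rate overrides that gets turned into optimizer param groups. A plausible sketch under that assumption (the `LR_CONFIG` keys and `build_param_groups` helper are hypothetical, not the ModelOpt API):

```python
import fnmatch

# Hypothetical lr_config contents: glob patterns over parameter
# names mapped to learning-rate overrides.
LR_CONFIG = {
    "*.lm_head.*": 1e-5,
    "*.layernorm.*": 5e-4,
}
DEFAULT_LR = 2e-4

def build_param_groups(named_params, lr_config, default_lr):
    """Bucket (name, param) pairs into optimizer param groups by the
    first matching glob pattern; unmatched params use the default lr."""
    groups = {lr: [] for lr in set(lr_config.values()) | {default_lr}}
    for name, param in named_params:
        lr = next((v for pat, v in lr_config.items()
                   if fnmatch.fnmatch(name, pat)), default_lr)
        groups[lr].append(param)
    return [{"params": ps, "lr": lr} for lr, ps in groups.items() if ps]
```

The returned list is exactly the `param_groups` format that `torch.optim` optimizers accept, so a trainer can feed it straight into its optimizer constructor.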

KDTrainer simplification

Removes mtd.convert() and the DistillationModel in-place class-swap for the HF path. The teacher model now lives directly on the trainer and is forwarded explicitly inside compute_kd_loss_func. This eliminates:

  • mtd.convert() in-place class swap and DynamicModule wrapping
  • Forward hooks for capturing intermediate outputs
  • hide_teacher_model / hide_loss_modules context managers for checkpointing
  • Deferred initialization branching (FSDP2 vs DDP/DeepSpeed)
  • save_model and QADTrainer._quantize_model overrides

Only logit-level distillation is supported for the HF path. The core DistillationModel/mtd.convert() API remains for Megatron and advanced intermediate-layer distillation use cases.
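With the class-swap gone, the control flow reduces to "run teacher, run student, compare logits". A plain-Python sketch of that shape, using the classic temperature-softened KL objective and toy callables in place of real models (names are illustrative, not the KDTrainer API):

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    s = sum(exps)
    return [e / s for e in exps]

def compute_kd_loss(student_logits, teacher_logits, temperature=2.0):
    # Temperature-softened KL(teacher || student), the standard
    # logit-level distillation loss (Hinton et al., 2015).
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    # Scale by T^2 so gradient magnitude is stable across temperatures.
    return temperature ** 2 * kl

def kd_loss_func(batch, student, teacher):
    # The trainer holds the teacher directly and forwards it inside
    # the loss function -- no wrapper class, no forward hooks.
    teacher_logits = teacher(batch)   # in practice, under torch.no_grad()
    student_logits = student(batch)
    return compute_kd_loss(student_logits, teacher_logits)
```

Because the teacher is an explicit argument rather than a hidden submodule, checkpointing and distributed wrapping only ever see the student.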

Test plan

  • pytest tests/unit/torch/distill/ (29 passed)
  • pytest tests/unit/torch/opt/plugins/test_hf_patching.py (2 passed)
  • pytest tests/unit/torch/opt/plugins/test_lr_config.py
  • Pre-commit hooks pass
  • GPU example tests: pytest tests/examples/llm_qat/ (QAT, QAD, LoRA QAT, QLoRA)
  • GPU distill example: pytest tests/examples/llm_distill/

🤖 Generated with Claude Code

@realAsma realAsma requested review from a team as code owners April 7, 2026 21:40
@realAsma realAsma requested review from Edwardf0t1 and removed request for a team April 7, 2026 21:40

coderabbitai bot commented Apr 7, 2026

Important

Review skipped

Auto reviews are disabled on base/target branches other than the default branch.

🗂️ Base branches to auto review (3)
  • main
  • release/.*
  • feature/.*

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 9eccaec1-ad0a-4208-89bb-e864a5c20749


@realAsma realAsma requested review from ChenhanYu and shengliangxu and removed request for a team April 7, 2026 21:40

codecov bot commented Apr 7, 2026

Codecov Report

❌ Patch coverage is 20.25723% with 248 lines in your changes missing coverage. Please review.
✅ Project coverage is 75.24%. Comparing base (b3d46c7) to head (99da38e).

Files with missing lines                               Patch %   Lines
modelopt/torch/opt/plugins/transformers.py             20.28%    165 Missing ⚠️
modelopt/torch/distill/plugins/huggingface.py          20.38%    82 Missing ⚠️
...torch/quantization/plugins/transformers_trainer.py   0.00%    1 Missing ⚠️
Additional details and impacted files
@@                Coverage Diff                 @@
##           asma/new-qat-1    #1191      +/-   ##
==================================================
- Coverage           76.71%   75.24%   -1.48%     
==================================================
  Files                 352      352              
  Lines               40473    40699     +226     
==================================================
- Hits                31050    30624     -426     
- Misses               9423    10075     +652     
Flag       Coverage Δ
examples   41.50% <20.25%> (-2.61%) ⬇️
gpu        56.59% <20.25%> (-0.21%) ⬇️

Flags with carried forward coverage won't be shown.

@realAsma realAsma force-pushed the asma/new-qat-2 branch 2 times, most recently from e349a90 to b762870 on April 7, 2026 22:36
@realAsma realAsma force-pushed the asma/new-qat-1 branch 2 times, most recently from 97759a4 to bfc343c on April 8, 2026 16:03
Replaces launch.sh with YAML-driven configs, adds ModelOptArgParser
with --config support, and moves dataset processing params from blend
YAML to DataArguments for full CLI overrideability.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
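The "YAML beats defaults, CLI beats YAML" precedence that makes every dataset parameter overridable can be sketched with stdlib argparse (a plain dict stands in for the parsed YAML file; the flag names here are hypothetical, not the actual ModelOptArgParser interface):

```python
import argparse

def parse_with_config(argv, config):
    """Sketch of the --config pattern: YAML values override the
    parser's built-in defaults, and explicit CLI flags override
    the YAML values in turn."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_max_length", type=int, default=4096)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.set_defaults(**config)   # YAML beats built-in defaults
    return parser.parse_args(argv)  # CLI beats YAML

# YAML sets both values; the CLI overrides only --lr:
args = parse_with_config(["--lr", "1e-4"],
                         {"model_max_length": 8192, "lr": 5e-4})
```

Here `args.model_max_length` comes from the config dict while `args.lr` takes the command-line value, which is the full-overrideability behavior the commit message describes.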
- Remove gradient_checkpointing from FSDP2 configs (conflicts with fsdp_activation_checkpointing)
- Increase default model_max_length from 4096 to 8192
- Add attn_implementation passthrough to quantize.py
- Fix quantize.py recipe.ptq_cfg -> recipe.quantize
- KDTrainer accepts teacher_model as explicit nn.Module kwarg instead of via DistillArguments
- Refactor test_llm_qat.py to use --config YAML style with FAST_DATA_ARGS/FAST_TRAIN_ARGS
- Switch QAD test from Llama to Qwen3 for consistency

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
@realAsma realAsma force-pushed the asma/new-qat-1 branch 3 times, most recently from cc45203 to 9dd1732 on April 9, 2026 18:48