Simplify KDTrainer and enhance ModelOptHFTrainer #1191
realAsma wants to merge 2 commits into asma/new-qat-1
Conversation
Codecov Report — coverage diff:

```
@@           Coverage Diff            @@
##       asma/new-qat-1    #1191      +/-   ##
==================================================
- Coverage       76.71%   75.24%   -1.48%
==================================================
  Files             352      352
  Lines           40473    40699     +226
==================================================
- Hits            31050    30624     -426
- Misses           9423    10075     +652
```
Replaces launch.sh with YAML-driven configs, adds ModelOptArgParser with --config support, and moves dataset processing params from blend YAML to DataArguments for full CLI overrideability. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: realAsma <akuriparambi@nvidia.com>
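The commit above makes every YAML config value overridable from the CLI. The real ModelOptArgParser is not shown in this PR excerpt; a minimal stand-in sketching the intended precedence (explicit CLI flag beats config-file value beats built-in default), assuming the `--config` YAML has already been loaded into a dict, could look like this (flag names are illustrative, not the actual ModelOpt CLI):

```python
import argparse

def parse_with_config(argv, config):
    """CLI flags override config-file values, which override defaults.

    `config` stands in for a dict loaded from the --config YAML file.
    The flags below are illustrative, not the real ModelOpt argument set.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_max_length", type=int, default=4096)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    # Config-file values become the new defaults, so any flag passed
    # explicitly on the command line still wins.
    parser.set_defaults(**config)
    return parser.parse_args(argv)
```

`set_defaults` is what gives the three-level precedence for free: argparse only applies a default when the flag is absent from `argv`.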
- Remove gradient_checkpointing from FSDP2 configs (conflicts with fsdp_activation_checkpointing)
- Increase default model_max_length from 4096 to 8192
- Add attn_implementation passthrough to quantize.py
- Fix quantize.py recipe.ptq_cfg -> recipe.quantize
- KDTrainer accepts teacher_model as explicit nn.Module kwarg instead of via DistillArguments
- Refactor test_llm_qat.py to use --config YAML style with FAST_DATA_ARGS/FAST_TRAIN_ARGS
- Switch QAD test from Llama to Qwen3 for consistency

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Summary
This PR simplifies the HuggingFace knowledge distillation trainer and enhances the base `ModelOptHFTrainer` with Liger fused loss, per-parameter learning rates, and training utilities.

Model-agnostic Liger kernel fused loss
Adds a custom Liger kernel integration in `ModelOptHFTrainer` that extends HuggingFace's built-in support:

- Works with any model's `lm_head`, unlike HF's Liger integration, which only supports a fixed set of model architectures.
- `KDTrainer` extends the fused loss to knowledge distillation via `LigerFusedLinearJSD` (fused lm_head projection + Jensen-Shannon divergence).
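The real `LigerFusedLinearJSD` fuses the lm_head projection with the divergence in a single kernel to avoid materializing full logits; as a reference for the loss itself, here is an unfused pure-Python sketch of the Jensen-Shannon divergence between student and teacher next-token distributions (helper names are ours, not the Liger API):

```python
import math

def softmax(logits):
    # Numerically stable softmax over a list of logits
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def _kl(p, q):
    # KL(p || q) for discrete distributions given as probability lists
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)

def jsd(student_logits, teacher_logits):
    # JSD(P || Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M), M = (P + Q) / 2.
    # Symmetric, and bounded above by log(2).
    p = softmax(student_logits)
    q = softmax(teacher_logits)
    m = [(pi + qi) / 2.0 for pi, qi in zip(p, q)]
    return 0.5 * _kl(p, m) + 0.5 * _kl(q, m)
```

Unlike plain KL, JSD is symmetric and finite even when student and teacher put mass on disjoint tokens, which is one reason it is a common distillation objective.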
ModelOptHFTrainer enhancements

- New `ModelOptTrainerArguments` with `--trainable_params`, `--frozen_params`, `--lr_config`, `--save_dtype`, and `--manual_gc` flags
- Per-parameter learning rates (via `lr_config`)
- `_prepare_model`, `_update_config_json`, and `_dtype` promoted to the base class

KDTrainer simplification
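The PR does not show the `lr_config` schema, so as an assumption-laden sketch: a common way to implement per-parameter learning rates is to map name patterns to LR overrides and build optimizer param groups from them. Everything below (function name, config shape) is hypothetical, not the actual ModelOpt implementation:

```python
import re

def build_param_groups(named_params, lr_config, base_lr):
    """Split (name, param) pairs into optimizer param groups.

    lr_config: {regex_pattern: learning_rate} — a guessed shape for
    illustration. First matching pattern wins; unmatched params get base_lr.
    """
    groups = {pattern: {"params": [], "lr": lr} for pattern, lr in lr_config.items()}
    default = {"params": [], "lr": base_lr}
    for name, param in named_params:
        for pattern, group in groups.items():
            if re.search(pattern, name):
                group["params"].append(param)
                break
        else:
            default["params"].append(param)
    # Drop empty groups; the result is shaped like torch.optim param_groups.
    return [g for g in (*groups.values(), default) if g["params"]]
```

The returned list can be passed directly to a PyTorch optimizer constructor, which accepts per-group `lr` overrides.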
Removes `mtd.convert()` and the `DistillationModel` in-place class swap for the HF path. The teacher model now lives directly on the trainer and is forwarded explicitly inside `compute_kd_loss_func`. This eliminates:

- the `mtd.convert()` in-place class swap and `DynamicModule` wrapping
- the `hide_teacher_model` / `hide_loss_modules` context managers for checkpointing
- the `save_model` and `QADTrainer._quantize_model` overrides

Only logit-level distillation is supported for the HF path. The core `DistillationModel` / `mtd.convert()` API remains for Megatron and advanced intermediate-layer distillation use cases.

Test plan
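The control flow described above can be sketched framework-free: the trainer owns the teacher as a plain attribute and calls it explicitly when computing the loss, so no class swapping or wrapping is needed. All names here are illustrative stand-ins, not the actual `KDTrainer` API:

```python
class KDTrainerSketch:
    """Minimal sketch of explicit teacher forwarding during KD loss."""

    def __init__(self, student, teacher, kd_ratio=0.5):
        self.student = student
        self.teacher = teacher  # plain attribute, not a wrapped submodule
        self.kd_ratio = kd_ratio

    def compute_kd_loss(self, batch, kd_loss_fn, ce_loss_fn):
        student_logits = self.student(batch)
        # In the real trainer this forward would run under torch.no_grad();
        # the teacher contributes targets only, never gradients.
        teacher_logits = self.teacher(batch)
        kd = kd_loss_fn(student_logits, teacher_logits)
        ce = ce_loss_fn(student_logits, batch)
        return self.kd_ratio * kd + (1.0 - self.kd_ratio) * ce
```

Because the teacher is an ordinary attribute rather than a swapped-in wrapper class, checkpoint saving and quantization hooks see the student model unchanged, which is what makes the hide/override machinery removable.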
- `pytest tests/unit/torch/distill/` (29 passed)
- `pytest tests/unit/torch/opt/plugins/test_hf_patching.py` (2 passed)
- `pytest tests/unit/torch/opt/plugins/test_lr_config.py`
- `pytest tests/examples/llm_qat/` (QAT, QAD, LoRA QAT, QLoRA)
- `pytest tests/examples/llm_distill/`

🤖 Generated with Claude Code