fix: megatron export correctness for TP>1 GQA, single-file MTP, and Hub remote code #1209
Conversation
…ub remote code

- `_qkv_slicing`: weight tensor is TP-sharded (shape[0] = per-rank heads * head_size), but all reshape/slice operations used global `qkv_total_dim`. Derive `per_rank_qkv_dim` and `num_query_groups_local` from the actual tensor shape so reshapes and arange slices are correct at any TP degree.
- `_get_mtp_state_dict`: `hf_hub_download` raised `EntryNotFoundError` for small models that ship a single `model.safetensors` instead of a sharded index. Fall back to downloading `model.safetensors` directly and scan its keys with `safe_open` when the index file is absent, for both Hub IDs and local dirs.
- `copy_remote_code`: raised `ValueError` for Hub model IDs because it only accepted local directory paths. Use `list_repo_files` + `hf_hub_download` to fetch top-level `.py` files (custom modeling code) from the Hub when the path is not a local directory.

Add tests for all three fixes in `tests/gpu_megatron/torch/export/`.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
📝 Walkthrough
Sequence Diagrams

```mermaid
sequenceDiagram
    participant Client
    participant Local as Local Filesystem
    participant HFHub as HF Hub API
    participant SaveDir as Save Directory
    Client->>Client: check if pretrained_model_path.is_dir()
    alt Local Directory
        Client->>Local: glob top-level *.py files
        Local-->>Client: file list
        Client->>Local: read each .py file
        Local-->>Client: file contents
    else Hub Model ID
        Client->>HFHub: list_repo_files(repo_id)
        HFHub-->>Client: all files in repo
        Client->>Client: filter top-level *.py files
        loop For each top-level .py file
            Client->>HFHub: hf_hub_download(repo_id, file_path)
            HFHub-->>Client: downloaded file
        end
    end
    Client->>SaveDir: copy all collected .py files
    SaveDir-->>Client: copy complete
```
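Realized in Python, the two-path flow above might look like the following minimal sketch. It assumes `huggingface_hub`'s `list_repo_files` and `hf_hub_download`; the function body is illustrative, not the merged implementation.

```python
import shutil
from pathlib import Path


def copy_remote_code(pretrained_model_path: str, save_dir: str) -> None:
    """Copy top-level custom-modeling .py files from a local dir or a Hub repo."""
    src = Path(pretrained_model_path)
    dst = Path(save_dir)
    dst.mkdir(parents=True, exist_ok=True)
    if src.is_dir():
        # Local directory: glob top-level .py files.
        py_files = [str(p) for p in src.glob("*.py")]
    else:
        # Hub model ID: list repo files, keep top-level .py files, download each.
        from huggingface_hub import hf_hub_download, list_repo_files

        repo_files = list_repo_files(pretrained_model_path)
        py_files = [
            hf_hub_download(repo_id=pretrained_model_path, filename=name)
            for name in repo_files
            if name.endswith(".py") and "/" not in name
        ]
    for path in py_files:
        shutil.copy(path, dst / Path(path).name)
```

Note that the Hub import is kept inside the fallback branch so the local-directory path works without `huggingface_hub` installed.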
```mermaid
sequenceDiagram
    participant Exporter as GPTModelExporter
    participant Local as Disk
    participant HFHub as HF Hub API
    participant SafeTensors as safetensors
    Exporter->>Exporter: check _hf_pretrained_model_name set?
    alt Not Set
        Exporter-->>Exporter: early return empty dict
    else Set
        Exporter->>Local: check model.safetensors exists?
        alt Single File
            Exporter->>SafeTensors: safe_open(model.safetensors)
            SafeTensors-->>Exporter: file handle
            Exporter->>SafeTensors: iterate keys, get_tensor(key)
            SafeTensors-->>Exporter: mtp.* tensors
        else Index File
            Exporter->>HFHub: download model.safetensors.index.json
            HFHub-->>Exporter: index metadata
            Exporter->>Exporter: enumerate mtp.* entries from index
            loop For each mtp.* tensor
                Exporter->>HFHub: hf_hub_download(shard_file)
                HFHub-->>Exporter: shard with mtp.* tensors
                Exporter->>SafeTensors: safe_open(shard), get_tensor(key)
                SafeTensors-->>Exporter: tensor data
            end
        else Neither Exists
            Exporter-->>Exporter: return empty dict
        end
        Exporter->>Exporter: append "mtp*" to exclude_modules (if mtp_exists)
    end
```
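For the local-directory case, the single-file/sharded fallback in the diagram above can be sketched as follows. This assumes `safetensors.safe_open` with the PyTorch framework; the function name and structure are illustrative, not the merged code.

```python
import json
from pathlib import Path


def get_mtp_state_dict(pretrained_model_path: str) -> dict:
    """Collect mtp.* tensors from a single-file or sharded safetensors checkpoint."""
    mtp_state_dict: dict = {}
    model_dir = Path(pretrained_model_path)
    single_file = model_dir / "model.safetensors"
    index_file = model_dir / "model.safetensors.index.json"
    if single_file.exists():
        # Single-file checkpoint: scan the keys directly.
        from safetensors import safe_open

        with safe_open(single_file, framework="pt") as f:
            for key in f.keys():
                if key.startswith("mtp."):
                    mtp_state_dict[key] = f.get_tensor(key)
    elif index_file.exists():
        # Sharded checkpoint: the index maps each tensor name to its shard file.
        from safetensors import safe_open

        weight_map = json.loads(index_file.read_text())["weight_map"]
        for key, shard in weight_map.items():
            if key.startswith("mtp."):
                with safe_open(model_dir / shard, framework="pt") as f:
                    mtp_state_dict[key] = f.get_tensor(key)
    # Neither file exists: return an empty dict (no MTP weights to export).
    return mtp_state_dict
```

For Hub model IDs the same two-path shape applies, with `hf_hub_download` fetching each file and `EntryNotFoundError` signalling that the index is absent.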
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 4 passed
⚠️ Review ran into problems: timed out fetching pipeline failures after 30000 ms.
🧹 Nitpick comments (3)
tests/gpu_megatron/torch/export/test_unified_export_megatron.py (1)
**296-302: Unusual but acceptable test pattern using `object.__new__`.**

Using `object.__new__(GPTModelExporter)` bypasses the constructor to create a minimal exporter for testing `_get_mtp_state_dict` in isolation. This is a valid testing technique but fragile: if `_get_mtp_state_dict` starts depending on additional attributes initialized in `__init__`, tests will break without clear indication why. Consider adding a brief comment explaining why this pattern is used, e.g.:

```python
# Bypass __init__ to test _get_mtp_state_dict in isolation without requiring
# a real GPTModel and HuggingFace config loading.
```

modelopt/torch/export/plugins/hf_checkpoint_utils.py (1)
**53-59: Consider adding error handling for Hub API failures.**

When `list_repo_files` or `hf_hub_download` fail (e.g., network issues, invalid repo ID, authentication required), the exception will propagate up uncaught. While this may be acceptable, consider whether you want to:

- Catch specific exceptions and provide clearer error messages, or
- Document in the docstring that exceptions from the Hub API may be raised.

This is a minor consideration since the caller can handle exceptions, but explicit error handling would improve robustness.

modelopt/torch/export/unified_export_megatron.py (1)
**550-569: Potential silent failure on unexpected Hub errors.**

When downloading from the Hub, `EntryNotFoundError` is caught and handled as a fallback. However, other Hub errors (network issues, authentication, rate limiting) will propagate up. This is likely the intended behavior, but consider whether you want to catch and wrap these errors with more context about what was being attempted.

Additionally, if both `model.safetensors.index.json` and `model.safetensors` are missing (lines 568-569), the function returns an empty dict silently. Consider logging a debug message to aid troubleshooting.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@modelopt/torch/export/plugins/hf_checkpoint_utils.py`:
- Around line 53-59: The code calls list_repo_files and hf_hub_download without
handling Hub/API failures; wrap the hub calls (the loop using
list_repo_files(repo_id) and hf_hub_download(repo_id=repo_id,
filename=filename)) in a try/except that catches relevant hub/network exceptions
(e.g., requests.exceptions.RequestException or the hub client’s specific
exception class) and either re-raises a more informative exception that includes
repo_id and filename or logs a clear error and continues/skips problematic
files; alternatively add a docstring note to the function referencing that
list_repo_files and hf_hub_download may raise hub/network exceptions so callers
can handle them.
In `@modelopt/torch/export/unified_export_megatron.py`:
- Around line 550-569: The code currently only handles EntryNotFoundError when
calling hf_hub_download for model.safetensors.index.json and model.safetensors;
update the try/except around hf_hub_download calls (referencing hf_hub_download,
EntryNotFoundError, safetensors_index_file, single_safetensors_file) to also
catch and re-raise other hub-related exceptions with additional context (e.g.,
wrap HubError/Exception with a message specifying the repo and filename
attempted) so upstream callers get actionable info, and add a debug/log
statement right before returning mtp_state_dict when both files are missing to
record that neither safetensors index nor single file was found.
In `@tests/gpu_megatron/torch/export/test_unified_export_megatron.py`:
- Around line 296-302: The test helper _make_exporter_for_mtp currently uses
object.__new__(GPTModelExporter) to bypass GPTModelExporter.__init__ for
isolating _get_mtp_state_dict; add a short comment above that line explaining
this intent (e.g., "Bypass __init__ to test _get_mtp_state_dict in isolation
without creating a real GPTModel or loading HF config") so future maintainers
understand why object.__new__ is used and the fragility of this pattern; keep
the rest of the helper (setting _hf_pretrained_model_name, _state_dict,
exclude_modules) unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 65457d00-85e9-4fe6-8e0e-19079d1cfc53
📒 Files selected for processing (4)

- modelopt/torch/export/plugins/hf_checkpoint_utils.py
- modelopt/torch/export/unified_export_megatron.py
- tests/gpu_megatron/torch/export/test_hf_checkpoint_utils.py
- tests/gpu_megatron/torch/export/test_unified_export_megatron.py
Codecov Report: ❌ Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff            @@
##             main    #1209     +/-   ##
=========================================
+ Coverage   71.65%   77.02%   +5.37%
=========================================
  Files         353      353
  Lines       40355    40382      +27
=========================================
+ Hits        28915    31105    +2190
+ Misses      11440     9277    -2163
```

Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
**ChenhanYu** left a comment:
PR Review: fix: megatron export correctness for TP>1 GQA, single-file MTP, and Hub remote code
Three well-scoped correctness fixes with thorough test coverage. LGTM.
Minor suggestions (non-blocking)
1. **Assert `num_query_groups_local` divides evenly** (unified_export_megatron.py:1023)

```python
num_query_groups_local = num_query_groups * per_rank_qkv_dim // qkv_total_dim
```

Relies on integer division being exact. Consider adding:

```python
assert num_query_groups * per_rank_qkv_dim % qkv_total_dim == 0, \
    f"num_query_groups={num_query_groups} not evenly divisible across TP"
```

Catches misconfigured TP/GQA combos early instead of silently producing wrong slices.
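In pure shape arithmetic, the per-rank derivation plus the suggested guard look like this sketch (`derive_qkv_dims` is a hypothetical helper for illustration; `weight_rows` stands in for `weight.shape[0]` of the sharded qkv weight):

```python
def derive_qkv_dims(weight_rows: int, num_query_groups: int,
                    qkv_total_dim: int, head_size: int) -> tuple[int, int]:
    """Derive per-rank qkv dims from the (possibly TP-sharded) weight's row count."""
    # The weight arriving at _qkv_slicing may already be TP-sharded, so derive
    # the per-rank dimension from the actual tensor shape, not the global one.
    per_rank_qkv_dim = weight_rows // head_size
    assert num_query_groups * per_rank_qkv_dim % qkv_total_dim == 0, (
        f"num_query_groups={num_query_groups} not evenly divisible across TP"
    )
    num_query_groups_local = num_query_groups * per_rank_qkv_dim // qkv_total_dim
    return per_rank_qkv_dim, num_query_groups_local


# E.g. 8 attention heads, 2 query groups -> qkv_total_dim = 8 + 2 * 2 = 12;
# with TP=2 and head_size=64, each rank sees 6 * 64 = 384 rows:
# derive_qkv_dims(384, 2, 12, 64) -> (6, 1)
```

For TP=1 the same call is a no-op relative to the global values, matching the PR's claim that the fix changes nothing in the unsharded case.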
2. **`copy_remote_code` error behavior changed** (hf_checkpoint_utils.py:48-52)

Invalid paths now raise `RepositoryNotFoundError` instead of the previous explicit `ValueError`. Fine as fail-fast behavior, but worth noting for callers that may have been catching `ValueError`.
3. **Nested try/except in `_get_mtp_state_dict`** (unified_export_megatron.py:549-568)

Double-nested `EntryNotFoundError` handling works but is slightly hard to follow. A flat helper pattern would improve readability. Minor nit.
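A flat helper along those lines could look like this sketch (`download_or_none` is hypothetical, not code from the PR; the "missing" exception type is parameterized so huggingface_hub's `EntryNotFoundError` can be passed in):

```python
def download_or_none(fetch, filename, missing_exc=FileNotFoundError):
    """Flatten a try/except fallback: return fetch(filename), or None if missing."""
    try:
        return fetch(filename)
    except missing_exc:
        return None


# With huggingface_hub, the nested fallback flattens to:
#   fetch = lambda name: hf_hub_download(repo_id=repo_id, filename=name)
#   path = (download_or_none(fetch, "model.safetensors.index.json", EntryNotFoundError)
#           or download_or_none(fetch, "model.safetensors", EntryNotFoundError))
```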
4. Copyright year 2024 on new test_hf_checkpoint_utils.py — should be 2026.
Test coverage is excellent — TP=2 GQA with FP8, single/sharded/no-MTP safetensors, and Hub ID mocking.
This is an AI-assisted review — human sign-off required before merging.
…ub remote code (#1209)

### What does this PR do?

Type of change: Bug fix

Three correctness fixes for the Megatron Core GPT export pipeline:

**1. `_qkv_slicing`: reshape failure with TP>1 on GQA models**

When tensor parallelism is enabled, the `linear_qkv` weight tensor arriving in `_qkv_slicing` is already TP-sharded, so `weight.shape[0]` equals `per_rank_qkv_dim * head_size`, not `qkv_total_dim * head_size`. All five reshape/`arange` operations were using the global `qkv_total_dim`, causing a runtime shape mismatch for any GQA model with TP > 1. The fix derives `per_rank_qkv_dim` and `num_query_groups_local` from the actual tensor shape, making the logic correct for any TP degree (a no-op for TP=1).

**2. `_get_mtp_state_dict`: `EntryNotFoundError` for non-sharded models**

`hf_hub_download("model.safetensors.index.json")` raises `EntryNotFoundError` for small models that ship a single `model.safetensors` rather than a sharded index. The function now catches this and falls back to downloading/reading `model.safetensors` directly, scanning its keys with `safe_open`. The same two-path logic applies to local directories.

**3. `copy_remote_code`: `ValueError` for Hub model IDs**

`copy_remote_code` only accepted local directory paths and raised `ValueError` for HuggingFace Hub model IDs (e.g. `"meta-llama/Llama-3.2-1B"`). The function now falls back to `list_repo_files` + `hf_hub_download` to fetch and copy top-level `.py` files (custom modeling code) when the path is not a local directory.

### Usage

```python
# TP>1 GQA export now works (previously raised RuntimeError on reshape)
export_mcore_gpt_to_hf(gqa_model, "meta-llama/Llama-3.2-1B", export_dir="./out", dtype=torch.bfloat16)

# Models with a single model.safetensors now have their MTP weights exported
export_mcore_gpt_to_hf(model, "./small_model_dir", export_dir="./out", dtype=torch.bfloat16)

# Hub model IDs no longer raise ValueError in copy_remote_code
export_mcore_gpt_to_hf(model, "org/custom-model-with-remote-code", export_dir="./out", dtype=torch.bfloat16)
```

### Testing

New tests added in `tests/gpu_megatron/torch/export/`:

- `test_unified_export_megatron.py::test_qkv_slicing_gqa_tp2` — FP8-quantized GQA model export with TP=2 (`num_query_groups=2 < num_attention_heads=8`), exercises both the weight reshape and per-channel weight-scale reshape paths.
- `test_unified_export_megatron.py::test_mtp_state_dict_single_safetensors` — unit test verifying MTP weights are collected from a single `model.safetensors` file.
- `test_unified_export_megatron.py::test_mtp_state_dict_index_file` — unit test verifying MTP weights are collected from a sharded checkpoint.
- `test_unified_export_megatron.py::test_mtp_state_dict_no_mtp_keys` — edge case: no MTP keys → empty dict, no side effects.
- `test_hf_checkpoint_utils.py` — four tests covering `copy_remote_code` for local directories and Hub model IDs (with and without `.py` files).

### Before your PR is "*Ready for review*"

- Is this change backward compatible?: ✅
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A
- Did you write any new necessary tests?: ✅
- Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A

### Additional Information

Fixes reported against Megatron export when running quantization with TP>1, small non-sharded HF models, and HuggingFace Hub model IDs passed to `export_mcore_gpt_to_hf`.

## Summary by CodeRabbit

- **New Features**
  - Export functionality now supports downloading code directly from Hugging Face Hub model repositories in addition to local directories.
- **Bug Fixes**
  - Improved safetensors loading with better error handling for missing model entries and support for both single and sharded weight files.
  - Enhanced tensor slicing behavior for multi-GPU distributed export scenarios.
- **Tests**
  - Added comprehensive test coverage for Hugging Face integration and export functionality.

---

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>