
Add Qwen3VL #895

Open

hychiang-git wants to merge 14 commits into NVIDIA:main from eigen-ai-labs:main

Conversation


hychiang-git commented Feb 16, 2026

What does this PR do?

Type of change: new feature

Overview: Add Qwen3-VL (Vision-Language) model support to the Megatron Core export/import
plugin, enabling HuggingFace-to-mcore weight conversion for PTQ/QAT/QAD workflows.

Details

Qwen3-VL has a different weight structure from Qwen3 text-only models (a few example checkpoint keys are shown after this list):

  • Language model weights are under model.language_model. prefix (not model.)
  • Visual encoder weights are under model.visual. prefix
  • The lm_head is at root level, not nested under language_model
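
For illustration, a few example checkpoint key names that follow the layout above (the visual-block key is a guess at the naming pattern, not taken from an actual checkpoint):

```python
# Illustrative Qwen3-VL HF checkpoint keys: language model nested, visual encoder
# under its own prefix, lm_head at the root.
qwen3vl_keys = [
    "model.language_model.embed_tokens.weight",
    "model.language_model.layers.0.self_attn.q_proj.weight",
    "model.language_model.norm.weight",
    "model.visual.blocks.0.attn.qkv.weight",  # hypothetical visual key, for illustration only
    "lm_head.weight",
]

# Corresponding Qwen3 text-only keys: no language_model/visual nesting.
qwen3_keys = [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.norm.weight",
    "lm_head.weight",
]
```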

This PR adds:

  • mcore_qwen3vl.py: Import/export weight mapping rules between HuggingFace
    Qwen3VLForConditionalGeneration and Megatron Core, handling the language_model prefix for
    all decoder layers, QKV merging/slicing, gated MLP merging/slicing, and Q/K layer norms
    (a simplified sketch of the rule shape follows this list).
  • mcore_common.py: Registers Qwen3VLForConditionalGeneration in
    all_mcore_hf_export_mapping and all_mcore_hf_import_mapping.
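
A simplified sketch of how the import rules are shaped. The dictionary keys and HF prefixes come from this PR's description; the rule classes (NameRemapping, QKVMerging, GatedMLPMerging) are referenced elsewhere in this repo, but their import location and constructor arguments here are assumptions for illustration, not the actual mcore_qwen3vl.py code:

```python
# Sketch only: rough shape of the Qwen3-VL import mapping. Real rules also carry
# TP-sharding/replication config and cover every decoder submodule.
from modelopt.torch.export.plugins.mcore_custom import (  # assumed import location
    GatedMLPMerging,
    NameRemapping,
    QKVMerging,
)

qwen3vl_causal_lm_import = {
    "word_embeddings": NameRemapping("model.language_model.embed_tokens."),
    "final_layernorm": NameRemapping("model.language_model.norm."),
    "output_layer": NameRemapping("lm_head."),
    "linear_qkv": QKVMerging("model.language_model.layers.{}.self_attn."),
    "linear_fc1": GatedMLPMerging("model.language_model.layers.{}.mlp."),
}
```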

Usage

  • Import Qwen3-VL from HuggingFace to MCore, and export the MCore model
```python
#!/usr/bin/env python3
"""Minimal example: Load Qwen3-VL with visual encoder from HF + language model via mcore mapping.

This script demonstrates the two-step loading process for Qwen3-VL:
  1. Visual encoder: loaded from HuggingFace directly in Qwen3VLModel.__init__
  2. Language model: imported via import_mcore_gpt_from_hf (uses mcore_qwen3vl.py mapping)

Usage (single GPU):
    python load_qwen3vl_example.py

Usage (multi-GPU with TP=2):
    torchrun --nproc_per_node=2 load_qwen3vl_example.py \
        --tensor-model-parallel-size 2

Requirements:
    pip install torch "transformers>=4.45,<5" flash-attn nvidia-modelopt safetensors
"""

import os
import sys

import torch

# Add Megatron-LM to path (adjust as needed)
MEGATRON_PATH = os.environ.get(
    "MEGATRON_PATH",
    os.path.join(os.path.dirname(__file__), "Megatron-LM-public"),
)
if os.path.isdir(MEGATRON_PATH):
    sys.path.insert(0, MEGATRON_PATH)

# For single-GPU runs (no torchrun), set distributed env vars
if "MASTER_ADDR" not in os.environ:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    os.environ["RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"

# Qwen3-VL-8B architecture args — injected into sys.argv so Megatron's
# argparse picks them up correctly (args_defaults gets overridden).
QWEN3VL_8B_ARGS = [
    "--num-layers", "36",
    "--hidden-size", "4096",
    "--ffn-hidden-size", "12288",
    "--num-attention-heads", "32",
    "--group-query-attention",
    "--num-query-groups", "8",
    "--seq-length", "4096",
    "--max-position-embeddings", "32768",
    "--norm-epsilon", "1e-6",
    "--swiglu",
    "--bf16",
    "--untie-embeddings-and-output-weights",
    "--position-embedding-type", "rope",
    "--rotary-base", "1000000",
    "--normalization", "RMSNorm",
    "--qk-layernorm",
    "--disable-bias-linear",
    "--img-h", "384",
    "--img-w", "384",
    "--micro-batch-size", "1",
    "--tokenizer-type", "HuggingFaceTokenizer",
    "--tokenizer-model", "Qwen/Qwen3-VL-8B-Instruct",
    "--no-load-rng",
    "--no-load-optim",
    "--no-gradient-accumulation-fusion",
    "--padded-vocab-size", "151936",
]

# Only inject defaults for args not already on the command line
for arg in QWEN3VL_8B_ARGS:
    if arg.startswith("--") and arg not in sys.argv:
        idx = QWEN3VL_8B_ARGS.index(arg)
        # Check if next element is a value (not a flag)
        if idx + 1 < len(QWEN3VL_8B_ARGS) and not QWEN3VL_8B_ARGS[idx + 1].startswith("--"):
            sys.argv.extend([arg, QWEN3VL_8B_ARGS[idx + 1]])
        else:
            sys.argv.append(arg)


def main():
    # ── 1. Initialize Megatron distributed environment ──────────────────
    from megatron.core.enums import ModelType
    from megatron.training import get_args, get_model, initialize_megatron

    def extra_args(parser):
        group = parser.add_argument_group("Qwen3-VL example")
        group.add_argument(
            "--hf-model-name",
            type=str,
            default="Qwen/Qwen3-VL-8B-Instruct",
            help="HuggingFace model name or local path",
        )
        return parser

    initialize_megatron(
        extra_args_provider=extra_args,
        args_defaults={"no_load_rng": True, "no_load_optim": True},
    )
    args = get_args()

    # ── 2. Build Qwen3VLModel (visual encoder auto-loaded from HF) ─────
    #
    # Qwen3VLModel.__init__ does:
    #   self.visual = Qwen3VLVisionEncoder(hf_model_name=...)
    #       → calls AutoModelForVision2Seq.from_pretrained(hf_model_name)
    #       → extracts hf_model.visual, deletes the rest
    #       → visual encoder weights are loaded ✓
    #
    #   self.language_model = GPTModel(config, layer_spec, ...)
    #       → creates empty MCore GPTModel with random weights
    #       → language model weights NOT loaded yet ✗
    from copy import deepcopy

    from megatron.core import parallel_state
    from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
    from megatron.core.models.multimodal.qwen3_vl_model import Qwen3VLModel
    from megatron.training.arguments import core_transformer_config_from_args
    from megatron.training.utils import print_rank_0, unwrap_model

    language_config = core_transformer_config_from_args(args)
    # Local spec uses TorchNorm which doesn't support persist_layer_norm
    language_config.persist_layer_norm = False
    # Use local spec (not TE spec) for HF weight import compatibility.
    # TE spec fuses layernorm+linear into TELayerNormColumnParallelLinear
    # which expects different state_dict keys than plain HF weights.
    language_layer_spec = get_gpt_layer_local_spec(
        num_experts=args.num_experts,
        moe_grouped_gemm=args.moe_grouped_gemm,
        qk_layernorm=args.qk_layernorm,
    )

    vision_config = deepcopy(language_config)
    vision_config.context_parallel_size = 1

    def model_provider(
        pre_process=True, post_process=True, parallel_output=True,
        config=None, pg_collection=None, vp_stage=None,
    ):
        model = Qwen3VLModel(
            language_transformer_config=language_config,
            language_transformer_layer_spec=language_layer_spec,
            language_vocab_size=args.padded_vocab_size,
            language_max_sequence_length=args.max_position_embeddings,
            vision_transformer_config=vision_config,
            hf_model_name=args.hf_model_name,
            parallel_output=parallel_output,
            language_position_embedding_type=args.position_embedding_type,
            language_rotary_percent=args.rotary_percent,
            language_rotary_base=args.rotary_base,
            pre_process=pre_process,
            post_process=post_process,
            add_encoder=parallel_state.is_pipeline_first_stage(),
            add_decoder=True,
            img_h=args.img_h,
            img_w=args.img_w,
        )
        return model

    model = get_model(model_provider, ModelType.encoder_or_decoder, wrap_with_ddp=False)
    unwrapped = unwrap_model(model)[0]

    print_rank_0(f"Visual encoder : {type(unwrapped.visual)}")
    print_rank_0(f"Language model : {type(unwrapped.language_model)}")

    # At this point:
    #   unwrapped.visual         → loaded from HF ✓
    #   unwrapped.language_model → random weights ✗

    # ── 3. Import language model weights via mcore mapping ──────────────
    #
    # import_mcore_gpt_from_hf internally:
    #   1. Reads config.json → finds "Qwen3VLForConditionalGeneration"
    #   2. Looks up qwen3vl_causal_lm_import from mcore_qwen3vl.py
    #   3. Maps HF weight names to mcore names:
    #      "model.language_model.layers.{L}.self_attn.q/k/v_proj" → "linear_qkv" (QKVMerging)
    #      "model.language_model.layers.{L}.mlp.gate/up_proj"     → "linear_fc1" (GatedMLPMerging)
    #      "model.language_model.embed_tokens"                     → "word_embeddings"
    #      "model.language_model.norm"                              → "final_layernorm"
    #      "lm_head"                                                → "output_layer"
    from modelopt.torch.export import import_mcore_gpt_from_hf

    workspace_dir = os.environ.get("WORKSPACE_DIR", "/tmp/mcore_workspace")
    import_dtype = torch.bfloat16 if args.bf16 else torch.float16

    print_rank_0(f"Importing language model weights from {args.hf_model_name} ...")
    import_mcore_gpt_from_hf(
        model=unwrapped.language_model,  # only the GPTModel, NOT the full VLM
        pretrained_model_path=args.hf_model_name,
        workspace_dir=workspace_dir,
        dtype=import_dtype,
    )
    print_rank_0("Language model weights imported ✓")

    # Now both components are loaded:
    #   unwrapped.visual         → from HF directly ✓
    #   unwrapped.language_model → from HF via mcore mapping ✓

    # ── 4. Verify: print parameter stats ────────────────────────────────
    if unwrapped.visual is not None:
        visual_params = sum(p.numel() for p in unwrapped.visual.parameters())
        print_rank_0(f"Visual encoder params : {visual_params:,}")
    lm_params = sum(p.numel() for p in unwrapped.language_model.parameters())
    print_rank_0(f"Language model params : {lm_params:,}")
    total = sum(p.numel() for p in unwrapped.parameters())
    print_rank_0(f"Total params          : {total:,}")

    # ── 5. Export back to HF format ────────────────────────────────────
    #
    # Two-step process:
    #   Step A: Export language model weights via mcore mapping (reverse of step 3)
    #   Step B: Copy visual encoder weights from original HF checkpoint
    import shutil
    from glob import glob

    from modelopt.torch.export import export_mcore_gpt_to_hf
    from safetensors import safe_open
    from safetensors.torch import save_file

    export_dir = os.environ.get("EXPORT_DIR", "/tmp/qwen3vl-exported")
    os.makedirs(export_dir, exist_ok=True)

    # Step A: Export language model (mcore → HF weight names)
    # Internally uses qwen3vl_causal_lm_export from mcore_qwen3vl.py:
    #   "linear_qkv" → q_proj/k_proj/v_proj (QKVSlicing)
    #   "linear_fc1" → gate_proj/up_proj     (GatedMLPSlicing)
    #   etc.
    print_rank_0(f"Exporting language model to {export_dir} ...")
    export_mcore_gpt_to_hf(
        model=unwrapped.language_model,  # only the GPTModel
        pretrained_model_name_or_path=args.hf_model_name,
        export_dir=export_dir,
        dtype=torch.bfloat16,
    )
    print_rank_0("Language model exported ✓")

    # Step B: Copy visual encoder from original HF checkpoint
    # (Only rank 0 does file I/O)
    if torch.distributed.get_rank() == 0:
        # Resolve HF model to local cache path
        hf_local_path = args.hf_model_name
        if not os.path.isdir(hf_local_path):
            from huggingface_hub import snapshot_download
            hf_local_path = snapshot_download(
                repo_id=args.hf_model_name, local_files_only=True,
            )

        # Extract visual weights from original HF safetensors
        visual_state_dict = {}
        for sf_file in glob(os.path.join(hf_local_path, "*.safetensors")):
            with safe_open(sf_file, framework="pt", device="cpu") as f:
                for key in f.keys():
                    if key.startswith("visual") or key.startswith("model.visual"):
                        visual_state_dict[key] = f.get_tensor(key)

        if visual_state_dict:
            print_rank_0(f"Found {len(visual_state_dict)} visual tensors")

            # Load exported language model weights
            all_weights = {}
            for sf_file in glob(os.path.join(export_dir, "model*.safetensors")):
                with safe_open(sf_file, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        all_weights[key] = f.get_tensor(key)

            # Merge visual + language model weights
            all_weights.update(visual_state_dict)
            merged_file = os.path.join(export_dir, "model.safetensors")
            save_file(all_weights, merged_file)
            print_rank_0(f"Merged checkpoint saved to {merged_file}")

            # Clean up old shard files
            for sf_file in glob(os.path.join(export_dir, "model*.safetensors")):
                if sf_file != merged_file:
                    os.remove(sf_file)
            for json_file in glob(os.path.join(export_dir, "model-*.json")):
                os.remove(json_file)

        # Copy VLM-specific config files
        for fname in ["preprocessor_config.json", "processor_config.json",
                       "chat_template.json"]:
            src = os.path.join(hf_local_path, fname)
            dst = os.path.join(export_dir, fname)
            if os.path.exists(src) and not os.path.exists(dst):
                shutil.copy(src, dst)
                print_rank_0(f"Copied {fname}")

        # Print export size
        total_size = sum(
            os.path.getsize(os.path.join(dp, f))
            for dp, _, fns in os.walk(export_dir) for f in fns
        )
        if total_size >= 1024**3:
            size_str = f"{total_size / 1024**3:.2f} GB"
        else:
            size_str = f"{total_size / 1024**2:.2f} MB"
        print_rank_0(f"Export size: {size_str} ({export_dir})")

    print_rank_0("Done! Full VLM exported to HF format.")


if __name__ == "__main__":
    main()
```
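
After the script completes, the exported directory can be reloaded with standard HuggingFace tooling as a quick sanity check (a minimal sketch; the path is the EXPORT_DIR default used above, and the Vision2Seq auto-class is assumed to resolve Qwen3-VL on a recent transformers version):

```python
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor

export_dir = "/tmp/qwen3vl-exported"  # EXPORT_DIR default from the example script

# Reload the merged checkpoint and confirm it instantiates end to end.
model = AutoModelForVision2Seq.from_pretrained(export_dir, torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(export_dir)
print(f"Reloaded {sum(p.numel() for p in model.parameters()):,} parameters")
```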

Testing

  • Verified round-trip import/export with Qwen3-VL-8B-Instruct using the example above
  • Unit tests in tests/unit/torch/export/test_mcore_qwen3vl.py (a minimal registration-check
    sketch follows this list) covering:
    • Registration in the global export/import mappings
    • Import mapping: dense keys, the model.language_model. prefix, lm_head. at root,
      QKVMerging, GatedMLPMerging, REPLICATE for layernorms, TP sharding configs
    • Export mapping: QKVSlicing, GatedMLPSlicing, no parallel_config
    • Import/export symmetry: same mcore keys, matching HF prefixes
    • Qwen3-VL vs Qwen3 difference: same keys, VL adds the language_model. prefix,
      lm_head unchanged
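
A minimal sketch of the registration check referenced above (hypothetical test body; the actual tests in test_mcore_qwen3vl.py are broader):

```python
from modelopt.torch.export.plugins.mcore_common import (
    all_mcore_hf_export_mapping,
    all_mcore_hf_import_mapping,
)


def test_qwen3vl_registered():
    # The architecture string from a Qwen3-VL config.json must resolve in both directions.
    assert "Qwen3VLForConditionalGeneration" in all_mcore_hf_import_mapping
    assert "Qwen3VLForConditionalGeneration" in all_mcore_hf_export_mapping
```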

Before your PR is "Ready for review"

  • Is this change backward compatible?: Yes, additive only
  • Did you write any new necessary tests?: Yes, tests/unit/torch/export/test_mcore_qwen3vl.py
  • Did you add or update any necessary documentation? Yes, see docs/source/deployment/3_unified_hf.rst
  • Did you update Changelog? Yes, see CHANGELOG.rst

Additional Information

A companion Megatron-LM PR adds Qwen3VLModel, Qwen3VLDataset, and pretrain_qwenvl.py; see NVIDIA/Megatron-LM#3444.

Summary by CodeRabbit

  • New Features

    • Added support for Qwen3-VL vision-language models with the TensorRT-LLM deployment framework.
    • Supports FP8 and NVFP4 quantization formats for optimized inference.
    • Compatible with both dense and Mixture of Experts (MoE) model variants.
  • Tests

    • Added comprehensive test coverage for Qwen3-VL integration and export/import workflows.

hychiang-git requested a review from a team as a code owner February 16, 2026 19:52

copy-pr-bot bot commented Feb 16, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Feb 16, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Walkthrough

This PR adds support for Qwen3-VL vision-language models to the Megatron Core export/import framework by introducing model mappings to convert between Hugging Face and Megatron Core structures for quantization workflows, alongside documentation updates and comprehensive unit tests.

Changes

  • Documentation & Changelog (CHANGELOG.rst, docs/source/deployment/3_unified_hf.rst):
    Added Qwen3-VL model entries to the changelog and the TensorRT-LLM framework support list.
  • Plugin Registration (modelopt/torch/export/plugins/mcore_common.py):
    Imported the new Qwen3VL export/import functions and registered Qwen3VLForConditionalGeneration
    in both public mapping dictionaries.
  • Qwen3-VL Mapping Implementation (modelopt/torch/export/plugins/mcore_qwen3vl.py):
    New module defining component-level mappings between Hugging Face and Megatron Core
    representations, supporting dense and MoE variants with QKV/MLP merging and slicing strategies.
  • Unit Tests (tests/unit/torch/export/test_mcore_qwen3vl.py):
    Comprehensive test suite validating Qwen3VLForConditionalGeneration registration, mapping
    presence/correctness, prefix handling, component merging/slicing, and import/export symmetry.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 3

❌ Failed checks (2 warnings, 1 inconclusive)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 33.33%, below the required 80.00% threshold.
    Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Merge Conflict Detection ⚠️ Warning: merge conflicts detected (22 files):

⚔️ .github/workflows/example_tests.yml (content)
⚔️ .github/workflows/gpu_tests.yml (content)
⚔️ CHANGELOG.rst (content)
⚔️ docs/source/deployment/3_unified_hf.rst (content)
⚔️ examples/llm_ptq/example_utils.py (content)
⚔️ examples/llm_ptq/hf_ptq.py (content)
⚔️ examples/llm_ptq/vlm_utils.py (content)
⚔️ examples/megatron_bridge/distill.py (content)
⚔️ examples/megatron_bridge/prune_minitron.py (content)
⚔️ modelopt/torch/export/model_utils.py (content)
⚔️ modelopt/torch/export/plugins/mcore_common.py (content)
⚔️ modelopt/torch/export/plugins/mcore_custom.py (content)
⚔️ modelopt/torch/export/plugins/mcore_nemotron.py (content)
⚔️ modelopt/torch/export/plugins/megatron_importer.py (content)
⚔️ modelopt/torch/export/plugins/vllm_fakequant_megatron.py (content)
⚔️ modelopt/torch/export/unified_export_hf.py (content)
⚔️ modelopt/torch/export/unified_export_megatron.py (content)
⚔️ modelopt/torch/opt/searcher.py (content)
⚔️ modelopt/torch/utils/plugins/mbridge.py (content)
⚔️ pyproject.toml (content)
⚔️ setup.py (content)
⚔️ tox.ini (content)

    These conflicts must be resolved before merging into main. Resolve them locally and push changes to this branch.
  • Title check ❓ Inconclusive: the title "Add Qwen3VL" is vague about which aspect of Qwen3VL support is being added, making it hard for developers scanning history to understand the scope.
    Resolution: consider a more descriptive title such as "Add Qwen3VL Megatron Core export/import plugin" or "Add Qwen3-VL vision-language model support".
✅ Passed checks (1 passed)
  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.


willg-nv and others added 13 commits February 16, 2026 21:52
## What does this PR do?

This PR implements the RegionSearch class. RegionSearch helps partition a
big ONNX model into small regions. QDQ autotuning will be performed on the
regions.


## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)**
and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes
- **Did you add or update any necessary documentation?**: No, documentation
updates are in Part 4.
- **Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**:
CHANGELOG will be updated when all changes are ready.


## Summary by CodeRabbit

## Release Notes

**Refactor**
* Improved ONNX quantization backend with new optimization framework and
extensive test coverage to enhance internal graph processing
capabilities.


---------

Signed-off-by: Will Guo <willg@nvidia.com>
Signed-off-by: Hung-Yueh <hungyueh.chiang@gmail.com>
## What does this PR do?

This PR implements the RegionInspect tool, which visualizes the regions
partitioned by the RegionSearch classes and helps analyze whether the
partitioned regions match the fusion patterns.


## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)**
and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes
- **Did you add or update any necessary documentation?**: No, documentation
updates are in Part 4.
- **Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**:
No, CHANGELOG will be updated when all changes are ready.


## Summary by CodeRabbit

## New Features
* Added a region inspection tool for ONNX models. Analyzes model
structure and generates detailed reports including region statistics,
hierarchical relationships, node coverage metrics, and size distribution
analysis. Available through a command-line interface with configurable
parameters.


---------

Signed-off-by: Will Guo <willg@nvidia.com>
Co-authored-by: Ajinkya Rasane <131806219+ajrasane@users.noreply.github.com>
Signed-off-by: Hung-Yueh <hungyueh.chiang@gmail.com>
## What does this PR do?

**Type of change:** Bug fix

**Overview:**

1. Fix megatron ignore modules having an additional `.` in the suffix
2. Change megatron export to save each layer as a safetensor (avoids ghost
safetensors)


## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)**
and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**:
Yes/No


## Summary by CodeRabbit

## Release Notes

* **New Features**
* Export workflow now supports additional model components (EAGLE/Medusa
modules)
* Per-layer model state organization for improved checkpoint management

* **Bug Fixes**
* More robust Hugging Face configuration, tokenizer, and image processor
preservation
  * Enhanced multimodal component extraction and loading

* **Refactor**
* Optimized model export process with improved per-layer safetensors
handling


Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Signed-off-by: Hung-Yueh <hungyueh.chiang@gmail.com>
## What does this PR do?

**Type of change:** New model support

**Overview:** Add PTQ support for
https://huggingface.co/nvidia/NVIDIA-Nemotron-Parse-v1.1

## Usage

```bash
python3 hf_ptq.py --pyt_ckpt_path /home/omniml_data_3/models/NVIDIA-Nemotron-Parse-v1.1 --qformat fp8 --export_path /home/omniml_data_3/zhiyuc/checkpoints/NVIDIA-Nemotron-Parse-v1.1-FP8 --trust_remote_code --kv_cache_qformat none --attn_implementation eager
```
By default, image-text data will be used in calibration for VLMs.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)**
and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**:
Not yet


## Summary by CodeRabbit

* **New Features**
* Added support for Nemotron-Parse multimodal models, including proper
device mapping, processor loading, and generation handling.

* **Improvements**
* Enhanced quantization robustness with safer handling of quantization
attributes and fallback logic.
* Improved model loading with better device placement and encoder buffer
management for vision-language models.


---------

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Hung-Yueh <hungyueh.chiang@gmail.com>
## What does this PR do?

**Type of change:** Bug fix

## Testing

Nemotron Nano v2 pruned can be saved


## Summary by CodeRabbit

* **Bug Fixes**
* Fixed Hugging Face model loading to properly respect the
`trust_remote_code` parameter during model instantiation.

* **Improvements**
* Enhanced distributed training logging with rank-0 aware warning and
logging mechanisms for cleaner, non-redundant output in multi-GPU and
multi-node scenarios.


Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Hung-Yueh <hungyueh.chiang@gmail.com>
Signed-off-by: Hung-Yueh <hungyueh.chiang@gmail.com>
## What does this PR do?

[Short term]: Megatron-based tests take a long time, often resulting in
CI/CD timeouts. Split megatron tests into a dedicated CI/CD job for a
faster overall CI/CD run.
[Mid/Long term]: Run all megatron GPU tests using `torchrun` instead of
`pytest` so all distributed processes are already created and individual
tests no longer need to set up and destroy their own processes, which adds
a lot of overhead per test.

## Testing

- [x] 1-GPU CI/CD passing (on this PR)
- [x] 2-GPU CI/CD passing (on nightly run - manually triggered):
https://github.com/NVIDIA/Model-Optimizer/actions/runs/22000517688

---------

Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Hung-Yueh <hungyueh.chiang@gmail.com>
Signed-off-by: Hung-Yueh <hungyueh.chiang@gmail.com>
NVIDIA#884)

## What does this PR do?

**Type of change:** new feature

**Overview:** Enable full TE spec support for NemotronH (Mamba hybrid)
models during HF-to-Megatron weight import via
`import_mcore_gpt_from_hf`.

Previously, importing HF weights into a Megatron model built with the
full TE spec (`TELayerNormColumnParallelLinear`, `TEGroupedMLP`, etc.)
failed for NemotronH models due to two issues:

1. **Grouped expert prefix bug**: The `experts.linear_fc1/fc2` import
rules had a hard-coded `mtp.layers.{}` prefix, which was only correct
for MTP layers. When regular decoder MoE layers use `TEGroupedMLP` (via
the full TE spec), the importer generated incorrect HF keys (e.g.,
`mtp.layers.27.mixer.experts.0.up_proj.weight` instead of
`backbone.layers.27.mixer.experts.0.up_proj.weight`).

2. **Fused layer norm loading**: In the full TE spec, layer norms are
fused into `TELayerNormColumnParallelLinear` modules as
`layer_norm_weight`. The importer's `_name_remapping` would crash trying
to load `layer_norm_weight` from a non-existent HF path (e.g.,
`backbone.layers.X.mixer.in_proj.layer_norm_weight`), when the actual HF
norm weight lives at `backbone.layers.X.norm.weight`.

### Changes

**`mcore_nemotron.py`**:
- Fixed grouped expert prefix from `mtp.layers.{}` to
`backbone.layers.{}`. The `_grouped_mlp_merging` function already
handles `backbone` → `mtp` replacement when `is_mtp=True`, so both
decoder and MTP layers work correctly.
- Added `mapping={"layer_norm_weight": None}` to `in_proj` and
`linear_fc1` rules to skip `layer_norm_weight` during `_name_remapping`
(loaded separately via `fused_norm`).
- Added `fused_norm` rule
(`NameRemapping("backbone.layers.{}.norm.weight")`) to load HF norm
weights into fused TE modules.

**`megatron_importer.py`**:
- Added a `source_key is None` check in `_name_remapping` to skip keys
mapped to `None` in the mapping dict (keeps the existing value instead of
crashing on a missing HF key); a schematic sketch of this guard follows this
list.
- Added fused norm loading in `_import_mamba_layer`: after loading
`in_proj`, loads `layer_norm_weight` from HF via `fused_norm` rule when
`layer.norm` is `IdentityOp`.
- Added fused norm loading in `_import_transformer_layer`: loads
`layer_norm_weight` into `linear_qkv` (when `input_layernorm` is
`IdentityOp`) and into `linear_fc1` (when `pre_mlp_layernorm` is
`IdentityOp`).
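
A schematic sketch of the guard described above (simplified and illustrative only, not the actual `_name_remapping` implementation; the variable names are assumptions):

```python
# Inside the remapping loop: keys explicitly mapped to None are skipped so the
# module keeps its existing value (e.g. a fused layer_norm_weight loaded separately).
for target_key, source_key in mapping.items():
    if source_key is None:
        continue
    state_dict[target_key] = hf_state_dict[source_key]
```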

## Usage

The full TE spec is enabled via the `--full-te-spec` flag on the
Megatron-LM side (separate PR). On the ModelOpt side, no user-facing
changes are needed -- the import rules automatically handle both local
spec and full TE spec models.

```bash
# Convert HF checkpoint to Megatron with full TE spec (megatron-lm side)
unset MLM_MODEL_CKPT && export MLM_MODEL_SAVE=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16_mlm && export HF_MODEL_CKPT=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
export PP=2
export MLM_EXTRA_ARGS="--full-te-spec"
bash convert.sh nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16

# Quantize the converted checkpoint (megatron-lm side)
export MLM_MODEL_CKPT=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16_mlm
export MLM_MODEL_SAVE=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-fp8_mlm
export HF_MODEL_CKPT=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
export PP=2 && export TP=4 && export EP=4 && export ETP=1
export MLM_EXTRA_ARGS="--full-te-spec"
bash quantize.sh nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 FP8_DEFAULT_CFG

# Generate
export PP=2 && export TP=4 && export EP=4 && export ETP=1
export MLM_EXTRA_ARGS="--full-te-spec"
export MLM_MODEL_CKPT=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-fp8_mlm && ./generate.sh nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16

# MMLU
export PP=2 && export TP=4 && export EP=4 && export ETP=1
export MLM_EXTRA_ARGS="--full-te-spec"
export MLM_MODEL_CKPT=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-fp8_mlm && export MLM_EXTRA_ARGS="--fraction 0.05 --disable-tqdm" && ./mmlu.sh nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
```

## Testing

- Tested end-to-end: HF → Megatron conversion → FP8 quantization →
inference (generate) → MMLU evaluation with
Nemotron-3-Nano-30B-A3B-BF16.
- Verified the resulting model structure matches Megatron-Bridge's TE
spec output (TELayerNormColumnParallelLinear, TEGroupedMLP, IdentityOp
norms, etc.).
- Verified quantized model produces coherent text generation outputs.
- Verified backward compatibility: all changes are no-ops for existing
local-spec pipelines (guarded by `IdentityOp` checks, `hasattr` checks,
and `"fused_norm" in self.rules` checks).

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)**
and your commits are signed.
- **Is this change backward compatible?**: Yes -- all changes are
guarded by conditions that only activate for full TE spec models. Local
spec models follow the exact same code paths as before.
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: No
- **Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**:
No

## Additional Information

Companion megatron-lm changes (separate PR):
- `megatron/core/post_training/modelopt/mamba/model_specs.py`: Added
`use_full_te_spec` parameter to return canonical `mamba_stack_spec` from
`mamba_layer_specs.py`.
- `megatron/post_training/model_builder.py`: Passes
`use_full_te_spec=args.full_te_spec` to `get_mamba_stack_modelopt_spec`.
- `megatron/post_training/arguments.py`: Added `--full-te-spec` CLI
flag.
- `examples/post_training/modelopt/convert_model.py`: Skip
`moe_grouped_gemm=False` override when `--full-te-spec` is set.


## Summary by CodeRabbit

* **New Features**
* Added support for loading fused normalization weights during model
import.

* **Bug Fixes**
* Improved weight mapping logic to correctly skip redundant layer norm
weights in specialized model architectures.

* **Refactor**
* Reorganized expert model parallel configuration paths for better
compatibility with mixed parallel processing settings.


Signed-off-by: James Shen <yueshen@nvidia.com>
Signed-off-by: Hung-Yueh <hungyueh.chiang@gmail.com>
Signed-off-by: Hung-Yueh <hungyueh.chiang@gmail.com>
## What does this PR do?

**Type of change:** new example

**Overview:**
Adding LTX-2 distillation trainer.

## Usage

```bash
accelerate launch \
    --config_file configs/accelerate/fsdp.yaml \
    --num_processes 8 \
    distillation_trainer.py --config configs/distillation_example.yaml
```

See readme for more details.

## Testing
Run training with single/multiple nodes.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)**
and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: NA
- **Did you add or update any necessary documentation?**: Yes
- **Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**:
Yes


## Summary by CodeRabbit

## New Features
* Added distillation training support for LTX-2 models with quantization
integration.
* Introduced comprehensive documentation and example configurations for
distillation workflows.
* Includes multi-GPU and multi-node training setup with distributed
training support and customizable configuration templates.


---------

Signed-off-by: Meng Xin <mxin@nvidia.com>
Signed-off-by: Hung-Yueh <hungyueh.chiang@gmail.com>
Signed-off-by: hychiang <kenny5312012@gmail.com>
Signed-off-by: hychiang <kenny5312012@gmail.com>

modelopt-bot left a comment


Review Summary

Overall this is a well-structured PR adding Qwen3-VL support. The mapping logic correctly handles the different weight structure (model.language_model. prefix) compared to Qwen3. The tests are comprehensive.

However, I have a few questions and suggestions that warrant discussion before approval.


modelopt-bot left a comment


Review Summary

Overall this is a well-structured PR adding Qwen3-VL support. The mapping logic correctly handles the different weight structure (model.language_model. prefix) compared to Qwen3. Tests are comprehensive.

Requested Changes:

  1. Merge conflicts - 22 files have conflicts that need resolution (as flagged by CodeRabbit)
  2. Missing newline - mcore_qwen3vl.py is missing a trailing newline

Questions/Clarifications:

  3. Copyright year mismatch (2023-2025 vs 2024 in the test file)
  4. Router mapping includes MoE but no shared_expert mappings - is this intentional?
  5. Consider adding docstrings to the mapping dictionaries

Please address the merge conflicts first, then the minor formatting issue.
