Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions modelopt/torch/export/plugins/hf_checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download, list_repo_files
from safetensors.torch import safe_open
from tqdm import tqdm

Expand All @@ -35,24 +36,27 @@ def copy_remote_code(
we need to copy them to the export directory for seamless integration with inference
frameworks.

If ``pretrained_model_path`` is a local directory, Python files are copied directly.
If it is a HuggingFace Hub model ID, Python files are downloaded from the Hub first.

Args:
pretrained_model_path: Path to the pretrained model.
pretrained_model_path: Local path to the pretrained model or HuggingFace Hub model ID.
save_directory: Path to the save directory.

Raises:
ValueError: If the pretrained model path is not a directory.
"""
hf_checkpoint_path = Path(pretrained_model_path)
save_dir = Path(save_directory)

if not hf_checkpoint_path.is_dir():
raise ValueError(
f"Invalid pretrained model path: {pretrained_model_path}. It should be a directory."
)

for py_file in hf_checkpoint_path.glob("*.py"):
if py_file.is_file():
shutil.copy(py_file, save_dir / py_file.name)
if hf_checkpoint_path.is_dir():
for py_file in hf_checkpoint_path.glob("*.py"):
if py_file.is_file():
shutil.copy(py_file, save_dir / py_file.name)
else:
# Hub model ID: download any top-level .py files (custom modeling code)
repo_id = str(pretrained_model_path)
for filename in list_repo_files(repo_id):
if "/" not in filename and filename.endswith(".py"):
local_path = hf_hub_download(repo_id=repo_id, filename=filename)
shutil.copy(local_path, save_dir / filename)


def load_multimodal_components(
Expand Down
81 changes: 58 additions & 23 deletions modelopt/torch/export/unified_export_megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import torch
import torch.distributed
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import EntryNotFoundError
from safetensors import safe_open
from safetensors.torch import save_file

from modelopt import __version__
Expand Down Expand Up @@ -534,29 +536,57 @@ def _get_mtp_state_dict(self) -> dict[str, torch.Tensor]:
# TODO Implement MTP export for quantized MTP
# Hacky version for now: copy MTP weights from pretrained model
mtp_state_dict = {}
if self._hf_pretrained_model_name:
if os.path.isdir(self._hf_pretrained_model_name):
safetensors_index_file = (
Path(self._hf_pretrained_model_name) / "model.safetensors.index.json"
)
else:
safetensors_index_file = hf_hub_download(
repo_id=self._hf_pretrained_model_name, filename="model.safetensors.index.json"
if not self._hf_pretrained_model_name:
return mtp_state_dict

mtp_exists = False

if os.path.isdir(self._hf_pretrained_model_name):
safetensors_index_file = (
Path(self._hf_pretrained_model_name) / "model.safetensors.index.json"
)
single_safetensors_file = Path(self._hf_pretrained_model_name) / "model.safetensors"
else:
try:
safetensors_index_file = Path(
hf_hub_download(
repo_id=self._hf_pretrained_model_name,
filename="model.safetensors.index.json",
)
)
single_safetensors_file = None
except EntryNotFoundError:
# Model uses a single unsharded safetensors file — check it for MTP weights.
safetensors_index_file = None
try:
single_safetensors_file = Path(
hf_hub_download(
repo_id=self._hf_pretrained_model_name,
filename="model.safetensors",
)
)
except EntryNotFoundError:
return mtp_state_dict

if safetensors_index_file is not None and safetensors_index_file.exists():
print(f"Exporting MTP: using safetensors_index_file: {safetensors_index_file}")
mtp_exists = False
if safetensors_index_file and os.path.exists(safetensors_index_file):
with open(safetensors_index_file) as f:
safetensors_index = json.load(f)
model_dir = Path(safetensors_index_file).parent
for key in safetensors_index["weight_map"]:
with open(safetensors_index_file) as f:
safetensors_index = json.load(f)
model_dir = safetensors_index_file.parent
for key in safetensors_index["weight_map"]:
if key.startswith("mtp.") and key not in self._state_dict:
mtp_state_dict[key] = get_safetensor(model_dir, key)
mtp_exists = True
elif single_safetensors_file is not None and single_safetensors_file.exists():
print(f"Exporting MTP: using single safetensors file: {single_safetensors_file}")
with safe_open(str(single_safetensors_file), framework="pt", device="cpu") as f:
for key in f.keys(): # noqa: SIM118
if key.startswith("mtp.") and key not in self._state_dict:
mtp_state_dict[key] = get_safetensor(model_dir, key)
mtp_state_dict[key] = f.get_tensor(key)
mtp_exists = True

if mtp_exists:
self.exclude_modules.append("mtp*")
if mtp_exists:
self.exclude_modules.append("mtp*")
return mtp_state_dict

def _get_mamba_layer_state_dict(self, layer, layer_id):
Expand Down Expand Up @@ -987,17 +1017,22 @@ def _qkv_slicing(
)
hidden_size = 2 * hidden_size

weight = weight.reshape([qkv_total_dim, head_size, hidden_size])
# When TP > 1 the weight tensor is already sharded: shape[0] = per_rank_qkv_dim, not
# qkv_total_dim. Derive the per-rank dimensions from the actual tensor shape so that
# all subsequent reshape/slice operations are correct regardless of TP degree.
per_rank_qkv_dim = weight.shape[0] // head_size
num_query_groups_local = num_query_groups * per_rank_qkv_dim // qkv_total_dim
weight = weight.reshape([per_rank_qkv_dim, head_size, hidden_size])
weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat)

q_slice = torch.cat(
[
torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
for i in range(num_query_groups)
for i in range(num_query_groups_local)
]
)
k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))
k_slice = torch.arange(heads_per_group, per_rank_qkv_dim, (heads_per_group + 2))
v_slice = torch.arange(heads_per_group + 1, per_rank_qkv_dim, (heads_per_group + 2))
## Example of slices
## 7b: num_query_groups = head_num = 32,
## q_slice = [0, 3, 6, 9 , ... 90, 93]
Expand All @@ -1022,7 +1057,7 @@ def _qkv_slicing(
weight_scale_dtype = weight_scale.dtype
weight_scale_hidden_size = weight_scale.shape[-1]
weight_scale = weight_scale.to(dtype=float).reshape(
[qkv_total_dim, head_size, weight_scale_hidden_size]
[per_rank_qkv_dim, head_size, weight_scale_hidden_size]
)
proj_weight_scales = [
weight_scale[s]
Expand Down Expand Up @@ -1063,7 +1098,7 @@ def _qkv_slicing(
if key == "bias":
# Slice bias similar to weight
bias = val.detach().clone()
bias = bias.reshape([qkv_total_dim, head_size])
bias = bias.reshape([per_rank_qkv_dim, head_size])
proj_biases = [bias[s].reshape(-1) for s in slices]
proj_bias_keys = [q_proj_prefix + key, k_proj_prefix + key, v_proj_prefix + key]
for bias_tensor, bias_key in zip(proj_biases, proj_bias_keys):
Expand Down
109 changes: 109 additions & 0 deletions tests/gpu_megatron/torch/export/test_hf_checkpoint_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for modelopt/torch/export/plugins/hf_checkpoint_utils.py"""

from unittest.mock import patch

from modelopt.torch.export.plugins.hf_checkpoint_utils import copy_remote_code


def test_copy_remote_code_local_dir(tmp_path):
    """copy_remote_code copies top-level .py files from a local directory."""
    source = tmp_path / "src"
    source.mkdir()
    # Top-level Python files that SHOULD be copied, keyed by name.
    expected_py = {
        "modeling_custom.py": "# custom model",
        "configuration_custom.py": "# custom config",
    }
    for fname, text in expected_py.items():
        (source / fname).write_text(text)
    # Files that should be ignored: a non-.py file and a nested .py file.
    (source / "not_python.txt").write_text("not python")
    nested_dir = source / "subdir"
    nested_dir.mkdir()
    (nested_dir / "nested.py").write_text("# nested — should not be copied")

    target = tmp_path / "dst"
    target.mkdir()

    copy_remote_code(source, target)

    for fname, text in expected_py.items():
        assert (target / fname).read_text() == text
    assert not (target / "not_python.txt").exists(), "non-.py files should not be copied"
    assert not (target / "nested.py").exists(), "nested .py files should not be copied"


def test_copy_remote_code_local_dir_no_py_files(tmp_path):
    """copy_remote_code is a no-op when the local directory has no .py files."""
    source = tmp_path / "src"
    source.mkdir()
    # Only a non-Python file present, so there is nothing eligible to copy.
    (source / "config.json").write_text("{}")

    target = tmp_path / "dst"
    target.mkdir()

    # Must complete without raising even though nothing matches.
    copy_remote_code(source, target)

    assert list(target.iterdir()) == [], "no files should be copied"


def test_copy_remote_code_hub_id(tmp_path):
    """copy_remote_code downloads and copies top-level .py files from a Hub model ID."""
    target = tmp_path / "dst"
    target.mkdir()

    # Stand-in for the local cache path hf_hub_download would return.
    cached_file = tmp_path / "cached_modeling_custom.py"
    cached_file.write_text("# custom hub model")

    listed_files = [
        "modeling_custom.py",  # top-level .py — should be downloaded
        "config.json",  # non-.py — skip
        "model.safetensors",  # non-.py — skip
        "subdir/nested.py",  # subdirectory .py — skip (contains "/")
    ]

    list_patcher = patch(
        "modelopt.torch.export.plugins.hf_checkpoint_utils.list_repo_files",
        return_value=listed_files,
    )
    download_patcher = patch(
        "modelopt.torch.export.plugins.hf_checkpoint_utils.hf_hub_download",
        return_value=str(cached_file),
    )
    with list_patcher as mock_list, download_patcher as mock_download:
        copy_remote_code("meta-llama/Llama-3.2-1B", target)

    mock_list.assert_called_once_with("meta-llama/Llama-3.2-1B")
    # Only the top-level .py should have been downloaded
    mock_download.assert_called_once_with(
        repo_id="meta-llama/Llama-3.2-1B", filename="modeling_custom.py"
    )
    assert (target / "modeling_custom.py").read_text() == "# custom hub model"


def test_copy_remote_code_hub_id_no_py_files(tmp_path):
    """copy_remote_code is a no-op when the Hub repo has no top-level .py files."""
    target = tmp_path / "dst"
    target.mkdir()

    # Repo listing contains no top-level .py, so no download should happen.
    list_patcher = patch(
        "modelopt.torch.export.plugins.hf_checkpoint_utils.list_repo_files",
        return_value=["config.json", "model.safetensors"],
    )
    download_patcher = patch(
        "modelopt.torch.export.plugins.hf_checkpoint_utils.hf_hub_download"
    )
    with list_patcher, download_patcher as mock_download:
        copy_remote_code("meta-llama/Llama-3.2-1B", target)

    mock_download.assert_not_called()
    assert list(target.iterdir()) == []
Loading
Loading