Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the review settings. Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 Walkthrough

Adds NVFP4 export support and post-processing: safetensors are saved plainly, then optionally merged with base checkpoints, NVFP4 weights padded, NVFP4 scales swizzled, and per-layer quant metadata injected; new flags propagate through the export call chain, and the quantizer restore context always runs on exit.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User/Export Script
    participant Export as export_hf_checkpoint
    participant DiffusersExport as _export_diffusers_checkpoint
    participant SaveComp as _save_component_state_dict_safetensors
    participant PostProc as _postprocess_safetensors
    participant SafeTensors as Files (.safetensors)
    participant Utils as diffusers_utils
    User->>Export: call export_hf_checkpoint(with NVFP4 flags)
    Export->>DiffusersExport: forward flags (enable_layerwise_quant_metadata, enable_swizzle_layout, padding_strategy)
    DiffusersExport->>SaveComp: save component(s) (plain safetensors)
    SaveComp->>SafeTensors: write component.safetensors
    DiffusersExport->>PostProc: invoke _postprocess_safetensors(export_dir, flags)
    PostProc->>SafeTensors: load component.safetensors
    SafeTensors-->>PostProc: state_dict
    alt padding_strategy provided
        PostProc->>Utils: pad_nvfp4_weights(state_dict, strategy)
        Utils-->>PostProc: padded state_dict
    end
    alt enable_swizzle_layout true
        PostProc->>Utils: swizzle_nvfp4_scales(state_dict)
        Utils-->>PostProc: swizzled scales
    end
    alt enable_layerwise_quant_metadata true
        PostProc->>Utils: build_layerwise_quant_metadata(state_dict, config)
        Utils-->>PostProc: per-layer quant JSON
    end
    PostProc->>SafeTensors: re-save component.safetensors with updated __metadata__
    DiffusersExport-->>Export: export complete
    Export-->>User: checkpoint saved
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Important: Pre-merge checks failed. Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (2 passed)
Actionable comments posted: 1
🧹 Nitpick comments (3)
modelopt/torch/export/diffusers_utils.py (2)
873-891: Minor: `padded_count` is computed but never used. The variable is incremented but not logged or returned, making it dead code.
♻️ Proposed fix: either remove or log it
```diff
 nvfp4_layers = _find_nvfp4_layers(state_dict)
-padded_count = 0
 for layer in sorted(nvfp4_layers):
     w_key = f"{layer}.weight"
@@ -887,7 +886,6 @@ def pad_nvfp4_weights(
     if pad_r > 0 or pad_c_w > 0:
         state_dict[w_key] = torch.nn.functional.pad(weight, (0, pad_c_w, 0, pad_r))
         state_dict[s_key] = torch.nn.functional.pad(scale, (0, pad_c_s, 0, pad_r))
-        padded_count += 1
 return state_dict
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/export/diffusers_utils.py` around lines 873 - 891, The variable padded_count is incremented inside the loop over nvfp4_layers but never used; either remove padded_count entirely or surface it (e.g., log or return it) so the work is observable. Update the code around the loop that pads state_dict entries (references: padded_count, nvfp4_layers, w_key/s_key construction, padding_strategy and the torch.nn.functional.pad calls) to either delete padded_count and its increments, or increment it and then emit a debug/info log via the existing logger or return it from the enclosing function so callers can see how many tensors were padded.
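For context, the padding this comment touches aligns both weight dimensions to multiples of 16; the core arithmetic is `(-n) % 16`. A stdlib-only sketch with hypothetical helper names (the real `pad_nvfp4_weights` pads torch tensors via `torch.nn.functional.pad`):

```python
def pad_amounts(rows, cols, align=16):
    # rows/cols to append so each dimension becomes a multiple of `align`;
    # (-n) % align is 0 when n is already aligned
    return (-rows) % align, (-cols) % align

def pad_matrix(matrix, align=16):
    # list-of-lists stand-in for F.pad(weight, (0, pad_c, 0, pad_r))
    rows, cols = len(matrix), len(matrix[0])
    pad_r, pad_c = pad_amounts(rows, cols, align)
    padded = [row + [0] * pad_c for row in matrix]
    padded += [[0] * (cols + pad_c) for _ in range(pad_r)]
    return padded

m = [[1] * 20 for _ in range(10)]  # 10 x 20: neither dimension is 16-aligned
p = pad_matrix(m)
print(len(p), len(p[0]))  # 16 32
```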
895-939: Consider adding a debug-time warning when scale dimensions suggest missing padding. The docstring correctly notes that padding should happen before swizzling. While `_to_blocked` internally pads to block boundaries (128 rows, 4 cols), this differs from the NVFP4 alignment requirement (16). If `pad_nvfp4_weights` is skipped but `swizzle_nvfp4_scales` is called, the internal padding may mask dimension misalignment issues. A debug-level warning (or at least a comment explaining the intentional fallback) could help during development.

Also, there's an extra blank line before `return state_dict` on line 939.

modelopt/torch/export/unified_export_hf.py (1)
151-218: Implementation looks correct; consider using `safe_open` for metadata reading. The manual header parsing with `struct.unpack` (lines 194-197) is correct per the safetensors specification, but using `safetensors.torch.safe_open` would be more robust and consistent with how `_merge_ltx2` reads base metadata in `diffusers_utils.py`.

♻️ Suggested alternative for metadata reading

```diff
-with open(sf_path, "rb") as f:
-    header_size = struct.unpack("<Q", f.read(8))[0]
-    header = json.loads(f.read(header_size))
-metadata = header.get("__metadata__", None) or {}
+from safetensors import safe_open
+with safe_open(str(sf_path), framework="pt", device="cpu") as f:
+    metadata = f.metadata() or {}
```

This would also allow removing the `import struct` statement on line 185.

The transformation order (merge → pad → swizzle → metadata injection) is correct per the PR requirements.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/export/unified_export_hf.py` around lines 151 - 218, _postprocess_safetensors currently parses safetensors headers manually using struct.unpack and json.loads; replace that manual header parsing with safetensors.torch.safe_open to read the file metadata (mirroring how _merge_ltx2 in diffusers_utils.py reads base metadata), remove the now-unused import struct, and keep the rest of the flow (merge/pad/swizzle/metadata injection) intact; locate the manual read around the header_size/header/json.loads block inside _postprocess_safetensors and swap it for a safe_open-based metadata fetch so save_file receives consistent metadata.
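The header layout under discussion can be sketched without the safetensors library itself: a file starts with an 8-byte little-endian header size, followed by a JSON header whose optional `__metadata__` key holds string-to-string metadata. A minimal stdlib-only illustration of a toy writer plus the manual-read pattern (real code should use `safetensors.torch.save_file` and `safe_open`):

```python
import json
import struct
import tempfile

def write_minimal_safetensors(path, metadata, tensors):
    # tensors: name -> raw little-endian float32 bytes (toy format sketch only)
    header = {"__metadata__": metadata}
    body, offset = b"", 0
    for name, raw in tensors.items():
        header[name] = {
            "dtype": "F32",
            "shape": [len(raw) // 4],
            "data_offsets": [offset, offset + len(raw)],
        }
        offset += len(raw)
        body += raw
    header_bytes = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))  # 8-byte LE header size
        f.write(header_bytes)
        f.write(body)

def read_metadata(path):
    # the manual-parsing pattern the review suggests replacing with safe_open()
    with open(path, "rb") as f:
        header_size = struct.unpack("<Q", f.read(8))[0]
        header = json.loads(f.read(header_size))
    return header.get("__metadata__") or {}

path = tempfile.mktemp(suffix=".safetensors")
write_minimal_safetensors(path, {"quantization_config": "{}"}, {"w": b"\x00" * 16})
print(read_metadata(path))  # {'quantization_config': '{}'}
```

`safe_open(...).metadata()` returns the same `__metadata__` dict without hand-rolling the binary read, which is why the review prefers it.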
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/diffusers/README.md`:
- Line 136: Fix the minor typo in the README sentence "To additionally apply
NVFP4 scale swizzle and padding , add:" by removing the extra space before the
comma so it reads "To additionally apply NVFP4 scale swizzle and padding, add:";
update the line in examples/diffusers/README.md where that exact sentence
occurs.
---
Nitpick comments:
In `@modelopt/torch/export/diffusers_utils.py`:
- Around line 873-891: The variable padded_count is incremented inside the loop
over nvfp4_layers but never used; either remove padded_count entirely or surface
it (e.g., log or return it) so the work is observable. Update the code around
the loop that pads state_dict entries (references: padded_count, nvfp4_layers,
w_key/s_key construction, padding_strategy and the torch.nn.functional.pad
calls) to either delete padded_count and its increments, or increment it and
then emit a debug/info log via the existing logger or return it from the
enclosing function so callers can see how many tensors were padded.
In `@modelopt/torch/export/unified_export_hf.py`:
- Around line 151-218: _postprocess_safetensors currently parses safetensors
headers manually using struct.unpack and json.loads; replace that manual header
parsing with safetensors.torch.safe_open to read the file metadata (mirroring
how _merge_ltx2 in diffusers_utils.py reads base metadata), remove the
now-unused import struct, and keep the rest of the flow
(merge/pad/swizzle/metadata injection) intact; locate the manual read around the
header_size/header/json.loads block inside _postprocess_safetensors and swap it
for a safe_open-based metadata fetch so save_file receives consistent metadata.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 082f8625-c0e8-4792-b981-ce792105371a
📒 Files selected for processing (6)
- examples/diffusers/README.md
- examples/diffusers/quantization/pipeline_manager.py
- examples/diffusers/quantization/quantize.py
- modelopt/torch/export/diffusers_utils.py
- modelopt/torch/export/unified_export_hf.py
- modelopt/torch/quantization/conversion.py
```
--extra-param merged_base_safetensor_path=./ltx-2-19b-dev-fp8.safetensors
```

To additionally apply NVFP4 scale swizzle and padding , add:
Minor typo: extra space before comma.
📝 Proposed fix
```diff
-To additionally apply NVFP4 scale swizzle and padding , add:
+To additionally apply NVFP4 scale swizzle and padding, add:
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```suggestion
To additionally apply NVFP4 scale swizzle and padding, add:
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/diffusers/README.md` at line 136, Fix the minor typo in the README
sentence "To additionally apply NVFP4 scale swizzle and padding , add:" by
removing the extra space before the comma so it reads "To additionally apply
NVFP4 scale swizzle and padding, add:"; update the line in
examples/diffusers/README.md where that exact sentence occurs.
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main    #1195   +/-   ##
==========================================
+ Coverage   76.03%   76.11%   +0.07%
==========================================
  Files         350      350
  Lines       40469    40546     +77
==========================================
+ Hits        30772    30860     +88
+ Misses       9697     9686      -11
```

Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
Force-pushed: 0170e83 to a2d3d21
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/diffusers/quantization/quantize.py`:
- Around line 323-330: Validate and normalize extra_params before inserting into
kwargs: for boolean flags "enable_swizzle_layout" and
"enable_layerwise_quant_metadata" (from model_config.extra_params) only accept
explicit true tokens ("true","1","yes") and explicit false tokens
("false","0","no")—if the token is present but not one of these, raise a
ValueError (or argparse-style error) instead of silently converting to False;
for "padding_strategy" validate the value against the allowed set used by the
exporter (e.g., "same","valid", etc.) and raise an error on unknown values;
update the code that reads model_config.extra_params and assigns to kwargs to
perform these checks and normalized conversions before returning/using kwargs.
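The strict normalization described above can be sketched stdlib-only; `parse_bool_flag` is a hypothetical helper name, and the accepted token sets mirror the ones listed in the comment:

```python
_TRUE_TOKENS = {"true", "1", "yes"}
_FALSE_TOKENS = {"false", "0", "no"}

def parse_bool_flag(name, token):
    # only explicit true/false tokens are accepted; anything else fails
    # loudly instead of silently becoming False
    t = str(token).strip().lower()
    if t in _TRUE_TOKENS:
        return True
    if t in _FALSE_TOKENS:
        return False
    raise ValueError(f"invalid boolean for {name!r}: {token!r}")

print(parse_bool_flag("enable_swizzle_layout", "YES"))  # True
try:
    parse_bool_flag("enable_swizzle_layout", "enabled")
except ValueError as e:
    print("rejected:", e)
```

The same shape applies to `padding_strategy`: check membership in the exporter's allowed set and raise on unknown values.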
In `@modelopt/torch/export/unified_export_hf.py`:
- Around line 187-223: The loop currently post-processes each file in
safetensor_files independently which corrupts sharded checkpoints; detect
sharded outputs by checking for an accompanying "*.safetensors.index.json" (e.g.
inspect export_dir for files matching that pattern) before per-file processing
and either raise/return an explicit error (fail-fast) or implement shard-aware
post-processing: load the index, apply transformations to the full model tensor
set (or reconstruct merged state dict), regenerate shards and a fresh index
before calling save_file; update logic around load_file, the sd cloning step,
merge_diffusion_checkpoint, pad_nvfp4_weights, swizzle_nvfp4_scales,
build_layerwise_quant_metadata, and save_file to operate on the unified state
dict (or bail out) when an index is present and when
merged_base_safetensor_path/model_type/hf_quant_config could change global
layout.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: d5db7523-559e-4d32-92db-e4d0f2b5da71
📒 Files selected for processing (6)
- examples/diffusers/README.md
- examples/diffusers/quantization/pipeline_manager.py
- examples/diffusers/quantization/quantize.py
- modelopt/torch/export/diffusers_utils.py
- modelopt/torch/export/unified_export_hf.py
- modelopt/torch/quantization/conversion.py
✅ Files skipped from review due to trivial changes (2)
- examples/diffusers/README.md
- examples/diffusers/quantization/pipeline_manager.py
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/quantization/conversion.py
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/unit/torch/export/test_nvfp4_utils.py`:
- Around line 81-90: The test always triggers the pad_r > 0 branch and misses
the scenario where rows are already 16-aligned but weight_scale requires column
padding; update or add a pytest case for pad_nvfp4_weights in
test_row_col_padding that uses an aligned row count (e.g., call
_make_nvfp4_state_dict with rows=16) and a columns value that is not 16-aligned
(e.g., cols=20) so the function must only pad weight_scale columns — verify
resulting "layer0.weight" and "layer0.weight_scale" shapes have both dims
%16==0; keep the test name or add a new test that explicitly covers the
row-aligned-but-column-padding case.
- Around line 140-171: Update the two tests to assert the safetensors file
header metadata and swizzle effects instead of relying on load_file() which
strips metadata: in test_metadata_injection call _postprocess_safetensors and
then open the saved safetensors file with the safetensors API that returns
metadata (e.g., safe_open/get_metadata) and assert the header contains the
expected layerwise-quant entries from hf_quant_config and
enable_layerwise_quant_metadata; in test_padding_and_swizzle, after calling
_postprocess_safetensors assert a swizzle-specific effect (e.g., a metadata flag
like "swizzled" or that the on-disk tensor bytes/order differ from the original
tensor before swizzle) in addition to shape and dtype checks so the test fails
if swizzling was skipped. Ensure you reference the existing helpers save_file,
load_file, and the function _postprocess_safetensors and assert keys like
"weight_scale" and the safetensors header fields directly.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: df2949d4-7b53-462f-aa24-3b05cb554c67
📒 Files selected for processing (3)
- examples/diffusers/quantization/quantize.py
- modelopt/torch/export/unified_export_hf.py
- tests/unit/torch/export/test_nvfp4_utils.py
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/export/unified_export_hf.py
♻️ Duplicate comments (2)
tests/unit/torch/export/test_nvfp4_utils.py (2)
147-162: ⚠️ Potential issue | 🟠 Major: Assert safetensors header metadata, not only tensor payload.

At line 160, `load_file()` only validates tensor bytes. This test still passes if metadata injection regresses, so it doesn't verify the feature under test.

Suggested patch

```diff
 import json
 import pytest
 import torch
+from safetensors import safe_open
 from safetensors.torch import load_file, save_file
@@ def test_metadata_injection(self, tmp_path):
     from modelopt.torch.export.unified_export_hf import _postprocess_safetensors
@@
     reloaded = load_file(str(tmp_path / "model.safetensors"))
     assert torch.allclose(reloaded["weight"], sd["weight"])
+    with safe_open(str(tmp_path / "model.safetensors"), framework="pt", device="cpu") as f:
+        metadata = f.metadata()
+    assert json.loads(metadata["quantization_config"]) == hf_quant_config
+    assert json.loads(metadata["_quantization_metadata"]) == {
+        "format_version": "1.0",
+        "layers": {},
+    }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/torch/export/test_nvfp4_utils.py` around lines 147 - 162, The test test_metadata_injection currently only reloads tensor payload via load_file and misses verifying header metadata injected by _postprocess_safetensors; update the test to open the saved safetensors file and assert the header metadata (e.g., using safetensors safe_open or equivalent) contains the expected quantization keys/values derived from hf_quant_config and enable_layerwise_quant_metadata after calling _postprocess_safetensors (keep using save_file, _postprocess_safetensors, and load_file for payload assertions but add explicit checks against the safetensors metadata for the quantization entries).
163-179: ⚠️ Potential issue | 🟠 Major: Add a swizzle-specific assertion to prevent false positives.

Lines 176-178 validate row padding and dtype, but this can still pass when swizzling is skipped (input `weight_scale` is already `float8_e4m3fn`). Add a swizzle-observable assertion (e.g., expected swizzled shape).

Suggested patch

```diff
 reloaded = load_file(str(tmp_path / "model.safetensors"))
 assert reloaded["layer0.weight"].shape[0] == 32
 assert reloaded["layer0.weight_scale"].dtype == torch.float8_e4m3fn
+assert reloaded["layer0.weight_scale"].shape == (128, 64 // 16)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/torch/export/test_nvfp4_utils.py` around lines 163 - 179, The test currently only checks padding and dtype and can pass if swizzling was skipped; to ensure swizzling actually happened, load the original safetensors before calling _postprocess_safetensors (using load_file) and after processing assert that the tensor for "layer0.weight" (or another swizzle-target key) is different from the original (e.g., not torch.equal), so the test verifies a swizzle-observable change in addition to the existing shape and dtype checks.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@tests/unit/torch/export/test_nvfp4_utils.py`:
- Around line 147-162: The test test_metadata_injection currently only reloads
tensor payload via load_file and misses verifying header metadata injected by
_postprocess_safetensors; update the test to open the saved safetensors file and
assert the header metadata (e.g., using safetensors safe_open or equivalent)
contains the expected quantization keys/values derived from hf_quant_config and
enable_layerwise_quant_metadata after calling _postprocess_safetensors (keep
using save_file, _postprocess_safetensors, and load_file for payload assertions
but add explicit checks against the safetensors metadata for the quantization
entries).
- Around line 163-179: The test currently only checks padding and dtype and can
pass if swizzling was skipped; to ensure swizzling actually happened, load the
original safetensors before calling _postprocess_safetensors (using load_file)
and after processing assert that the tensor for "layer0.weight" (or another
swizzle-target key) is different from the original (e.g., not torch.equal), so
the test verifies a swizzle-observable change in addition to the existing shape
and dtype checks.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 0da7f869-0e67-435b-ad9b-684169258f0e
📒 Files selected for processing (1)
tests/unit/torch/export/test_nvfp4_utils.py
Force-pushed: 2807ebd to f46429f

Force-pushed: f46429f to 13eb1e9
**cjluo-nv** left a comment
Summary: Adds NVFP4 post-processing pipeline for exported diffusion model checkpoints — padding, cuBLAS block-scale swizzling, and independent metadata injection — refactored out of the save step into a separate _postprocess_safetensors pass. Also fixes a context-manager restoration bug in set_quantizer_by_cfg_context.
Issues Found:
- [Correctness] Metadata merge order is reversed from prior behavior. Old code: `base_metadata.update(export_metadata)` (export keys like `_export_format` win). New code in `_postprocess_safetensors`: `metadata.update(base_metadata)` (base keys win). Currently benign because key sets likely don't overlap, but fragile if base checkpoints ever contain `_export_format` or `_class_name`.
  - File: modelopt/torch/export/unified_export_hf.py, new lines ~207-210
- [Correctness] `quantization_config` is written twice when merge is active: once inside `merge_diffusion_checkpoint` (line 830 of diffusers_utils.py) and again explicitly in `_postprocess_safetensors` (line ~220). The second write is identical, so not a bug, but the redundancy means the merge function's metadata write is now dead code for merged paths.
  - File: modelopt/torch/export/unified_export_hf.py, lines ~218-220 and diffusers_utils.py line 830
- [Correctness] `_postprocess_safetensors` is now invoked for all quantized components, including those saved via `save_pretrained`. Previously these components got no merge/swizzle/padding/metadata in the safetensors file. This is likely intentional for extending NVFP4 support, but it's a behavioral change worth confirming — especially for components that aren't the main transformer (e.g., a quantized VAE could have its scales swizzled too).
  - File: modelopt/torch/export/unified_export_hf.py, lines ~1047-1055
- [Readability] `import struct` is inside the `_postprocess_safetensors` function body. It's a stdlib module — move it to the module-level imports for consistency with the rest of the file.
  - File: modelopt/torch/export/unified_export_hf.py, line ~186
- [Readability] Metadata is read via manual binary header parsing (`struct.unpack` + `json.loads`). The `safetensors` library already provides `safe_open(...).metadata()` for this purpose (used elsewhere in this codebase, e.g., `_merge_ltx2`). Using `safe_open` would be simpler and less fragile.
  - File: modelopt/torch/export/unified_export_hf.py, lines ~194-197
- [Tests] The `try/finally` fix in `conversion.py` (the most safety-critical change in this PR) has zero test coverage. Codecov confirms 12 missing lines. A test that raises inside the `set_quantizer_by_cfg_context` body and verifies quantizer state is still restored would validate this fix.
- [Tests] No test covers the `save_pretrained` → `_postprocess_safetensors` path (issue #3 above). The existing tests only exercise the custom-save path via standalone safetensors files.
- [Correctness] `_to_blocked` internally pads scale rows to multiples of 128 and changes the output tensor shape (e.g., input `(16, 4)` → output `(128, 4)`). This means after swizzle, `weight.shape[0] != weight_scale.shape[0]`. The test explicitly asserts this, and cuBLAS likely requires it, but this shape mismatch should be documented in the `swizzle_nvfp4_scales` docstring so future maintainers don't "fix" it.
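The `try/finally` test gap noted above can be exercised with a minimal sketch: a hypothetical stand-in for `set_quantizer_by_cfg_context` (names here are illustrative, not the real ModelOpt API) that must restore state even when the body raises.

```python
from contextlib import contextmanager

@contextmanager
def set_config_context(obj, **overrides):
    # hypothetical stand-in for set_quantizer_by_cfg_context:
    # apply overrides, then restore originals on exit
    saved = {k: getattr(obj, k) for k in overrides}
    for k, v in overrides.items():
        setattr(obj, k, v)
    try:
        yield obj
    finally:  # the fix under review: restoration must survive exceptions
        for k, v in saved.items():
            setattr(obj, k, v)

class Quantizer:
    enabled = True

q = Quantizer()
try:
    with set_config_context(q, enabled=False):
        assert q.enabled is False
        raise RuntimeError("simulated failure inside the context body")
except RuntimeError:
    pass
print(q.enabled)  # True: state restored despite the exception
```

A pytest version would wrap the same raise-inside-context in `pytest.raises` and assert the quantizer config afterward.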
Suggestions:
- Consider passing `hf_quant_config=None` to `merge_diffusion_checkpoint` in `_postprocess_safetensors` to avoid the redundant `quantization_config` write, since it's set explicitly afterward anyway.
- `_find_nvfp4_layers` is called separately by `pad_nvfp4_weights` and `swizzle_nvfp4_scales`. When both are invoked in `_postprocess_safetensors`, detection runs twice over the same state dict. A minor optimization would be to compute it once and pass it in, but not blocking.
- The copyright header in `test_nvfp4_utils.py` says 2024 — should be 2026.
Overall Assessment: The core functionality (padding, swizzle, metadata) is well-structured and correctly extracted. The conversion.py try/finally fix is a genuine bug fix. However, the metadata precedence reversal and the behavioral change for save_pretrained components need explicit confirmation, and the try/finally fix needs a test.
```diff
@@ -28,7 +28,9 @@
 import torch
 import torch.nn as nn
```
load_file is added here but safe_open (already available from safetensors) would be better for reading metadata at line ~196, avoiding the manual struct.unpack header parsing.
```python
import struct

safetensor_files = sorted(export_dir.glob("*.safetensors"))
if not safetensor_files:
```
Move import struct to the top of the file with other stdlib imports. Inline imports are typically reserved for optional/heavy dependencies, not stdlib modules.
```python
    "Export with a larger max_shard_size or disable merge/metadata options."
)

for sf_path in safetensor_files:
```
Consider replacing this manual header parsing with safetensors.safe_open:
```python
with safe_open(str(sf_path), framework="pt") as f:
    metadata = dict(f.metadata() or {})
```

This is the pattern used in `_merge_ltx2` and is less fragile than binary header parsing.
```python
sd = {k: v.clone() for k, v in sd.items()}

if merged_base_safetensor_path is not None and model_type is not None:
    sd, base_metadata = merge_diffusion_checkpoint(
```
`metadata.update(base_metadata)` reverses the old merge precedence. Previously export keys (`_export_format`, `_class_name`) always won. Now base checkpoint keys would win if they overlap. Consider using `base_metadata.update(metadata); metadata = base_metadata` to preserve the old semantics.
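The precedence point comes down to `dict.update` semantics: keys from the argument overwrite keys on the receiver. A toy illustration (the metadata values here are made up):

```python
# dict.update: the argument's keys win on overlap
base_metadata = {"_export_format": "base-format", "author": "upstream"}
export_metadata = {"_export_format": "nvfp4", "_class_name": "Transformer"}

# old semantics: start from base, let export keys overwrite
old = dict(base_metadata)
old.update(export_metadata)

# new semantics in _postprocess_safetensors: base keys overwrite export keys
new = dict(export_metadata)
new.update(base_metadata)

print(old["_export_format"])  # nvfp4
print(new["_export_format"])  # base-format
```

Both orders keep non-overlapping keys from each side; only the overlapping keys change hands.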
```diff
@@ -797,23 +830,119 @@ def merge_diffusion_checkpoint(
     if hf_quant_config is not None:
```
This quantization_config write is now redundant — _postprocess_safetensors always writes it again when hf_quant_config is not None. Consider removing this line (or passing hf_quant_config=None from the caller) to avoid confusion about where metadata is authored.
```python
    dtype=input_matrix.dtype,
)
padded[:rows, :cols] = input_matrix
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
```
The docstring should note that the output shape may differ from the input shape (rows padded to multiples of 128 for cuBLAS alignment). This is a non-obvious permanent transformation that could surprise maintainers.
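The `view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)` step above can be illustrated in pure Python with tiny blocks; `to_blocked` here is a hypothetical stand-in using 2×2 blocks instead of the real 128×4, and it assumes the input is already padded to block multiples:

```python
def to_blocked(matrix, br, bc):
    # mimic view(R//br, br, C//bc, bc).permute(0, 2, 1, 3) then flatten:
    # blocks are emitted row-block-major, with column blocks ordered
    # before rows within each block
    rows, cols = len(matrix), len(matrix[0])
    out = []
    for rb in range(rows // br):
        for cb in range(cols // bc):
            for r in range(rb * br, (rb + 1) * br):
                out.extend(matrix[r][cb * bc:(cb + 1) * bc])
    return out

m = [[0, 1, 2, 3],
     [4, 5, 6, 7]]  # one 2-row block, two 2-column blocks
print(to_blocked(m, 2, 2))  # [0, 1, 4, 5, 2, 3, 6, 7]
```

With the real 128-row blocks, a 16-row input is padded up to 128 rows first, which is exactly the output-shape growth the comment asks to document.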
```diff
@@ -0,0 +1,206 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```
Copyright year says 2024 — should be 2026 to match the current year.
```python
_save_component_state_dict_safetensors(component, component_export_dir)

# Step 7: Post-process — merge, metadata, padding, swizzle
_postprocess_safetensors(
```
This _postprocess_safetensors call now runs for save_pretrained components too (the if hasattr(component, "save_pretrained") branch above). In the old code, those components got no merge/swizzle/metadata. Please confirm this behavioral change is intentional and add a test for it.
For diffusers-based models like flux2, the `save_pretrained` branch will be hit, and swizzling and metadata will be required for those models as well.
…and support for different padding strategy Signed-off-by: ynankani <ynankani@nvidia.com> Signed-off-by: YASH Nankani <ynankani@2u1g-x570-0073.ipp2a1.colossus.nvidia.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: ynankani-nv <ynankani@nvidia.com> Signed-off-by: YASH Nankani <ynankani@2u1g-x570-0073.ipp2a1.colossus.nvidia.com>
Signed-off-by: YASH Nankani <ynankani@2u1g-x570-0073.ipp2a1.colossus.nvidia.com>
Force-pushed: 13eb1e9 to 2b729ba
Signed-off-by: YASH Nankani <ynankani@dl325g11-1979.ipp2a2.colossus.nvidia.com>
What does this PR do?
Type of change: new feature
Adds post-processing support for exported diffusion model checkpoints to enable NVFP4 block scale swizzling and configurable padding strategies. This allows exported quantized checkpoints to be directly consumed by inference runtimes (e.g., ComfyUI with comfy_kitchen) that require cuBLAS 2-D block-scaling-factors layout.
Changes:
Usage
Testing
Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: N/A

Additional Information
Summary by CodeRabbit
New Features
Bug Fixes
Tests
Documentation