feat(deepseek): add --cast_mxfp4_to_nvfp4 to deepseek_v4 quantize step #1653

Merged
kevalmorabia97 merged 1 commit into main from chenjiel/dsv4-cast-mxfp4-to-nvfp4 on Jun 10, 2026

Conversation

@cjluo-nv cjluo-nv commented Jun 8, 2026

Collaborator

What does this PR do?

Type of change: new feature

Brings the GPT-OSS lossless MXFP4 → NVFP4 cast (#1372) to DeepSeek V4's routed-expert export by adding a --cast_mxfp4_to_nvfp4 flag to examples/deepseek/deepseek_v4/quantize_to_nvfp4.py.

To avoid duplicating the closed-form math, the shared numerics — mxfp4_to_nvfp4_global_amax, mxfp4_to_nvfp4_per_block_amax, and the E2M1/E4M3/E8M0 constants — are hoisted out of the GPT-OSS example cast into the library at modelopt/torch/quantization/utils/numeric_utils.py. Both the GPT-OSS cast (examples/llm_ptq/cast_mxfp4_to_nvfp4.py) and the new DeepSeek path now import them from there.
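For readers unfamiliar with the three encodings named above, here is a minimal pure-Python sketch of the format constants and of MXFP4 block dequantization. The constant values are the standard ones for E2M1, E4M3, and E8M0; the names and helper functions are illustrative assumptions and are not the identifiers exported by numeric_utils.py.

```python
# Format constants (standard values; names are illustrative, not the library's).
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)  # low 3 bits of a nibble
E2M1_MAX = 6.0          # largest E2M1 magnitude
E4M3_MAX = 448.0        # largest finite float8 E4M3 value
E8M0_BIAS = 127         # E8M0 stores a biased power-of-two exponent
MXFP4_BLOCK_SIZE = 32   # elements sharing one E8M0 scale


def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit E2M1 code: bit 3 is the sign, bits 0-2 index the magnitude."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[nibble & 0x7]


def decode_e8m0(byte: int) -> float:
    """Decode an E8M0 scale byte to 2**(byte - 127). The NaN code 0xFF is not handled."""
    return 2.0 ** (byte - E8M0_BIAS)


def dequantize_mxfp4_block(nibbles, scale_byte):
    """Dequantize one block: every element is its E2M1 value times the block scale."""
    scale = decode_e8m0(scale_byte)
    return [decode_e2m1(n) * scale for n in nibbles]
```

A real block holds 32 nibbles; shorter lists work the same way for experimentation.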

DeepSeek V4's routed experts ship as MXFP4 (E2M1 nibbles + a power-of-two E8M0 scale per 32-element block). By default the export dequantizes them to BF16 and re-quantizes to NVFP4 using the calibrated per-tensor weight amax, which re-derives per-block scales from the data and is therefore lossy. With the flag, the cast pins scale_2 = 2^(k_max-8) and each per-block E4M3 scale to 2^(k_j-m) straight from the source E8M0 scales, so per_block_scale * scale_2 = 2^k_j and the NVFP4 nibbles equal the source MXFP4 nibbles bit-for-bit (for every block whose k_j lands in E4M3's representable window; rare out-of-range blocks clamp). The one V4-specific addition is that w1/w3 share a single scale_2 for the fused GEMM1, so k_max is taken over both projections. The flag only affects routed-expert weights — activation input_scale still comes from --amax_path calibration.
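The scale pinning described above can be sketched in plain Python as follows. This is only an illustration of the arithmetic under stated assumptions: the function names are hypothetical, blocks are shortened from 32 elements, and the out-of-range clamping path is left out, so every k_j is assumed to land in E4M3's window.

```python
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(nibble):
    """Decode a 4-bit E2M1 code (bit 3 = sign, bits 0-2 = magnitude index)."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[nibble & 0x7]


def pin_scales(block_exponents, partner_exponents=()):
    """Return (scale_2, per_block_scale) from unbiased E8M0 exponents k_j.

    partner_exponents models the w1/w3 case, where k_max spans both projections.
    """
    k_max = max([*block_exponents, *partner_exponents])
    m = k_max - 8                      # scale_2 = 2^(k_max - 8)
    scale_2 = 2.0 ** m
    per_block_scale = [2.0 ** (k - m) for k in block_exponents]
    return scale_2, per_block_scale


def quantize_e2m1(value, scale):
    """Round value / scale to the nearest E2M1 code."""
    x = value / scale
    code = min(range(8), key=lambda i: abs(E2M1_MAGNITUDES[i] - abs(x)))
    return code | (0x8 if x < 0 else 0)


def roundtrip(nibbles_per_block, block_exponents):
    """Dequantize MXFP4 blocks, then requantize them with the pinned NVFP4 scales."""
    scale_2, per_block_scale = pin_scales(block_exponents)
    out = []
    for nibbles, k, s in zip(nibbles_per_block, block_exponents, per_block_scale):
        values = [decode_e2m1(n) * 2.0 ** k for n in nibbles]  # MXFP4 dequant
        out.append([quantize_e2m1(v, s * scale_2) for v in values])
    return out
```

Because per_block_scale * scale_2 equals 2^k_j exactly, dividing each dequantized value by that product recovers the original E2M1 value, so the requantized nibble matches the source. A negative-zero nibble (0x8) would come back as 0x0 in this sketch, which corresponds to the sign-of-zero caveat mentioned in the Testing section.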

Usage

python deepseek_v4/quantize_to_nvfp4.py \
    --amax_path ${AMAX} \
    --source_ckpt ${DS_V4} \
    --output_ckpt ${HF_NVFP4_PATH} \
    --cast_mxfp4_to_nvfp4

Testing

  • The hoisted numerics get unit tests in tests/unit/torch/quantization/test_numeric_utils.py (10 cases: per-tensor global_amax, per-block amax incl. out-of-range, magnitude-table cache) — 10/10 pass. The example test tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py keeps the cast-specific cases (quantizer naming, build_amax_map, apply_to_model).
  • Validated on real DeepSeek-V4-Flash expert tensors (incl. the on-disk float8_e8m0fnu scale dtype): 23.5M blocks, 100% lossless, 0 error.
  • Generated a full NVFP4 checkpoint for DeepSeek-V4-Flash (43 layers, 256 routed experts) end-to-end: [cast] lossless MXFP4->NVFP4 blocks: 8,657,043,456/8,657,043,456 (100.0000%). Output weights match an independently-produced reference cast byte-for-byte (weight_scale, weight_scale_2, packed nibbles modulo the harmless sign-of-zero).

Before your PR is "Ready for review"

  • Is this change backward compatible?: ✅ (new opt-in flag; default export behavior unchanged; hoist re-exports through the existing example module)
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ N/A (no new deps; shared numerics moved into the library rather than duplicated)
  • Did you write any new necessary tests?: ✅ (library numerics covered by tests/unit/torch/quantization/test_numeric_utils.py; end-to-end validated on a real DeepSeek-V4 checkpoint)
  • Did you update Changelog?: ✅
  • Did you get Claude approval on this PR?: ❌ (will run /claude review)

Additional Information

Mirrors and reuses #1372 (GPT-OSS MXFP4 → NVFP4 cast); the closed-form numerics are now shared via modelopt.torch.quantization.utils.numeric_utils.

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features

    • Added --cast_mxfp4_to_nvfp4 flag to perform a closed-form, mostly lossless MXFP4→NVFP4 conversion for routed-expert weights with aggregated lossless/block statistics.
  • Documentation

    • Updated DeepSeek V4 export instructions and README to document the new flag and clarify calibration behavior for activation scales.
  • Chores

    • Exposed shared numeric quantization utilities for MXFP4→NVFP4 casting.
  • Tests

    • Added and updated tests to validate the new numeric helpers and conversion behavior.

@cjluo-nv cjluo-nv requested a review from a team as a code owner June 8, 2026 22:24
@cjluo-nv cjluo-nv requested a review from sugunav14 June 8, 2026 22:24
coderabbitai Bot commented Jun 8, 2026

Contributor

Walkthrough

Adds a --cast_mxfp4_to_nvfp4 CLI flag to DeepSeek V4 quantization, implements closed-form tensor-only numeric helpers for MXFP4→NVFP4 scale selection, re-exports those helpers, updates example scripts/tests to use them, and wires a lossless cast path into shard conversion with aggregated logging.

Changes

MXFP4→NVFP4 Lossless Cast

| Layer / File(s) | Summary |
| --- | --- |
| Feature documentation: CHANGELOG.rst, examples/deepseek/README.md, examples/deepseek/deepseek_v4/quantize_to_nvfp4.py | Adds changelog and README entries and updates the module docstring to document --cast_mxfp4_to_nvfp4, scoped to routed-expert weights and noting input_scale remains from --amax_path calibration. |
| Numeric utilities and re-exports: modelopt/torch/quantization/utils/numeric_utils.py, modelopt/torch/quantization/utils/__init__.py, examples/llm_ptq/cast_mxfp4_to_nvfp4.py, tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py, tests/unit/torch/quantization/test_numeric_utils.py | New numeric_utils module providing E2M1/E4M3/E8M0 constants, a cached E2M1 magnitude table, and mxfp4_to_nvfp4_global_amax / mxfp4_to_nvfp4_per_block_amax; __init__ re-exports these symbols; example script and tests updated to use them. |
| DeepSeek lossless cast implementation & CLI integration: examples/deepseek/deepseek_v4/quantize_to_nvfp4.py | Implements closed-form lossless MXFP4→NVFP4 routed-expert weight cast (derive k_max, pin NVFP4 scale_2, per-block amax/fallback), extends convert_shard(..., cast: bool=False), conditionally selects cast vs calibrated quantization, wires CLI flag, and logs aggregated lossless-block statistics. |
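As a rough illustration of the CLI wiring summarized in the table, an opt-in boolean flag of this kind is commonly added with argparse as sketched below. The four flag names come from the Usage section of this PR; the parser layout, the required=True settings, the help text, and the stub convert_shard body are assumptions for illustration and do not reproduce the actual script.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser; only the flag names are taken from the PR."""
    parser = argparse.ArgumentParser(description="Export DeepSeek V4 to NVFP4 (sketch)")
    parser.add_argument("--amax_path", required=True)
    parser.add_argument("--source_ckpt", required=True)
    parser.add_argument("--output_ckpt", required=True)
    # store_true keeps the default False, so existing invocations behave as before.
    parser.add_argument(
        "--cast_mxfp4_to_nvfp4",
        action="store_true",
        help="Use the closed-form MXFP4 -> NVFP4 cast for routed-expert weights.",
    )
    return parser


def convert_shard(shard_name: str, cast: bool = False) -> str:
    """Stub standing in for the real conversion; it only reports the chosen path."""
    return "closed-form cast" if cast else "calibrated requantization"


def main(argv=None) -> str:
    args = build_parser().parse_args(argv)
    return convert_shard("shard-0", cast=args.cast_mxfp4_to_nvfp4)
```

Defaulting the flag to False is what keeps the default export behavior unchanged, as the PR description states.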

Sequence Diagram

sequenceDiagram
  participant User
  participant DeepSeekScript
  participant numeric_utils
  participant NVFP4Quantizer
  participant HFCheckpoint
  User->>DeepSeekScript: run quantize_to_nvfp4.py --cast_mxfp4_to_nvfp4
  DeepSeekScript->>NVFP4Quantizer: convert_shard(cast=True)
  NVFP4Quantizer->>numeric_utils: mxfp4_to_nvfp4_global_amax / mxfp4_to_nvfp4_per_block_amax
  numeric_utils-->>NVFP4Quantizer: amax, k_max, diagnostics
  NVFP4Quantizer-->>DeepSeekScript: quantized NVFP4 weights + lossless stats
  DeepSeekScript->>HFCheckpoint: write `--output_ckpt` NVFP4 checkpoint

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • NVIDIA/Model-Optimizer#1341: Modifies DeepSeek V4 quantize_to_nvfp4.py export/quantization flow for routed-expert MXFP4→NVFP4 weights.

Suggested labels

cherry-pick-0.45.0

Suggested reviewers

  • meenchen
  • realAsma
  • sugunav14
🚥 Pre-merge checks | ✅ 6
✅ Passed checks (6 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The PR title accurately and specifically describes the main change: adding a --cast_mxfp4_to_nvfp4 flag to the deepseek_v4 quantize step, which aligns directly with the primary objective of the changeset. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 91.67% which is sufficient. The required threshold is 80.00%. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Security Anti-Patterns | ✅ Passed | No security anti-patterns detected. torch.load uses weights_only=True; no numpy.load, eval/exec, trust_remote_code, or nosec comments present; no new non-permissive dependencies added. |

codecov Bot commented Jun 8, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 76.07%. Comparing base (01415c2) to head (ada91f0).
⚠️ Report is 11 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1653      +/-   ##
==========================================
- Coverage   77.51%   76.07%   -1.44%     
==========================================
  Files         489      509      +20     
  Lines       54498    57850    +3352     
==========================================
+ Hits        42242    44008    +1766     
- Misses      12256    13842    +1586     
| Flag | Coverage Δ |
| --- | --- |
| examples | 42.81% <52.00%> (-0.10%) ⬇️ |
| gpu | 58.37% <24.00%> (-0.67%) ⬇️ |
| regression | 14.87% <0.00%> (+0.06%) ⬆️ |
| unit | 54.46% <100.00%> (+0.39%) ⬆️ |

@cjluo-nv cjluo-nv force-pushed the chenjiel/dsv4-cast-mxfp4-to-nvfp4 branch from 1d16130 to 460509f Compare June 8, 2026 22:36
@cjluo-nv cjluo-nv force-pushed the chenjiel/dsv4-cast-mxfp4-to-nvfp4 branch from 460509f to 656101f Compare June 9, 2026 05:28
@cjluo-nv cjluo-nv requested review from a team as code owners June 9, 2026 05:28
github-actions Bot commented Jun 9, 2026

Contributor
PR Preview Action v1.8.1
Preview removed because the pull request was closed.
2026-06-10 05:24 UTC

@coderabbitai coderabbitai Bot left a comment

Contributor

🧹 Nitpick comments (1)
tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py (1)

278-285: 💤 Low value

Consider moving the import to module level.

The numeric_utils import inside the test function could be moved to the top of the file with other imports. Per coding guidelines, imports belong at module top so import errors surface at collection time. However, since this is a test file and only this single test uses numeric_utils, the current approach is acceptable.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py` around lines 278 - 285,
Move the local import of numeric_utils out of
test_e2m1_magnitude_table_cached_per_device and into the module-level imports so
import errors surface during test collection; specifically, import
modelopt.torch.quantization.utils.numeric_utils at the top of the file and then
in the test simply call numeric_utils._e2m1_magnitude_table(torch.device("cpu"))
and assert as before (references: test_e2m1_magnitude_table_cached_per_device,
numeric_utils, _e2m1_magnitude_table, _E2M1_MAGNITUDE).

Source: Coding guidelines

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: ee495359-5a81-4f65-bf13-107e56db31b6

📥 Commits

Reviewing files that changed from the base of the PR and between 460509f and 656101f.

📒 Files selected for processing (7)
  • CHANGELOG.rst
  • examples/deepseek/README.md
  • examples/deepseek/deepseek_v4/quantize_to_nvfp4.py
  • examples/llm_ptq/cast_mxfp4_to_nvfp4.py
  • modelopt/torch/quantization/utils/__init__.py
  • modelopt/torch/quantization/utils/numeric_utils.py
  • tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py
✅ Files skipped from review due to trivial changes (2)
  • examples/deepseek/README.md
  • CHANGELOG.rst

@cjluo-nv cjluo-nv force-pushed the chenjiel/dsv4-cast-mxfp4-to-nvfp4 branch 2 times, most recently from dd1444d to 2a2ac77 Compare June 9, 2026 05:40

@meenchen meenchen left a comment

Contributor

Bot review — DM the bot to share feedback.

Solid follow-up to #1372: the closed-form MXFP4→NVFP4 numerics are hoisted into modelopt/torch/quantization/utils/numeric_utils.py and reused by both example cast paths instead of duplicated, and the V4 twist (w1/w3 sharing one scale_2 for the fused GEMM1) is the right minimal addition. Design protocol satisfied — the PR body explicitly justifies hoisting rather than copy-pasting, and numeric_utils.py is a thin pure-tensor utility, not a new framework.

Math reads correctly: per_block_scale × scale_2 = 2^(k_j-m) × 2^m = 2^k_j exactly, BF16 dequant of E2M1 magnitudes is lossless, and the lossless-accounting mask (k >= k_max - 17) | (e8m0 == 0) correctly handles all-zero blocks (whose k_j = -127 would otherwise read as out-of-range). NVFP4QTensor.quantize accepts a pre-supplied per-block scale tensor and that path was already exercised by the GPT-OSS cast.
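One way to see where the 17 in the mask comes from: with scale_2 = 2^(k_max-8), the per-block scale is 2^(k_j - k_max + 8), and E4M3 represents powers of two exactly only from 2^-9 (its smallest subnormal) up to 2^8 (the largest power of two not exceeding 448). The sketch below works through that window in plain Python; the function and constant names are illustrative assumptions, not the script's own.

```python
E8M0_BIAS = 127
E4M3_MIN_POW2_EXP = -9   # smallest E4M3 subnormal is 2**-9
E4M3_MAX_POW2_EXP = 8    # largest power of two <= 448 is 2**8


def lossless_mask(e8m0_bytes):
    """Per-block mask equivalent to (k >= k_max - 17) | (e8m0 == 0)."""
    ks = [b - E8M0_BIAS for b in e8m0_bytes]
    k_max = max(ks)
    window = E4M3_MAX_POW2_EXP - E4M3_MIN_POW2_EXP  # 17 exponent steps below k_max
    return [(k >= k_max - window) or (b == 0) for k, b in zip(ks, e8m0_bytes)]
```

The upper bound never needs checking because k_j <= k_max by definition, so only the lower edge of the window and the zero-byte special case appear in the mask.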

A few things that warrant a human pass before sign-off:

  • No unit tests for the new V4-specific helpers. _quantize_weight_nvfp4_lossless, _build_w13_kmax_overrides, and _kmax_from_mxfp4_scale are exercised only by the on-the-real-V4-checkpoint validation reported in the PR body; the moved tests in tests/unit/torch/quantization/test_numeric_utils.py cover only the hoisted library helpers. A cheap unit test for the V4 path (synthetic MXFP4 weight + scale → roundtrip lossless count, plus a w1/w3 shared-k_max case) would lock in the V4-specific contract.
  • README quietly promotes --cast_mxfp4_to_nvfp4 to the standard example. examples/deepseek/README.md now appends --cast_mxfp4_to_nvfp4 to the main quantize_to_nvfp4.py invocation, not as a "lossless variant" callout — that contradicts the PR body's "default export behavior unchanged" framing. Worth either dropping the flag from the top-level usage block (keeping the dedicated subsection that explains the trade-off) or saying explicitly that the flag is now recommended for V4 sources.
  • Stat naming nit: stats[f"cast_oor_layers_{block_kind}"] is incremented per expert-projection tensor (w1/w2/w3), not per layer — cast_oor_tensors_* would match experts_* / weight_synth_* siblings.
  • Minor: _build_w13_kmax_overrides reads each *.scale once and the main loop reads it again via f.get_tensor(scale_key). Cheap (E8M0 is tiny relative to packed nibbles), but if you're already touching this you could reuse the cached read.
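The unit test suggested in the first bullet could look roughly like the sketch below. It does not call the script's private helpers, whose signatures are not shown in this thread; the three functions are hypothetical stand-ins that mimic the described behavior on synthetic E8M0 scale bytes.

```python
E8M0_BIAS = 127


def kmax_from_scale_bytes(e8m0_bytes):
    """Hypothetical stand-in: largest unbiased exponent in a tensor's scale bytes."""
    return max(b - E8M0_BIAS for b in e8m0_bytes)


def shared_w13_kmax(w1_bytes, w3_bytes):
    """Hypothetical stand-in: w1 and w3 share one scale_2, so k_max spans both."""
    return max(kmax_from_scale_bytes(w1_bytes), kmax_from_scale_bytes(w3_bytes))


def count_lossless(e8m0_bytes, k_max):
    """Count blocks inside the E4M3 window, treating zero scale bytes as lossless."""
    return sum((b - E8M0_BIAS >= k_max - 17) or b == 0 for b in e8m0_bytes)


def test_w1_w3_share_kmax():
    w1 = [120, 105]   # exponents -7 and -22
    w3 = [130, 128]   # exponents 3 and 1
    # On its own, w1 keeps both blocks in range (threshold -7 - 17 = -24).
    assert count_lossless(w1, kmax_from_scale_bytes(w1)) == 2
    # Sharing k_max with w3 raises the threshold to 3 - 17 = -14,
    # which pushes the -22 block out of range.
    k_max = shared_w13_kmax(w1, w3)
    assert k_max == 3
    assert count_lossless(w1, k_max) == 1
    assert count_lossless(w3, k_max) == 2
```

The synthetic values are chosen so that the shared k_max changes the outcome for w1, which is the V4-specific contract the bullet asks to lock in.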

Nothing blocking; the cast itself looks right. Just want a human reviewer to land the unit-test ask and the README usage call.

cjluo-nv commented Jun 9, 2026

Collaborator Author

It's OK to skip the unit-test implementation for example code.

@cjluo-nv cjluo-nv force-pushed the chenjiel/dsv4-cast-mxfp4-to-nvfp4 branch from 2a2ac77 to 4fe63d0 Compare June 9, 2026 06:21

@mxinO mxinO left a comment

Contributor

LGTM! Thanks.

return overrides


def _quantize_weight_nvfp4_lossless(

Contributor

I suggest making this a general util in modelopt; I bet the MXFP4 → NVFP4 cast will be an often-used util later.

Collaborator Author

For the normal PTQ path, we already have the feature integrated. See https://github.com/NVIDIA/Model-Optimizer/blob/main/examples/llm_ptq/cast_mxfp4_to_nvfp4.py

@meenchen meenchen left a comment

Contributor

Bot review — DM the bot to share feedback.

Re-review: prior critical/important comments are addressed.

  • 💬 Operator: "We don't need to add unittest for this example" — drops the V4-specific-helpers unit-test ask from the prior nudge.
  • README usage call: top-level quantize_to_nvfp4.py invocation no longer appends --cast_mxfp4_to_nvfp4; the flag is now in a dedicated "Lossless MXFP4 → NVFP4 weight cast" subsection with the trade-off explained, matching the "default unchanged" framing.
  • Stat naming: cast_oor_layers_* renamed to cast_oor_tensors_* to match the per-tensor increment and the experts_* / weight_synth_* siblings.
  • Hoist + math (closed-form scale_2 = 2^(k_max-8), per_block_scale = 2^(k_j-m), lossless mask (k >= k_max - 17) | (e8m0 == 0), w1/w3 shared k_max) is consistent with PR #1372 and previously reviewed.
  • numeric_utils.py carries the standard NVIDIA Apache-2.0 header; no other licensing changes.

Complex PR: spans 7 directories (≥ 5). Looping in a human for approval.

Comment thread modelopt/torch/quantization/utils/__init__.py

@meenchen meenchen left a comment

Contributor

Looks good

@kevalmorabia97 kevalmorabia97 added the cherry-pick-0.45.0 After code freeze, cherry-pick to release branch for next rc (bulk update). Only for bug fixes / doc label Jun 9, 2026
Add a closed-form, bit-exact MXFP4 -> NVFP4 routed-expert weight cast to examples/deepseek/deepseek_v4/quantize_to_nvfp4.py via a --cast_mxfp4_to_nvfp4 flag. Pins scale_2 = 2^(k_max-8) and each per-block E4M3 scale to 2^(k_j-m) from the source E8M0 scales, so the NVFP4 nibbles equal the source MXFP4 nibbles bit-for-bit for every in-range block. w1/w3 share one scale_2 for the fused GEMM1; activation input_scale still comes from --amax_path calibration.

Hoist the shared closed-form numerics (mxfp4_to_nvfp4_global_amax, mxfp4_to_nvfp4_per_block_amax, and the E2M1/E4M3/E8M0 constants) out of the GPT-OSS example cast (examples/llm_ptq/cast_mxfp4_to_nvfp4.py, PR #1372) into modelopt.torch.quantization.utils.numeric_utils, so both the GPT-OSS and DeepSeek-V4 cast paths import them from the library. Their unit tests move to tests/unit/torch/quantization/test_numeric_utils.py; the example test keeps the cast-specific cases (quantizer naming, build_amax_map, apply_to_model).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
@cjluo-nv cjluo-nv force-pushed the chenjiel/dsv4-cast-mxfp4-to-nvfp4 branch from 4fe63d0 to ada91f0 Compare June 9, 2026 17:09
@cjluo-nv cjluo-nv enabled auto-merge (squash) June 9, 2026 17:15

@Edwardf0t1 Edwardf0t1 left a comment

Contributor

It seems the README and the docstrings have overlapping content; could we simplify?

@kevalmorabia97 kevalmorabia97 disabled auto-merge June 10, 2026 04:12
@kevalmorabia97 kevalmorabia97 merged commit bde162a into main Jun 10, 2026
55 of 56 checks passed
@kevalmorabia97 kevalmorabia97 deleted the chenjiel/dsv4-cast-mxfp4-to-nvfp4 branch June 10, 2026 05:24

Labels

cherry-pick-0.45.0 After code freeze, cherry-pick to release branch for next rc (bulk update). Only for bug fixes / doc

5 participants