Refine layerwise non-mutating calibration #1592

Closed
realAsma wants to merge 4 commits into fridah/layerwise-config from asma/layerwise_skip_in_meta

Conversation

@realAsma

@realAsma realAsma commented Jun 1, 2026

Contributor

Stacked on #1571.

Summary:

  • Add calib_mutates_weights gating for non-mutating layerwise calibration.
  • Skip layer weight checkpoint/writeback for quantizer-state-only calibration.
  • Keep FSDP2/Accelerate writeback conditional and improve layerwise progress reporting.
  • Make _zeros_from_meta allocate skip placeholders on torch.device("meta") so skip-mode preserves structure without real-device tensor allocation.
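
As a rough illustration of the gating in the first three bullets, the sketch below uses plain Python stand-ins rather than the real classes touched by this PR. `LayerwiseConfig` here is a hypothetical dataclass that models only the `calib_mutates_weights` flag, and `plan_layer_checkpoint` is an invented helper; the actual code may be organized quite differently.

```python
from dataclasses import dataclass


@dataclass
class LayerwiseConfig:
    """Hypothetical stand-in; only the flag discussed in this PR is modeled."""

    calib_mutates_weights: bool = True


def plan_layer_checkpoint(cfg: LayerwiseConfig) -> dict:
    """Decide what a per-layer checkpoint stores and whether weights are written back."""
    return {
        # Quantizer state (for example amax) is always needed to resume calibration.
        "save_quantizer_state": True,
        # Weights only need saving and writeback if calibration may have changed them.
        "save_weights": cfg.calib_mutates_weights,
        "writeback": cfg.calib_mutates_weights,
    }


mutating = plan_layer_checkpoint(LayerwiseConfig(calib_mutates_weights=True))
non_mutating = plan_layer_checkpoint(LayerwiseConfig(calib_mutates_weights=False))
print(mutating)
print(non_mutating)
```

With the flag left at its default, the plan keeps the full weight checkpoint and writeback; with `calib_mutates_weights=False`, only quantizer state is kept.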

Testing:

  • Pre-commit hooks passed during commit.

realAsma added 2 commits June 1, 2026 21:41
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented Jun 1, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented Jun 1, 2026

Contributor

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.


Comment on lines +507 to +511
if not writeback:
    with _fsdp2_unshard_context(fsdp_module):
        yield
    return

@realAsma realAsma Jun 1, 2026

Contributor Author

@sugunav14 here is an easy perf improvement for layerwise FSDP2

Contributor

@realAsma Claude flags a correctness issue on this branch. It makes sense to me; please check on your side:

Bug: FSDP2 writeback=False path never all-gathers weights (non-mutating layerwise calibration computes on partial shards)

Where: _fsdp2_unshard_context (core_utils.py:480) reached via persistent_materialization(layer, writeback=False).

Root cause — a collision between two layers of the same context stack.

persistent_materialization enters its context managers in this order:

with (
    _disable_fsdp_unshard_reshard(layer),                              # ① enters FIRST
    enable_weight_access_and_writeback(layer, layer, writeback=False), # ② enters SECOND
    temporarily_remove_accelerate_hook(layer),
):
  • ① monkeypatches the class method FSDPParamGroup.unshard to a no-op (to stop FSDP from re-sharding weights between
    calibration batches). This patch is global and is now active.

  • ② with writeback=False routes into _fsdp2_unshard_context, whose gather step is:

    if was_sharded:
        fsdp_module.unshard()   # core_utils.py:480

    But FSDPModule.unshard() delegates the actual all-gather to fsdp_param_group.unshard() — which is exactly the method
    just patched to a no-op. So no all-gather happens; the params stay sharded DTensors, and the calibration/capture forward
    runs on partial shards → wrong _amax (or a shape/dtype error).

The layer's own FSDP pre-forward hook can't save it either: that hook also calls FSDPParamGroup.unshard, which is still the
no-op.

Why writeback=True (GPTQ/AWQ/SmoothQuant) is unaffected: that branch of fsdp2_weight_access_and_writeback_context gathers
via param.redistribute(... Replicate ...) — a DTensor collective that doesn't touch the patched unshard(). So only the
writeback=False (max/mse/local_hessian) path, i.e. the non-mutating optimization this PR adds, is broken on FSDP2.

_fsdp2_unshard_context is internally self-contradictory under this caller: line 480 calls unshard() (needs it to work) while
line 482 calls _disable_fsdp_unshard_reshard (disables it). It assumes it runs before anyone disables unshard, but its only
production caller disables it first.

Impact / trigger: FSDP2 layerwise calibration with calib_mutates_weights=False. The default (True) takes the redistribute
path and works, so the bug only fires when a user opts into the non-mutating feature.

Test that should catch this (likely unrun — it's a multi-GPU GPU test on a draft PR):
tests/gpu/torch/quantization/test_fsdp2.py:264

with persistent_materialization(layer, writeback=False):
    assert not isinstance(layer[0].weight, DTensor)   # fails: still a DTensor
    layer(inputs)

Suggested fix (don't double-disable):

  1. Have _fsdp2_unshard_context call the unpatched unshard (_disable_fsdp_unshard_reshard already captures orig_unshard;
    expose it), so the gather works while per-forward reshard stays suppressed; or
  2. Skip persistent_materialization's outer _disable_fsdp_unshard_reshard on the FSDP writeback=False path, since
    _fsdp2_unshard_context already suppresses reshard internally after a real unshard.
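
The collision and suggested fix 1 can be reproduced with a small, torch-free toy. In the sketch below, `ParamGroup` and `Module` are hypothetical stand-ins for `FSDPParamGroup` and `FSDPModule`, and `disable_unshard` only imitates the patching behavior described above; none of this is the real FSDP2 or ModelOpt code.

```python
from contextlib import contextmanager


class ParamGroup:
    """Toy stand-in for FSDPParamGroup: tracks whether params are gathered."""

    def __init__(self):
        self.sharded = True

    def unshard(self):
        self.sharded = False  # stands in for the all-gather


class Module:
    """Toy stand-in for FSDPModule: unshard() delegates to the param group."""

    def __init__(self):
        self.group = ParamGroup()

    def unshard(self):
        self.group.unshard()


@contextmanager
def disable_unshard():
    """Patch the class method to a no-op and yield the captured original."""
    orig_unshard = ParamGroup.unshard
    ParamGroup.unshard = lambda self: None
    try:
        yield orig_unshard
    finally:
        ParamGroup.unshard = orig_unshard


# Buggy order: the patch is already active, so the module-level unshard gathers nothing.
buggy = Module()
with disable_unshard():
    buggy.unshard()
    buggy_still_sharded = buggy.group.sharded

# Fix option 1: use the captured, unpatched method for the one-time gather.
fixed = Module()
with disable_unshard() as orig_unshard:
    orig_unshard(fixed.group)
    fixed_still_sharded = fixed.group.sharded

print(buggy_still_sharded, fixed_still_sharded)  # True False
```

Because the patch replaces the method on the class, every instance is affected, including the one the inner context tries to gather.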

Contributor Author

This is possible. Can we change _disable_fsdp_unshard_reshard(layer) to _disable_fsdp_reshard(layer)?

Contributor Author

🤖 Bot comment.

Fixed in 9bfec17479.

persistent_materialization() now enters enable_weight_access_and_writeback(..., writeback=writeback) before _disable_fsdp_unshard_reshard(...), so the writeback=False FSDP2 path performs the real one-time unshard before per-forward unshard/reshard is suppressed.
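
Why the reordering helps can be seen in a toy model: context managers listed in a single `with` statement are entered left to right, so placing the weight-access context first lets its gather run before unshard is patched out. `Layer`, `weight_access`, and `disable_unshard` below are hypothetical stand-ins, not the functions changed in the commit.

```python
from contextlib import contextmanager


class Layer:
    """Toy layer whose unshard can be patched out at class level."""

    def __init__(self):
        self.sharded = True

    def unshard(self):
        self.sharded = False


@contextmanager
def weight_access(layer):
    """One-time gather on entry (stand-in for the writeback=False path)."""
    layer.unshard()
    yield


@contextmanager
def disable_unshard():
    """Suppress later unshard calls by patching the class method to a no-op."""
    orig = Layer.unshard
    Layer.unshard = lambda self: None
    try:
        yield
    finally:
        Layer.unshard = orig


def materialized(fixed_order: bool) -> bool:
    layer = Layer()
    if fixed_order:
        # Gather first, then suppress per-forward unshard.
        with weight_access(layer), disable_unshard():
            return not layer.sharded
    # Old order: unshard is already a no-op when the gather is attempted.
    with disable_unshard(), weight_access(layer):
        return not layer.sharded


print(materialized(fixed_order=False))  # False
print(materialized(fixed_order=True))   # True
```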

I also added regression coverage for the actual public flow: mtq.quantize(...) with layerwise calib_mutates_weights=False on FSDP2 now asserts that the layer is materialized inside persistent_materialization (params are no longer DTensors), and still checks the quantized FSDP2 output against the non-FSDP reference.

Validated with:

pytest_pwd tests/unit/torch/quantization/test_layerwise_calibrate.py -q
pytest_pwd CUDA_VISIBLE_DEVICES=2,3 tests/gpu/torch/quantization/test_fsdp2.py -k "layerwise_calibrate_fsdp2 or persistent_materialization" -q

Signed-off-by: realAsma <akuriparambi@nvidia.com>
if self.layerwise.save_quantizers_only and not self._supports_save_quantizers_only:
def _validate_non_mutating_layerwise_supported(self):
    """Enforce the ``calib_mutates_weights=False`` whitelist."""
    if not self.layerwise.calib_mutates_weights and not self._supports_save_quantizers_only:

Contributor

nit: we could rename _supports_save_quantizers_only to something like _calib_is_amax_only to align with the new vocabulary; otherwise a reader has to learn that _supports_save_quantizers_only=True means "amax-only algorithm, so calib_mutates_weights=False is allowed"
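
For readers unfamiliar with the whitelist, here is a minimal sketch of the check under simplifying assumptions: the flag is stored directly on the object rather than on a nested `layerwise` config, and `CalibAlgorithm`, `MaxCalib`, and `GPTQCalib` are hypothetical classes invented for illustration. The amax-only versus weight-mutating split follows the thread above (max/mse/local_hessian versus GPTQ/AWQ/SmoothQuant).

```python
class CalibAlgorithm:
    """Hypothetical base: subclasses declare whether they only update amax."""

    _supports_save_quantizers_only = False

    def __init__(self, calib_mutates_weights: bool = True):
        self.calib_mutates_weights = calib_mutates_weights
        self._validate_non_mutating_layerwise_supported()

    def _validate_non_mutating_layerwise_supported(self):
        """Reject calib_mutates_weights=False for weight-mutating algorithms."""
        if not self.calib_mutates_weights and not self._supports_save_quantizers_only:
            raise ValueError(
                f"{type(self).__name__} mutates weights; "
                "calib_mutates_weights=False is not supported."
            )


class MaxCalib(CalibAlgorithm):
    _supports_save_quantizers_only = True  # amax-only, so it is whitelisted


class GPTQCalib(CalibAlgorithm):
    pass  # rewrites weights, so it stays off the whitelist


MaxCalib(calib_mutates_weights=False)   # accepted
GPTQCalib(calib_mutates_weights=True)   # accepted: default mutating path
try:
    GPTQCalib(calib_mutates_weights=False)
    rejected = False
except ValueError:
    rejected = True
print(rejected)  # True
```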

if ckpt:
    ckpt.save(layer_idx, model, transformer_layers, next_inputs)
if ckpt:
    ckpt.save(layer_idx, model, transformer_layers, next_inputs)

Contributor

Could you explain why we move ckpt.save inside the persistent_materialization context?

Contributor Author

We should materialize a layer once and do all the operations (get inputs, calibrate, save) inside that context. Outside of it, the layer might be sharded or moved to disk, in which case we would need to materialize it again. The point is to avoid redundant materialization.
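
The cost being avoided can be illustrated with a toy counter. In this sketch, `materialize` is a hypothetical stand-in for `persistent_materialization`, and `save_checkpoint` assumes that saving needs resident weights; both are simplifications, not the real implementation.

```python
from contextlib import contextmanager


@contextmanager
def materialize(layer):
    """Toy stand-in for persistent_materialization: counts each materialization."""
    layer["materialize_count"] += 1
    layer["resident"] = True
    try:
        yield layer
    finally:
        layer["resident"] = False  # layer may be resharded or offloaded again


def save_checkpoint(layer):
    """Saving needs the real weights, so re-materialize if they are not resident."""
    if layer["resident"]:
        return
    with materialize(layer):
        pass


def process_layer(save_inside: bool) -> int:
    layer = {"resident": False, "materialize_count": 0}
    with materialize(layer):
        # Getting inputs and calibrating would happen here.
        if save_inside:
            save_checkpoint(layer)
    if not save_inside:
        save_checkpoint(layer)
    return layer["materialize_count"]


print(process_layer(save_inside=True))   # 1
print(process_layer(save_inside=False))  # 2
```

Saving inside the context reuses the already-resident weights, while saving afterwards triggers a second materialization in this model.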

with (
    _disable_fsdp_unshard_reshard(layer),
    enable_weight_access_and_writeback(layer, layer, writeback=writeback),
    temporarily_remove_accelerate_hook(layer),

Contributor Author

I added this during an experiment that kept the zeros_from_meta placeholders on the meta device instead of the actual device.

However, keeping zeros_from_meta on the meta device broke a test, so I did not check that change in. I think we should modify that test instead.

Signed-off-by: realAsma <akuriparambi@nvidia.com>
Fridah-nv added a commit that referenced this pull request Jun 5, 2026
The skip-layer-weight-checkpoint optimization (save_quantizers_only) is
moved out of this foundation PR; it lands complete in the stacked PR #1592
(as calib_mutates_weights). _CheckpointState now always saves the full layer
weights blob. save_every and the rest of the nested LayerwiseConfig are kept.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
@Fridah-nv

Contributor

Moved the changes in this PR to #1640 because the base branch changed.

@Fridah-nv Fridah-nv closed this Jun 5, 2026
Fridah-nv added a commit that referenced this pull request Jun 5, 2026
save_every); this PR adds the two optimizations on top: calib_mutates_weights
(skip weight checkpoint + writeback for amax-only algorithms) and meta-device
skip-layer placeholders.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>