Refine layerwise non-mutating calibration#1592
Signed-off-by: realAsma <akuriparambi@nvidia.com>
    if not writeback:
        with _fsdp2_unshard_context(fsdp_module):
            yield
        return
@sugunav14 here is an easy perf improvement for layerwise FSDP2
@realAsma Claude claims there is a correctness issue on this branch. It makes sense to me; please check from your side:
Bug: FSDP2 writeback=False path never all-gathers weights (non-mutating layerwise calibration computes on sharded shards)
Where: _fsdp2_unshard_context (core_utils.py:480) reached via persistent_materialization(layer, writeback=False).
Root cause — a collision between two layers of the same context stack.
persistent_materialization enters its context managers in this order:
with (
_disable_fsdp_unshard_reshard(layer), # ① enters FIRST
enable_weight_access_and_writeback(layer, layer, writeback=False), # ② enters SECOND
temporarily_remove_accelerate_hook(layer),
):
- ① monkeypatches the class method FSDPParamGroup.unshard to a no-op (to stop FSDP from re-sharding weights between calibration batches). This patch is global and is now active.
- ② with writeback=False routes into _fsdp2_unshard_context, whose gather step is:
      if was_sharded: fsdp_module.unshard()  # core_utils.py:480
  But FSDPModule.unshard() delegates the actual all-gather to fsdp_param_group.unshard(), which is exactly the method ① just patched to a no-op. So no all-gather happens; the params stay sharded DTensors, and the calibration/capture forward runs on partial shards → wrong _amax (or a shape/dtype error).
The layer's own FSDP pre-forward hook can't save it either: that hook also calls FSDPParamGroup.unshard, which is still the no-op.
Why writeback=True (GPTQ/AWQ/SmoothQuant) is unaffected: that branch of fsdp2_weight_access_and_writeback_context gathers
via param.redistribute(... Replicate ...) — a DTensor collective that doesn't touch the patched unshard(). So only the
writeback=False (max/mse/local_hessian) path, i.e. the non-mutating optimization this PR adds, is broken on FSDP2.
_fsdp2_unshard_context is internally self-contradictory under this caller: line 480 calls unshard() (needs it to work) while
line 482 calls _disable_fsdp_unshard_reshard (disables it). It assumes it runs before anyone disables unshard, but its only
production caller disables it first.
Impact / trigger: FSDP2 layerwise calibration with calib_mutates_weights=False. The default (True) takes the redistribute
path and works, so the bug only fires when a user opts into the non-mutating feature.
Test that should catch this (likely unrun; it's a multi-GPU test on a draft PR):
tests/gpu/torch/quantization/test_fsdp2.py:264
    with persistent_materialization(layer, writeback=False):
        assert not isinstance(layer[0].weight, DTensor)  # fails: still a DTensor
        layer(inputs)
Suggested fix (don't double-disable):
- Have _fsdp2_unshard_context call the unpatched unshard (_disable_fsdp_unshard_reshard already captures orig_unshard; expose it), so the gather works while per-forward reshard stays suppressed; or
- Skip persistent_materialization's outer _disable_fsdp_unshard_reshard on the FSDP writeback=False path, since _fsdp2_unshard_context already suppresses reshard internally after a real unshard.
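For anyone who wants to see the ordering problem without a multi-GPU setup, here is a minimal sketch in plain Python. It assumes nothing about the real code: ToyParamGroup, disable_unshard_reshard and unshard_context are hypothetical stand-ins for FSDPParamGroup, _disable_fsdp_unshard_reshard and _fsdp2_unshard_context, and the "gather" is just a flag flip. It only illustrates why a class-level no-op patch entered first turns the one-time gather into a no-op, and why entering the gather first avoids that.

```python
from contextlib import contextmanager


class ToyParamGroup:
    """Hypothetical stand-in for a sharded parameter group (not the real FSDPParamGroup)."""

    def __init__(self):
        self.sharded = True

    def unshard(self):
        # The real method would all-gather the shards; here we only flip a flag.
        self.sharded = False

    def reshard(self):
        self.sharded = True


@contextmanager
def disable_unshard_reshard():
    """Patch the class methods to no-ops (a global, class-level patch)."""
    orig_unshard, orig_reshard = ToyParamGroup.unshard, ToyParamGroup.reshard
    ToyParamGroup.unshard = lambda self: None
    ToyParamGroup.reshard = lambda self: None
    try:
        yield
    finally:
        ToyParamGroup.unshard, ToyParamGroup.reshard = orig_unshard, orig_reshard


@contextmanager
def unshard_context(group):
    """Gather once on entry and reshard on exit."""
    was_sharded = group.sharded
    if was_sharded:
        group.unshard()  # silently does nothing if the class is already patched
    try:
        yield
    finally:
        if was_sharded:
            group.reshard()


def sharded_inside(disable_first):
    """Return whether the group is still sharded inside the combined context."""
    group = ToyParamGroup()
    if disable_first:
        # Order described in the report: the patch is active before the gather.
        with disable_unshard_reshard(), unshard_context(group):
            return group.sharded
    # Alternative order: gather for real, then suppress further unshard/reshard.
    with unshard_context(group), disable_unshard_reshard():
        return group.sharded


buggy_order = sharded_inside(disable_first=True)
fixed_order = sharded_inside(disable_first=False)
print(buggy_order, fixed_order)
```

In this toy, the first order leaves the group sharded inside the context, while the second order materializes it. Whether reordering or exposing the unpatched unshard is the better fix for the real code depends on details this sketch does not model.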
This is possible. Can we move _disable_fsdp_unshard_reshard(layer) to _disable_fsdp_reshard(layer)?
🤖 Bot comment.
Fixed in 9bfec17479.
persistent_materialization() now enters enable_weight_access_and_writeback(..., writeback=writeback) before _disable_fsdp_unshard_reshard(...), so the writeback=False FSDP2 path performs the real one-time unshard before per-forward unshard/reshard is suppressed.
I also added regression coverage for the actual public flow: mtq.quantize(...) with layerwise calib_mutates_weights=False on FSDP2 now asserts that the layer is materialized inside persistent_materialization (params are no longer DTensors), and still checks the quantized FSDP2 output against the non-FSDP reference.
Validated with:
pytest_pwd tests/unit/torch/quantization/test_layerwise_calibrate.py -q
pytest_pwd CUDA_VISIBLE_DEVICES=2,3 tests/gpu/torch/quantization/test_fsdp2.py -k "layerwise_calibrate_fsdp2 or persistent_materialization" -q
Signed-off-by: realAsma <akuriparambi@nvidia.com>
| if self.layerwise.save_quantizers_only and not self._supports_save_quantizers_only: | ||
| def _validate_non_mutating_layerwise_supported(self): | ||
| """Enforce the ``calib_mutates_weights=False`` whitelist.""" | ||
| if not self.layerwise.calib_mutates_weights and not self._supports_save_quantizers_only: |
nit: we could rename _supports_save_quantizers_only to something like _calib_is_amax_only to align with the new vocabulary. Otherwise a reader has to learn that _supports_save_quantizers_only=True means "amax-only algorithm, so calib_mutates_weights=False is allowed".
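To make the naming point concrete, here is a small sketch of how the whitelist check could read with the suggested name. Everything here is hypothetical: ToyLayerwiseConfig, ToyCalibMode and its subclasses are illustration-only classes, and _calib_is_amax_only is the name proposed in this comment, not an attribute that exists in the codebase.

```python
from dataclasses import dataclass


@dataclass
class ToyLayerwiseConfig:
    """Hypothetical stand-in for the nested layerwise config."""

    calib_mutates_weights: bool = True


class ToyCalibMode:
    """Hypothetical calibration mode; the flag name is the proposed rename."""

    _calib_is_amax_only = False

    def __init__(self, layerwise):
        self.layerwise = layerwise

    def validate_non_mutating_supported(self):
        # Only amax-only algorithms may skip weight checkpoint and writeback.
        if not self.layerwise.calib_mutates_weights and not self._calib_is_amax_only:
            raise ValueError(
                "calib_mutates_weights=False is only allowed for amax-only algorithms"
            )


class ToyAmaxOnlyMode(ToyCalibMode):
    _calib_is_amax_only = True


class ToyWeightMutatingMode(ToyCalibMode):
    _calib_is_amax_only = False


def is_allowed(mode):
    """Return True if validation passes, False if it raises ValueError."""
    try:
        mode.validate_non_mutating_supported()
    except ValueError:
        return False
    return True


non_mutating = ToyLayerwiseConfig(calib_mutates_weights=False)
amax_only_ok = is_allowed(ToyAmaxOnlyMode(non_mutating))
mutating_ok = is_allowed(ToyWeightMutatingMode(non_mutating))
default_ok = is_allowed(ToyWeightMutatingMode(ToyLayerwiseConfig()))
print(amax_only_ok, mutating_ok, default_ok)
```

With this name the condition reads directly as "non-mutating requested, but the algorithm is not amax-only", which is the point of the rename.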
| if ckpt: | ||
| ckpt.save(layer_idx, model, transformer_layers, next_inputs) | ||
| if ckpt: | ||
| ckpt.save(layer_idx, model, transformer_layers, next_inputs) |
Could you explain why we move ckpt.save inside the persistent_materialization context?
We should materialize a layer once and do all the operations (get inputs, calibrate, save) inside that context. Outside of it, the layer might be sharded or moved to disk, in which case we would need to materialize it again. The point is to avoid redundant materialization.
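A rough sketch of the cost being avoided, in plain Python. ToyLayer and the materialized context manager are hypothetical stand-ins (not persistent_materialization itself); materialization is modeled as a counter so the difference between wrapping each step and wrapping all steps once is visible.

```python
from contextlib import contextmanager


class ToyLayer:
    """Hypothetical layer that tracks how often it gets materialized."""

    def __init__(self):
        self.is_materialized = False
        self.materialize_count = 0


@contextmanager
def materialized(layer):
    """Materialize on entry unless already materialized; release on exit."""
    if layer.is_materialized:
        # Nested use inside an outer context: nothing to redo.
        yield layer
        return
    layer.is_materialized = True
    layer.materialize_count += 1  # stands in for the unshard or load from disk
    try:
        yield layer
    finally:
        layer.is_materialized = False


STEPS = ("get_inputs", "calibrate", "save")


def run_step(layer, name, log):
    """Each step needs the real weights, so it materializes defensively."""
    with materialized(layer):
        log.append(name)


def run_steps_separately():
    layer, log = ToyLayer(), []
    for name in STEPS:
        run_step(layer, name, log)
    return layer.materialize_count


def run_steps_in_one_context():
    layer, log = ToyLayer(), []
    with materialized(layer):
        for name in STEPS:
            run_step(layer, name, log)
    return layer.materialize_count


separate_count = run_steps_separately()
single_count = run_steps_in_one_context()
print(separate_count, single_count)
```

In the toy, running the three steps separately materializes the layer three times, while one enclosing context materializes it once, which is the motivation for keeping ckpt.save inside the context.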
    with (
        _disable_fsdp_unshard_reshard(layer),
        enable_weight_access_and_writeback(layer, layer, writeback=writeback),
        temporarily_remove_accelerate_hook(layer),
I added this during an experiment that kept the zeros_from_meta placeholders on the meta device instead of the actual device.
However, keeping them on the meta device broke a test, so I did not check that change in. I think we should modify that test instead.
Signed-off-by: realAsma <akuriparambi@nvidia.com>
The skip-layer-weight-checkpoint optimization (save_quantizers_only) is moved out of this foundation PR; it lands complete in the stacked PR #1592 (as calib_mutates_weights). _CheckpointState now always saves the full layer weights blob. save_every and the rest of the nested LayerwiseConfig are kept.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Move the changes in the current PR to #1640, as the base branch changed.
save_every); this PR adds the two optimizations on top: calib_mutates_weights (skip weight checkpoint + writeback for amax-only algorithms) and meta-device skip-layer placeholders.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Stacked on #1571.
Summary:
- _zeros_from_meta allocates skip placeholders on torch.device("meta") so skip-mode preserves structure without real-device tensor allocation.
Testing: