|
1 | 1 | import copy |
2 | 2 | import os |
| 3 | +import pathlib |
3 | 4 | import pickle |
4 | 5 | import platform |
| 6 | +import subprocess |
5 | 7 | import sys |
6 | 8 | from tempfile import TemporaryDirectory |
7 | 9 |
|
@@ -431,3 +433,96 @@ def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_st |
431 | 433 | grad_compiled = x.grad.clone() |
432 | 434 |
|
433 | 435 | torch.testing.assert_close(grad_compiled, grad_ref) |
| 436 | + |
| 437 | + |
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_params4bit_quant_state_attr_access(device, quant_type, compress_statistics):
    """Check that Params4bit forwards QuantState attributes via getattr (#1405).

    FSDP's state_dict code resolves dotted FQNs such as
    'model.layers.0.weight.absmax' one getattr() at a time. Every key that
    QuantState.as_dict(packed=True) emits must therefore be reachable from
    the weight parameter by plain attribute access, while genuinely unknown
    names must keep raising AttributeError.
    """
    hpu_unsupported = device == "hpu" and not is_supported_on_hpu(quant_type)
    if hpu_unsupported:
        pytest.skip("This configuration is not supported on HPU.")

    # Moving the layer to the target device is what triggers quantization.
    weight = (
        bnb.nn.Linear4bit(
            64,
            64,
            bias=False,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
        )
        .to(device)
        .weight
    )

    assert weight.quant_state is not None, "quant_state should be set after quantization"
    state = weight.quant_state

    # Attributes proxied straight through to the QuantState.
    assert torch.equal(weight.absmax, state.absmax)
    assert torch.equal(weight.code, state.code)

    # as_dict() serializes "code" under the name "quant_map", so that alias
    # has to resolve to the same tensor.
    assert torch.equal(weight.quant_map, state.code)

    # as_dict(packed=True) yields a "quant_state.bitsandbytes__<type>" key;
    # FSDP looks the last segment up on the QuantState object itself.
    packed_name = f"bitsandbytes__{quant_type}"
    assert hasattr(state, packed_name)
    assert isinstance(getattr(state, packed_name), torch.Tensor)

    # Replay the FSDP _get_fqns walk for every serialized key. Keys are
    # relative to "weight.", e.g. "absmax" or "quant_state.bitsandbytes__nf4".
    for dotted_key in state.as_dict(packed=True):
        node = weight
        for attr_name in dotted_key.split("."):
            node = getattr(node, attr_name)
        assert node is not None

    # hasattr must agree with the proxy: True for forwarded names ...
    for proxied_name in ("absmax", "code", "quant_map"):
        assert hasattr(weight, proxied_name)
    # ... and False for anything else.
    assert not hasattr(weight, "nonexistent_attribute")

    # Direct access to an unknown attribute still has to raise.
    with pytest.raises(AttributeError, match="nonexistent_attribute"):
        _ = weight.nonexistent_attribute

    # Regular Params4bit attributes must not be disturbed by __getattr__.
    assert isinstance(weight.quant_state, bnb.functional.QuantState)
    assert isinstance(weight.bnb_quantized, bool)
    assert weight.bnb_quantized is True
| 502 | + |
| 503 | + |
@pytest.mark.skipif(not torch.cuda.is_available(), reason="FSDP requires CUDA")
@pytest.mark.skipif(
    # Check is_available() first: when the distributed package is not built,
    # torch.distributed exposes no other APIs, so calling is_nccl_available()
    # directly would raise at collection time instead of skipping.
    not (torch.distributed.is_available() and torch.distributed.is_nccl_available()),
    reason="FSDP test requires NCCL backend",
)
def test_fsdp_state_dict_save_4bit():
    """Integration test: FSDP get_model_state_dict with cpu_offload on a 4-bit model (#1405).

    Runs the sibling script ``fsdp_state_dict_save.py`` as a single-process
    distributed job to exercise the real _get_fqns() code path that
    previously crashed with:
        AttributeError: 'Params4bit' object has no attribute 'absmax'

    The test fails, with the child's captured stdout/stderr in the message,
    if the launcher exits non-zero or does not finish within the timeout.
    """

    def _text(stream):
        # TimeoutExpired may carry bytes (or None) even when text=True was requested.
        if isinstance(stream, bytes):
            return stream.decode(errors="replace")
        return stream or ""

    script = pathlib.Path(__file__).with_name("fsdp_state_dict_save.py")
    timeout_s = 120

    # Launch via the interpreter running this test rather than whatever
    # "torchrun" happens to be first on PATH; "python -m torch.distributed.run"
    # is the module behind the torchrun console script.
    cmd = [sys.executable, "-m", "torch.distributed.run", "--nproc_per_node=1", str(script)]

    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=timeout_s,
        )
    except subprocess.TimeoutExpired as exc:
        # Report a hang as a regular test failure and keep the partial output.
        pytest.fail(
            f"FSDP state_dict test timed out after {timeout_s}s:\n"
            f"stdout: {_text(exc.stdout)}\n"
            f"stderr: {_text(exc.stderr)}"
        )

    if result.returncode != 0:
        pytest.fail(
            f"FSDP state_dict test failed (exit {result.returncode}):\n"
            f"stdout: {result.stdout}\n"
            f"stderr: {result.stderr}"
        )
0 commit comments