Skip to content

Commit 22cc015

Browse files
TimDettmers and claude committed
Add FSDP integration test for 4-bit state_dict save (#1405)
Add a subprocess-based pytest test that launches a single-GPU FSDP process via torchrun to verify get_model_state_dict with cpu_offload=True works for QLoRA-style models with frozen Params4bit base weights. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 857d562 commit 22cc015

File tree

2 files changed

+109
-0
lines changed

2 files changed

+109
-0
lines changed

tests/fsdp_state_dict_save.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""FSDP state_dict save integration test for 4-bit quantized models (#1405).
2+
3+
This script must be launched via torchrun (not directly):
4+
torchrun --nproc_per_node=1 tests/fsdp_state_dict_save.py
5+
6+
It wraps a QLoRA-style model (frozen 4-bit base + trainable adapter) in FSDP
7+
and calls get_model_state_dict with cpu_offload=True, which exercises the
8+
_get_fqns() getattr traversal that previously crashed with:
9+
AttributeError: 'Params4bit' object has no attribute 'absmax'
10+
"""
11+
12+
import sys
13+
14+
import torch
15+
import torch.distributed as dist
16+
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
17+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
18+
import torch.nn as nn
19+
20+
import bitsandbytes as bnb
21+
22+
23+
class SimpleQLoRAModel(nn.Module):
    """Minimal model with a frozen 4-bit base layer and a trainable adapter."""

    def __init__(self, quant_type="nf4"):
        super().__init__()
        # Quantized base plus a full-precision adapter: the smallest layout
        # that reproduces the QLoRA setup exercised by the FSDP test.
        self.base = bnb.nn.Linear4bit(64, 64, bias=False, quant_type=quant_type)
        self.adapter = nn.Linear(64, 64, bias=False)

    def forward(self, x):
        # Sum of the quantized path and the adapter path, as in QLoRA.
        base_out = self.base(x)
        adapter_out = self.adapter(x)
        return base_out + adapter_out
33+
34+
35+
def main():
    """Run the FSDP full-state-dict save check for both 4-bit quant types.

    Exits with status 1 (and a failure summary on stderr) if any check
    failed, status 0 otherwise.
    """
    # NCCL backend: this script is meant to be launched via torchrun on CUDA.
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    errors = []

    def run_one(quant_type):
        """Build one QLoRA-style model, wrap it in FSDP, and attempt the save."""
        model = SimpleQLoRAModel(quant_type=quant_type).to("cuda")

        # Freeze quantized base weights (as in real QLoRA)
        for param in model.base.parameters():
            param.requires_grad = False

        # Tell FSDP to ignore the frozen quantized params (can't flatten int dtypes)
        frozen = list(model.base.parameters())
        fsdp_model = FSDP(model, device_id=rank, ignored_states=frozen, use_orig_params=True)

        options = StateDictOptions(full_state_dict=True, cpu_offload=True)
        try:
            state_dict = get_model_state_dict(fsdp_model, options=options)

            # Verify expected keys are present
            for substr in ("base.weight", "absmax", "quant_map", "adapter.weight"):
                if not any(substr in k for k in state_dict.keys()):
                    errors.append(f"{quant_type}: missing key containing '{substr}' in {list(state_dict.keys())}")

            print(f"{quant_type}: SUCCESS ({len(state_dict)} keys)", flush=True)
        except Exception as e:
            errors.append(f"{quant_type}: {type(e).__name__}: {e}")
            print(f"{quant_type}: FAILED: {e}", flush=True)

    for quant_type in ("nf4", "fp4"):
        run_one(quant_type)

    dist.destroy_process_group()

    if errors:
        print("\nFAILURES:\n" + "\n".join(errors), file=sys.stderr, flush=True)
        sys.exit(1)
    print("\nAll FSDP state_dict tests passed.", flush=True)
    sys.exit(0)


if __name__ == "__main__":
    main()

tests/test_linear4bit.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import copy
22
import os
3+
import pathlib
34
import pickle
45
import platform
6+
import subprocess
57
import sys
68
from tempfile import TemporaryDirectory
79

@@ -497,3 +499,30 @@ def test_params4bit_quant_state_attr_access(device, quant_type, compress_statist
497499
assert isinstance(w.quant_state, bnb.functional.QuantState)
498500
assert isinstance(w.bnb_quantized, bool)
499501
assert w.bnb_quantized is True
502+
503+
504+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="FSDP requires CUDA")
@pytest.mark.skipif(
    not torch.distributed.is_nccl_available(),
    reason="FSDP test requires NCCL backend",
)
def test_fsdp_state_dict_save_4bit():
    """Integration test: FSDP get_model_state_dict with cpu_offload on a 4-bit model (#1405).

    Launches a single-GPU FSDP process via torchrun to exercise the real
    _get_fqns() code path that previously crashed with:
        AttributeError: 'Params4bit' object has no attribute 'absmax'
    """
    script = pathlib.Path(__file__).with_name("fsdp_state_dict_save.py")
    try:
        result = subprocess.run(
            ["torchrun", "--nproc_per_node=1", str(script)],
            capture_output=True,
            text=True,
            timeout=120,
        )
    except FileNotFoundError:
        # torchrun is not on PATH (e.g. torch installed without console
        # scripts); skip rather than error the whole suite with a traceback.
        pytest.skip("torchrun launcher not available")
    except subprocess.TimeoutExpired as e:
        # A hung NCCL init or deadlocked save would otherwise surface as a
        # bare TimeoutExpired error; report it as a test failure with output.
        pytest.fail(
            f"FSDP state_dict test timed out after {e.timeout}s:\n"
            f"stdout: {e.stdout}\n"
            f"stderr: {e.stderr}"
        )
    if result.returncode != 0:
        pytest.fail(
            f"FSDP state_dict test failed (exit {result.returncode}):\n"
            f"stdout: {result.stdout}\n"
            f"stderr: {result.stderr}"
        )

0 commit comments

Comments
 (0)