Skip to content

Commit 31610c9

Browse files
authored
bug: fix 8-bit optimizer support with FSDP (#1840)
1 parent 0bd4782 commit 31610c9

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

bitsandbytes/optim/optimizer.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ def register_module_override(self, module, param_name, config):
111111

112112

113113
class Optimizer8bit(torch.optim.Optimizer):
114+
_FSDP_WRAPPED_QUANT_STATE_KEY = "__bnb_optimizer_quant_state__"
115+
114116
def __init__(self, params, defaults, optim_bits=32, is_paged=False):
115117
"""
116118
Base 8-bit optimizer class.
@@ -152,6 +154,34 @@ def fill_qmap(self):
152154
self.name2qmap["dynamic"] = F.create_dynamic_map(signed=True)
153155
self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)
154156

157+
def state_dict(self):
    """Return the optimizer state with quantization tensors wrapped for FSDP.

    FSDP's ``full_optim_state_dict`` gathers every tensor in the optimizer
    state across ranks. Quantization bookkeeping tensors (``state1``,
    ``state2``, ``absmax``, ...) do not share the model parameters' shapes,
    which makes that gather fail. Hiding them inside a single nested dict
    (keyed by ``_FSDP_WRAPPED_QUANT_STATE_KEY``) makes FSDP skip them.

    Returns:
        dict: The usual ``torch.optim.Optimizer`` state dict, with all
        non-castable quantization tensors of each parameter moved under one
        wrapper entry in that parameter's state.
    """
    sd = super().state_dict()

    # super().state_dict() only shallow-copies the per-parameter dicts;
    # copy each one so the key-popping below never mutates live state.
    copied_state = {}
    for param_id, entry in sd["state"].items():
        copied_state[param_id] = dict(entry) if isinstance(entry, dict) else entry
    sd["state"] = copied_state

    # Relocate quantization-specific tensors under the wrapper key so they
    # are invisible to FSDP's gathering pass.
    for entry in copied_state.values():
        if not isinstance(entry, dict):
            continue
        wrapped = {key: entry.pop(key) for key in list(entry) if key in self.non_castable_tensor_keys}
        if wrapped:
            entry[self._FSDP_WRAPPED_QUANT_STATE_KEY] = wrapped

    return sd
184+
155185
def __setstate__(self, state):
156186
super().__setstate__(state)
157187

@@ -166,6 +196,13 @@ def load_state_dict(self, state_dict, move_to_device=True):
166196
"""
167197
# deepcopy, to be consistent with module API
168198
state_dict = deepcopy(state_dict)
199+
200+
# Unwrap quantization states that were wrapped for FSDP compatibility
201+
for param_state in state_dict["state"].values():
202+
if isinstance(param_state, dict) and self._FSDP_WRAPPED_QUANT_STATE_KEY in param_state:
203+
quant_state = param_state.pop(self._FSDP_WRAPPED_QUANT_STATE_KEY)
204+
param_state.update(quant_state)
205+
169206
# Validate the state_dict
170207
groups = self.param_groups
171208
saved_groups = state_dict["param_groups"]

0 commit comments

Comments
 (0)