@@ -111,6 +111,8 @@ def register_module_override(self, module, param_name, config):
111111
112112
113113class Optimizer8bit (torch .optim .Optimizer ):
114+ _FSDP_WRAPPED_QUANT_STATE_KEY = "__bnb_optimizer_quant_state__"
115+
114116 def __init__ (self , params , defaults , optim_bits = 32 , is_paged = False ):
115117 """
116118 Base 8-bit optimizer class.
@@ -152,6 +154,34 @@ def fill_qmap(self):
152154 self .name2qmap ["dynamic" ] = F .create_dynamic_map (signed = True )
153155 self .name2qmap ["udynamic" ] = F .create_dynamic_map (signed = False )
154156
def state_dict(self):
    """Return the optimizer state, wrapping quantization tensors for FSDP compatibility.

    FSDP's full_optim_state_dict gathers all tensor states across ranks.
    Quantization states (state1, state2, absmax, etc.) have different shapes
    than model parameters, causing gather operations to fail. Hiding those
    tensors inside a nested dict keyed by ``_FSDP_WRAPPED_QUANT_STATE_KEY``
    makes FSDP skip them during gathering.
    """
    sd = super().state_dict()

    # PyTorch's state_dict() only shallow-copies the per-parameter state
    # dicts; replace each with a fresh dict so the pops below can never
    # mutate the live optimizer state.
    sd["state"] = {idx: dict(entry) if isinstance(entry, dict) else entry for idx, entry in sd["state"].items()}

    for entry in sd["state"].values():
        if not isinstance(entry, dict):
            continue
        # Move quantization-specific tensors into a nested dict so FSDP
        # does not attempt to gather them across ranks.
        wrapped = {key: entry.pop(key) for key in list(entry) if key in self.non_castable_tensor_keys}
        if wrapped:
            entry[self._FSDP_WRAPPED_QUANT_STATE_KEY] = wrapped

    return sd
184+
155185 def __setstate__ (self , state ):
156186 super ().__setstate__ (state )
157187
@@ -166,6 +196,13 @@ def load_state_dict(self, state_dict, move_to_device=True):
166196 """
167197 # deepcopy, to be consistent with module API
168198 state_dict = deepcopy (state_dict )
199+
200+ # Unwrap quantization states that were wrapped for FSDP compatibility
201+ for param_state in state_dict ["state" ].values ():
202+ if isinstance (param_state , dict ) and self ._FSDP_WRAPPED_QUANT_STATE_KEY in param_state :
203+ quant_state = param_state .pop (self ._FSDP_WRAPPED_QUANT_STATE_KEY )
204+ param_state .update (quant_state )
205+
169206 # Validate the state_dict
170207 groups = self .param_groups
171208 saved_groups = state_dict ["param_groups" ]
0 commit comments