@@ -258,42 +258,85 @@ def __setstate__(self, state):
258258 self .bnb_quantized = state ["bnb_quantized" ]
259259 self .module = state ["module" ]
260260
261- # Map from state_dict key names (as produced by QuantState.as_dict) to
262- # the actual QuantState attribute/access path. FSDP's _get_fqns() resolves
263- # dotted FQN keys via getattr, so "weight.quant_map" becomes
264- # getattr(weight, "quant_map") — we must map that to quant_state.code.
265- _QUANT_STATE_ATTR_MAP = {
266- # Direct QuantState attributes
267- "absmax" : lambda qs : qs .absmax ,
268- "code" : lambda qs : qs .code ,
269- "blocksize" : lambda qs : qs .blocksize ,
270- "dtype" : lambda qs : qs .dtype ,
271- "shape" : lambda qs : qs .shape ,
272- "offset" : lambda qs : qs .offset ,
273- "state2" : lambda qs : qs .state2 ,
274- # as_dict serializes code → "quant_map"
275- "quant_map" : lambda qs : qs .code ,
276- "quant_type" : lambda qs : qs .quant_type ,
277- # as_dict serializes nested state2 attributes under "nested_*" keys
278- "nested_absmax" : lambda qs : qs .state2 .absmax ,
279- "nested_blocksize" : lambda qs : qs .state2 .blocksize ,
280- "nested_quant_map" : lambda qs : qs .state2 .code ,
281- "nested_dtype" : lambda qs : qs .state2 .dtype ,
282- "nested_offset" : lambda qs : qs .offset ,
283- }
284-
285- def __getattr__ (self , name ):
286- # Proxy known QuantState attributes so that PyTorch's FSDP state_dict
287- # machinery (which traverses FQN paths via getattr) can find them.
288- accessor = self ._QUANT_STATE_ATTR_MAP .get (name )
289- if accessor is not None :
290- quant_state = self .__dict__ .get ("quant_state" )
291- if quant_state is not None :
292- try :
293- return accessor (quant_state )
294- except AttributeError :
295- pass
296- raise AttributeError (f"'{ type (self ).__name__ } ' object has no attribute '{ name } '" )
261+ # Properties that proxy QuantState attributes for FSDP state_dict traversal.
262+ # FSDP's _get_fqns() resolves dotted FQN keys via getattr, e.g. "weight.absmax"
263+ # becomes getattr(weight, "absmax"). Using @property instead of __getattr__
264+ # avoids torch.compile graph breaks (see #1904), since Dynamo can trace
265+ # descriptor protocol access but not __getattr__ on Tensor subclasses.
266+ #
267+ # Note: attributes that collide with Params4bit instance attrs (blocksize,
268+ # quant_type) or Tensor attrs (dtype, shape) are intentionally omitted —
269+ # they are packed into the bitsandbytes__* blob and not traversed by FSDP.
270+
271+ @property
272+ def absmax (self ):
273+ qs = self .__dict__ .get ("quant_state" )
274+ if qs is not None :
275+ return qs .absmax
276+ raise AttributeError (f"'{ type (self ).__name__ } ' object has no attribute 'absmax'" )
277+
278+ @property
279+ def code (self ):
280+ qs = self .__dict__ .get ("quant_state" )
281+ if qs is not None :
282+ return qs .code
283+ raise AttributeError (f"'{ type (self ).__name__ } ' object has no attribute 'code'" )
284+
285+ @property
286+ def quant_map (self ):
287+ qs = self .__dict__ .get ("quant_state" )
288+ if qs is not None :
289+ return qs .code
290+ raise AttributeError (f"'{ type (self ).__name__ } ' object has no attribute 'quant_map'" )
291+
292+ @property
293+ def offset (self ):
294+ qs = self .__dict__ .get ("quant_state" )
295+ if qs is not None :
296+ return qs .offset
297+ raise AttributeError (f"'{ type (self ).__name__ } ' object has no attribute 'offset'" )
298+
299+ @property
300+ def state2 (self ):
301+ qs = self .__dict__ .get ("quant_state" )
302+ if qs is not None :
303+ return qs .state2
304+ raise AttributeError (f"'{ type (self ).__name__ } ' object has no attribute 'state2'" )
305+
306+ @property
307+ def nested_absmax (self ):
308+ qs = self .__dict__ .get ("quant_state" )
309+ if qs is not None and qs .state2 is not None :
310+ return qs .state2 .absmax
311+ raise AttributeError (f"'{ type (self ).__name__ } ' object has no attribute 'nested_absmax'" )
312+
313+ @property
314+ def nested_blocksize (self ):
315+ qs = self .__dict__ .get ("quant_state" )
316+ if qs is not None and qs .state2 is not None :
317+ return qs .state2 .blocksize
318+ raise AttributeError (f"'{ type (self ).__name__ } ' object has no attribute 'nested_blocksize'" )
319+
320+ @property
321+ def nested_quant_map (self ):
322+ qs = self .__dict__ .get ("quant_state" )
323+ if qs is not None and qs .state2 is not None :
324+ return qs .state2 .code
325+ raise AttributeError (f"'{ type (self ).__name__ } ' object has no attribute 'nested_quant_map'" )
326+
327+ @property
328+ def nested_dtype (self ):
329+ qs = self .__dict__ .get ("quant_state" )
330+ if qs is not None and qs .state2 is not None :
331+ return qs .state2 .dtype
332+ raise AttributeError (f"'{ type (self ).__name__ } ' object has no attribute 'nested_dtype'" )
333+
334+ @property
335+ def nested_offset (self ):
336+ qs = self .__dict__ .get ("quant_state" )
337+ if qs is not None :
338+ return qs .offset
339+ raise AttributeError (f"'{ type (self ).__name__ } ' object has no attribute 'nested_offset'" )
297340
298341 def __deepcopy__ (self , memo ):
299342 new_instance = type (self ).__new__ (type (self ))
0 commit comments