@@ -222,6 +222,21 @@ def __new__(
222222 if data is None :
223223 data = torch .empty (0 )
224224
225+ # Handle FakeTensor creation during dynamo tracing
226+ if torch ._dynamo .is_compiling () and not isinstance (data , cls ):
227+ if isinstance (data , torch ._subclasses .FakeTensor ):
228+ param = data .as_subclass (cls )
229+ param .requires_grad = requires_grad
230+ param .quant_state = quant_state
231+ param .blocksize = blocksize
232+ param .compress_statistics = compress_statistics
233+ param .quant_type = quant_type
234+ param .quant_storage = quant_storage
235+ param .module = module
236+ param .bnb_quantized = bnb_quantized
237+ return param
238+
239+ # Standard initialization for real tensors
225240 self = torch .Tensor ._make_subclass (cls , data , requires_grad )
226241 self .blocksize = blocksize
227242 self .compress_statistics = compress_statistics
@@ -324,26 +339,23 @@ def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: ...
324339 def to (self : T , tensor : Tensor , non_blocking : bool = ...) -> T : ...
325340
326341 def to (self , * args , ** kwargs ):
327- device , dtype , non_blocking , convert_to_format = torch ._C ._nn ._parse_to (* args , ** kwargs )
342+ device , dtype , non_blocking , _ = torch ._C ._nn ._parse_to (* args , ** kwargs )
328343
329344 if device is not None and device .type == "cuda" and not self .bnb_quantized :
330345 return self ._quantize (device )
331346 else :
332- if self .quant_state is not None :
333- self .quant_state .to (device )
334-
335- new_param = Params4bit (
347+ return Params4bit (
336348 super ().to (device = device , dtype = dtype , non_blocking = non_blocking ),
337349 requires_grad = self .requires_grad ,
338- quant_state = self .quant_state ,
350+ quant_state = self .quant_state . to ( device ) if self . quant_state else None ,
339351 blocksize = self .blocksize ,
340352 compress_statistics = self .compress_statistics ,
341353 quant_type = self .quant_type ,
342354 quant_storage = self .quant_storage ,
355+ module = self .module ,
356+ bnb_quantized = self .bnb_quantized ,
343357 )
344358
345- return new_param
346-
347359 def __tensor_flatten__ (self ):
348360 """Return data tensor and non-tensor context"""
349361 ctx = {
@@ -361,6 +373,20 @@ def __tensor_flatten__(self):
361373 def __tensor_unflatten__ (inner_tensors , ctx , outer_size , outer_stride ):
362374 """Reconstruct Params4bit from components"""
363375 data = inner_tensors ["data" ]
376+
377+ # Special handling for FakeTensor reconstruction
378+ if isinstance (data , torch ._subclasses .FakeTensor ):
379+ param = data .as_subclass (Params4bit )
380+ param .blocksize = ctx ["blocksize" ]
381+ param .compress_statistics = ctx ["compress_statistics" ]
382+ param .quant_type = ctx ["quant_type" ]
383+ param .quant_state = ctx ["quant_state" ]
384+ param .quant_storage = ctx ["quant_storage" ]
385+ param .module = ctx ["module" ]
386+ param .bnb_quantized = ctx ["bnb_quantized" ]
387+ return param
388+
389+ # Standard reconstruction for real tensors
364390 return Params4bit (
365391 data ,
366392 requires_grad = data .requires_grad ,
@@ -373,6 +399,21 @@ def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
373399 bnb_quantized = ctx ["bnb_quantized" ],
374400 )
375401
402+ @classmethod
403+ def __torch_function__ (cls , func , types , args = (), kwargs = None ):
404+ # Type preservation through ops
405+ result = super ().__torch_function__ (func , types , args , kwargs or {})
406+ if isinstance (result , torch .Tensor ) and not isinstance (result , cls ):
407+ return result .as_subclass (cls )
408+ return result
409+
410+ @classmethod
411+ def __torch_dispatch__ (cls , func , types , args = (), kwargs = None ):
412+ # Delegate to FakeTensor implementation when needed
413+ if any (isinstance (x , torch ._subclasses .FakeTensor ) for x in args ):
414+ return torch ._C .DispatchKey .Fake (func (* args , ** (kwargs or {})))
415+ return super ().__torch_dispatch__ (func , types , args , kwargs )
416+
376417 def detach (self ):
377418 """Create new instance preserving quantization state"""
378419 return type (self )(
@@ -460,21 +501,36 @@ def __init__(
460501 bias (`bool`, defaults to `True`):
461502 Whether the linear class uses the bias term as well.
462503 """
463- super ().__init__ (input_features , output_features , bias , device )
504+ # Bypass nn.Linear's parameter initialization
505+ super (nn .Linear , self ).__init__ ()
506+ self .in_features = input_features
507+ self .out_features = output_features
508+
509+ # Manually register parameters
464510 self .weight = Params4bit (
465- self . weight . data ,
511+ torch . empty (( output_features , input_features ), dtype = quant_storage ) ,
466512 requires_grad = False ,
467513 compress_statistics = compress_statistics ,
468514 quant_type = quant_type ,
469515 quant_storage = quant_storage ,
470516 module = self ,
471517 )
472- # self.persistent_buffers = [] # TODO consider as way to save quant state
518+
519+ if bias :
520+ self .bias = nn .Parameter (torch .empty (output_features ))
521+ else :
522+ self .register_parameter ("bias" , None )
523+
524+ self .reset_parameters ()
473525 self .compute_dtype = compute_dtype
474526 self .compute_type_is_set = False
475527 self .quant_state = None
476528 self .quant_storage = quant_storage
477529
530+ def reset_parameters (self ):
531+ # Disable standard initialization
532+ pass
533+
478534 def set_compute_type (self , x ):
479535 if x .dtype in [torch .float32 , torch .bfloat16 ]:
480536 # the input is in a dtype that is safe to compute in, we switch
0 commit comments