@@ -552,9 +552,26 @@ def half(self) -> torch.Tensor:
552552 # pylint: disable=missing-function-docstring
553553 return self .dequantize (dtype = torch .float16 )
554554
555- def cpu (self , memory_format = torch .preserve_format ) -> torch .Tensor :
555+ def cpu (self , memory_format = torch .preserve_format ) -> QuantizedTensor :
556+ """Move tensor to CPU while preserving the QuantizedTensor type.
557+
558+ Routes through ``aten._to_copy.default`` so the subclass-preserving
559+ handler in ``__torch_dispatch__`` runs (rather than dequantizing).
560+
561+ """
556562 # pylint: disable=missing-function-docstring
557- return self .dequantize ().cpu (memory_format = memory_format )
563+ return self .to (device = torch .device ("cpu" ), memory_format = memory_format )
564+
565+ def untyped_storage (self ) -> torch .UntypedStorage :
566+ """Return an empty UntypedStorage on the tensor's device.
567+
568+ ``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real
569+ backing storage of its own; the actual bytes live in the inner
570+ buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are
571+ an implementation detail of the quantization scheme. Need to define
572+ this method to avoid DCP staging errors with FSDP2.
573+ """
574+ return torch .UntypedStorage (0 , device = self .device )
558575
559576 def expand_as (self , other : torch .Tensor ) -> torch .Tensor :
560577 # pylint: disable=missing-function-docstring
@@ -608,6 +625,36 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
608625 dst .copy_ (src )
609626 return None
610627
628+ # _to_copy op (used by .to(device=...), .cpu(), DCP staging).
629+ # Preserve the QuantizedTensor subclass and move all internal
630+ # buffers (data, scales, etc.) to the requested device.
631+ if func == torch .ops .aten ._to_copy .default :
632+ tensor = args [0 ]
633+ kw = dict (kwargs ) if kwargs else {}
634+ dtype = kw .get ("dtype" , None )
635+ if dtype is None or dtype == tensor .dtype :
636+ target_device = kw .get ("device" , tensor .device ) or tensor .device
637+ target_device = torch .device (target_device )
638+ pin_memory = bool (kw .get ("pin_memory" , False ))
639+ non_blocking = bool (kw .get ("non_blocking" , False ))
640+ new_metadata = {"device" : target_device }
641+ # Update tensor storage metadata
642+ for key , value in tensor .get_metadata ().items ():
643+ if isinstance (value , torch .Tensor ):
644+ value = value .to (device = target_device , non_blocking = non_blocking )
645+ if pin_memory and target_device .type == "cpu" :
646+ value = value .pin_memory ()
647+ new_metadata [key ] = value
648+ # Update torch Tensor metadata
649+ new_metadata .update (
650+ {
651+ "dtype" : tensor .dtype ,
652+ "shape" : tensor .shape ,
653+ "requires_grad" : tensor .requires_grad ,
654+ }
655+ )
656+ return type (tensor )(** new_metadata )
657+
611658 # View op
612659 if func == torch .ops .aten .view .default :
613660 raise NotImplementedError ("{cls.__name__} class does not support tensor views" )
@@ -748,14 +795,19 @@ def make_like(
748795 """Create new quantized tensor
749796
750797 By default, new tensor has the same attributes and underlying
751- data. This function is intended to create view of tensors.
752-
798+ data. This function is intended to create a view of ``tensor``,
753799 """
754800 shape = shape if shape is not None else tensor .shape
755801 dtype = dtype if dtype is not None else tensor .dtype
756802 kwargs = tensor .get_metadata ()
757803 kwargs ["fake_dtype" ] = dtype
758- return cls (shape = shape , dtype = dtype , requires_grad = requires_grad , ** kwargs )
804+ return cls (
805+ shape = shape ,
806+ dtype = dtype ,
807+ requires_grad = requires_grad ,
808+ device = tensor .device ,
809+ ** kwargs ,
810+ )
759811
760812 def to_dtype (self , dtype : torch .dtype ) -> QuantizedTensor :
761813 """Create `QuantizedTensor` with given nominal dtype
0 commit comments