diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 891ac28455af..49509cbf04b9 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -22,7 +22,7 @@
 import safetensors.torch
 import torch
 
-from ..utils import get_logger, is_accelerate_available
+from ..utils import get_logger, is_accelerate_available, is_torchao_available
 from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
 from .hooks import HookRegistry, ModelHook
 
@@ -35,6 +35,54 @@
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
+def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
+    if not is_torchao_available():
+        return False
+    from torchao.utils import TorchAOBaseTensor
+
+    return isinstance(tensor, TorchAOBaseTensor)
+
+
+def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
+    """Get names of all internal tensor data attributes from a TorchAO tensor."""
+    cls = type(tensor)
+    names = list(getattr(cls, "tensor_data_names", []))
+    for attr_name in getattr(cls, "optional_tensor_data_names", []):
+        if getattr(tensor, attr_name, None) is not None:
+            names.append(attr_name)
+    return names
+
+
+def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
+    """Move a TorchAO parameter to the device of `source` via `swap_tensors`.
+
+    `param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces
+    the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the
+    original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so
+    that any dict keyed by `id(param)` remains valid.
+
+    Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion.
+    """
+    torch.utils.swap_tensors(param, source)
+
+
+def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
+    """Restore internal tensor data of a TorchAO parameter from `source` without mutating `source`.
+
+    Unlike `_swap_torchao_tensor` this copies attribute references one-by-one via `setattr` so that `source` is **not**
+    modified. Use this when `source` is a cached tensor that must remain unchanged (e.g. a pinned CPU copy in
+    `cpu_param_dict`).
+    """
+    for attr_name in _get_torchao_inner_tensor_names(source):
+        setattr(param, attr_name, getattr(source, attr_name))
+
+
+def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
+    """Record stream for all internal tensors of a TorchAO parameter."""
+    for attr_name in _get_torchao_inner_tensor_names(param):
+        getattr(param, attr_name).record_stream(stream)
+
+
 # fmt: off
 _GROUP_OFFLOADING = "group_offloading"
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
@@ -124,6 +172,13 @@ def __init__(
             else torch.cuda
         )
 
+    @staticmethod
+    def _to_cpu(tensor, low_cpu_mem_usage):
+        # For TorchAO tensors, `.data` returns an incomplete wrapper without internal attributes
+        # (e.g. `.qdata`, `.scale`), so we must call `.cpu()` on the tensor directly.
+        t = tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu()
+        return t if low_cpu_mem_usage else t.pin_memory()
+
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
@@ -131,17 +186,15 @@ def _init_cpu_param_dict(self):
 
         for module in self.modules:
             for param in module.parameters():
-                cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+                cpu_param_dict[param] = self._to_cpu(param, self.low_cpu_mem_usage)
             for buffer in module.buffers():
-                cpu_param_dict[buffer] = (
-                    buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
-                )
+                cpu_param_dict[buffer] = self._to_cpu(buffer, self.low_cpu_mem_usage)
 
         for param in self.parameters:
-            cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+            cpu_param_dict[param] = self._to_cpu(param, self.low_cpu_mem_usage)
 
         for buffer in self.buffers:
-            cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+            cpu_param_dict[buffer] = self._to_cpu(buffer, self.low_cpu_mem_usage)
 
         return cpu_param_dict
 
@@ -157,9 +210,16 @@ def _pinned_memory_tensors(self):
             pinned_dict = None
 
     def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
-        tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+        moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+        if _is_torchao_tensor(tensor):
+            _swap_torchao_tensor(tensor, moved)
+        else:
+            tensor.data = moved
         if self.record_stream:
-            tensor.data.record_stream(default_stream)
+            if _is_torchao_tensor(tensor):
+                _record_stream_torchao_tensor(tensor, default_stream)
+            else:
+                tensor.data.record_stream(default_stream)
 
     def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
         for group_module in self.modules:
@@ -178,7 +238,18 @@ def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None)
             source = pinned_memory[buffer] if pinned_memory else buffer.data
             self._transfer_tensor_to_device(buffer, source, default_stream)
 
+    def _check_disk_offload_torchao(self):
+        """Raise `ValueError` if any key of `tensor_to_key` is a TorchAO tensor (unsupported for disk offload)."""
+        if any(_is_torchao_tensor(t) for t in self.tensor_to_key):
+            raise ValueError(
+                "Disk offloading is not supported for TorchAO quantized tensors because safetensors "
+                "cannot serialize TorchAO subclass tensors. Use memory offloading instead by not "
+                "setting `offload_to_disk_path`."
+            )
+
     def _onload_from_disk(self):
+        self._check_disk_offload_torchao()
+
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
@@ -221,6 +292,8 @@ def _onload_from_memory(self):
                 self._process_tensors_from_modules(None)
 
     def _offload_to_disk(self):
+        self._check_disk_offload_torchao()
+
         # TODO: we can potentially optimize this code path by checking if the _all_ the desired
         # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
         # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
@@ -245,18 +318,35 @@ def _offload_to_memory(self):
 
             for group_module in self.modules:
                 for param in group_module.parameters():
-                    param.data = self.cpu_param_dict[param]
+                    if _is_torchao_tensor(param):
+                        _restore_torchao_tensor(param, self.cpu_param_dict[param])
+                    else:
+                        param.data = self.cpu_param_dict[param]
             for param in self.parameters:
-                param.data = self.cpu_param_dict[param]
+                if _is_torchao_tensor(param):
+                    _restore_torchao_tensor(param, self.cpu_param_dict[param])
+                else:
+                    param.data = self.cpu_param_dict[param]
             for buffer in self.buffers:
-                buffer.data = self.cpu_param_dict[buffer]
+                if _is_torchao_tensor(buffer):
+                    _restore_torchao_tensor(buffer, self.cpu_param_dict[buffer])
+                else:
+                    buffer.data = self.cpu_param_dict[buffer]
         else:
             for group_module in self.modules:
                 group_module.to(self.offload_device, non_blocking=False)
             for param in self.parameters:
-                param.data = param.data.to(self.offload_device, non_blocking=False)
+                if _is_torchao_tensor(param):
+                    moved = param.to(self.offload_device, non_blocking=False)
+                    _swap_torchao_tensor(param, moved)
+                else:
+                    param.data = param.data.to(self.offload_device, non_blocking=False)
             for buffer in self.buffers:
-                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
+                if _is_torchao_tensor(buffer):
+                    moved = buffer.to(self.offload_device, non_blocking=False)
+                    _swap_torchao_tensor(buffer, moved)
+                else:
+                    buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
 
     @torch.compiler.disable()
     def onload_(self):