-
Notifications
You must be signed in to change notification settings - Fork 6.9k
[core] fix group offloading when using torchao #13276
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
019a9de
8797398
1a959dc
9b9e2e1
d2666a9
6125a4f
7006773
a8cef07
8671923
7eaeb99
59c1b25
0650979
f60afe5
baddc28
cb7402e
011b294
344adce
027e8cc
5a73207
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,7 @@ | |
| import safetensors.torch | ||
| import torch | ||
|
|
||
| from ..utils import get_logger, is_accelerate_available | ||
| from ..utils import get_logger, is_accelerate_available, is_torchao_available | ||
| from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS | ||
| from .hooks import HookRegistry, ModelHook | ||
|
|
||
|
|
@@ -35,6 +35,54 @@ | |
| logger = get_logger(__name__) # pylint: disable=invalid-name | ||
|
|
||
|
|
||
| def _is_torchao_tensor(tensor: torch.Tensor) -> bool: | ||
| if not is_torchao_available(): | ||
| return False | ||
| from torchao.utils import TorchAOBaseTensor | ||
|
|
||
| return isinstance(tensor, TorchAOBaseTensor) | ||
|
|
||
|
|
||
| def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]: | ||
| """Get names of all internal tensor data attributes from a TorchAO tensor.""" | ||
| cls = type(tensor) | ||
| names = list(getattr(cls, "tensor_data_names", [])) | ||
| for attr_name in getattr(cls, "optional_tensor_data_names", []): | ||
| if getattr(tensor, attr_name, None) is not None: | ||
| names.append(attr_name) | ||
| return names | ||
|
|
||
|
|
||
| def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None: | ||
| """Move a TorchAO parameter to the device of `source` via `swap_tensors`. | ||
|
|
||
| `param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces | ||
| the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the | ||
| original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so | ||
| that any dict keyed by `id(param)` remains valid. | ||
|
|
||
| Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion. | ||
| """ | ||
| torch.utils.swap_tensors(param, source) | ||
|
|
||
|
|
||
| def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None: | ||
| """Restore internal tensor data of a TorchAO parameter from `source` without mutating `source`. | ||
|
|
||
| Unlike `_swap_torchao_tensor` this copies attribute references one-by-one via `setattr` so that `source` is **not** | ||
| modified. Use this when `source` is a cached tensor that must remain unchanged (e.g. a pinned CPU copy in | ||
| `cpu_param_dict`). | ||
| """ | ||
| for attr_name in _get_torchao_inner_tensor_names(source): | ||
| setattr(param, attr_name, getattr(source, attr_name)) | ||
|
|
||
|
|
||
| def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None: | ||
| """Record stream for all internal tensors of a TorchAO parameter.""" | ||
| for attr_name in _get_torchao_inner_tensor_names(param): | ||
| getattr(param, attr_name).record_stream(stream) | ||
|
|
||
|
|
||
| # fmt: off | ||
| _GROUP_OFFLOADING = "group_offloading" | ||
| _LAYER_EXECUTION_TRACKER = "layer_execution_tracker" | ||
|
|
@@ -124,24 +172,29 @@ def __init__( | |
| else torch.cuda | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def _to_cpu(tensor, low_cpu_mem_usage): | ||
| # For TorchAO tensors, `.data` returns an incomplete wrapper without internal attributes | ||
| # (e.g. `.qdata`, `.scale`), so we must call `.cpu()` on the tensor directly. | ||
| t = tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu() | ||
| return t if low_cpu_mem_usage else t.pin_memory() | ||
|
|
||
| def _init_cpu_param_dict(self): | ||
| cpu_param_dict = {} | ||
| if self.stream is None: | ||
| return cpu_param_dict | ||
|
|
||
| for module in self.modules: | ||
| for param in module.parameters(): | ||
| cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory() | ||
| cpu_param_dict[param] = self._to_cpu(param, self.low_cpu_mem_usage) | ||
| for buffer in module.buffers(): | ||
| cpu_param_dict[buffer] = ( | ||
| buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory() | ||
| ) | ||
| cpu_param_dict[buffer] = self._to_cpu(buffer, self.low_cpu_mem_usage) | ||
|
|
||
| for param in self.parameters: | ||
| cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory() | ||
| cpu_param_dict[param] = self._to_cpu(param, self.low_cpu_mem_usage) | ||
|
|
||
| for buffer in self.buffers: | ||
| cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory() | ||
| cpu_param_dict[buffer] = self._to_cpu(buffer, self.low_cpu_mem_usage) | ||
|
|
||
| return cpu_param_dict | ||
|
|
||
|
|
@@ -157,9 +210,16 @@ def _pinned_memory_tensors(self): | |
| pinned_dict = None | ||
|
|
||
| def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream): | ||
| tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does this mean the issue still occurs? If you have a minimal repro, we might be able to fix it, I think.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel the proper way to do this is: parameter.data is not a recommended API. and
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But this does not seem like a fix? Your snippet mentions Also, how do I best implement it in the context of the error and the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
this is the proper way to move device for a tensor subclass instance I think. please ignore comments, that was copied from your original example. this runs on my side.
basically we should not be using we have to go through all linear modules in the model, and use swap_tensor to change device: |
||
| moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking) | ||
| if _is_torchao_tensor(tensor): | ||
| _swap_torchao_tensor(tensor, moved) | ||
| else: | ||
| tensor.data = moved | ||
| if self.record_stream: | ||
| tensor.data.record_stream(default_stream) | ||
| if _is_torchao_tensor(tensor): | ||
| _record_stream_torchao_tensor(tensor, default_stream) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we could potentially implement the tensor.record_stream(default_stream) directly. also wondering if this would work if you just do this for nn.Parameter as well (parameter.record_stream(default_stream) instead of (parameter.data.record_stream(default_stream))?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggestion looks great. But I guess that will take some work on your end to ship. Maybe we can add a comment about it here and revisit when you land it?
Wouldn't mind refactoring it later. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. sounds good to check in another PR |
||
| else: | ||
| tensor.data.record_stream(default_stream) | ||
|
|
||
| def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None): | ||
| for group_module in self.modules: | ||
|
|
@@ -178,7 +238,19 @@ def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None) | |
| source = pinned_memory[buffer] if pinned_memory else buffer.data | ||
| self._transfer_tensor_to_device(buffer, source, default_stream) | ||
|
|
||
| def _check_disk_offload_torchao(self): | ||
| all_tensors = list(self.tensor_to_key.keys()) | ||
| has_torchao = any(_is_torchao_tensor(t) for t in all_tensors) | ||
| if has_torchao: | ||
| raise ValueError( | ||
| "Disk offloading is not supported for TorchAO quantized tensors because safetensors " | ||
| "cannot serialize TorchAO subclass tensors. Use memory offloading instead by not " | ||
| "setting `offload_to_disk_path`." | ||
| ) | ||
|
|
||
| def _onload_from_disk(self): | ||
| self._check_disk_offload_torchao() | ||
|
|
||
| if self.stream is not None: | ||
| # Wait for previous Host->Device transfer to complete | ||
| self.stream.synchronize() | ||
|
|
@@ -221,6 +293,8 @@ def _onload_from_memory(self): | |
| self._process_tensors_from_modules(None) | ||
|
|
||
| def _offload_to_disk(self): | ||
| self._check_disk_offload_torchao() | ||
|
|
||
| # TODO: we can potentially optimize this code path by checking if the _all_ the desired | ||
| # safetensor files exist on the disk and if so, skip this step entirely, reducing IO | ||
| # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not | ||
|
|
@@ -245,18 +319,35 @@ def _offload_to_memory(self): | |
|
|
||
| for group_module in self.modules: | ||
| for param in group_module.parameters(): | ||
| param.data = self.cpu_param_dict[param] | ||
| if _is_torchao_tensor(param): | ||
| _restore_torchao_tensor(param, self.cpu_param_dict[param]) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. similarly for this one I'm wondering if it would make sense to implement some copy op in torchao tensor subclasses, also
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed that would be great! |
||
| else: | ||
| param.data = self.cpu_param_dict[param] | ||
| for param in self.parameters: | ||
| param.data = self.cpu_param_dict[param] | ||
| if _is_torchao_tensor(param): | ||
| _restore_torchao_tensor(param, self.cpu_param_dict[param]) | ||
| else: | ||
| param.data = self.cpu_param_dict[param] | ||
| for buffer in self.buffers: | ||
| buffer.data = self.cpu_param_dict[buffer] | ||
| if _is_torchao_tensor(buffer): | ||
| _restore_torchao_tensor(buffer, self.cpu_param_dict[buffer]) | ||
| else: | ||
| buffer.data = self.cpu_param_dict[buffer] | ||
| else: | ||
| for group_module in self.modules: | ||
| group_module.to(self.offload_device, non_blocking=False) | ||
| for param in self.parameters: | ||
| param.data = param.data.to(self.offload_device, non_blocking=False) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I remember hearing from Brian and Alban before that param.data is a private API and we should not rely on it, I think it also does not work with tensor subclasses |
||
| if _is_torchao_tensor(param): | ||
| moved = param.to(self.offload_device, non_blocking=False) | ||
| _swap_torchao_tensor(param, moved) | ||
| else: | ||
| param.data = param.data.to(self.offload_device, non_blocking=False) | ||
| for buffer in self.buffers: | ||
| buffer.data = buffer.data.to(self.offload_device, non_blocking=False) | ||
| if _is_torchao_tensor(buffer): | ||
| moved = buffer.to(self.offload_device, non_blocking=False) | ||
| _swap_torchao_tensor(buffer, moved) | ||
| else: | ||
| buffer.data = buffer.data.to(self.offload_device, non_blocking=False) | ||
|
|
||
| @torch.compiler.disable() | ||
| def onload_(self): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While most of the TorchAO tensor subclasses are developed on top of
TorchAOBaseTensor, it's not a requirement to use it. Practically, this should work for most use cases, but it's not 100% guaranteed. I feel that, ideally / long term, we can refactor all uses of `parameter.data` to just operate on the parameter itself (if it works).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would love that but sadly that's not the case currently as we cannot always control implementation details from external dependencies.