Skip to content

Commit 7dad173

Browse files
committed
error early in auto_cpu_offload
1 parent 8f80dda commit 7dad173

1 file changed

Lines changed: 19 additions & 7 deletions

File tree

src/diffusers/modular_pipelines/components_manager.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,10 @@ def __call__(self, hooks, model_id, model, execution_device):
160160
if len(hooks) == 0:
161161
return []
162162

163-
current_module_size = model.get_memory_footprint()
163+
try:
164+
current_module_size = model.get_memory_footprint()
165+
except AttributeError:
166+
raise AttributeError(f"Do not know how to compute memory footprint of `{model.__class__.__name__}.")
164167

165168
device_type = execution_device.type
166169
device_module = getattr(torch, device_type, torch.cuda)
@@ -703,19 +706,28 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None,
703706
if not is_accelerate_available():
704707
raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
705708

706-
# TODO: add a warning if mem_get_info isn't available on `device`.
709+
if device is None:
710+
device = get_device()
711+
if not isinstance(device, torch.device):
712+
device = torch.device(device)
713+
714+
device_type = device.type
715+
device_module = getattr(torch, device_type, torch.cuda)
716+
if not hasattr(device_module, "mem_get_info"):
717+
raise NotImplementedError(
718+
f"`enable_auto_cpu_offload() relies on the `mem_get_info()` method. It's not implemented for {str(device.type)}."
719+
)
720+
721+
if device.index is None:
722+
device = torch.device(f"{device.type}:{0}")
707723

708724
for name, component in self.components.items():
709725
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
710726
remove_hook_from_module(component, recurse=True)
711727

712728
self.disable_auto_cpu_offload()
713729
offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
714-
if device is None:
715-
device = get_device()
716-
device = torch.device(device)
717-
if device.index is None:
718-
device = torch.device(f"{device.type}:{0}")
730+
719731
all_hooks = []
720732
for name, component in self.components.items():
721733
if isinstance(component, torch.nn.Module):

0 commit comments

Comments
 (0)