Skip to content

Commit ab6d634

Browse files
committed
style
1 parent da4242d commit ab6d634

1 file changed

Lines changed: 3 additions & 4 deletions

File tree

src/diffusers/modular_pipelines/components_manager.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,12 @@
2121

2222
import torch
2323

24+
from ..hooks import ModelHook
2425
from ..utils import (
2526
is_accelerate_available,
2627
logging,
2728
)
2829

29-
from ..hooks import ModelHook
30-
3130

3231
if is_accelerate_available():
3332
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
@@ -69,6 +68,7 @@ class CustomOffloadHook(ModelHook):
6968
The device on which the model should be executed. Will default to the MPS device if it's available, then
7069
GPU 0 if there is a GPU, and finally to the CPU.
7170
"""
71+
7272
no_grad = False
7373

7474
def __init__(
@@ -541,10 +541,9 @@ def matches_pattern(component_id, pattern, exact_match=False):
541541
raise ValueError(f"Invalid type for names: {type(names)}")
542542

543543
def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"):
544-
545544
if not is_accelerate_available():
546545
raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
547-
546+
548547
for name, component in self.components.items():
549548
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
550549
remove_hook_from_module(component, recurse=True)

0 commit comments

Comments
 (0)