
Commit 05a1d3a

Commit message: update
Merge: 2 parents ac194af + 05e7a85

11 files changed: 419 additions & 237 deletions


src/diffusers/hooks/group_offloading.py

Lines changed: 140 additions & 153 deletions
Large diffs are not rendered by default.
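Although the group offloading internals are not rendered above, the feature the rest of this commit integrates with can be summarized from the user side. A minimal, hedged sketch, assuming a transformer-backed, LoRA-capable pipeline (the model and LoRA ids are placeholders):

```python
# Hedged usage sketch: enabling group offloading on a pipeline component.
# "some/model-id" and "some/lora-id" are placeholders, not real checkpoints.
import torch

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("some/model-id", torch_dtype=torch.bfloat16)

# Group offloading keeps groups of layers on the offload device and onloads
# them to the accelerator just in time; "block_level" groups whole blocks.
pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
)

# The rest of this commit makes LoRA loading aware of these hooks: they are
# removed before the LoRA layers are injected and reapplied afterwards.
pipe.load_lora_weights("some/lora-id")
```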

src/diffusers/loaders/lora_base.py

Lines changed: 30 additions & 17 deletions
```diff
@@ -25,6 +25,7 @@
 from huggingface_hub import model_info
 from huggingface_hub.constants import HF_HUB_OFFLINE

+from ..hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading
 from ..models.modeling_utils import ModelMixin, load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
@@ -391,7 +392,9 @@ def _load_lora_into_text_encoder(
     adapter_name = get_adapter_name(text_encoder)

     # <Unsafe code
-    is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
+    is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(
+        _pipeline
+    )
     # inject LoRA layers and load the state dict
     # in transformers we automatically check whether the adapter name is already in use or not
     text_encoder.load_adapter(
@@ -410,6 +413,10 @@ def _load_lora_into_text_encoder(
         _pipeline.enable_model_cpu_offload()
     elif is_sequential_cpu_offload:
         _pipeline.enable_sequential_cpu_offload()
+    elif is_group_offload:
+        for component in _pipeline.components.values():
+            if isinstance(component, torch.nn.Module):
+                _maybe_remove_and_reapply_group_offloading(component)
     # Unsafe code />

     if prefix is not None and not state_dict:
@@ -433,30 +440,36 @@ def _func_optionally_disable_offloading(_pipeline):

     Returns:
         tuple:
-            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
     """
     is_model_cpu_offload = False
     is_sequential_cpu_offload = False
+    is_group_offload = False

     if _pipeline is not None and _pipeline.hf_device_map is None:
         for _, component in _pipeline.components.items():
-            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
-                if not is_model_cpu_offload:
-                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
-                if not is_sequential_cpu_offload:
-                    is_sequential_cpu_offload = (
-                        isinstance(component._hf_hook, AlignDevicesHook)
-                        or hasattr(component._hf_hook, "hooks")
-                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
-                    )
+            if not isinstance(component, nn.Module):
+                continue
+            is_group_offload = is_group_offload or _is_group_offload_enabled(component)
+            if not hasattr(component, "_hf_hook"):
+                continue
+            is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
+            is_sequential_cpu_offload = is_sequential_cpu_offload or (
+                isinstance(component._hf_hook, AlignDevicesHook)
+                or hasattr(component._hf_hook, "hooks")
+                and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+            )

-                logger.info(
-                    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                )
-                if is_sequential_cpu_offload or is_model_cpu_offload:
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+        if is_sequential_cpu_offload or is_model_cpu_offload:
+            logger.info(
+                "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+            )
+            for _, component in _pipeline.components.items():
+                if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
+                    continue
+                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

-    return (is_model_cpu_offload, is_sequential_cpu_offload)
+    return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)


 class LoraBaseMixin:
```
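One subtlety in the rewritten check above: the `is_sequential_cpu_offload` expression relies on Python's operator precedence, where `a or b and c` parses as `a or (b and c)`. A tiny self-contained illustration:

```python
# `and` binds tighter than `or`, so the hook check reads as:
# isinstance(hook, AlignDevicesHook)
#     or (hasattr(hook, "hooks") and isinstance(hook.hooks[0], AlignDevicesHook))
a, b, c = True, False, False
assert (a or b and c) == (a or (b and c)) == True  # actual parse
assert ((a or b) and c) == False                   # a different parse would flip the result
```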

src/diffusers/loaders/peft.py

Lines changed: 12 additions & 1 deletion
```diff
@@ -22,6 +22,7 @@
 import safetensors
 import torch

+from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
 from ..utils import (
     MIN_PEFT_VERSION,
     USE_PEFT_BACKEND,
@@ -263,7 +264,9 @@ def load_lora_adapter(

         # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
         # otherwise loading LoRA weights will lead to an error.
-        is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+        is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+            _pipeline
+        )
         peft_kwargs = {}
         if is_peft_version(">=", "0.13.1"):
             peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -354,6 +357,10 @@ def map_state_dict_for_hotswap(sd):
             _pipeline.enable_model_cpu_offload()
         elif is_sequential_cpu_offload:
             _pipeline.enable_sequential_cpu_offload()
+        elif is_group_offload:
+            for component in _pipeline.components.values():
+                if isinstance(component, torch.nn.Module):
+                    _maybe_remove_and_reapply_group_offloading(component)
         # Unsafe code />

         if prefix is not None and not state_dict:
@@ -693,6 +700,10 @@ def unload_lora(self):
         recurse_remove_peft_layers(self)
         if hasattr(self, "peft_config"):
             del self.peft_config
+        if hasattr(self, "_hf_peft_config_loaded"):
+            self._hf_peft_config_loaded = None
+
+        _maybe_remove_and_reapply_group_offloading(self)

     def disable_lora(self):
         """
```

src/diffusers/loaders/unet.py

Lines changed: 15 additions & 4 deletions
```diff
@@ -22,6 +22,7 @@
 import torch.nn.functional as F
 from huggingface_hub.utils import validate_hf_hub_args

+from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
 from ..models.embeddings import (
     ImageProjection,
     IPAdapterFaceIDImageProjection,
@@ -203,6 +204,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
+        is_group_offload = False

         if is_lora:
             deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
@@ -211,7 +213,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         if is_custom_diffusion:
             attn_processors = self._process_custom_diffusion(state_dict=state_dict)
         elif is_lora:
-            is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
+            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._process_lora(
                 state_dict=state_dict,
                 unet_identifier_key=self.unet_name,
                 network_alphas=network_alphas,
@@ -230,7 +232,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict

         # For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
         if is_custom_diffusion and _pipeline is not None:
-            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
+            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+                _pipeline=_pipeline
+            )

         # only custom diffusion needs to set attn processors
         self.set_attn_processor(attn_processors)
@@ -241,6 +245,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
             _pipeline.enable_model_cpu_offload()
         elif is_sequential_cpu_offload:
             _pipeline.enable_sequential_cpu_offload()
+        elif is_group_offload:
+            for component in _pipeline.components.values():
+                if isinstance(component, torch.nn.Module):
+                    _maybe_remove_and_reapply_group_offloading(component)
         # Unsafe code />

     def _process_custom_diffusion(self, state_dict):
@@ -307,6 +315,7 @@ def _process_lora(

         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
+        is_group_offload = False
         state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict

         if len(state_dict_to_be_used) > 0:
@@ -356,7 +365,9 @@ def _process_lora(

         # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
         # otherwise loading LoRA weights will lead to an error
-        is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+        is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+            _pipeline
+        )
         peft_kwargs = {}
         if is_peft_version(">=", "0.13.1"):
             peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -389,7 +400,7 @@ def _process_lora(
         if warn_msg:
             logger.warning(warn_msg)

-        return is_model_cpu_offload, is_sequential_cpu_offload
+        return is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload

     @classmethod
     # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
```
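The pattern that forced all of these call sites to change in lockstep is plain tuple unpacking: once `_process_lora` and `_optionally_disable_offloading` return three flags, any remaining two-name unpack raises at call time. A self-contained illustration:

```python
def offload_status():
    return True, False, False  # the helpers now return three flags

try:
    is_model, is_sequential = offload_status()  # stale two-flag call site
except ValueError as err:
    print(err)  # too many values to unpack (expected 2)

is_model, is_sequential, is_group = offload_status()  # updated call site
```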

tests/lora/test_lora_layers_cogvideox.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -16,6 +16,7 @@
 import unittest

 import torch
+from parameterized import parameterized
 from transformers import AutoTokenizer, T5EncoderModel

 from diffusers import (
@@ -28,6 +29,7 @@
 from diffusers.utils.testing_utils import (
     floats_tensor,
     require_peft_backend,
+    require_torch_accelerator,
 )


@@ -127,6 +129,13 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self):
     def test_lora_scale_kwargs_match_fusion(self):
         super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)

+    @parameterized.expand([("block_level", True), ("leaf_level", False)])
+    @require_torch_accelerator
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
+        # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
+        super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
+
     @unittest.skip("Not supported in CogVideoX.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
```
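For context on the decorator, `parameterized.expand` turns each tuple into an independent test case with a generated name suffix (index plus a slug of the first string argument). A minimal standalone sketch:

```python
import unittest

from parameterized import parameterized


class OffloadDemoTests(unittest.TestCase):
    # Expands into test_offload_type_0_block_level and test_offload_type_1_leaf_level,
    # so each (offload_type, use_stream) pair runs and reports separately.
    @parameterized.expand([("block_level", True), ("leaf_level", False)])
    def test_offload_type(self, offload_type, use_stream):
        self.assertIn(offload_type, {"block_level", "leaf_level"})
        self.assertIsInstance(use_stream, bool)
```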

tests/lora/test_lora_layers_cogview4.py

Lines changed: 15 additions & 1 deletion
```diff
@@ -18,10 +18,17 @@

 import numpy as np
 import torch
+from parameterized import parameterized
 from transformers import AutoTokenizer, GlmModel

 from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device
+from diffusers.utils.testing_utils import (
+    floats_tensor,
+    require_peft_backend,
+    require_torch_accelerator,
+    skip_mps,
+    torch_device,
+)


 sys.path.append(".")
@@ -141,6 +148,13 @@ def test_simple_inference_save_pretrained(self):
             "Loading from saved checkpoints should give same results.",
         )

+    @parameterized.expand([("block_level", True), ("leaf_level", False)])
+    @require_torch_accelerator
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
+        # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
+        super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
+
     @unittest.skip("Not supported in CogView4.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
```
