2525from huggingface_hub import model_info
2626from huggingface_hub .constants import HF_HUB_OFFLINE
2727
28+ from ..hooks .group_offloading import _is_group_offload_enabled , _maybe_remove_and_reapply_group_offloading
2829from ..models .modeling_utils import ModelMixin , load_state_dict
2930from ..utils import (
3031 USE_PEFT_BACKEND ,
@@ -391,7 +392,9 @@ def _load_lora_into_text_encoder(
391392 adapter_name = get_adapter_name (text_encoder )
392393
393394 # <Unsafe code
394- is_model_cpu_offload , is_sequential_cpu_offload = _func_optionally_disable_offloading (_pipeline )
395+ is_model_cpu_offload , is_sequential_cpu_offload , is_group_offload = _func_optionally_disable_offloading (
396+ _pipeline
397+ )
395398 # inject LoRA layers and load the state dict
396399 # in transformers we automatically check whether the adapter name is already in use or not
397400 text_encoder .load_adapter (
@@ -410,6 +413,10 @@ def _load_lora_into_text_encoder(
410413 _pipeline .enable_model_cpu_offload ()
411414 elif is_sequential_cpu_offload :
412415 _pipeline .enable_sequential_cpu_offload ()
416+ elif is_group_offload :
417+ for component in _pipeline .components .values ():
418+ if isinstance (component , torch .nn .Module ):
419+ _maybe_remove_and_reapply_group_offloading (component )
413420 # Unsafe code />
414421
415422 if prefix is not None and not state_dict :
@@ -433,30 +440,36 @@ def _func_optionally_disable_offloading(_pipeline):
433440
434441 Returns:
435442 tuple:
436- A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
443+ A tuple `(is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)` of flags indicating which offloading strategies were detected on the pipeline's components.
437444 """
438445 is_model_cpu_offload = False
439446 is_sequential_cpu_offload = False
447+ is_group_offload = False
440448
441449 if _pipeline is not None and _pipeline .hf_device_map is None :
442450 for _ , component in _pipeline .components .items ():
443- if isinstance (component , nn .Module ) and hasattr (component , "_hf_hook" ):
444- if not is_model_cpu_offload :
445- is_model_cpu_offload = isinstance (component ._hf_hook , CpuOffload )
446- if not is_sequential_cpu_offload :
447- is_sequential_cpu_offload = (
448- isinstance (component ._hf_hook , AlignDevicesHook )
449- or hasattr (component ._hf_hook , "hooks" )
450- and isinstance (component ._hf_hook .hooks [0 ], AlignDevicesHook )
451- )
451+ if not isinstance (component , nn .Module ):
452+ continue
453+ is_group_offload = is_group_offload or _is_group_offload_enabled (component )
454+ if not hasattr (component , "_hf_hook" ):
455+ continue
456+ is_model_cpu_offload = is_model_cpu_offload or isinstance (component ._hf_hook , CpuOffload )
457+ is_sequential_cpu_offload = is_sequential_cpu_offload or (
458+ isinstance (component ._hf_hook , AlignDevicesHook )
459+ or hasattr (component ._hf_hook , "hooks" )
460+ and isinstance (component ._hf_hook .hooks [0 ], AlignDevicesHook )
461+ )
452462
453- logger .info (
454- "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
455- )
456- if is_sequential_cpu_offload or is_model_cpu_offload :
457- remove_hook_from_module (component , recurse = is_sequential_cpu_offload )
463+ if is_sequential_cpu_offload or is_model_cpu_offload :
464+ logger .info (
465+ "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
466+ )
467+ for _ , component in _pipeline .components .items ():
468+ if not isinstance (component , nn .Module ) or not hasattr (component , "_hf_hook" ):
469+ continue
470+ remove_hook_from_module (component , recurse = is_sequential_cpu_offload )
458471
459- return (is_model_cpu_offload , is_sequential_cpu_offload )
472+ return (is_model_cpu_offload , is_sequential_cpu_offload , is_group_offload )
460473
461474
462475class LoraBaseMixin :
0 commit comments