@@ -3637,7 +3637,8 @@ def lora_state_dict(
36373637 r"""
36383638 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
36393639 """
3640- # Load the main state dict first which has the LoRA layers
3640+ # Load the main state dict first which has the LoRA layers for either of
3641+ # transformer and text encoder or both.
36413642 cache_dir = kwargs .pop ("cache_dir" , None )
36423643 force_download = kwargs .pop ("force_download" , False )
36433644 proxies = kwargs .pop ("proxies" , None )
@@ -3695,7 +3696,7 @@ def load_lora_weights(
36953696 raise ValueError ("PEFT backend is required for this method." )
36963697
36973698 low_cpu_mem_usage = kwargs .pop ("low_cpu_mem_usage" , _LOW_CPU_MEM_USAGE_DEFAULT_LORA )
3698- if low_cpu_mem_usage and not is_peft_version (">= " , "0.13.1 " ):
3699+ if low_cpu_mem_usage and is_peft_version ("< " , "0.13.0 " ):
36993700 raise ValueError (
37003701 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
37013702 )
@@ -3712,7 +3713,6 @@ def load_lora_weights(
37123713 if not is_correct_format :
37133714 raise ValueError ("Invalid LoRA checkpoint." )
37143715
3715- # Load LoRA into transformer
37163716 self .load_lora_into_transformer (
37173717 state_dict ,
37183718 transformer = getattr (self , self .transformer_name ) if not hasattr (self , "transformer" ) else self .transformer ,
@@ -3738,7 +3738,7 @@ def load_lora_into_transformer(
37383738 """
37393739 See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
37403740 """
3741- if low_cpu_mem_usage and not is_peft_version (">= " , "0.13.1 " ):
3741+ if low_cpu_mem_usage and is_peft_version ("< " , "0.13.0 " ):
37423742 raise ValueError (
37433743 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
37443744 )
@@ -3765,7 +3765,7 @@ def save_lora_weights(
37653765 weight_name : str = None ,
37663766 save_function : Callable = None ,
37673767 safe_serialization : bool = True ,
3768- transformer_lora_adapter_metadata = None ,
3768+ transformer_lora_adapter_metadata : Optional [ dict ] = None ,
37693769 ):
37703770 r"""
37713771 See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
@@ -3778,7 +3778,7 @@ def save_lora_weights(
37783778 lora_metadata [cls .transformer_name ] = transformer_lora_adapter_metadata
37793779
37803780 if not lora_layers :
3781- raise ValueError ("You must pass at least one of `transformer_lora_layers`" )
3781+ raise ValueError ("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`. " )
37823782
37833783 cls ._save_lora_weights (
37843784 save_directory = save_directory ,
0 commit comments