|
49 | 49 | _convert_non_diffusers_anima_lora_to_diffusers, |
50 | 50 | _convert_non_diffusers_flux2_lora_to_diffusers, |
51 | 51 | _convert_non_diffusers_hidream_lora_to_diffusers, |
| 52 | + _convert_non_diffusers_ideogram4_lora_to_diffusers, |
52 | 53 | _convert_non_diffusers_lora_to_diffusers, |
53 | 54 | _convert_non_diffusers_ltx2_lora_to_diffusers, |
54 | 55 | _convert_non_diffusers_ltxv_lora_to_diffusers, |
@@ -6018,6 +6019,213 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): |
6018 | 6019 | super().unfuse_lora(components=components, **kwargs) |
6019 | 6020 |
|
6020 | 6021 |
|
| 6022 | +class Ideogram4LoraLoaderMixin(LoraBaseMixin): |
| 6023 | + r""" |
| 6024 | + Load LoRA layers into [`Ideogram4Transformer2DModel`]. Specific to [`Ideogram4Pipeline`]. |
| 6025 | + """ |
| 6026 | + |
| 6027 | + _lora_loadable_modules = ["transformer"] |
| 6028 | + transformer_name = TRANSFORMER_NAME |
| 6029 | + |
| 6030 | + @classmethod |
| 6031 | + @validate_hf_hub_args |
| 6032 | + def lora_state_dict( |
| 6033 | + cls, |
| 6034 | + pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], |
| 6035 | + **kwargs, |
| 6036 | + ): |
| 6037 | + r""" |
| 6038 | + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details. |
| 6039 | + """ |
| 6040 | + # Load the main state dict first which has the LoRA layers for either of |
| 6041 | + # transformer and text encoder or both. |
| 6042 | + cache_dir = kwargs.pop("cache_dir", None) |
| 6043 | + force_download = kwargs.pop("force_download", False) |
| 6044 | + proxies = kwargs.pop("proxies", None) |
| 6045 | + local_files_only = kwargs.pop("local_files_only", None) |
| 6046 | + token = kwargs.pop("token", None) |
| 6047 | + revision = kwargs.pop("revision", None) |
| 6048 | + subfolder = kwargs.pop("subfolder", None) |
| 6049 | + weight_name = kwargs.pop("weight_name", None) |
| 6050 | + use_safetensors = kwargs.pop("use_safetensors", None) |
| 6051 | + return_lora_metadata = kwargs.pop("return_lora_metadata", False) |
| 6052 | + |
| 6053 | + allow_pickle = False |
| 6054 | + if use_safetensors is None: |
| 6055 | + use_safetensors = True |
| 6056 | + allow_pickle = True |
| 6057 | + |
| 6058 | + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} |
| 6059 | + |
| 6060 | + state_dict, metadata = _fetch_state_dict( |
| 6061 | + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, |
| 6062 | + weight_name=weight_name, |
| 6063 | + use_safetensors=use_safetensors, |
| 6064 | + local_files_only=local_files_only, |
| 6065 | + cache_dir=cache_dir, |
| 6066 | + force_download=force_download, |
| 6067 | + proxies=proxies, |
| 6068 | + token=token, |
| 6069 | + revision=revision, |
| 6070 | + subfolder=subfolder, |
| 6071 | + user_agent=user_agent, |
| 6072 | + allow_pickle=allow_pickle, |
| 6073 | + ) |
| 6074 | + |
| 6075 | + is_dora_scale_present = any("dora_scale" in k for k in state_dict) |
| 6076 | + if is_dora_scale_present: |
| 6077 | + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." |
| 6078 | + logger.warning(warn_msg) |
| 6079 | + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} |
| 6080 | + |
| 6081 | + # ai-toolkit (ostris) saves Ideogram4 LoRAs under a `diffusion_model.` prefix with a fused |
| 6082 | + # `attention.qkv` projection; convert those to the diffusers layout before loading. |
| 6083 | + is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict) or any( |
| 6084 | + ".attention.qkv." in k for k in state_dict |
| 6085 | + ) |
| 6086 | + if is_non_diffusers_format: |
| 6087 | + state_dict = _convert_non_diffusers_ideogram4_lora_to_diffusers(state_dict) |
| 6088 | + |
| 6089 | + out = (state_dict, metadata) if return_lora_metadata else state_dict |
| 6090 | + return out |
| 6091 | + |
| 6092 | + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights |
| 6093 | + def load_lora_weights( |
| 6094 | + self, |
| 6095 | + pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], |
| 6096 | + adapter_name: str | None = None, |
| 6097 | + hotswap: bool = False, |
| 6098 | + **kwargs, |
| 6099 | + ): |
| 6100 | + """ |
| 6101 | + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details. |
| 6102 | + """ |
| 6103 | + if not USE_PEFT_BACKEND: |
| 6104 | + raise ValueError("PEFT backend is required for this method.") |
| 6105 | + |
| 6106 | + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) |
| 6107 | + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): |
| 6108 | + raise ValueError( |
| 6109 | + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." |
| 6110 | + ) |
| 6111 | + |
| 6112 | + # if a dict is passed, copy it instead of modifying it inplace |
| 6113 | + if isinstance(pretrained_model_name_or_path_or_dict, dict): |
| 6114 | + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() |
| 6115 | + |
| 6116 | + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. |
| 6117 | + kwargs["return_lora_metadata"] = True |
| 6118 | + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) |
| 6119 | + |
| 6120 | + is_correct_format = all("lora" in key for key in state_dict.keys()) |
| 6121 | + if not is_correct_format: |
| 6122 | + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") |
| 6123 | + |
| 6124 | + self.load_lora_into_transformer( |
| 6125 | + state_dict, |
| 6126 | + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, |
| 6127 | + adapter_name=adapter_name, |
| 6128 | + metadata=metadata, |
| 6129 | + _pipeline=self, |
| 6130 | + low_cpu_mem_usage=low_cpu_mem_usage, |
| 6131 | + hotswap=hotswap, |
| 6132 | + ) |
| 6133 | + |
| 6134 | + @classmethod |
| 6135 | + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel |
| 6136 | + def load_lora_into_transformer( |
| 6137 | + cls, |
| 6138 | + state_dict, |
| 6139 | + transformer, |
| 6140 | + adapter_name=None, |
| 6141 | + _pipeline=None, |
| 6142 | + low_cpu_mem_usage=False, |
| 6143 | + hotswap: bool = False, |
| 6144 | + metadata=None, |
| 6145 | + ): |
| 6146 | + """ |
| 6147 | + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details. |
| 6148 | + """ |
| 6149 | + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): |
| 6150 | + raise ValueError( |
| 6151 | + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." |
| 6152 | + ) |
| 6153 | + |
| 6154 | + # Load the layers corresponding to transformer. |
| 6155 | + logger.info(f"Loading {cls.transformer_name}.") |
| 6156 | + transformer.load_lora_adapter( |
| 6157 | + state_dict, |
| 6158 | + network_alphas=None, |
| 6159 | + adapter_name=adapter_name, |
| 6160 | + metadata=metadata, |
| 6161 | + _pipeline=_pipeline, |
| 6162 | + low_cpu_mem_usage=low_cpu_mem_usage, |
| 6163 | + hotswap=hotswap, |
| 6164 | + ) |
| 6165 | + |
| 6166 | + @classmethod |
| 6167 | + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights |
| 6168 | + def save_lora_weights( |
| 6169 | + cls, |
| 6170 | + save_directory: str | os.PathLike, |
| 6171 | + transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None, |
| 6172 | + is_main_process: bool = True, |
| 6173 | + weight_name: str = None, |
| 6174 | + save_function: Callable = None, |
| 6175 | + safe_serialization: bool = True, |
| 6176 | + transformer_lora_adapter_metadata: dict | None = None, |
| 6177 | + ): |
| 6178 | + r""" |
| 6179 | + See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information. |
| 6180 | + """ |
| 6181 | + lora_layers = {} |
| 6182 | + lora_metadata = {} |
| 6183 | + |
| 6184 | + if transformer_lora_layers: |
| 6185 | + lora_layers[cls.transformer_name] = transformer_lora_layers |
| 6186 | + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata |
| 6187 | + |
| 6188 | + if not lora_layers: |
| 6189 | + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") |
| 6190 | + |
| 6191 | + cls._save_lora_weights( |
| 6192 | + save_directory=save_directory, |
| 6193 | + lora_layers=lora_layers, |
| 6194 | + lora_metadata=lora_metadata, |
| 6195 | + is_main_process=is_main_process, |
| 6196 | + weight_name=weight_name, |
| 6197 | + save_function=save_function, |
| 6198 | + safe_serialization=safe_serialization, |
| 6199 | + ) |
| 6200 | + |
| 6201 | + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora |
| 6202 | + def fuse_lora( |
| 6203 | + self, |
| 6204 | + components: list[str] = ["transformer"], |
| 6205 | + lora_scale: float = 1.0, |
| 6206 | + safe_fusing: bool = False, |
| 6207 | + adapter_names: list[str] | None = None, |
| 6208 | + **kwargs, |
| 6209 | + ): |
| 6210 | + r""" |
| 6211 | + See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details. |
| 6212 | + """ |
| 6213 | + super().fuse_lora( |
| 6214 | + components=components, |
| 6215 | + lora_scale=lora_scale, |
| 6216 | + safe_fusing=safe_fusing, |
| 6217 | + adapter_names=adapter_names, |
| 6218 | + **kwargs, |
| 6219 | + ) |
| 6220 | + |
| 6221 | + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora |
| 6222 | + def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): |
| 6223 | + r""" |
| 6224 | + See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details. |
| 6225 | + """ |
| 6226 | + super().unfuse_lora(components=components, **kwargs) |
| 6227 | + |
| 6228 | + |
6021 | 6229 | class ErnieImageLoraLoaderMixin(LoraBaseMixin): |
6022 | 6230 | r""" |
6023 | 6231 | Load LoRA layers into [`ErnieImageTransformer2DModel`]. Specific to [`ErnieImagePipeline`]. |
|
0 commit comments