Skip to content

Commit 3467efa

Browse files
linoytsabanclaudesayakpauldg845
authored
Add Ideogram4LoraLoaderMixin (LoRA loading for Ideogram4) (#13921)
* add Ideogram4LoraLoaderMixin * Support loading non-diffusers Ideogram4 LoRA checkpoints (#13919) support loading non-diffusers Ideogram4 LoRAs * add Ideogram4 LoRA loader tests * support call-time LoRA scaling via attention_kwargs in Ideogram4 * fix and un-skip Ideogram4 LoRA loader tests * document attention_kwargs in Ideogram4 forward and pipeline * style Ideogram4 attention_kwargs docstrings * fix Ideogram4 LoRA loader CI test failures - pipeline: run the text encoder on its parameters' current device, then move features to the execution device, so encode_prompt works under enable_model_cpu_offload. The pipeline calls the text encoder's submodules directly to tap intermediate layers, which bypasses accelerate's onload hook, so the weights stay on CPU while inputs are on the execution device. Fixes test_lora_loading_model_cpu_offload. - tests: override test_lora_fuse_nan to corrupt a weight under Ideogram4's `layers` tower (the base test probes transformer_blocks/blocks/etc.). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * address review nits on Ideogram4 LoRA loader - pipeline: clarify the te_device comment (per review) — explain the CpuOffload hook attaches to forward, why submodule calls bypass it, and that te_device is the offload device under enable_model_cpu_offload. - tests: drop the unused `import sys` and `sys.path.append(".")`. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
1 parent 5e7540f commit 3467efa

7 files changed

Lines changed: 600 additions & 5 deletions

File tree

docs/source/en/api/loaders/lora.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
144144
## KandinskyLoraLoaderMixin
145145
[[autodoc]] loaders.lora_pipeline.KandinskyLoraLoaderMixin
146146

147+
## Ideogram4LoraLoaderMixin
148+
149+
[[autodoc]] loaders.lora_pipeline.Ideogram4LoraLoaderMixin
150+
147151
## LoraBaseMixin
148152

149153
[[autodoc]] loaders.lora_base.LoraBaseMixin

src/diffusers/loaders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def text_encoder_attn_modules(text_encoder):
8686
"QwenImageLoraLoaderMixin",
8787
"ZImageLoraLoaderMixin",
8888
"Flux2LoraLoaderMixin",
89+
"Ideogram4LoraLoaderMixin",
8990
"ErnieImageLoraLoaderMixin",
9091
"CosmosLoraLoaderMixin",
9192
]
@@ -128,6 +129,7 @@ def text_encoder_attn_modules(text_encoder):
128129
HeliosLoraLoaderMixin,
129130
HiDreamImageLoraLoaderMixin,
130131
HunyuanVideoLoraLoaderMixin,
132+
Ideogram4LoraLoaderMixin,
131133
KandinskyLoraLoaderMixin,
132134
LoraLoaderMixin,
133135
LTX2LoraLoaderMixin,

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2883,3 +2883,88 @@ def get_alpha_scales(down_weight, alpha_key):
28832883

28842884
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
28852885
return converted_state_dict
2886+
2887+
2888+
def _convert_non_diffusers_ideogram4_lora_to_diffusers(state_dict):
2889+
"""
2890+
Convert non-diffusers Ideogram4 LoRA state dict to diffusers format.
2891+
2892+
Handles:
2893+
- `diffusion_model.` / `conditional_transformer.` prefix removal
2894+
- `lora_down`/`lora_up` (kohya) -> `lora_A`/`lora_B`, with `.alpha` folded into the weights
2895+
- fused `attention.qkv` -> split `to_q`/`to_k`/`to_v`; `attention.o` -> `to_out.0`
2896+
- `feed_forward.w1`/`w2`/`w3` and `adaln_modulation` map one-to-one
2897+
"""
2898+
for prefix in ("diffusion_model.", "conditional_transformer."):
2899+
if any(k.startswith(prefix) for k in state_dict):
2900+
state_dict = {k.removeprefix(prefix): v for k, v in state_dict.items()}
2901+
break
2902+
2903+
is_kohya = any(".lora_down.weight" in k for k in state_dict)
2904+
down_suffix = ".lora_down.weight" if is_kohya else ".lora_A.weight"
2905+
up_suffix = ".lora_up.weight" if is_kohya else ".lora_B.weight"
2906+
2907+
def get_alpha_scales(down_weight, alpha_key):
2908+
rank = down_weight.shape[0]
2909+
alpha_tensor = state_dict.pop(alpha_key, None)
2910+
if alpha_tensor is None:
2911+
return 1.0, 1.0
2912+
# LoRA is scaled by `alpha / rank` in the forward pass; split the factor between down and up.
2913+
scale = alpha_tensor.item() / rank
2914+
scale_down, scale_up = scale, 1.0
2915+
while scale_down * 2 < scale_up:
2916+
scale_down *= 2
2917+
scale_up /= 2
2918+
return scale_down, scale_up
2919+
2920+
def pull(base):
2921+
"""Pop the scaled (lora_A, lora_B) pair for a module path, or return None if absent."""
2922+
down_key = base + down_suffix
2923+
if down_key not in state_dict:
2924+
return None
2925+
down = state_dict.pop(down_key)
2926+
up = state_dict.pop(base + up_suffix)
2927+
scale_down, scale_up = get_alpha_scales(down, base + ".alpha")
2928+
return down * scale_down, up * scale_up
2929+
2930+
num_layers = 0
2931+
for k in state_dict:
2932+
match = re.match(r"layers\.(\d+)\.", k)
2933+
if match:
2934+
num_layers = max(num_layers, int(match.group(1)) + 1)
2935+
2936+
converted_state_dict = {}
2937+
for i in range(num_layers):
2938+
layer_prefix = f"layers.{i}"
2939+
2940+
# Fused qkv -> split to_q / to_k / to_v (shared down/lora_A, chunk up/lora_B in thirds).
2941+
qkv = pull(f"{layer_prefix}.attention.qkv")
2942+
if qkv is not None:
2943+
down, up = qkv
2944+
up_q, up_k, up_v = torch.chunk(up, 3, dim=0)
2945+
for proj, up_proj in (("to_q", up_q), ("to_k", up_k), ("to_v", up_v)):
2946+
converted_state_dict[f"{layer_prefix}.attention.{proj}.lora_A.weight"] = down.clone()
2947+
converted_state_dict[f"{layer_prefix}.attention.{proj}.lora_B.weight"] = up_proj.contiguous()
2948+
2949+
# attention.o -> attention.to_out.0
2950+
out = pull(f"{layer_prefix}.attention.o")
2951+
if out is not None:
2952+
down, up = out
2953+
converted_state_dict[f"{layer_prefix}.attention.to_out.0.lora_A.weight"] = down
2954+
converted_state_dict[f"{layer_prefix}.attention.to_out.0.lora_B.weight"] = up
2955+
2956+
# feed_forward.{w1,w2,w3} and adaln_modulation map one-to-one.
2957+
for module in ("feed_forward.w1", "feed_forward.w2", "feed_forward.w3", "adaln_modulation"):
2958+
pair = pull(f"{layer_prefix}.{module}")
2959+
if pair is not None:
2960+
down, up = pair
2961+
converted_state_dict[f"{layer_prefix}.{module}.lora_A.weight"] = down
2962+
converted_state_dict[f"{layer_prefix}.{module}.lora_B.weight"] = up
2963+
2964+
if len(state_dict) > 0:
2965+
raise ValueError(
2966+
f"`state_dict` should be empty at this point but has {sorted(state_dict.keys())}. "
2967+
"This may be an unsupported Ideogram4 LoRA layout."
2968+
)
2969+
2970+
return {f"transformer.{k}": v for k, v in converted_state_dict.items()}

src/diffusers/loaders/lora_pipeline.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
_convert_non_diffusers_anima_lora_to_diffusers,
5050
_convert_non_diffusers_flux2_lora_to_diffusers,
5151
_convert_non_diffusers_hidream_lora_to_diffusers,
52+
_convert_non_diffusers_ideogram4_lora_to_diffusers,
5253
_convert_non_diffusers_lora_to_diffusers,
5354
_convert_non_diffusers_ltx2_lora_to_diffusers,
5455
_convert_non_diffusers_ltxv_lora_to_diffusers,
@@ -6018,6 +6019,213 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
60186019
super().unfuse_lora(components=components, **kwargs)
60196020

60206021

6022+
class Ideogram4LoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`Ideogram4Transformer2DModel`]. Specific to [`Ideogram4Pipeline`].
    """

    _lora_loadable_modules = ["transformer"]
    transformer_name = TRANSFORMER_NAME

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
        **kwargs,
    ):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.

        Checkpoints saved in a non-diffusers layout (a `diffusion_model.` or `conditional_transformer.` key prefix,
        or a fused `attention.qkv` projection) are converted to the diffusers layout before being returned. When
        `return_lora_metadata=True` is passed, a `(state_dict, metadata)` tuple is returned instead of the state
        dict alone.
        """
        # Load the main state dict first; it holds the LoRA layers for the transformer.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

        # When the caller expresses no preference, request safetensors but allow a pickle fallback.
        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

        state_dict, metadata = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        # ai-toolkit (ostris) saves Ideogram4 LoRAs under a `diffusion_model.` prefix with a fused
        # `attention.qkv` projection; convert those to the diffusers layout before loading. The converter also
        # strips a `conditional_transformer.` prefix, so detect that prefix here as well: otherwise a checkpoint
        # using it without a fused `attention.qkv` key would skip conversion entirely.
        non_diffusers_prefixes = ("diffusion_model.", "conditional_transformer.")
        is_non_diffusers_format = any(k.startswith(non_diffusers_prefixes) for k in state_dict) or any(
            ".attention.qkv." in k for k in state_dict
        )
        if is_non_diffusers_format:
            state_dict = _convert_non_diffusers_ideogram4_lora_to_diffusers(state_dict)

        out = (state_dict, metadata) if return_lora_metadata else state_dict
        return out

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
        adapter_name: str | None = None,
        hotswap: bool = False,
        **kwargs,
    ):
        """
        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        kwargs["return_lora_metadata"] = True
        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

        self.load_lora_into_transformer(
            state_dict,
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
    def load_lora_into_transformer(
        cls,
        state_dict,
        transformer,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
        hotswap: bool = False,
        metadata=None,
    ):
        """
        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
        """
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
    def save_lora_weights(
        cls,
        save_directory: str | os.PathLike,
        transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
        transformer_lora_adapter_metadata: dict | None = None,
    ):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
        """
        lora_layers = {}
        lora_metadata = {}

        if transformer_lora_layers:
            lora_layers[cls.transformer_name] = transformer_lora_layers
            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

        if not lora_layers:
            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

        cls._save_lora_weights(
            save_directory=save_directory,
            lora_layers=lora_layers,
            lora_metadata=lora_metadata,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
    def fuse_lora(
        self,
        components: list[str] = ["transformer"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: list[str] | None = None,
        **kwargs,
    ):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
        """
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
    def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
        r"""
        See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
        """
        super().unfuse_lora(components=components, **kwargs)
6227+
6228+
60216229
class ErnieImageLoraLoaderMixin(LoraBaseMixin):
60226230
r"""
60236231
Load LoRA layers into [`ErnieImageTransformer2DModel`]. Specific to [`ErnieImagePipeline`].

src/diffusers/models/transformers/transformer_ideogram4.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from ...configuration_utils import ConfigMixin, register_to_config
2323
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
24-
from ...utils import logging
24+
from ...utils import apply_lora_scale, logging
2525
from ...utils.torch_utils import maybe_allow_in_graph
2626
from ..attention import AttentionMixin, AttentionModuleMixin
2727
from ..attention_dispatch import dispatch_attention_fn
@@ -365,6 +365,7 @@ def __init__(
365365
adaln_dim=adaln_dim,
366366
)
367367

368+
@apply_lora_scale("attention_kwargs")
368369
def forward(
369370
self,
370371
hidden_states: torch.Tensor,
@@ -373,6 +374,7 @@ def forward(
373374
position_ids: torch.Tensor,
374375
segment_ids: torch.Tensor,
375376
indicator: torch.Tensor,
377+
attention_kwargs: dict | None = None,
376378
return_dict: bool = True,
377379
) -> Transformer2DModelOutput | tuple[torch.Tensor]:
378380
r"""
@@ -391,6 +393,9 @@ def forward(
391393
Per-token sample id within a packed batch. Positions sharing a `segment_id` attend to each other.
392394
indicator (`torch.Tensor` of shape `(batch_size, sequence_length)`):
393395
Per-token role: `LLM_TOKEN_INDICATOR` (text) or `OUTPUT_IMAGE_INDICATOR` (image).
396+
attention_kwargs (`dict`, *optional*):
397+
A kwargs dictionary passed along to the attention processor. A `"scale"` entry scales the LoRA weights
398+
(when the PEFT backend is active).
394399
return_dict (`bool`, *optional*, defaults to `True`):
395400
Whether to return a [`~models.modeling_outputs.Transformer2DModelOutput`] instead of a plain tuple.
396401

0 commit comments

Comments
 (0)