
Commit 8202d32

Add Flux2 LoKR adapter support via peft and fix alpha handling
Support native LoKR loading for Flux2/Klein models through peft's LoKrConfig + inject_adapter_in_model, matching how LoRA is already handled. Also fix lora_down/lora_up + alpha keys being left unconsumed.

Changes:
- Fix alpha baking for the lora_down/lora_up format in the Flux2 converter
- Add a LoKR conversion function mapping BFL keys to diffusers paths
- Use a fuse-first QKV strategy: fuse the model's projections so LoKR targets the fused to_qkv/to_added_qkv directly (avoids lossy Kronecker splitting)
- Detect and infer decompose_factor from checkpoint w1 shapes
- Add a LoKR detection branch in load_lora_adapter with LoKrConfig creation
- Add a re-fusion helper for loading LoRA onto already-fused models
- Temporary debug logging with a [LoKR DEBUG] prefix

GitHub issue: huggingface#13261
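For orientation, a minimal usage sketch of the flow this commit enables. The pipeline class, model id, and file path are illustrative assumptions, not part of the diff:

```python
import torch
from diffusers import Flux2Pipeline  # assumed pipeline class for Flux2

pipe = Flux2Pipeline.from_pretrained(
    "black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16  # placeholder model id
)

# A kohya/LyCORIS LoKR checkpoint is detected by its `lokr_` keys; the loader
# converts it to diffusers paths, fuses the transformer's QKV projections, and
# injects the adapter through peft's LoKrConfig + inject_adapter_in_model.
pipe.load_lora_weights("my-flux2-lokr.safetensors")  # placeholder path
```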
1 parent b9761ce commit 8202d32

4 files changed

Lines changed: 362 additions & 55 deletions


src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 172 additions & 0 deletions
@@ -2331,6 +2331,18 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
             temp_state_dict[new_key] = v
     original_state_dict = temp_state_dict
 
+    # Bake alpha/rank scaling into lora_A weights (matches Flux1 kohya alpha handling).
+    # Without this, .alpha keys remain unconsumed and trigger a validation error.
+    alpha_keys = [k for k in original_state_dict if k.endswith(".alpha")]
+    for alpha_key in alpha_keys:
+        alpha = original_state_dict.pop(alpha_key).item()
+        module_path = alpha_key[: -len(".alpha")]
+        lora_a_key = f"{module_path}.lora_A.weight"
+        if lora_a_key in original_state_dict:
+            rank = original_state_dict[lora_a_key].shape[0]
+            scale = alpha / rank
+            original_state_dict[lora_a_key] = original_state_dict[lora_a_key] * scale
+
     num_double_layers = 0
     num_single_layers = 0
     for key in original_state_dict.keys():
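A quick numeric check of the alpha-baking identity the hunk above relies on; this is an illustration, not code from the patch:

```python
import torch

r, d_in, d_out, alpha = 4, 8, 8, 16.0
A = torch.randn(r, d_in)   # lora_A (kohya: lora_down)
B = torch.randn(d_out, r)  # lora_B (kohya: lora_up)

# Applying alpha/r at load time and baking it into lora_A give the same delta,
# so the `.alpha` keys can be consumed here and dropped from the state dict.
delta_runtime = (alpha / r) * (B @ A)
delta_baked = B @ ((alpha / r) * A)
assert torch.allclose(delta_runtime, delta_baked)
```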
@@ -2443,6 +2455,166 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
     return converted_state_dict
 
 
+def _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict):
+    """Convert non-diffusers Flux2 LoKR state dict (kohya/LyCORIS format) to peft-compatible diffusers format.
+
+    Uses fuse-first QKV mapping: BFL fused `img_attn.qkv` maps to diffusers `attn.to_qkv` (created by
+    `fuse_projections()`), avoiding lossy Kronecker factor splitting. The caller must fuse the model's
+    QKV projections before injecting the adapter.
+    """
+    converted_state_dict = {}
+
+    prefix = "diffusion_model."
+    original_state_dict = {k[len(prefix) :] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
+
+    # Log key patterns for debugging
+    suffixes = set()
+    for k in original_state_dict:
+        for s in ("lokr_w1_a", "lokr_w1_b", "lokr_w2_a", "lokr_w2_b", "lokr_w1", "lokr_w2", "lokr_t2", "alpha"):
+            if k.endswith(f".{s}"):
+                suffixes.add(s)
+                break
+    logger.warning(
+        f"[LoKR DEBUG] LoKR converter: {len(original_state_dict)} keys, weight types: {sorted(suffixes)}"
+    )
+
+    num_double_layers = 0
+    num_single_layers = 0
+    for key in original_state_dict:
+        if key.startswith("single_blocks."):
+            num_single_layers = max(num_single_layers, int(key.split(".")[1]) + 1)
+        elif key.startswith("double_blocks."):
+            num_double_layers = max(num_double_layers, int(key.split(".")[1]) + 1)
+
+    lokr_suffixes = ("lokr_w1", "lokr_w1_a", "lokr_w1_b", "lokr_w2", "lokr_w2_a", "lokr_w2_b", "lokr_t2")
+
+    def _remap_lokr_module(bfl_path, diff_path):
+        """Pop all lokr keys for a BFL module, bake alpha scaling, and store under diffusers path."""
+        alpha_key = f"{bfl_path}.alpha"
+        alpha = original_state_dict.pop(alpha_key).item() if alpha_key in original_state_dict else None
+
+        for suffix in lokr_suffixes:
+            src_key = f"{bfl_path}.{suffix}"
+            if src_key not in original_state_dict:
+                continue
+
+            weight = original_state_dict.pop(src_key)
+
+            # Bake alpha/rank scaling into the first w1 tensor encountered for this module.
+            # After baking, peft's config uses alpha=r so its runtime scaling is 1.0.
+            if alpha is not None and suffix in ("lokr_w1", "lokr_w1_a"):
+                w2a_key = f"{bfl_path}.lokr_w2_a"
+                w1a_key = f"{bfl_path}.lokr_w1_a"
+                if w2a_key in original_state_dict:
+                    r_eff = original_state_dict[w2a_key].shape[1]
+                elif w1a_key in original_state_dict:
+                    r_eff = original_state_dict[w1a_key].shape[1]
+                else:
+                    r_eff = alpha
+                scale = alpha / r_eff
+                logger.warning(
+                    f"[LoKR DEBUG] Alpha bake: {bfl_path} alpha={alpha:.1f} r={r_eff} scale={scale:.4f}"
+                )
+                weight = weight * scale
+                alpha = None  # only bake once per module
+
+            converted_state_dict[f"{diff_path}.{suffix}"] = weight
+
+    # --- Single blocks ---
+    for sl in range(num_single_layers):
+        _remap_lokr_module(f"single_blocks.{sl}.linear1", f"single_transformer_blocks.{sl}.attn.to_qkv_mlp_proj")
+        _remap_lokr_module(f"single_blocks.{sl}.linear2", f"single_transformer_blocks.{sl}.attn.to_out")
+
+    # --- Double blocks ---
+    for dl in range(num_double_layers):
+        tb = f"transformer_blocks.{dl}"
+        db = f"double_blocks.{dl}"
+
+        # QKV → fused to_qkv / to_added_qkv (model must be fused before injection)
+        _remap_lokr_module(f"{db}.img_attn.qkv", f"{tb}.attn.to_qkv")
+        _remap_lokr_module(f"{db}.txt_attn.qkv", f"{tb}.attn.to_added_qkv")
+
+        # Projections
+        _remap_lokr_module(f"{db}.img_attn.proj", f"{tb}.attn.to_out.0")
+        _remap_lokr_module(f"{db}.txt_attn.proj", f"{tb}.attn.to_add_out")
+
+        # MLPs
+        _remap_lokr_module(f"{db}.img_mlp.0", f"{tb}.ff.linear_in")
+        _remap_lokr_module(f"{db}.img_mlp.2", f"{tb}.ff.linear_out")
+        _remap_lokr_module(f"{db}.txt_mlp.0", f"{tb}.ff_context.linear_in")
+        _remap_lokr_module(f"{db}.txt_mlp.2", f"{tb}.ff_context.linear_out")
+
+    # --- Extra mappings (embedders, modulation, final layer) ---
+    extra_mappings = {
+        "img_in": "x_embedder",
+        "txt_in": "context_embedder",
+        "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
+        "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
+        "final_layer.linear": "proj_out",
+        "final_layer.adaLN_modulation.1": "norm_out.linear",
+        "single_stream_modulation.lin": "single_stream_modulation.linear",
+        "double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
+        "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
+    }
+    for bfl_key, diff_key in extra_mappings.items():
+        _remap_lokr_module(bfl_key, diff_key)
+
+    if len(original_state_dict) > 0:
+        raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
+
+
+def _refuse_flux2_lora_state_dict(state_dict):
+    """Re-fuse separate Q/K/V LoRA keys into fused to_qkv/to_added_qkv keys.
+
+    When the model's QKV projections are fused (e.g., from a prior LoKR load), incoming LoRA keys
+    that target separate to_q/to_k/to_v must be re-fused to match. This is the exact inverse of the
+    QKV split in ``_convert_non_diffusers_flux2_lora_to_diffusers``.
+    """
+    converted = {}
+    remaining = dict(state_dict)
+
+    # Detect double block indices from keys
+    num_double_layers = 0
+    for key in remaining:
+        if ".transformer_blocks." in key:
+            parts = key.split(".")
+            idx = parts.index("transformer_blocks") + 1
+            if idx < len(parts) and parts[idx].isdigit():
+                num_double_layers = max(num_double_layers, int(parts[idx]) + 1)
+
+    # Fuse Q/K/V for image stream and text stream
+    qkv_groups = [
+        (["to_q", "to_k", "to_v"], "to_qkv"),
+        (["add_q_proj", "add_k_proj", "add_v_proj"], "to_added_qkv"),
+    ]
+
+    for dl in range(num_double_layers):
+        attn_prefix = f"transformer.transformer_blocks.{dl}.attn"
+        for separate_keys, fused_name in qkv_groups:
+            for lora_key in ("lora_A", "lora_B"):
+                src_keys = [f"{attn_prefix}.{sk}.{lora_key}.weight" for sk in separate_keys]
+                if not all(k in remaining for k in src_keys):
+                    continue
+
+                weights = [remaining.pop(k) for k in src_keys]
+                dst_key = f"{attn_prefix}.{fused_name}.{lora_key}.weight"
+                if lora_key == "lora_A":
+                    # lora_A was replicated during split — all three are identical, take the first
+                    converted[dst_key] = weights[0]
+                else:
+                    # lora_B was chunked along dim=0 — concatenate back
+                    converted[dst_key] = torch.cat(weights, dim=0)
+
+    # Pass through all non-QKV keys unchanged
+    converted.update(remaining)
+    return converted
+
+
 def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
     """
     Convert non-diffusers ZImage LoRA state dict to diffusers format.

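Why the converter maps BFL's fused qkv onto the fused to_qkv module instead of splitting per projection: a LoKR delta is a Kronecker product, and chunking its rows generally cuts through Kronecker blocks. A small illustrative sketch, not from the patch:

```python
import torch

# Effective LoKR delta: kron(w1, w2), where each factor may itself be a
# low-rank product (w1_a @ w1_b, w2_a @ w2_b).
w1 = torch.randn(4, 2)      # lokr_w1
w2 = torch.randn(3, 5)      # lokr_w2
delta = torch.kron(w1, w2)  # shape (12, 10); rows come in blocks of 3 (w2's height)

# Splitting a fused QKV delta into per-projection chunks of 4 rows cuts through
# the 3-row Kronecker blocks, so a chunk has no exact kron(w1_q, w2) form.
q_chunk = delta[:4]
# The fuse-first strategy avoids this: to_q/to_k/to_v are fused into to_qkv via
# fuse_projections(), and the adapter targets that fused module directly.
```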
src/diffusers/loaders/lora_pipeline.py

Lines changed: 37 additions & 5 deletions
@@ -45,6 +45,7 @@
     _convert_hunyuan_video_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
     _convert_musubi_wan_lora_to_diffusers,
+    _convert_non_diffusers_flux2_lokr_to_diffusers,
     _convert_non_diffusers_flux2_lora_to_diffusers,
     _convert_non_diffusers_hidream_lora_to_diffusers,
     _convert_non_diffusers_lora_to_diffusers,
@@ -56,6 +57,7 @@
     _convert_non_diffusers_z_image_lora_to_diffusers,
     _convert_xlabs_flux_lora_to_diffusers,
     _maybe_map_sgm_blocks_to_diffusers,
+    _refuse_flux2_lora_state_dict,
 )
 
 
@@ -5679,12 +5681,26 @@ def lora_state_dict(
 
         is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict)
         if is_ai_toolkit:
-            state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)
+            is_lokr = any("lokr_" in k for k in state_dict)
+            if is_lokr:
+                logger.warning(f"[LoKR DEBUG] Detected LoKR format ({len(state_dict)} keys), converting to diffusers")
+                state_dict = _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict)
+                logger.warning(f"[LoKR DEBUG] Converted to {len(state_dict)} diffusers keys")
+                if metadata is None:
+                    metadata = {}
+                metadata["is_lokr"] = "true"
+            else:
+                has_alpha = any(k.endswith(".alpha") for k in state_dict)
+                has_down_up = any("lora_down" in k or "lora_up" in k for k in state_dict)
+                logger.warning(
+                    f"[LoKR DEBUG] Detected LoRA format ({len(state_dict)} keys, "
+                    f"has_alpha={has_alpha}, has_lora_down_up={has_down_up})"
+                )
+                state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)
 
         out = (state_dict, metadata) if return_lora_metadata else state_dict
         return out
 
-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
     def load_lora_weights(
         self,
         pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
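The detection above is purely key-pattern based, and the `is_lokr` metadata flag is how `lora_state_dict` tells `load_lora_weights` (next hunk) to fuse QKV projections before injection. A toy sketch of the handshake, with made-up keys:

```python
lokr_sd = {"diffusion_model.double_blocks.0.img_attn.qkv.lokr_w1": ...}
lora_sd = {"diffusion_model.double_blocks.0.img_attn.qkv.lora_A.weight": ...}

def tag_metadata(state_dict, metadata=None):
    # Mirrors the branch above: mark LoKR checkpoints so later stages know.
    if any("lokr_" in k for k in state_dict):
        metadata = dict(metadata or {})
        metadata["is_lokr"] = "true"
    return metadata

assert tag_metadata(lokr_sd) == {"is_lokr": "true"}
assert tag_metadata(lora_sd) is None
```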
@@ -5712,13 +5728,29 @@ def load_lora_weights(
         kwargs["return_lora_metadata"] = True
         state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
-        is_correct_format = all("lora" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key or "lokr" in key for key in state_dict.keys())
         if not is_correct_format:
-            raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
+            raise ValueError("Invalid LoRA/LoKR checkpoint. Make sure all param names contain `'lora'` or `'lokr'`.")
+
+        # For LoKR adapters, fuse QKV projections so peft can target the fused modules directly.
+        is_lokr = metadata is not None and metadata.get("is_lokr") == "true"
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        if is_lokr:
+            logger.warning("[LoKR DEBUG] Fusing QKV projections for LoKR adapter")
+            transformer.fuse_qkv_projections()
+        elif (
+            hasattr(transformer, "transformer_blocks")
+            and len(transformer.transformer_blocks) > 0
+            and getattr(transformer.transformer_blocks[0].attn, "fused_projections", False)
+        ):
+            # Model is fused (e.g., from a prior LoKR load) but this is a LoRA adapter
+            # targeting separate Q/K/V. Re-fuse the LoRA keys to match the fused modules.
+            logger.warning("[LoKR DEBUG] Model is fused, re-fusing LoRA keys to match fused QKV modules")
+            state_dict = _refuse_flux2_lora_state_dict(state_dict)
 
         self.load_lora_into_transformer(
             state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            transformer=transformer,
             adapter_name=adapter_name,
             metadata=metadata,
             _pipeline=self,

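The re-fusion in `_refuse_flux2_lora_state_dict` is lossless because it exactly inverts the converter's QKV split: `lora_A` was replicated across Q/K/V and `lora_B` was chunked along dim 0. An illustrative round-trip check, not from the patch:

```python
import torch

r, hidden = 4, 8
lora_A = torch.randn(r, hidden)      # shared input projection
lora_B = torch.randn(3 * hidden, r)  # fused QKV output projection

# Split, as the Flux2 LoRA converter does: replicate A, chunk B along dim 0.
A_q, A_k, A_v = lora_A, lora_A.clone(), lora_A.clone()
B_q, B_k, B_v = lora_B.chunk(3, dim=0)

# Re-fuse, as _refuse_flux2_lora_state_dict does: take the first A, concat B.
assert torch.equal(A_q, A_k) and torch.equal(A_q, A_v)
assert torch.equal(torch.cat([B_q, B_k, B_v], dim=0), lora_B)
```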
src/diffusers/loaders/peft.py

Lines changed: 67 additions & 50 deletions
@@ -37,7 +37,7 @@
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
-from ..utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys
+from ..utils.peft_utils import _create_lokr_config, _create_lora_config, _maybe_warn_for_unhandled_keys
 from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
 from .unet_loader_utils import _maybe_expand_lora_scales
 
@@ -232,56 +232,73 @@ def load_lora_adapter(
                 "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
             )
 
-        # check with first key if is not in peft format
-        first_key = next(iter(state_dict.keys()))
-        if "lora_A" not in first_key:
-            state_dict = convert_unet_state_dict_to_peft(state_dict)
-
-        # Control LoRA from SAI is different from BFL Control LoRA
-        # https://huggingface.co/stabilityai/control-lora
-        # https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors
-        is_sai_sd_control_lora = "lora_controlnet" in state_dict
-        if is_sai_sd_control_lora:
-            state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict)
-
-        rank = {}
-        for key, val in state_dict.items():
-            # Cannot figure out rank from lora layers that don't have at least 2 dimensions.
-            # Bias layers in LoRA only have a single dimension
-            if "lora_B" in key and val.ndim > 1:
-                # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
-                # We may run into some ambiguous configuration values when a model has module
-                # names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`,
-                # for example) and they have different LoRA ranks.
-                rank[f"^{key}"] = val.shape[1]
-
-        if network_alphas is not None and len(network_alphas) >= 1:
-            alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
-            network_alphas = {
-                k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
-            }
-
-        # adapter_name
-        if adapter_name is None:
-            adapter_name = get_adapter_name(self)
-
-        # create LoraConfig
-        lora_config = _create_lora_config(
-            state_dict,
-            network_alphas,
-            metadata,
-            rank,
-            model_state_dict=self.state_dict(),
-            adapter_name=adapter_name,
-        )
+        # Detect whether this is a LoKR adapter (Kronecker product, not low-rank)
+        is_lokr = any("lokr_" in k for k in state_dict)
+
+        if is_lokr:
+            logger.warning(f"[LoKR DEBUG] load_lora_adapter: detected LoKR state dict ({len(state_dict)} keys)")
+            if adapter_name is None:
+                adapter_name = get_adapter_name(self)
+            lora_config = _create_lokr_config(state_dict)
+            logger.warning(
+                f"[LoKR DEBUG] LoKrConfig: r={lora_config.r}, alpha={lora_config.alpha}, "
+                f"decompose_both={lora_config.decompose_both}, "
+                f"decompose_factor={lora_config.decompose_factor}, "
+                f"targets={len(lora_config.target_modules)} modules, "
+                f"rank_pattern={dict(lora_config.rank_pattern) if lora_config.rank_pattern else '{}'}"
+            )
+            is_sai_sd_control_lora = False
+        else:
+            # check with first key if is not in peft format
+            first_key = next(iter(state_dict.keys()))
+            if "lora_A" not in first_key:
+                state_dict = convert_unet_state_dict_to_peft(state_dict)
+
+            # Control LoRA from SAI is different from BFL Control LoRA
+            # https://huggingface.co/stabilityai/control-lora
+            # https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors
+            is_sai_sd_control_lora = "lora_controlnet" in state_dict
+            if is_sai_sd_control_lora:
+                state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict)
+
+            rank = {}
+            for key, val in state_dict.items():
+                # Cannot figure out rank from lora layers that don't have at least 2 dimensions.
+                # Bias layers in LoRA only have a single dimension
+                if "lora_B" in key and val.ndim > 1:
+                    # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
+                    # We may run into some ambiguous configuration values when a model has module
+                    # names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`,
+                    # for example) and they have different LoRA ranks.
+                    rank[f"^{key}"] = val.shape[1]
+
+            if network_alphas is not None and len(network_alphas) >= 1:
+                alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
+                network_alphas = {
+                    k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
+                }
+
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(self)
+
+            # create LoraConfig
+            lora_config = _create_lora_config(
+                state_dict,
+                network_alphas,
+                metadata,
+                rank,
+                model_state_dict=self.state_dict(),
+                adapter_name=adapter_name,
+            )
 
-        # Adjust LoRA config for Control LoRA
-        if is_sai_sd_control_lora:
-            lora_config.lora_alpha = lora_config.r
-            lora_config.alpha_pattern = lora_config.rank_pattern
-            lora_config.bias = "all"
-            lora_config.modules_to_save = lora_config.exclude_modules
-            lora_config.exclude_modules = None
+            # Adjust LoRA config for Control LoRA
+            if is_sai_sd_control_lora:
+                lora_config.lora_alpha = lora_config.r
+                lora_config.alpha_pattern = lora_config.rank_pattern
+                lora_config.bias = "all"
+                lora_config.modules_to_save = lora_config.exclude_modules
+                lora_config.exclude_modules = None
 
         # <Unsafe code
         # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
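`_create_lokr_config` is added in `utils/peft_utils.py`, whose diff is not shown in this view; per the commit message it also infers `decompose_factor` from the checkpoint's w1 shapes. A rough sketch of what such a helper might look like, using peft's `LoKrConfig`; the details here are assumptions, not the actual implementation:

```python
from peft import LoKrConfig

def _create_lokr_config_sketch(state_dict):
    # Target every module that carries LoKR factors in the checkpoint.
    target_modules = sorted(
        {k.removeprefix("transformer.").rsplit(".", 1)[0] for k in state_dict if ".lokr_" in k}
    )
    # Infer the effective rank from a w2_a factor, shape (out_block, r).
    ranks = [v.shape[1] for k, v in state_dict.items() if k.endswith(".lokr_w2_a")]
    r = ranks[0] if ranks else 8
    return LoKrConfig(
        r=r,
        alpha=r,  # alpha was baked into the weights, so runtime scaling stays 1.0
        target_modules=target_modules,
        decompose_both=any(k.endswith(".lokr_w1_a") for k in state_dict),
        # decompose_factor would be inferred from lokr_w1 shapes in the real helper.
    )
```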
