Skip to content

Commit c41b7b5

Browse files
Add Flux2 LoKR adapter support with dual conversion paths
- Custom lossless path: BFL LoKR keys → peft LoKrConfig (fuse-first QKV)
- Generic lossy path: optional SVD conversion via peft.convert_to_lora
- Fix alpha handling for lora_down/lora_up format checkpoints
- Re-fuse LoRA keys when model QKV is fused from prior LoKR load
1 parent da6718f commit c41b7b5

File tree

4 files changed

+361
-58
lines changed

4 files changed

+361
-58
lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 160 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2331,6 +2331,18 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
23312331
temp_state_dict[new_key] = v
23322332
original_state_dict = temp_state_dict
23332333

2334+
# Bake alpha/rank scaling into lora_A weights so .alpha keys are consumed.
2335+
# Matches the pattern used by _convert_kohya_flux_lora_to_diffusers for Flux1.
2336+
alpha_keys = [k for k in original_state_dict if k.endswith(".alpha")]
2337+
for alpha_key in alpha_keys:
2338+
alpha = original_state_dict.pop(alpha_key).item()
2339+
module_path = alpha_key[: -len(".alpha")]
2340+
lora_a_key = f"{module_path}.lora_A.weight"
2341+
if lora_a_key in original_state_dict:
2342+
rank = original_state_dict[lora_a_key].shape[0]
2343+
scale = alpha / rank
2344+
original_state_dict[lora_a_key] = original_state_dict[lora_a_key] * scale
2345+
23342346
num_double_layers = 0
23352347
num_single_layers = 0
23362348
for key in original_state_dict.keys():
@@ -2443,6 +2455,152 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
24432455
return converted_state_dict
24442456

24452457

2458+
def _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict):
2459+
"""Convert non-diffusers Flux2 LoKR state dict (kohya/LyCORIS format) to peft-compatible diffusers format.
2460+
2461+
Uses fuse-first QKV mapping: BFL fused `img_attn.qkv` maps to diffusers `attn.to_qkv` (created by
2462+
`fuse_projections()`), avoiding lossy Kronecker factor splitting. The caller must fuse the model's
2463+
QKV projections before injecting the adapter.
2464+
"""
2465+
converted_state_dict = {}
2466+
2467+
prefix = "diffusion_model."
2468+
original_state_dict = {k[len(prefix) :] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
2469+
2470+
num_double_layers = 0
2471+
num_single_layers = 0
2472+
for key in original_state_dict:
2473+
if key.startswith("single_blocks."):
2474+
num_single_layers = max(num_single_layers, int(key.split(".")[1]) + 1)
2475+
elif key.startswith("double_blocks."):
2476+
num_double_layers = max(num_double_layers, int(key.split(".")[1]) + 1)
2477+
2478+
lokr_suffixes = ("lokr_w1", "lokr_w1_a", "lokr_w1_b", "lokr_w2", "lokr_w2_a", "lokr_w2_b", "lokr_t2")
2479+
2480+
def _remap_lokr_module(bfl_path, diff_path):
2481+
"""Pop all lokr keys for a BFL module, bake alpha scaling, and store under diffusers path."""
2482+
alpha_key = f"{bfl_path}.alpha"
2483+
alpha = original_state_dict.pop(alpha_key).item() if alpha_key in original_state_dict else None
2484+
2485+
for suffix in lokr_suffixes:
2486+
src_key = f"{bfl_path}.{suffix}"
2487+
if src_key not in original_state_dict:
2488+
continue
2489+
2490+
weight = original_state_dict.pop(src_key)
2491+
2492+
# Bake alpha/rank scaling into the first w1 tensor encountered for this module.
2493+
# After baking, peft's config uses alpha=r so its runtime scaling is 1.0.
2494+
if alpha is not None and suffix in ("lokr_w1", "lokr_w1_a"):
2495+
w2a_key = f"{bfl_path}.lokr_w2_a"
2496+
w1a_key = f"{bfl_path}.lokr_w1_a"
2497+
if w2a_key in original_state_dict:
2498+
r_eff = original_state_dict[w2a_key].shape[1]
2499+
elif w1a_key in original_state_dict:
2500+
r_eff = original_state_dict[w1a_key].shape[1]
2501+
else:
2502+
r_eff = alpha
2503+
scale = alpha / r_eff
2504+
weight = weight * scale
2505+
alpha = None # only bake once per module
2506+
2507+
converted_state_dict[f"{diff_path}.{suffix}"] = weight
2508+
2509+
# --- Single blocks ---
2510+
for sl in range(num_single_layers):
2511+
_remap_lokr_module(f"single_blocks.{sl}.linear1", f"single_transformer_blocks.{sl}.attn.to_qkv_mlp_proj")
2512+
_remap_lokr_module(f"single_blocks.{sl}.linear2", f"single_transformer_blocks.{sl}.attn.to_out")
2513+
2514+
# --- Double blocks ---
2515+
for dl in range(num_double_layers):
2516+
tb = f"transformer_blocks.{dl}"
2517+
db = f"double_blocks.{dl}"
2518+
2519+
# QKV → fused to_qkv / to_added_qkv (model must be fused before injection)
2520+
_remap_lokr_module(f"{db}.img_attn.qkv", f"{tb}.attn.to_qkv")
2521+
_remap_lokr_module(f"{db}.txt_attn.qkv", f"{tb}.attn.to_added_qkv")
2522+
2523+
# Projections
2524+
_remap_lokr_module(f"{db}.img_attn.proj", f"{tb}.attn.to_out.0")
2525+
_remap_lokr_module(f"{db}.txt_attn.proj", f"{tb}.attn.to_add_out")
2526+
2527+
# MLPs
2528+
_remap_lokr_module(f"{db}.img_mlp.0", f"{tb}.ff.linear_in")
2529+
_remap_lokr_module(f"{db}.img_mlp.2", f"{tb}.ff.linear_out")
2530+
_remap_lokr_module(f"{db}.txt_mlp.0", f"{tb}.ff_context.linear_in")
2531+
_remap_lokr_module(f"{db}.txt_mlp.2", f"{tb}.ff_context.linear_out")
2532+
2533+
# --- Extra mappings (embedders, modulation, final layer) ---
2534+
extra_mappings = {
2535+
"img_in": "x_embedder",
2536+
"txt_in": "context_embedder",
2537+
"time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
2538+
"time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
2539+
"final_layer.linear": "proj_out",
2540+
"final_layer.adaLN_modulation.1": "norm_out.linear",
2541+
"single_stream_modulation.lin": "single_stream_modulation.linear",
2542+
"double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
2543+
"double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
2544+
}
2545+
for bfl_key, diff_key in extra_mappings.items():
2546+
_remap_lokr_module(bfl_key, diff_key)
2547+
2548+
if len(original_state_dict) > 0:
2549+
raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
2550+
2551+
for key in list(converted_state_dict.keys()):
2552+
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
2553+
2554+
return converted_state_dict
2555+
2556+
2557+
def _refuse_flux2_lora_state_dict(state_dict):
2558+
"""Re-fuse separate Q/K/V LoRA keys into fused to_qkv/to_added_qkv keys.
2559+
2560+
When the model's QKV projections are fused, incoming LoRA keys targeting separate
2561+
to_q/to_k/to_v must be re-fused to match. Inverse of the QKV split performed by
2562+
``_convert_non_diffusers_flux2_lora_to_diffusers``.
2563+
"""
2564+
converted = {}
2565+
remaining = dict(state_dict)
2566+
2567+
# Detect double block indices from keys
2568+
num_double_layers = 0
2569+
for key in remaining:
2570+
if ".transformer_blocks." in key:
2571+
parts = key.split(".")
2572+
idx = parts.index("transformer_blocks") + 1
2573+
if idx < len(parts) and parts[idx].isdigit():
2574+
num_double_layers = max(num_double_layers, int(parts[idx]) + 1)
2575+
2576+
# Fuse Q/K/V for image stream and text stream
2577+
qkv_groups = [
2578+
(["to_q", "to_k", "to_v"], "to_qkv"),
2579+
(["add_q_proj", "add_k_proj", "add_v_proj"], "to_added_qkv"),
2580+
]
2581+
2582+
for dl in range(num_double_layers):
2583+
attn_prefix = f"transformer.transformer_blocks.{dl}.attn"
2584+
for separate_keys, fused_name in qkv_groups:
2585+
for lora_key in ("lora_A", "lora_B"):
2586+
src_keys = [f"{attn_prefix}.{sk}.{lora_key}.weight" for sk in separate_keys]
2587+
if not all(k in remaining for k in src_keys):
2588+
continue
2589+
2590+
weights = [remaining.pop(k) for k in src_keys]
2591+
dst_key = f"{attn_prefix}.{fused_name}.{lora_key}.weight"
2592+
if lora_key == "lora_A":
2593+
# lora_A was replicated during split - all three are identical, take the first
2594+
converted[dst_key] = weights[0]
2595+
else:
2596+
# lora_B was chunked along dim=0 - concatenate back
2597+
converted[dst_key] = torch.cat(weights, dim=0)
2598+
2599+
# Pass through all non-QKV keys unchanged
2600+
converted.update(remaining)
2601+
return converted
2602+
2603+
24462604
def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
24472605
"""
24482606
Convert non-diffusers ZImage LoRA state dict to diffusers format.
@@ -2600,14 +2758,14 @@ def get_alpha_scales(down_weight, alpha_key):
26002758

26012759
base = k[: -len(lora_dot_down_key)]
26022760

2603-
# Skip combined "qkv" projection individual to.q/k/v keys are also present.
2761+
# Skip combined "qkv" projection - individual to.q/k/v keys are also present.
26042762
if base.endswith(".qkv"):
26052763
state_dict.pop(k)
26062764
state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None)
26072765
state_dict.pop(base + ".alpha", None)
26082766
continue
26092767

2610-
# Skip bare "out.lora.*" "to_out.0.lora.*" covers the same projection.
2768+
# Skip bare "out.lora.*" - "to_out.0.lora.*" covers the same projection.
26112769
if re.search(r"\.out$", base) and ".to_out" not in base:
26122770
state_dict.pop(k)
26132771
state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None)

src/diffusers/loaders/lora_pipeline.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
_convert_hunyuan_video_lora_to_diffusers,
4646
_convert_kohya_flux_lora_to_diffusers,
4747
_convert_musubi_wan_lora_to_diffusers,
48+
_convert_non_diffusers_flux2_lokr_to_diffusers,
4849
_convert_non_diffusers_flux2_lora_to_diffusers,
4950
_convert_non_diffusers_hidream_lora_to_diffusers,
5051
_convert_non_diffusers_lora_to_diffusers,
@@ -56,6 +57,7 @@
5657
_convert_non_diffusers_z_image_lora_to_diffusers,
5758
_convert_xlabs_flux_lora_to_diffusers,
5859
_maybe_map_sgm_blocks_to_diffusers,
60+
_refuse_flux2_lora_state_dict,
5961
)
6062

6163

@@ -5679,12 +5681,18 @@ def lora_state_dict(
56795681

56805682
is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict)
56815683
if is_ai_toolkit:
5682-
state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)
5684+
is_lokr = any("lokr_" in k for k in state_dict)
5685+
if is_lokr:
5686+
state_dict = _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict)
5687+
if metadata is None:
5688+
metadata = {}
5689+
metadata["is_lokr"] = "true"
5690+
else:
5691+
state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)
56835692

56845693
out = (state_dict, metadata) if return_lora_metadata else state_dict
56855694
return out
56865695

5687-
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
56885696
def load_lora_weights(
56895697
self,
56905698
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
@@ -5712,13 +5720,26 @@ def load_lora_weights(
57125720
kwargs["return_lora_metadata"] = True
57135721
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
57145722

5715-
is_correct_format = all("lora" in key for key in state_dict.keys())
5723+
is_correct_format = all("lora" in key or "lokr" in key for key in state_dict.keys())
57165724
if not is_correct_format:
5717-
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
5725+
raise ValueError("Invalid LoRA/LoKR checkpoint. Make sure all param names contain `'lora'` or `'lokr'`.")
5726+
5727+
# For LoKR adapters, fuse QKV projections so peft can target the fused modules directly.
5728+
is_lokr = metadata is not None and metadata.get("is_lokr") == "true"
5729+
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
5730+
if is_lokr:
5731+
transformer.fuse_qkv_projections()
5732+
elif (
5733+
hasattr(transformer, "transformer_blocks")
5734+
and len(transformer.transformer_blocks) > 0
5735+
and getattr(transformer.transformer_blocks[0].attn, "fused_projections", False)
5736+
):
5737+
# Model QKV is fused but LoRA targets separate Q/K/V - re-fuse the keys to match.
5738+
state_dict = _refuse_flux2_lora_state_dict(state_dict)
57185739

57195740
self.load_lora_into_transformer(
57205741
state_dict,
5721-
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
5742+
transformer=transformer,
57225743
adapter_name=adapter_name,
57235744
metadata=metadata,
57245745
_pipeline=self,

src/diffusers/loaders/peft.py

Lines changed: 60 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,15 @@
3838
set_adapter_layers,
3939
set_weights_and_activate_adapters,
4040
)
41-
from ..utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys
41+
from ..utils.peft_utils import _create_lokr_config, _create_lora_config, _maybe_warn_for_unhandled_keys
4242
from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
4343
from .unet_loader_utils import _maybe_expand_lora_scales
4444

4545

4646
logger = logging.get_logger(__name__)
4747

4848
_SET_ADAPTER_SCALE_FN_MAPPING = defaultdict(
49-
lambda: (lambda model_cls, weights: weights),
49+
lambda: lambda model_cls, weights: weights,
5050
{
5151
"UNet2DConditionModel": _maybe_expand_lora_scales,
5252
"UNetMotionModel": _maybe_expand_lora_scales,
@@ -213,56 +213,65 @@ def load_lora_adapter(
213213
"Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
214214
)
215215

216-
# check with first key if is not in peft format
217-
first_key = next(iter(state_dict.keys()))
218-
if "lora_A" not in first_key:
219-
state_dict = convert_unet_state_dict_to_peft(state_dict)
220-
221-
# Control LoRA from SAI is different from BFL Control LoRA
222-
# https://huggingface.co/stabilityai/control-lora
223-
# https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors
224-
is_sai_sd_control_lora = "lora_controlnet" in state_dict
225-
if is_sai_sd_control_lora:
226-
state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict)
227-
228-
rank = {}
229-
for key, val in state_dict.items():
230-
# Cannot figure out rank from lora layers that don't have at least 2 dimensions.
231-
# Bias layers in LoRA only have a single dimension
232-
if "lora_B" in key and val.ndim > 1:
233-
# Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
234-
# We may run into some ambiguous configuration values when a model has module
235-
# names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`,
236-
# for example) and they have different LoRA ranks.
237-
rank[f"^{key}"] = val.shape[1]
238-
239-
if network_alphas is not None and len(network_alphas) >= 1:
240-
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
241-
network_alphas = {
242-
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
243-
}
244-
245-
# adapter_name
246-
if adapter_name is None:
247-
adapter_name = get_adapter_name(self)
248-
249-
# create LoraConfig
250-
lora_config = _create_lora_config(
251-
state_dict,
252-
network_alphas,
253-
metadata,
254-
rank,
255-
model_state_dict=self.state_dict(),
256-
adapter_name=adapter_name,
257-
)
216+
# Detect whether this is a LoKR adapter (Kronecker product, not low-rank)
217+
is_lokr = any("lokr_" in k for k in state_dict)
218+
219+
if is_lokr:
220+
if adapter_name is None:
221+
adapter_name = get_adapter_name(self)
222+
lora_config = _create_lokr_config(state_dict)
223+
is_sai_sd_control_lora = False
224+
else:
225+
# check with first key if is not in peft format
226+
first_key = next(iter(state_dict.keys()))
227+
if "lora_A" not in first_key:
228+
state_dict = convert_unet_state_dict_to_peft(state_dict)
229+
230+
# Control LoRA from SAI is different from BFL Control LoRA
231+
# https://huggingface.co/stabilityai/control-lora
232+
# https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors
233+
is_sai_sd_control_lora = "lora_controlnet" in state_dict
234+
if is_sai_sd_control_lora:
235+
state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict)
236+
237+
rank = {}
238+
for key, val in state_dict.items():
239+
# Cannot figure out rank from lora layers that don't have at least 2 dimensions.
240+
# Bias layers in LoRA only have a single dimension
241+
if "lora_B" in key and val.ndim > 1:
242+
# Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
243+
# We may run into some ambiguous configuration values when a model has module
244+
# names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`,
245+
# for example) and they have different LoRA ranks.
246+
rank[f"^{key}"] = val.shape[1]
247+
248+
if network_alphas is not None and len(network_alphas) >= 1:
249+
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
250+
network_alphas = {
251+
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
252+
}
253+
254+
# adapter_name
255+
if adapter_name is None:
256+
adapter_name = get_adapter_name(self)
257+
258+
# create LoraConfig
259+
lora_config = _create_lora_config(
260+
state_dict,
261+
network_alphas,
262+
metadata,
263+
rank,
264+
model_state_dict=self.state_dict(),
265+
adapter_name=adapter_name,
266+
)
258267

259-
# Adjust LoRA config for Control LoRA
260-
if is_sai_sd_control_lora:
261-
lora_config.lora_alpha = lora_config.r
262-
lora_config.alpha_pattern = lora_config.rank_pattern
263-
lora_config.bias = "all"
264-
lora_config.modules_to_save = lora_config.exclude_modules
265-
lora_config.exclude_modules = None
268+
# Adjust LoRA config for Control LoRA
269+
if is_sai_sd_control_lora:
270+
lora_config.lora_alpha = lora_config.r
271+
lora_config.alpha_pattern = lora_config.rank_pattern
272+
lora_config.bias = "all"
273+
lora_config.modules_to_save = lora_config.exclude_modules
274+
lora_config.exclude_modules = None
266275

267276
# <Unsafe code
268277
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype

0 commit comments

Comments
 (0)