Skip to content

Commit 8cb0b7b

Browse files
Add Flux2 LoKR adapter support with dual conversion paths
- Custom lossless path: BFL LoKR keys → peft LoKrConfig (fuse-first QKV)
- Generic lossy path: optional SVD conversion via peft.convert_to_lora
- Fix alpha handling for lora_down/lora_up format checkpoints
- Re-fuse LoRA keys when model QKV is fused from prior LoKR load
1 parent b757035 commit 8cb0b7b

File tree

4 files changed

+314
-58
lines changed

4 files changed

+314
-58
lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2331,6 +2331,18 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
23312331
temp_state_dict[new_key] = v
23322332
original_state_dict = temp_state_dict
23332333

2334+
# Bake alpha/rank scaling into lora_A weights so .alpha keys are consumed.
2335+
# Matches the pattern used by _convert_kohya_flux_lora_to_diffusers for Flux1.
2336+
alpha_keys = [k for k in original_state_dict if k.endswith(".alpha")]
2337+
for alpha_key in alpha_keys:
2338+
alpha = original_state_dict.pop(alpha_key).item()
2339+
module_path = alpha_key[: -len(".alpha")]
2340+
lora_a_key = f"{module_path}.lora_A.weight"
2341+
if lora_a_key in original_state_dict:
2342+
rank = original_state_dict[lora_a_key].shape[0]
2343+
scale = alpha / rank
2344+
original_state_dict[lora_a_key] = original_state_dict[lora_a_key] * scale
2345+
23342346
num_double_layers = 0
23352347
num_single_layers = 0
23362348
for key in original_state_dict.keys():
@@ -2628,6 +2640,105 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
26282640
return ait_sd
26292641

26302642

2643+
def _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict):
2644+
"""Convert non-diffusers Flux2 LoKR state dict (kohya/LyCORIS format) to peft-compatible diffusers format.
2645+
2646+
Uses fuse-first QKV mapping: BFL fused `img_attn.qkv` maps to diffusers `attn.to_qkv` (created by
2647+
`fuse_projections()`), avoiding lossy Kronecker factor splitting. The caller must fuse the model's
2648+
QKV projections before injecting the adapter.
2649+
"""
2650+
converted_state_dict = {}
2651+
2652+
prefix = "diffusion_model."
2653+
original_state_dict = {k[len(prefix) :] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
2654+
2655+
num_double_layers = 0
2656+
num_single_layers = 0
2657+
for key in original_state_dict:
2658+
if key.startswith("single_blocks."):
2659+
num_single_layers = max(num_single_layers, int(key.split(".")[1]) + 1)
2660+
elif key.startswith("double_blocks."):
2661+
num_double_layers = max(num_double_layers, int(key.split(".")[1]) + 1)
2662+
2663+
lokr_suffixes = ("lokr_w1", "lokr_w1_a", "lokr_w1_b", "lokr_w2", "lokr_w2_a", "lokr_w2_b", "lokr_t2")
2664+
2665+
def _remap_lokr_module(bfl_path, diff_path):
2666+
"""Pop all lokr keys for a BFL module, bake alpha scaling, and store under diffusers path."""
2667+
alpha_key = f"{bfl_path}.alpha"
2668+
alpha = original_state_dict.pop(alpha_key).item() if alpha_key in original_state_dict else None
2669+
2670+
for suffix in lokr_suffixes:
2671+
src_key = f"{bfl_path}.{suffix}"
2672+
if src_key not in original_state_dict:
2673+
continue
2674+
2675+
weight = original_state_dict.pop(src_key)
2676+
2677+
# Bake alpha/rank scaling into the first w1 tensor encountered for this module.
2678+
# After baking, peft's config uses alpha=r so its runtime scaling is 1.0.
2679+
if alpha is not None and suffix in ("lokr_w1", "lokr_w1_a"):
2680+
w2a_key = f"{bfl_path}.lokr_w2_a"
2681+
w1a_key = f"{bfl_path}.lokr_w1_a"
2682+
if w2a_key in original_state_dict:
2683+
r_eff = original_state_dict[w2a_key].shape[1]
2684+
elif w1a_key in original_state_dict:
2685+
r_eff = original_state_dict[w1a_key].shape[1]
2686+
else:
2687+
r_eff = alpha
2688+
scale = alpha / r_eff
2689+
weight = weight * scale
2690+
alpha = None # only bake once per module
2691+
2692+
converted_state_dict[f"{diff_path}.{suffix}"] = weight
2693+
2694+
# --- Single blocks ---
2695+
for sl in range(num_single_layers):
2696+
_remap_lokr_module(f"single_blocks.{sl}.linear1", f"single_transformer_blocks.{sl}.attn.to_qkv_mlp_proj")
2697+
_remap_lokr_module(f"single_blocks.{sl}.linear2", f"single_transformer_blocks.{sl}.attn.to_out")
2698+
2699+
# --- Double blocks ---
2700+
for dl in range(num_double_layers):
2701+
tb = f"transformer_blocks.{dl}"
2702+
db = f"double_blocks.{dl}"
2703+
2704+
# QKV -> fused to_qkv / to_added_qkv (model must be fused before injection)
2705+
_remap_lokr_module(f"{db}.img_attn.qkv", f"{tb}.attn.to_qkv")
2706+
_remap_lokr_module(f"{db}.txt_attn.qkv", f"{tb}.attn.to_added_qkv")
2707+
2708+
# Projections
2709+
_remap_lokr_module(f"{db}.img_attn.proj", f"{tb}.attn.to_out.0")
2710+
_remap_lokr_module(f"{db}.txt_attn.proj", f"{tb}.attn.to_add_out")
2711+
2712+
# MLPs
2713+
_remap_lokr_module(f"{db}.img_mlp.0", f"{tb}.ff.linear_in")
2714+
_remap_lokr_module(f"{db}.img_mlp.2", f"{tb}.ff.linear_out")
2715+
_remap_lokr_module(f"{db}.txt_mlp.0", f"{tb}.ff_context.linear_in")
2716+
_remap_lokr_module(f"{db}.txt_mlp.2", f"{tb}.ff_context.linear_out")
2717+
2718+
# --- Extra mappings (embedders, modulation, final layer) ---
2719+
extra_mappings = {
2720+
"img_in": "x_embedder",
2721+
"txt_in": "context_embedder",
2722+
"time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
2723+
"time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
2724+
"final_layer.linear": "proj_out",
2725+
"final_layer.adaLN_modulation.1": "norm_out.linear",
2726+
"single_stream_modulation.lin": "single_stream_modulation.linear",
2727+
"double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
2728+
"double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
2729+
}
2730+
for bfl_key, diff_key in extra_mappings.items():
2731+
_remap_lokr_module(bfl_key, diff_key)
2732+
2733+
if len(original_state_dict) > 0:
2734+
raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
2735+
2736+
for key in list(converted_state_dict.keys()):
2737+
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
2738+
2739+
return converted_state_dict
2740+
2741+
26312742
def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
26322743
"""
26332744
Convert non-diffusers ZImage LoRA state dict to diffusers format.
@@ -2785,14 +2896,14 @@ def get_alpha_scales(down_weight, alpha_key):
27852896

27862897
base = k[: -len(lora_dot_down_key)]
27872898

2788-
# Skip combined "qkv" projection individual to.q/k/v keys are also present.
2899+
# Skip combined "qkv" projection - individual to.q/k/v keys are also present.
27892900
if base.endswith(".qkv"):
27902901
state_dict.pop(k)
27912902
state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None)
27922903
state_dict.pop(base + ".alpha", None)
27932904
continue
27942905

2795-
# Skip bare "out.lora.*" "to_out.0.lora.*" covers the same projection.
2906+
# Skip bare "out.lora.*" - "to_out.0.lora.*" covers the same projection.
27962907
if re.search(r"\.out$", base) and ".to_out" not in base:
27972908
state_dict.pop(k)
27982909
state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None)

src/diffusers/loaders/lora_pipeline.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
_convert_kohya_flux2_lora_to_diffusers,
4747
_convert_kohya_flux_lora_to_diffusers,
4848
_convert_musubi_wan_lora_to_diffusers,
49+
_convert_non_diffusers_flux2_lokr_to_diffusers,
4950
_convert_non_diffusers_flux2_lora_to_diffusers,
5051
_convert_non_diffusers_hidream_lora_to_diffusers,
5152
_convert_non_diffusers_lora_to_diffusers,
@@ -57,6 +58,7 @@
5758
_convert_non_diffusers_z_image_lora_to_diffusers,
5859
_convert_xlabs_flux_lora_to_diffusers,
5960
_maybe_map_sgm_blocks_to_diffusers,
61+
_refuse_flux2_lora_state_dict,
6062
)
6163

6264

@@ -5687,12 +5689,18 @@ def lora_state_dict(
56875689

56885690
is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict)
56895691
if is_ai_toolkit:
5690-
state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)
5692+
is_lokr = any("lokr_" in k for k in state_dict)
5693+
if is_lokr:
5694+
state_dict = _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict)
5695+
if metadata is None:
5696+
metadata = {}
5697+
metadata["is_lokr"] = "true"
5698+
else:
5699+
state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)
56915700

56925701
out = (state_dict, metadata) if return_lora_metadata else state_dict
56935702
return out
56945703

5695-
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
56965704
def load_lora_weights(
56975705
self,
56985706
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
@@ -5720,13 +5728,26 @@ def load_lora_weights(
57205728
kwargs["return_lora_metadata"] = True
57215729
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
57225730

5723-
is_correct_format = all("lora" in key for key in state_dict.keys())
5731+
is_correct_format = all("lora" in key or "lokr" in key for key in state_dict.keys())
57245732
if not is_correct_format:
5725-
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
5733+
raise ValueError("Invalid LoRA/LoKR checkpoint. Make sure all param names contain `'lora'` or `'lokr'`.")
5734+
5735+
# For LoKR adapters, fuse QKV projections so peft can target the fused modules directly.
5736+
is_lokr = metadata is not None and metadata.get("is_lokr") == "true"
5737+
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
5738+
if is_lokr:
5739+
transformer.fuse_qkv_projections()
5740+
elif (
5741+
hasattr(transformer, "transformer_blocks")
5742+
and len(transformer.transformer_blocks) > 0
5743+
and getattr(transformer.transformer_blocks[0].attn, "fused_projections", False)
5744+
):
5745+
# Model QKV is fused but LoRA targets separate Q/K/V - re-fuse the keys to match.
5746+
state_dict = _refuse_flux2_lora_state_dict(state_dict)
57265747

57275748
self.load_lora_into_transformer(
57285749
state_dict,
5729-
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
5750+
transformer=transformer,
57305751
adapter_name=adapter_name,
57315752
metadata=metadata,
57325753
_pipeline=self,

src/diffusers/loaders/peft.py

Lines changed: 60 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,15 @@
3838
set_adapter_layers,
3939
set_weights_and_activate_adapters,
4040
)
41-
from ..utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys
41+
from ..utils.peft_utils import _create_lokr_config, _create_lora_config, _maybe_warn_for_unhandled_keys
4242
from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
4343
from .unet_loader_utils import _maybe_expand_lora_scales
4444

4545

4646
logger = logging.get_logger(__name__)
4747

4848
_SET_ADAPTER_SCALE_FN_MAPPING = defaultdict(
49-
lambda: (lambda model_cls, weights: weights),
49+
lambda: lambda model_cls, weights: weights,
5050
{
5151
"UNet2DConditionModel": _maybe_expand_lora_scales,
5252
"UNetMotionModel": _maybe_expand_lora_scales,
@@ -213,56 +213,65 @@ def load_lora_adapter(
213213
"Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
214214
)
215215

216-
# check with first key if is not in peft format
217-
first_key = next(iter(state_dict.keys()))
218-
if "lora_A" not in first_key:
219-
state_dict = convert_unet_state_dict_to_peft(state_dict)
220-
221-
# Control LoRA from SAI is different from BFL Control LoRA
222-
# https://huggingface.co/stabilityai/control-lora
223-
# https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors
224-
is_sai_sd_control_lora = "lora_controlnet" in state_dict
225-
if is_sai_sd_control_lora:
226-
state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict)
227-
228-
rank = {}
229-
for key, val in state_dict.items():
230-
# Cannot figure out rank from lora layers that don't have at least 2 dimensions.
231-
# Bias layers in LoRA only have a single dimension
232-
if "lora_B" in key and val.ndim > 1:
233-
# Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
234-
# We may run into some ambiguous configuration values when a model has module
235-
# names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`,
236-
# for example) and they have different LoRA ranks.
237-
rank[f"^{key}"] = val.shape[1]
238-
239-
if network_alphas is not None and len(network_alphas) >= 1:
240-
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
241-
network_alphas = {
242-
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
243-
}
244-
245-
# adapter_name
246-
if adapter_name is None:
247-
adapter_name = get_adapter_name(self)
248-
249-
# create LoraConfig
250-
lora_config = _create_lora_config(
251-
state_dict,
252-
network_alphas,
253-
metadata,
254-
rank,
255-
model_state_dict=self.state_dict(),
256-
adapter_name=adapter_name,
257-
)
216+
# Detect whether this is a LoKR adapter (Kronecker product, not low-rank)
217+
is_lokr = any("lokr_" in k for k in state_dict)
218+
219+
if is_lokr:
220+
if adapter_name is None:
221+
adapter_name = get_adapter_name(self)
222+
lora_config = _create_lokr_config(state_dict)
223+
is_sai_sd_control_lora = False
224+
else:
225+
# check with first key if is not in peft format
226+
first_key = next(iter(state_dict.keys()))
227+
if "lora_A" not in first_key:
228+
state_dict = convert_unet_state_dict_to_peft(state_dict)
229+
230+
# Control LoRA from SAI is different from BFL Control LoRA
231+
# https://huggingface.co/stabilityai/control-lora
232+
# https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors
233+
is_sai_sd_control_lora = "lora_controlnet" in state_dict
234+
if is_sai_sd_control_lora:
235+
state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict)
236+
237+
rank = {}
238+
for key, val in state_dict.items():
239+
# Cannot figure out rank from lora layers that don't have at least 2 dimensions.
240+
# Bias layers in LoRA only have a single dimension
241+
if "lora_B" in key and val.ndim > 1:
242+
# Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
243+
# We may run into some ambiguous configuration values when a model has module
244+
# names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`,
245+
# for example) and they have different LoRA ranks.
246+
rank[f"^{key}"] = val.shape[1]
247+
248+
if network_alphas is not None and len(network_alphas) >= 1:
249+
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
250+
network_alphas = {
251+
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
252+
}
253+
254+
# adapter_name
255+
if adapter_name is None:
256+
adapter_name = get_adapter_name(self)
257+
258+
# create LoraConfig
259+
lora_config = _create_lora_config(
260+
state_dict,
261+
network_alphas,
262+
metadata,
263+
rank,
264+
model_state_dict=self.state_dict(),
265+
adapter_name=adapter_name,
266+
)
258267

259-
# Adjust LoRA config for Control LoRA
260-
if is_sai_sd_control_lora:
261-
lora_config.lora_alpha = lora_config.r
262-
lora_config.alpha_pattern = lora_config.rank_pattern
263-
lora_config.bias = "all"
264-
lora_config.modules_to_save = lora_config.exclude_modules
265-
lora_config.exclude_modules = None
268+
# Adjust LoRA config for Control LoRA
269+
if is_sai_sd_control_lora:
270+
lora_config.lora_alpha = lora_config.r
271+
lora_config.alpha_pattern = lora_config.rank_pattern
272+
lora_config.bias = "all"
273+
lora_config.modules_to_save = lora_config.exclude_modules
274+
lora_config.exclude_modules = None
266275

267276
# <Unsafe code
268277
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype

0 commit comments

Comments (0)