
Commit b646b2c

Copilot and lstein committed
Add FLUX.2 LOKR model support (detection and loading)
Co-authored-by: lstein <111189+lstein@users.noreply.github.com>

Fix BFL LOKR models being misidentified as AIToolkit format

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>

Fix alpha key warning in LOKR QKV split layers

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
1 parent afbd45a commit b646b2c

13 files changed: 375 additions & 34 deletions

invokeai/backend/model_manager/configs/lora.py

Lines changed: 129 additions & 2 deletions
@@ -79,6 +79,32 @@ def _get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None:
 _FLUX1_MLP_RATIO = 4
 
 
+def _lokr_in_dim(state_dict: dict[str | int, Any], key_prefix: str) -> int | None:
+    """Compute the input dimension of a LOKR layer: w1.shape[1] * w2.shape[1].
+
+    Supports both full LOKR (lokr_w1/lokr_w2) and factorized LOKR (lokr_w1_b/lokr_w2_b).
+    Returns None if the required keys are not present.
+    """
+    if f"{key_prefix}.lokr_w1" in state_dict and f"{key_prefix}.lokr_w2" in state_dict:
+        return state_dict[f"{key_prefix}.lokr_w1"].shape[1] * state_dict[f"{key_prefix}.lokr_w2"].shape[1]
+    elif f"{key_prefix}.lokr_w1_b" in state_dict and f"{key_prefix}.lokr_w2_b" in state_dict:
+        return state_dict[f"{key_prefix}.lokr_w1_b"].shape[1] * state_dict[f"{key_prefix}.lokr_w2_b"].shape[1]
+    return None
+
+
+def _lokr_out_dim(state_dict: dict[str | int, Any], key_prefix: str) -> int | None:
+    """Compute the output dimension of a LOKR layer: w1.shape[0] * w2.shape[0].
+
+    Supports both full LOKR (lokr_w1/lokr_w2) and factorized LOKR (lokr_w1_a/lokr_w2_a).
+    Returns None if the required keys are not present.
+    """
+    if f"{key_prefix}.lokr_w1" in state_dict and f"{key_prefix}.lokr_w2" in state_dict:
+        return state_dict[f"{key_prefix}.lokr_w1"].shape[0] * state_dict[f"{key_prefix}.lokr_w2"].shape[0]
+    elif f"{key_prefix}.lokr_w1_a" in state_dict and f"{key_prefix}.lokr_w2_a" in state_dict:
+        return state_dict[f"{key_prefix}.lokr_w1_a"].shape[0] * state_dict[f"{key_prefix}.lokr_w2_a"].shape[0]
+    return None
+
+
 def _is_flux2_lora(mod: ModelOnDisk) -> bool:
     """Check if a FLUX-format LoRA is specifically for FLUX.2 (Klein) rather than FLUX.1.
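For context on the dimension arithmetic in these helpers: LoKR (LyCORIS's Kronecker-product adaptation) stores a layer's weight as W ≈ kron(w1, w2), and the Kronecker product of an (a, b) matrix with a (c, d) matrix has shape (a·c, b·d). Multiplying the factors' shape[1] values therefore recovers the layer's input dimension, and multiplying their shape[0] values recovers the output dimension. A minimal sketch, assuming PyTorch and toy shapes rather than real checkpoint dimensions:

import torch

# Toy factor shapes; real checkpoints vary.
w1 = torch.randn(8, 16)     # (out1, in1)
w2 = torch.randn(384, 192)  # (out2, in2)

full = torch.kron(w1, w2)   # Kronecker product reconstructs the full weight
assert full.shape == (8 * 384, 16 * 192)  # (out1*out2, in1*in2) == (3072, 3072)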
@@ -147,7 +173,30 @@ def _is_flux2_lora_state_dict(state_dict: dict[str | int, Any]) -> bool:
         elif "vector_in" in key and key.endswith("lora_A.weight"):
             return state_dict[key].shape[1] in _FLUX2_VEC_IN_DIMS
 
-    # BFL PEFT: hidden_size matches FLUX.1. Check MLP ratio to distinguish Klein 4B.
+        # BFL LyCORIS (LoKR/LoHA): attention projection → check hidden_size via product of dims
+        elif key.endswith((".img_attn.proj.lokr_w1", ".img_attn.proj.lokr_w1_b")):
+            layer_prefix = key.rsplit(".", 1)[0]
+            in_dim = _lokr_in_dim(state_dict, layer_prefix)
+            if in_dim is not None:
+                if in_dim != _FLUX1_HIDDEN_SIZE:
+                    return True
+                bfl_hidden_size = in_dim  # ambiguous, keep checking
+
+        # BFL LyCORIS: context_embedder/txt_in
+        elif "txt_in" in key and key.endswith((".lokr_w1", ".lokr_w1_b")):
+            layer_prefix = key.rsplit(".", 1)[0]
+            in_dim = _lokr_in_dim(state_dict, layer_prefix)
+            if in_dim is not None:
+                return in_dim in _FLUX2_CONTEXT_IN_DIMS
+
+        # BFL LyCORIS: vector_in
+        elif "vector_in" in key and key.endswith((".lokr_w1", ".lokr_w1_b")):
+            layer_prefix = key.rsplit(".", 1)[0]
+            in_dim = _lokr_in_dim(state_dict, layer_prefix)
+            if in_dim is not None:
+                return in_dim in _FLUX2_VEC_IN_DIMS
+
+    # BFL PEFT/LyCORIS: hidden_size matches FLUX.1. Check MLP ratio to distinguish Klein 4B.
     # Klein 4B uses mlp_ratio=6 (ffn_dim=18432), FLUX.1 uses mlp_ratio=4 (ffn_dim=12288).
     if bfl_hidden_size == _FLUX1_HIDDEN_SIZE:
         for key in state_dict:
@@ -158,6 +207,13 @@ def _is_flux2_lora_state_dict(state_dict: dict[str | int, Any]) -> bool:
                 if ffn_dim != bfl_hidden_size * _FLUX1_MLP_RATIO:
                     return True
                 break
+            # BFL LyCORIS: check output dim of img_mlp.0 via product of dims
+            if key.startswith(_bfl_prefixes) and key.endswith((".img_mlp.0.lokr_w1", ".img_mlp.0.lokr_w1_a")):
+                layer_prefix = key.rsplit(".", 1)[0]
+                out_dim = _lokr_out_dim(state_dict, layer_prefix)
+                if out_dim is not None and out_dim != bfl_hidden_size * _FLUX1_MLP_RATIO:
+                    return True
+                break
 
     # Check kohya format: look for context_embedder or vector_in keys
     # Kohya format uses lora_unet_ prefix with underscores instead of dots
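The numbers in the mlp_ratio comment are consistent: FLUX.1's hidden size is 3072, so mlp_ratio=4 gives ffn_dim 12288 while Klein 4B's mlp_ratio=6 gives 18432. Any img_mlp.0 output dimension other than hidden_size * 4 therefore rules out FLUX.1. A quick arithmetic check (constants restated from the comments above):

FLUX1_HIDDEN_SIZE = 3072  # implied by ffn_dim 12288 / mlp_ratio 4

assert FLUX1_HIDDEN_SIZE * 4 == 12288  # FLUX.1: mlp_ratio = 4
assert FLUX1_HIDDEN_SIZE * 6 == 18432  # Klein 4B: mlp_ratio = 6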
@@ -167,9 +223,21 @@ def _is_flux2_lora_state_dict(state_dict: dict[str | int, Any]) -> bool:
         if key.startswith("lora_unet_txt_in.") or key.startswith("lora_unet_context_embedder."):
             if key.endswith("lora_down.weight"):
                 return state_dict[key].shape[1] in _FLUX2_CONTEXT_IN_DIMS
+            # Kohya LyCORIS (LoKR)
+            elif key.endswith((".lokr_w1", ".lokr_w1_b")):
+                layer_prefix = key.rsplit(".", 1)[0]
+                in_dim = _lokr_in_dim(state_dict, layer_prefix)
+                if in_dim is not None:
+                    return in_dim in _FLUX2_CONTEXT_IN_DIMS
         if key.startswith("lora_unet_vector_in.") or key.startswith("lora_unet_time_text_embed_text_embedder_"):
             if key.endswith("lora_down.weight"):
                 return state_dict[key].shape[1] in _FLUX2_VEC_IN_DIMS
+            # Kohya LyCORIS (LoKR)
+            elif key.endswith((".lokr_w1", ".lokr_w1_b")):
+                layer_prefix = key.rsplit(".", 1)[0]
+                in_dim = _lokr_in_dim(state_dict, layer_prefix)
+                if in_dim is not None:
+                    return in_dim in _FLUX2_VEC_IN_DIMS
 
     return False

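To make the prefix parsing concrete, here is a toy walkthrough of the LoKR branch (key name and tensor shapes are illustrative, not taken from a real checkpoint):

import torch

state_dict = {
    "diffusion_model.double_blocks.0.img_attn.proj.lokr_w1": torch.randn(8, 16),
    "diffusion_model.double_blocks.0.img_attn.proj.lokr_w2": torch.randn(384, 192),
}

key = "diffusion_model.double_blocks.0.img_attn.proj.lokr_w1"
layer_prefix = key.rsplit(".", 1)[0]  # strips the trailing ".lokr_w1"

# Same computation as _lokr_in_dim: product of the factors' input dims.
in_dim = (
    state_dict[f"{layer_prefix}.lokr_w1"].shape[1]
    * state_dict[f"{layer_prefix}.lokr_w2"].shape[1]
)
assert in_dim == 3072  # matches FLUX.1's hidden size, so this layer alone is ambiguous

In that ambiguous case the detector records bfl_hidden_size and falls through to the MLP-ratio check instead of deciding immediately.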
@@ -244,7 +312,7 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp
                     return Flux2VariantType.Klein9B
             return None
 
-    # Check BFL PEFT format (diffusion_model.* or base_model.model.* prefix with BFL names)
+    # Check BFL PEFT/LyCORIS format (diffusion_model.* or base_model.model.* prefix with BFL names)
     _bfl_prefixes = ("diffusion_model.", "base_model.model.")
     for key in state_dict:
         if not isinstance(key, str):
@@ -279,6 +347,39 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp
                     return Flux2VariantType.Klein9B
             return None
 
+        # BFL LyCORIS (LoKR): context embedder (txt_in)
+        if "txt_in" in key and key.endswith((".lokr_w1", ".lokr_w1_b")):
+            layer_prefix = key.rsplit(".", 1)[0]
+            in_dim = _lokr_in_dim(state_dict, layer_prefix)
+            if in_dim is not None:
+                if in_dim == KLEIN_4B_CONTEXT_DIM:
+                    return Flux2VariantType.Klein4B
+                if in_dim == KLEIN_9B_CONTEXT_DIM:
+                    return Flux2VariantType.Klein9B
+            return None
+
+        # BFL LyCORIS (LoKR): vector embedder (vector_in)
+        if "vector_in" in key and key.endswith((".lokr_w1", ".lokr_w1_b")):
+            layer_prefix = key.rsplit(".", 1)[0]
+            in_dim = _lokr_in_dim(state_dict, layer_prefix)
+            if in_dim is not None:
+                if in_dim == KLEIN_4B_VEC_DIM:
+                    return Flux2VariantType.Klein4B
+                if in_dim == KLEIN_9B_VEC_DIM:
+                    return Flux2VariantType.Klein9B
+            return None
+
+        # BFL LyCORIS (LoKR): attention projection
+        if key.endswith((".img_attn.proj.lokr_w1", ".img_attn.proj.lokr_w1_b")):
+            layer_prefix = key.rsplit(".", 1)[0]
+            in_dim = _lokr_in_dim(state_dict, layer_prefix)
+            if in_dim is not None:
+                if in_dim == KLEIN_4B_HIDDEN_SIZE:
+                    return Flux2VariantType.Klein4B
+                if in_dim == KLEIN_9B_HIDDEN_SIZE:
+                    return Flux2VariantType.Klein9B
+            return None
+
     # Check kohya format
     for key in state_dict:
         if not isinstance(key, str):
@@ -291,6 +392,16 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp
                 if dim == KLEIN_9B_CONTEXT_DIM:
                     return Flux2VariantType.Klein9B
                 return None
+            # Kohya LyCORIS (LoKR)
+            elif key.endswith((".lokr_w1", ".lokr_w1_b")):
+                layer_prefix = key.rsplit(".", 1)[0]
+                in_dim = _lokr_in_dim(state_dict, layer_prefix)
+                if in_dim is not None:
+                    if in_dim == KLEIN_4B_CONTEXT_DIM:
+                        return Flux2VariantType.Klein4B
+                    if in_dim == KLEIN_9B_CONTEXT_DIM:
+                        return Flux2VariantType.Klein9B
+                return None
         if key.startswith("lora_unet_vector_in.") or key.startswith("lora_unet_time_text_embed_text_embedder_"):
             if key.endswith("lora_down.weight"):
                 dim = state_dict[key].shape[1]
@@ -299,6 +410,16 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp
                 if dim == KLEIN_9B_VEC_DIM:
                     return Flux2VariantType.Klein9B
                 return None
+            # Kohya LyCORIS (LoKR)
+            elif key.endswith((".lokr_w1", ".lokr_w1_b")):
+                layer_prefix = key.rsplit(".", 1)[0]
+                in_dim = _lokr_in_dim(state_dict, layer_prefix)
+                if in_dim is not None:
+                    if in_dim == KLEIN_4B_VEC_DIM:
+                        return Flux2VariantType.Klein4B
+                    if in_dim == KLEIN_9B_VEC_DIM:
+                        return Flux2VariantType.Klein9B
+                return None
 
     return None

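The .lokr_w1_b / .lokr_w1_a key variants seen throughout these branches come from factorized LoKR, where each Kronecker factor is itself stored as a low-rank pair (w = w_a @ w_b). The input dimension then lives on the *_b factors' shape[1] and the output dimension on the *_a factors' shape[0], which is why _lokr_in_dim probes lokr_w1_b/lokr_w2_b while _lokr_out_dim probes lokr_w1_a/lokr_w2_a. A sketch with toy shapes, assuming PyTorch:

import torch

rank = 4  # toy rank
w1_a, w1_b = torch.randn(8, rank), torch.randn(rank, 16)
w2_a, w2_b = torch.randn(384, rank), torch.randn(rank, 192)

# Reconstruct the full weight: Kronecker product of the two low-rank products.
w = torch.kron(w1_a @ w1_b, w2_a @ w2_b)

in_dim = w1_b.shape[1] * w2_b.shape[1]   # input dims sit on the *_b factors
out_dim = w1_a.shape[0] * w2_a.shape[0]  # output dims sit on the *_a factors
assert w.shape == (out_dim, in_dim)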
@@ -423,6 +544,12 @@ def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
             "to_q_lora.down.weight",
             "lora_A.weight",
             "lora_B.weight",
+            # LyCORIS LoKR suffixes
+            "lokr_w1",
+            "lokr_w2",
+            # LyCORIS LoHA suffixes
+            "hada_w1_a",
+            "hada_w2_a",
         },
     )

invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py

Lines changed: 21 additions & 0 deletions
@@ -30,6 +30,27 @@ def is_state_dict_likely_in_flux_aitoolkit_format(
     if not _has_flux_layer_structure(state_dict):
         return False
 
+    # AIToolkit only produces standard PEFT LoRA (lora_A.weight / lora_B.weight).
+    # Exclude LyCORIS algorithm variants (LoKR, LoHA, etc.) which use different weight key suffixes.
+    # These are handled by the BFL PEFT converter instead.
+    _LYCORIS_SUFFIXES = (
+        "lokr_w1",
+        "lokr_w2",
+        "lokr_w1_a",
+        "lokr_w1_b",
+        "lokr_w2_a",
+        "lokr_w2_b",
+        "lokr_t2",
+        "hada_w1_a",
+        "hada_w1_b",
+        "hada_w2_a",
+        "hada_w2_b",
+        "hada_t1",
+        "hada_t2",
+    )
+    if any(k.endswith(_LYCORIS_SUFFIXES) for k in state_dict.keys() if isinstance(k, str)):
+        return False
+
     if metadata:
         try:
             software = json.loads(metadata.get("software", "{}"))
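A quick check of how the new guard behaves: str.endswith accepts a tuple and matches any listed suffix, so a LoKR state dict is routed away from the AIToolkit loader even when its layer structure looks FLUX-like. The keys below are illustrative and the suffix tuple is abridged:

_LYCORIS_SUFFIXES = ("lokr_w1", "lokr_w2", "hada_w1_a")  # abridged

lokr_sd = {"diffusion_model.double_blocks.0.img_attn.proj.lokr_w1": None}
peft_sd = {"diffusion_model.double_blocks.0.img_attn.proj.lora_A.weight": None}

def has_lycoris_keys(sd: dict) -> bool:
    # endswith with a tuple returns True if any suffix matches
    return any(k.endswith(_LYCORIS_SUFFIXES) for k in sd if isinstance(k, str))

assert has_lycoris_keys(lokr_sd)
assert not has_lycoris_keys(peft_sd)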
