@@ -79,6 +79,32 @@ def _get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None:
7979_FLUX1_MLP_RATIO = 4
8080
8181
82+ def _lokr_in_dim (state_dict : dict [str | int , Any ], key_prefix : str ) -> int | None :
83+ """Compute the input dimension of a LOKR layer: w1.shape[1] * w2.shape[1].
84+
85+ Supports both full LOKR (lokr_w1/lokr_w2) and factorized LOKR (lokr_w1_b/lokr_w2_b).
86+ Returns None if the required keys are not present.
87+ """
88+ if f"{ key_prefix } .lokr_w1" in state_dict and f"{ key_prefix } .lokr_w2" in state_dict :
89+ return state_dict [f"{ key_prefix } .lokr_w1" ].shape [1 ] * state_dict [f"{ key_prefix } .lokr_w2" ].shape [1 ]
90+ elif f"{ key_prefix } .lokr_w1_b" in state_dict and f"{ key_prefix } .lokr_w2_b" in state_dict :
91+ return state_dict [f"{ key_prefix } .lokr_w1_b" ].shape [1 ] * state_dict [f"{ key_prefix } .lokr_w2_b" ].shape [1 ]
92+ return None
93+
94+
95+ def _lokr_out_dim (state_dict : dict [str | int , Any ], key_prefix : str ) -> int | None :
96+ """Compute the output dimension of a LOKR layer: w1.shape[0] * w2.shape[0].
97+
98+ Supports both full LOKR (lokr_w1/lokr_w2) and factorized LOKR (lokr_w1_a/lokr_w2_a).
99+ Returns None if the required keys are not present.
100+ """
101+ if f"{ key_prefix } .lokr_w1" in state_dict and f"{ key_prefix } .lokr_w2" in state_dict :
102+ return state_dict [f"{ key_prefix } .lokr_w1" ].shape [0 ] * state_dict [f"{ key_prefix } .lokr_w2" ].shape [0 ]
103+ elif f"{ key_prefix } .lokr_w1_a" in state_dict and f"{ key_prefix } .lokr_w2_a" in state_dict :
104+ return state_dict [f"{ key_prefix } .lokr_w1_a" ].shape [0 ] * state_dict [f"{ key_prefix } .lokr_w2_a" ].shape [0 ]
105+ return None
106+
107+
82108def _is_flux2_lora (mod : ModelOnDisk ) -> bool :
83109 """Check if a FLUX-format LoRA is specifically for FLUX.2 (Klein) rather than FLUX.1.
84110
@@ -147,7 +173,30 @@ def _is_flux2_lora_state_dict(state_dict: dict[str | int, Any]) -> bool:
147173 elif "vector_in" in key and key .endswith ("lora_A.weight" ):
148174 return state_dict [key ].shape [1 ] in _FLUX2_VEC_IN_DIMS
149175
150- # BFL PEFT: hidden_size matches FLUX.1. Check MLP ratio to distinguish Klein 4B.
176+ # BFL LyCORIS (LoKR/LoHA): attention projection → check hidden_size via product of dims
177+ elif key .endswith ((".img_attn.proj.lokr_w1" , ".img_attn.proj.lokr_w1_b" )):
178+ layer_prefix = key .rsplit ("." , 1 )[0 ]
179+ in_dim = _lokr_in_dim (state_dict , layer_prefix )
180+ if in_dim is not None :
181+ if in_dim != _FLUX1_HIDDEN_SIZE :
182+ return True
183+ bfl_hidden_size = in_dim # ambiguous, keep checking
184+
185+ # BFL LyCORIS: context_embedder/txt_in
186+ elif "txt_in" in key and key .endswith ((".lokr_w1" , ".lokr_w1_b" )):
187+ layer_prefix = key .rsplit ("." , 1 )[0 ]
188+ in_dim = _lokr_in_dim (state_dict , layer_prefix )
189+ if in_dim is not None :
190+ return in_dim in _FLUX2_CONTEXT_IN_DIMS
191+
192+ # BFL LyCORIS: vector_in
193+ elif "vector_in" in key and key .endswith ((".lokr_w1" , ".lokr_w1_b" )):
194+ layer_prefix = key .rsplit ("." , 1 )[0 ]
195+ in_dim = _lokr_in_dim (state_dict , layer_prefix )
196+ if in_dim is not None :
197+ return in_dim in _FLUX2_VEC_IN_DIMS
198+
199+ # BFL PEFT/LyCORIS: hidden_size matches FLUX.1. Check MLP ratio to distinguish Klein 4B.
151200 # Klein 4B uses mlp_ratio=6 (ffn_dim=18432), FLUX.1 uses mlp_ratio=4 (ffn_dim=12288).
152201 if bfl_hidden_size == _FLUX1_HIDDEN_SIZE :
153202 for key in state_dict :
@@ -158,6 +207,13 @@ def _is_flux2_lora_state_dict(state_dict: dict[str | int, Any]) -> bool:
158207 if ffn_dim != bfl_hidden_size * _FLUX1_MLP_RATIO :
159208 return True
160209 break
210+ # BFL LyCORIS: check output dim of img_mlp.0 via product of dims
211+ if key .startswith (_bfl_prefixes ) and key .endswith ((".img_mlp.0.lokr_w1" , ".img_mlp.0.lokr_w1_a" )):
212+ layer_prefix = key .rsplit ("." , 1 )[0 ]
213+ out_dim = _lokr_out_dim (state_dict , layer_prefix )
214+ if out_dim is not None and out_dim != bfl_hidden_size * _FLUX1_MLP_RATIO :
215+ return True
216+ break
161217
162218 # Check kohya format: look for context_embedder or vector_in keys
163219 # Kohya format uses lora_unet_ prefix with underscores instead of dots
@@ -167,9 +223,21 @@ def _is_flux2_lora_state_dict(state_dict: dict[str | int, Any]) -> bool:
167223 if key .startswith ("lora_unet_txt_in." ) or key .startswith ("lora_unet_context_embedder." ):
168224 if key .endswith ("lora_down.weight" ):
169225 return state_dict [key ].shape [1 ] in _FLUX2_CONTEXT_IN_DIMS
226+ # Kohya LyCORIS (LoKR)
227+ elif key .endswith ((".lokr_w1" , ".lokr_w1_b" )):
228+ layer_prefix = key .rsplit ("." , 1 )[0 ]
229+ in_dim = _lokr_in_dim (state_dict , layer_prefix )
230+ if in_dim is not None :
231+ return in_dim in _FLUX2_CONTEXT_IN_DIMS
170232 if key .startswith ("lora_unet_vector_in." ) or key .startswith ("lora_unet_time_text_embed_text_embedder_" ):
171233 if key .endswith ("lora_down.weight" ):
172234 return state_dict [key ].shape [1 ] in _FLUX2_VEC_IN_DIMS
235+ # Kohya LyCORIS (LoKR)
236+ elif key .endswith ((".lokr_w1" , ".lokr_w1_b" )):
237+ layer_prefix = key .rsplit ("." , 1 )[0 ]
238+ in_dim = _lokr_in_dim (state_dict , layer_prefix )
239+ if in_dim is not None :
240+ return in_dim in _FLUX2_VEC_IN_DIMS
173241
174242 return False
175243
@@ -244,7 +312,7 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp
244312 return Flux2VariantType .Klein9B
245313 return None
246314
247- # Check BFL PEFT format (diffusion_model.* or base_model.model.* prefix with BFL names)
315+ # Check BFL PEFT/LyCORIS format (diffusion_model.* or base_model.model.* prefix with BFL names)
248316 _bfl_prefixes = ("diffusion_model." , "base_model.model." )
249317 for key in state_dict :
250318 if not isinstance (key , str ):
@@ -279,6 +347,39 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp
279347 return Flux2VariantType .Klein9B
280348 return None
281349
350+ # BFL LyCORIS (LoKR): context embedder (txt_in)
351+ if "txt_in" in key and key .endswith ((".lokr_w1" , ".lokr_w1_b" )):
352+ layer_prefix = key .rsplit ("." , 1 )[0 ]
353+ in_dim = _lokr_in_dim (state_dict , layer_prefix )
354+ if in_dim is not None :
355+ if in_dim == KLEIN_4B_CONTEXT_DIM :
356+ return Flux2VariantType .Klein4B
357+ if in_dim == KLEIN_9B_CONTEXT_DIM :
358+ return Flux2VariantType .Klein9B
359+ return None
360+
361+ # BFL LyCORIS (LoKR): vector embedder (vector_in)
362+ if "vector_in" in key and key .endswith ((".lokr_w1" , ".lokr_w1_b" )):
363+ layer_prefix = key .rsplit ("." , 1 )[0 ]
364+ in_dim = _lokr_in_dim (state_dict , layer_prefix )
365+ if in_dim is not None :
366+ if in_dim == KLEIN_4B_VEC_DIM :
367+ return Flux2VariantType .Klein4B
368+ if in_dim == KLEIN_9B_VEC_DIM :
369+ return Flux2VariantType .Klein9B
370+ return None
371+
372+ # BFL LyCORIS (LoKR): attention projection
373+ if key .endswith ((".img_attn.proj.lokr_w1" , ".img_attn.proj.lokr_w1_b" )):
374+ layer_prefix = key .rsplit ("." , 1 )[0 ]
375+ in_dim = _lokr_in_dim (state_dict , layer_prefix )
376+ if in_dim is not None :
377+ if in_dim == KLEIN_4B_HIDDEN_SIZE :
378+ return Flux2VariantType .Klein4B
379+ if in_dim == KLEIN_9B_HIDDEN_SIZE :
380+ return Flux2VariantType .Klein9B
381+ return None
382+
282383 # Check kohya format
283384 for key in state_dict :
284385 if not isinstance (key , str ):
@@ -291,6 +392,16 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp
291392 if dim == KLEIN_9B_CONTEXT_DIM :
292393 return Flux2VariantType .Klein9B
293394 return None
395+ # Kohya LyCORIS (LoKR)
396+ elif key .endswith ((".lokr_w1" , ".lokr_w1_b" )):
397+ layer_prefix = key .rsplit ("." , 1 )[0 ]
398+ in_dim = _lokr_in_dim (state_dict , layer_prefix )
399+ if in_dim is not None :
400+ if in_dim == KLEIN_4B_CONTEXT_DIM :
401+ return Flux2VariantType .Klein4B
402+ if in_dim == KLEIN_9B_CONTEXT_DIM :
403+ return Flux2VariantType .Klein9B
404+ return None
294405 if key .startswith ("lora_unet_vector_in." ) or key .startswith ("lora_unet_time_text_embed_text_embedder_" ):
295406 if key .endswith ("lora_down.weight" ):
296407 dim = state_dict [key ].shape [1 ]
@@ -299,6 +410,16 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp
299410 if dim == KLEIN_9B_VEC_DIM :
300411 return Flux2VariantType .Klein9B
301412 return None
413+ # Kohya LyCORIS (LoKR)
414+ elif key .endswith ((".lokr_w1" , ".lokr_w1_b" )):
415+ layer_prefix = key .rsplit ("." , 1 )[0 ]
416+ in_dim = _lokr_in_dim (state_dict , layer_prefix )
417+ if in_dim is not None :
418+ if in_dim == KLEIN_4B_VEC_DIM :
419+ return Flux2VariantType .Klein4B
420+ if in_dim == KLEIN_9B_VEC_DIM :
421+ return Flux2VariantType .Klein9B
422+ return None
302423
303424 return None
304425
@@ -423,6 +544,12 @@ def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
423544 "to_q_lora.down.weight" ,
424545 "lora_A.weight" ,
425546 "lora_B.weight" ,
547+ # LyCORIS LoKR suffixes
548+ "lokr_w1" ,
549+ "lokr_w2" ,
550+ # LyCORIS LoHA suffixes
551+ "hada_w1_a" ,
552+ "hada_w2_a" ,
426553 },
427554 )
428555
0 commit comments