Skip to content

Commit afbd45a

Browse files
Pfannkuchensackclaudelstein
authored
Feature: flux2 klein lora support (invoke-ai#8862)
* WIP: Add FLUX.2 Klein LoRA support (BFL PEFT format) Initial implementation for loading and applying LoRA models trained with BFL's PEFT format for FLUX.2 Klein transformers. Changes: - Add LoRA_Diffusers_Flux2_Config and LoRA_LyCORIS_Flux2_Config - Add BflPeft format to FluxLoRAFormat taxonomy - Add flux_bfl_peft_lora_conversion_utils for weight conversion - Add Flux2KleinLoraLoaderInvocation node Status: Work in progress - not yet fully tested Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * feat(flux2): add LoRA support for FLUX.2 Klein models Add BFL PEFT LoRA support for FLUX.2 Klein, including runtime conversion of BFL-format keys to diffusers format with fused QKV splitting, improved detection of Klein 4B LoRAs via MLP ratio check, and frontend graph wiring. * feat(flux2): detect Klein LoRA variant (4B/9B) and filter by compatibility Auto-detect FLUX.2 Klein LoRA variant from tensor dimensions during model probe, warn on variant mismatch at load time, and filter the LoRA picker to only show variant-compatible LoRAs. * Chore Ruff * Chore pnpm * Fix detection and loading of 3 unrecognized Flux.2 Klein LoRA formats Three Flux.2 Klein LoRAs were either unrecognized or misclassified due to format detection gaps: 1. PEFT-wrapped BFL format (base_model.model.* prefix) was not recognized because the detector only accepted the diffusion_model.* prefix. 2. Klein 4B LoRAs with hidden_size=3072 were misidentified as Flux.1 due to a break statement exiting the detection loop before txt_in/vector_in dimensions could be checked. 3. Flux2 native diffusers format (to_qkv_mlp_proj, ff.linear_in) was not detected because the detector only checked for Flux.1 diffusers keys. Also handles mixed PEFT/standard LoRA suffix formats within the same file. --------- Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com> Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
1 parent b9f9015 commit afbd45a

13 files changed

Lines changed: 1396 additions & 35 deletions

File tree

invokeai/app/invocations/flux2_denoise.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
)
3939
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
4040
from invokeai.backend.patches.layer_patcher import LayerPatcher
41+
from invokeai.backend.patches.lora_conversions.flux_bfl_peft_lora_conversion_utils import (
42+
convert_bfl_lora_patch_to_diffusers,
43+
)
4144
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
4245
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
4346
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
@@ -503,11 +506,17 @@ def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor)
503506
return mask.expand_as(latents)
504507

505508
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
    """Yield (patch, weight) pairs for every LoRA attached to the transformer.

    Each loaded patch is passed through convert_bfl_lora_patch_to_diffusers
    before being yielded, since FLUX.2 Klein uses Flux2Transformer2DModel
    (diffusers naming) but LoRAs may have been loaded with BFL naming
    (e.g. when a Klein 4B LoRA is misidentified as FLUX.1).
    """
    for lora_field in self.transformer.loras:
        loaded = context.models.load(lora_field.lora)
        assert isinstance(loaded.model, ModelPatchRaw)
        # Normalize key naming to diffusers format before the patch is applied.
        yield (convert_bfl_lora_patch_to_diffusers(loaded.model), lora_field.weight)
        del loaded

513522
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
"""FLUX.2 Klein LoRA Loader Invocation.
2+
3+
Applies LoRA models to a FLUX.2 Klein transformer and/or Qwen3 text encoder.
4+
Unlike standard FLUX which uses CLIP+T5, Klein uses only Qwen3 for text encoding.
5+
"""
6+
7+
from typing import Optional
8+
9+
from invokeai.app.invocations.baseinvocation import (
10+
BaseInvocation,
11+
BaseInvocationOutput,
12+
Classification,
13+
invocation,
14+
invocation_output,
15+
)
16+
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
17+
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, Qwen3EncoderField, TransformerField
18+
from invokeai.app.services.shared.invocation_context import InvocationContext
19+
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
20+
21+
22+
@invocation_output("flux2_klein_lora_loader_output")
class Flux2KleinLoRALoaderOutput(BaseInvocationOutput):
    """Output of the FLUX.2 Klein LoRA loader invocations.

    Either field may be None when the corresponding model was not connected
    as an input to the loader.
    """

    # Transformer with LoRA fields attached; None when no transformer was supplied.
    transformer: Optional[TransformerField] = OutputField(
        default=None, description=FieldDescriptions.transformer, title="Transformer"
    )
    # Qwen3 text encoder with LoRA fields attached; None when not supplied.
    qwen3_encoder: Optional[Qwen3EncoderField] = OutputField(
        default=None, description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder"
    )
32+
33+
34+
@invocation(
    "flux2_klein_lora_loader",
    title="Apply LoRA - Flux2 Klein",
    tags=["lora", "model", "flux", "klein", "flux2"],
    category="model",
    version="1.0.0",
    classification=Classification.Prototype,
)
class Flux2KleinLoRALoaderInvocation(BaseInvocation):
    """Apply a LoRA model to a FLUX.2 Klein transformer and/or Qwen3 text encoder."""

    lora: ModelIdentifierField = InputField(
        description=FieldDescriptions.lora_model,
        title="LoRA",
        ui_model_base=BaseModelType.Flux2,
        ui_model_type=ModelType.LoRA,
    )
    weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
    transformer: TransformerField | None = InputField(
        default=None,
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="Transformer",
    )
    qwen3_encoder: Qwen3EncoderField | None = InputField(
        default=None,
        title="Qwen3 Encoder",
        description=FieldDescriptions.qwen3_encoder,
        input=Input.Connection,
    )

    def invoke(self, context: InvocationContext) -> Flux2KleinLoRALoaderOutput:
        key = self.lora.key

        if not context.models.exists(key):
            raise ValueError(f"Unknown lora: {key}!")

        # Variant sanity check: a Klein LoRA trained for one size applied to a
        # transformer of another size will likely fail with shape errors at
        # patch time, so warn (but do not block) on a mismatch.
        config = context.models.get_config(key)
        variant = getattr(config, "variant", None)
        if variant and self.transformer is not None:
            t_config = context.models.get_config(self.transformer.transformer.key)
            t_variant = getattr(t_config, "variant", None)
            if t_variant and variant != t_variant:
                context.logger.warning(
                    f"LoRA variant mismatch: LoRA '{config.name}' is for {variant.value} "
                    f"but transformer is {t_variant.value}. This may cause shape errors."
                )

        # Refuse to stack the same LoRA twice on either target model.
        if self.transformer and any(field.lora.key == key for field in self.transformer.loras):
            raise ValueError(f'LoRA "{key}" already applied to transformer.')
        if self.qwen3_encoder and any(field.lora.key == key for field in self.qwen3_encoder.loras):
            raise ValueError(f'LoRA "{key}" already applied to Qwen3 encoder.')

        output = Flux2KleinLoRALoaderOutput()

        # Deep-copy each connected model field and attach the LoRA entry.
        if self.transformer is not None:
            output.transformer = self.transformer.model_copy(deep=True)
            output.transformer.loras.append(LoRAField(lora=self.lora, weight=self.weight))
        if self.qwen3_encoder is not None:
            output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True)
            output.qwen3_encoder.loras.append(LoRAField(lora=self.lora, weight=self.weight))

        return output
110+
111+
112+
@invocation(
    "flux2_klein_lora_collection_loader",
    title="Apply LoRA Collection - Flux2 Klein",
    tags=["lora", "model", "flux", "klein", "flux2"],
    category="model",
    version="1.0.0",
    classification=Classification.Prototype,
)
class Flux2KleinLoRACollectionLoader(BaseInvocation):
    """Applies a collection of LoRAs to a FLUX.2 Klein transformer and/or Qwen3 text encoder."""

    loras: Optional[LoRAField | list[LoRAField]] = InputField(
        default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
    )

    transformer: Optional[TransformerField] = InputField(
        default=None,
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="Transformer",
    )
    qwen3_encoder: Qwen3EncoderField | None = InputField(
        default=None,
        title="Qwen3 Encoder",
        description=FieldDescriptions.qwen3_encoder,
        input=Input.Connection,
    )

    def invoke(self, context: InvocationContext) -> Flux2KleinLoRALoaderOutput:
        """Attach every LoRA in the collection to the connected models.

        Raises:
            ValueError: if a LoRA key is unknown or a LoRA's base model is not
                FLUX.1 / FLUX.2.
        """
        output = Flux2KleinLoRALoaderOutput()
        # Normalize input: a single LoRAField (or None) becomes a one-element list.
        loras = self.loras if isinstance(self.loras, list) else [self.loras]
        added_loras: list[str] = []

        if self.transformer is not None:
            output.transformer = self.transformer.model_copy(deep=True)

        if self.qwen3_encoder is not None:
            output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True)

        for lora in loras:
            if lora is None:
                continue
            # Skip duplicates within this collection.
            if lora.lora.key in added_loras:
                continue

            if not context.models.exists(lora.lora.key):
                # ValueError for consistency with Flux2KleinLoRALoaderInvocation
                # (was a bare Exception).
                raise ValueError(f"Unknown lora: {lora.lora.key}!")

            # Explicit raise instead of `assert`, which is stripped under `python -O`.
            if lora.lora.base not in (BaseModelType.Flux, BaseModelType.Flux2):
                raise ValueError(
                    f'LoRA "{lora.lora.key}" has unsupported base model {lora.lora.base}; '
                    "expected a FLUX.1 or FLUX.2 LoRA."
                )

            # Warn (but do not block) if the LoRA variant doesn't match the
            # transformer variant — a mismatch may cause shape errors at patch time.
            lora_config = context.models.get_config(lora.lora.key)
            lora_variant = getattr(lora_config, "variant", None)
            if lora_variant and self.transformer is not None:
                transformer_config = context.models.get_config(self.transformer.transformer.key)
                transformer_variant = getattr(transformer_config, "variant", None)
                if transformer_variant and lora_variant != transformer_variant:
                    context.logger.warning(
                        f"LoRA variant mismatch: LoRA '{lora_config.name}' is for {lora_variant.value} "
                        f"but transformer is {transformer_variant.value}. This may cause shape errors."
                    )

            added_loras.append(lora.lora.key)

            if self.transformer is not None and output.transformer is not None:
                output.transformer.loras.append(lora)

            if self.qwen3_encoder is not None and output.qwen3_encoder is not None:
                output.qwen3_encoder.loras.append(lora)

        return output

invokeai/backend/model_manager/configs/factory.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,13 @@
4040
from invokeai.backend.model_manager.configs.llava_onevision import LlavaOnevision_Diffusers_Config
4141
from invokeai.backend.model_manager.configs.lora import (
4242
ControlLoRA_LyCORIS_FLUX_Config,
43+
LoRA_Diffusers_Flux2_Config,
4344
LoRA_Diffusers_FLUX_Config,
4445
LoRA_Diffusers_SD1_Config,
4546
LoRA_Diffusers_SD2_Config,
4647
LoRA_Diffusers_SDXL_Config,
4748
LoRA_Diffusers_ZImage_Config,
49+
LoRA_LyCORIS_Flux2_Config,
4850
LoRA_LyCORIS_FLUX_Config,
4951
LoRA_LyCORIS_SD1_Config,
5052
LoRA_LyCORIS_SD2_Config,
@@ -197,18 +199,24 @@
197199
Annotated[ControlNet_Diffusers_SDXL_Config, ControlNet_Diffusers_SDXL_Config.get_tag()],
198200
Annotated[ControlNet_Diffusers_FLUX_Config, ControlNet_Diffusers_FLUX_Config.get_tag()],
199201
# LoRA - LyCORIS format
202+
# IMPORTANT: FLUX.2 must be checked BEFORE FLUX.1 because FLUX.2 has specific validation
203+
# that will reject FLUX.1 models, but FLUX.1 validation may incorrectly match FLUX.2 models
200204
Annotated[LoRA_LyCORIS_SD1_Config, LoRA_LyCORIS_SD1_Config.get_tag()],
201205
Annotated[LoRA_LyCORIS_SD2_Config, LoRA_LyCORIS_SD2_Config.get_tag()],
202206
Annotated[LoRA_LyCORIS_SDXL_Config, LoRA_LyCORIS_SDXL_Config.get_tag()],
207+
Annotated[LoRA_LyCORIS_Flux2_Config, LoRA_LyCORIS_Flux2_Config.get_tag()],
203208
Annotated[LoRA_LyCORIS_FLUX_Config, LoRA_LyCORIS_FLUX_Config.get_tag()],
204209
Annotated[LoRA_LyCORIS_ZImage_Config, LoRA_LyCORIS_ZImage_Config.get_tag()],
205210
# LoRA - OMI format
206211
Annotated[LoRA_OMI_SDXL_Config, LoRA_OMI_SDXL_Config.get_tag()],
207212
Annotated[LoRA_OMI_FLUX_Config, LoRA_OMI_FLUX_Config.get_tag()],
208213
# LoRA - diffusers format
214+
# IMPORTANT: FLUX.2 must be checked BEFORE FLUX.1 because FLUX.2 has specific validation
215+
# that will reject FLUX.1 models, but FLUX.1 validation may incorrectly match FLUX.2 models
209216
Annotated[LoRA_Diffusers_SD1_Config, LoRA_Diffusers_SD1_Config.get_tag()],
210217
Annotated[LoRA_Diffusers_SD2_Config, LoRA_Diffusers_SD2_Config.get_tag()],
211218
Annotated[LoRA_Diffusers_SDXL_Config, LoRA_Diffusers_SDXL_Config.get_tag()],
219+
Annotated[LoRA_Diffusers_Flux2_Config, LoRA_Diffusers_Flux2_Config.get_tag()],
212220
Annotated[LoRA_Diffusers_FLUX_Config, LoRA_Diffusers_FLUX_Config.get_tag()],
213221
Annotated[LoRA_Diffusers_ZImage_Config, LoRA_Diffusers_ZImage_Config.get_tag()],
214222
# ControlLoRA - diffusers format

0 commit comments

Comments
 (0)