Commit 9a90430

Authored by Pfannkuchensack, JPPhoto, and lstein

fix(flux2): remove inert guidance UI, add Klein 4B Base variant, fix metadata recall (#8995)
* fix(flux2): correct guidance_embed, add guidance support for Klein 9B Base, and fix metadata recall

  Klein 4B and 9B (distilled) have guidance_embeds=False, while Klein 9B Base (undistilled) has guidance_embeds=True. This commit:

  - Sets guidance_embed=False for Klein 4B/9B and adds Klein9BBase with True
  - Adds guidance parameter to Flux2DenoiseInvocation (used by Klein 9B Base)
  - Passes the real guidance value instead of a hardcoded 1.0 in flux2/denoise.py
  - Hides the guidance slider for distilled Klein models, shows it for Klein 9B Base
  - Shows the Flux scheduler dropdown for all Flux2 Klein models
  - Passes the scheduler to the Flux2 denoise node and saves it in metadata
  - Adds KleinVAEModel and KleinQwen3EncoderModel to the recall parameters panel

* test(flux2): cover Klein guidance gating, scheduler metadata, and recall dedupe

  Add a mock-based harness for buildFLUXGraph that locks in the FLUX.2 orchestration: guidance is written to metadata and the flux2_denoise node only for klein_9b_base; distilled variants (klein_9b, klein_4b) omit it; the FLUX scheduler is persisted into both metadata and the denoise node; and separately selected Klein VAE / Qwen3 encoder models land in metadata.

  Add parsing tests for the metadata recall handlers: KleinVAEModel and KleinQwen3EncoderModel only fire when the current main model is FLUX.2, and the generic VAEModel handler now bails out for flux2 / z-image so the metadata viewer no longer renders duplicate VAE rows next to the dedicated Klein / Z-Image handlers.

* chore: pnpm fix

* Update version to 1.5.0 in flux2_denoise.py

* Update condition for rendering ParamFluxScheduler

* feat(flux2): add Klein4BBase variant for FLUX.2 Klein Base 4B models

  Recognize FLUX.2-klein-base-4B on import via a filename heuristic. The variant shares Klein4B's architecture (Qwen3-4B encoder, context_in_dim=7680) and reports guidance_embeds=False in its HF config, consistent with Klein 9B Base. UI behavior stays identical to distilled Klein4B until CFG support is wired up in a follow-up.

* Fix wrong comment

* refactor(flux2): remove inert guidance UI/metadata for FLUX.2 Klein

  All current FLUX.2 Klein variants (4B, 4B Base, 9B, 9B Base) report guidance_embeds=false in their HF transformer config (or have zeroed projection weights), so the guidance scalar has no effect on output. The linear UI previously exposed a guidance slider for klein_9b_base and wrote the value into metadata, which misled users into thinking it was steering generation.

* chore: typegen

* fix test

* fix(flux2): skip Guidance metadata recall for legacy FLUX.2 images

  The generic Guidance metadata handler unconditionally parsed `metadata.guidance` and dispatched `setGuidance(value)` into the shared params slice. For images generated before the Klein guidance cleanup, this still fired, silently writing a stale guidance value into the global state, which then leaked back into FLUX.1 on model switch.

  Gate the handler on `metadata.model.base`: reject parsing when the image was generated with a FLUX.2 model. The handler is then skipped for both display and recall on legacy FLUX.2 metadata, matching the "silently ignored" contract stated in the PR.

  - parsing.tsx: check metadata.model.base in Guidance.parse()
  - parsing.test.tsx: three new cases covering FLUX.2 gating, FLUX.1 pass-through, and back-compat for metadata without a model field

---------

Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
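The gated Guidance recall described in the last commit can be sketched as follows. This is an illustrative Python mirror of the TypeScript `Guidance.parse()` logic, not the real handler; the `"flux2"` base identifier and the function name are assumptions for the sketch (the actual handler compares `metadata.model.base` against the FLUX.2 base enum value in parsing.tsx).

```python
def parse_guidance(metadata: dict) -> float:
    """Recall guidance from image metadata, gated on the generating model's base.

    Illustrative Python mirror of the TS Guidance.parse() gating; the "flux2"
    base literal is an assumption for this sketch.
    """
    model = metadata.get("model")
    if isinstance(model, dict) and model.get("base") == "flux2":
        # Guidance on FLUX.2 images is inert legacy metadata: raising here means
        # the handler is skipped for both display and recall ("silently ignored").
        raise ValueError("guidance is not recalled for FLUX.2 images")
    value = metadata.get("guidance")
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise ValueError("no guidance value in metadata")
    return float(value)
```

This reproduces the three test cases named above: FLUX.2 metadata is rejected, FLUX.1 passes through, and metadata without a model field still recalls (back-compat).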
1 parent f9f2a32 commit 9a90430

17 files changed

Lines changed: 477 additions & 144 deletions

invokeai/app/invocations/flux2_denoise.py

Lines changed: 11 additions & 2 deletions
@@ -53,8 +53,8 @@
     "flux2_denoise",
     title="FLUX2 Denoise",
     tags=["image", "flux", "flux2", "klein", "denoise"],
-    category="latents",
-    version="1.4.0",
+    category="image",
+    version="1.5.0",
     classification=Classification.Prototype,
 )
 class Flux2DenoiseInvocation(BaseInvocation):
@@ -101,6 +101,14 @@ class Flux2DenoiseInvocation(BaseInvocation):
         description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
         input=Input.Connection,
     )
+    guidance: float = InputField(
+        default=4.0,
+        ge=0,
+        le=20,
+        description="Guidance strength for distilled guidance-embedding models. "
+        "Inert for all current FLUX.2 Klein variants (their guidance_embeds weights are absent/zero); "
+        "kept for node-graph compatibility and future guidance-embedded models.",
+    )
     cfg_scale: float = InputField(
         default=1.0,
         description=FieldDescriptions.cfg_scale,
@@ -467,6 +475,7 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
             txt_ids=txt_ids,
             timesteps=timesteps,
             step_callback=self._build_step_callback(context),
+            guidance=self.guidance,
             cfg_scale=cfg_scale_list,
             neg_txt=neg_txt,
             neg_txt_ids=neg_txt_ids,

invokeai/app/invocations/flux2_klein_model_loader.py

Lines changed: 2 additions & 2 deletions
@@ -207,9 +207,9 @@ def _validate_qwen3_encoder_variant(self, context: InvocationContext, main_confi
         flux2_variant = main_config.variant

         # Validate the variants match
-        # Klein4B requires Qwen3_4B, Klein9B/Klein9BBase requires Qwen3_8B
+        # Klein4B/Klein4BBase requires Qwen3_4B, Klein9B/Klein9BBase requires Qwen3_8B
         expected_qwen3_variant = None
-        if flux2_variant == Flux2VariantType.Klein4B:
+        if flux2_variant in (Flux2VariantType.Klein4B, Flux2VariantType.Klein4BBase):
             expected_qwen3_variant = Qwen3VariantType.Qwen3_4B
         elif flux2_variant in (Flux2VariantType.Klein9B, Flux2VariantType.Klein9BBase):
             expected_qwen3_variant = Qwen3VariantType.Qwen3_8B
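The variant-to-encoder pairing enforced above can be sketched in plain Python. The enum values mirror `Flux2VariantType` from taxonomy.py; returning the Qwen3 hidden size (rather than the real `Qwen3VariantType`, whose literal values this sketch does not assume) keeps the example self-contained.

```python
from enum import Enum


class Flux2Variant(str, Enum):
    # Values mirror Flux2VariantType in invokeai/backend/model_manager/taxonomy.py.
    KLEIN_4B = "klein_4b"
    KLEIN_4B_BASE = "klein_4b_base"
    KLEIN_9B = "klein_9b"
    KLEIN_9B_BASE = "klein_9b_base"


def expected_qwen3_hidden_size(variant: Flux2Variant) -> int:
    """Pairing rule from _validate_qwen3_encoder_variant: the 4B family needs the
    Qwen3-4B encoder (hidden size 2560); the 9B family needs Qwen3-8B (4096)."""
    if variant in (Flux2Variant.KLEIN_4B, Flux2Variant.KLEIN_4B_BASE):
        return 2560
    return 4096
```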

invokeai/backend/flux/util.py

Lines changed: 36 additions & 2 deletions
@@ -133,7 +133,24 @@ def get_flux_ae_params() -> AutoEncoderParams:
         axes_dim=[16, 56, 56],
         theta=10_000,
         qkv_bias=True,
-        guidance_embed=True,
+        guidance_embed=False,
+    ),
+    # Flux2 Klein 4B Base is the undistilled foundation model. It shares the same
+    # architecture as Klein 4B (distilled) and reports guidance_embeds=False in its
+    # HF transformer config - classical CFG (external negative pass) is the guidance mechanism.
+    Flux2VariantType.Klein4BBase: FluxParams(
+        in_channels=64,
+        vec_in_dim=2560,  # Qwen3-4B hidden size (used for pooled output)
+        context_in_dim=7680,  # 3 layers * 2560 = 7680 for Qwen3-4B
+        hidden_size=3072,
+        mlp_ratio=4.0,
+        num_heads=24,
+        depth=19,
+        depth_single_blocks=38,
+        axes_dim=[16, 56, 56],
+        theta=10_000,
+        qkv_bias=True,
+        guidance_embed=False,
     ),
     # Flux2 Klein 9B uses Qwen3 8B text encoder with stacked embeddings from layers [9, 18, 27]
     # The context_in_dim is 3 * hidden_size of Qwen3 (3 * 4096 = 12288)
@@ -149,7 +166,24 @@ def get_flux_ae_params() -> AutoEncoderParams:
         axes_dim=[16, 56, 56],
         theta=10_000,
         qkv_bias=True,
-        guidance_embed=True,
+        guidance_embed=False,
+    ),
+    # Flux2 Klein 9B Base is the undistilled foundation model. It shares the same
+    # architecture as Klein 9B (distilled) and reports guidance_embeds=False in its
+    # HF transformer config - the guidance scalar is inert for all Klein variants.
+    Flux2VariantType.Klein9BBase: FluxParams(
+        in_channels=64,
+        vec_in_dim=4096,  # Qwen3-8B hidden size (used for pooled output)
+        context_in_dim=12288,  # 3 layers * 4096 = 12288 for Qwen3-8B
+        hidden_size=3072,
+        mlp_ratio=4.0,
+        num_heads=24,
+        depth=19,
+        depth_single_blocks=38,
+        axes_dim=[16, 56, 56],
+        theta=10_000,
+        qkv_bias=True,
+        guidance_embed=False,
     ),
 }
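The `context_in_dim` values in these FluxParams follow directly from stacking three Qwen3 hidden-state layers, as the inline comments note. A quick arithmetic check (constant names here are illustrative, not from the codebase):

```python
QWEN3_4B_HIDDEN = 2560   # Qwen3-4B hidden size
QWEN3_8B_HIDDEN = 4096   # Qwen3-8B hidden size
STACKED_LAYERS = 3       # FLUX.2 Klein stacks embeddings from three encoder layers

# context_in_dim = number of stacked layers * encoder hidden size
klein_4b_context_in_dim = STACKED_LAYERS * QWEN3_4B_HIDDEN  # 7680
klein_9b_context_in_dim = STACKED_LAYERS * QWEN3_8B_HIDDEN  # 12288
```

The same products reappear below as the `joint_attention_dim` thresholds used for variant detection in main.py.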

invokeai/backend/flux2/denoise.py

Lines changed: 15 additions & 8 deletions
@@ -26,6 +26,7 @@ def denoise(
     # sampling parameters
     timesteps: list[float],
     step_callback: Callable[[PipelineIntermediateState], None],
+    guidance: float,
     cfg_scale: list[float],
     # Negative conditioning for CFG
     neg_txt: torch.Tensor | None = None,
@@ -45,7 +46,10 @@ def denoise(
     This is a simplified denoise function for FLUX.2 Klein models that uses
     the diffusers Flux2Transformer2DModel interface.

-    Note: FLUX.2 Klein has guidance_embeds=False, so no guidance parameter is used.
+    All current FLUX.2 Klein variants (4B, 4B Base, 9B, 9B Base) have guidance_embeds=False
+    in their HF transformer config (or absent/zeroed projection weights), so the guidance
+    value is passed but effectively ignored by the model. The argument is retained for
+    node-graph compatibility and future variants that may ship trained guidance projections.
     CFG is applied externally using negative conditioning when cfg_scale != 1.0.

     Args:
@@ -56,6 +60,8 @@ def denoise(
         txt_ids: Text position IDs tensor.
         timesteps: List of timesteps for denoising schedule (linear sigmas from 1.0 to 1/n).
         step_callback: Callback function for progress updates.
+        guidance: Guidance strength. Inert for all current FLUX.2 Klein variants
+            (their guidance_embeds projection weights are absent/zero).
         cfg_scale: List of CFG scale values per step.
         neg_txt: Negative text embeddings for CFG (optional).
         neg_txt_ids: Negative text position IDs (optional).
@@ -76,9 +82,10 @@ def denoise(
         img = torch.cat([img, img_cond_seq], dim=1)
         img_ids = torch.cat([img_ids, img_cond_seq_ids], dim=1)

-    # Klein has guidance_embeds=False, but the transformer forward() still requires a guidance tensor
-    # We pass a dummy value (1.0) since it won't affect the output when guidance_embeds=False
-    guidance = torch.full((img.shape[0],), 1.0, device=img.device, dtype=img.dtype)
+    # The transformer forward() requires a guidance tensor even when guidance_embeds=False,
+    # because the Flux2TimestepGuidanceEmbeddings forward signature takes it unconditionally.
+    # All current Klein variants have guidance_embeds=False, so the value is ignored internally.
+    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

     # Use scheduler if provided
     use_scheduler = scheduler is not None
@@ -121,7 +128,7 @@ def denoise(
                 timestep=t_vec,
                 img_ids=img_ids,
                 txt_ids=txt_ids,
-                guidance=guidance,
+                guidance=guidance_vec,
                 return_dict=False,
             )

@@ -141,7 +148,7 @@ def denoise(
                 timestep=t_vec,
                 img_ids=img_ids,
                 txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids,
-                guidance=guidance,
+                guidance=guidance_vec,
                 return_dict=False,
             )

@@ -222,7 +229,7 @@ def denoise(
                 timestep=t_vec,
                 img_ids=img_ids,
                 txt_ids=txt_ids,
-                guidance=guidance,
+                guidance=guidance_vec,
                 return_dict=False,
             )

@@ -242,7 +249,7 @@ def denoise(
                 timestep=t_vec,
                 img_ids=img_ids,
                 txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids,
-                guidance=guidance,
+                guidance=guidance_vec,
                 return_dict=False,
             )

invokeai/backend/model_manager/configs/main.py

Lines changed: 15 additions & 9 deletions
@@ -81,8 +81,8 @@ def from_base(
             return cls(steps=35, cfg_scale=4.5, width=1024, height=1024)
         case BaseModelType.Flux2:
             # Different defaults based on variant
-            if variant == Flux2VariantType.Klein9BBase:
-                # Undistilled base model needs more steps
+            if variant in (Flux2VariantType.Klein4BBase, Flux2VariantType.Klein9BBase):
+                # Undistilled base models need more steps
                 return cls(steps=28, cfg_scale=1.0, width=1024, height=1024)
             else:
                 # Distilled models (Klein 4B, Klein 9B) use fewer steps
@@ -389,6 +389,7 @@ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | N
         # Default to Klein9B - callers use filename heuristics to detect Klein9BBase
         return Flux2VariantType.Klein9B
     elif context_in_dim == KLEIN_4B_CONTEXT_DIM:
+        # Default to Klein4B - callers use filename heuristics to detect Klein4BBase
         return Flux2VariantType.Klein4B
     elif context_in_dim > 4096:
         # Unknown FLUX.2 variant, default to 4B
@@ -573,10 +574,12 @@ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
         if variant is None:
             raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")

-        # Klein 9B Base and Klein 9B have identical architectures.
-        # Use filename heuristic to detect the Base (undistilled) variant.
+        # Base (undistilled) and distilled variants share identical architectures.
+        # Use filename heuristic to detect the Base variant.
         if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name):
             return Flux2VariantType.Klein9BBase
+        if variant == Flux2VariantType.Klein4B and _filename_suggests_base(mod.name):
+            return Flux2VariantType.Klein4BBase

         return variant
@@ -745,10 +748,12 @@ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
         if variant is None:
             raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")

-        # Klein 9B Base and Klein 9B have identical architectures.
-        # Use filename heuristic to detect the Base (undistilled) variant.
+        # Base (undistilled) and distilled variants share identical architectures.
+        # Use filename heuristic to detect the Base variant.
         if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name):
             return Flux2VariantType.Klein9BBase
+        if variant == Flux2VariantType.Klein4B and _filename_suggests_base(mod.name):
+            return Flux2VariantType.Klein4BBase

         return variant
@@ -856,11 +861,10 @@ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
         """Determine the FLUX.2 variant from the transformer config.

         FLUX.2 Klein uses Qwen3 text encoder with larger joint_attention_dim:
-        - Klein 4B: joint_attention_dim = 7680 (3×Qwen3-4B hidden size)
+        - Klein 4B/4B Base: joint_attention_dim = 7680 (3×Qwen3-4B hidden size)
         - Klein 9B/9B Base: joint_attention_dim = 12288 (3×Qwen3-8B hidden size)

-        Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures
-        and both have guidance_embeds=False. We use a filename heuristic to detect Base models.
+        Distilled and Base variants share identical architectures. We use a filename heuristic to detect Base models.
         """
         KLEIN_4B_CONTEXT_DIM = 7680  # 3 × 2560
         KLEIN_9B_CONTEXT_DIM = 12288  # 3 × 4096
@@ -875,6 +879,8 @@ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
                 return Flux2VariantType.Klein9BBase
             return Flux2VariantType.Klein9B
         elif joint_attention_dim == KLEIN_4B_CONTEXT_DIM:
+            if _filename_suggests_base(mod.name):
+                return Flux2VariantType.Klein4BBase
             return Flux2VariantType.Klein4B
         elif joint_attention_dim > 4096:
             # Unknown FLUX.2 variant, default to 4B

invokeai/backend/model_manager/taxonomy.py

Lines changed: 4 additions & 1 deletion
@@ -132,7 +132,10 @@ class Flux2VariantType(str, Enum):
     """FLUX.2 model variants."""

     Klein4B = "klein_4b"
-    """Flux2 Klein 4B variant using Qwen3 4B text encoder."""
+    """Flux2 Klein 4B variant using Qwen3 4B text encoder (distilled)."""
+
+    Klein4BBase = "klein_4b_base"
+    """Flux2 Klein 4B Base variant - undistilled foundation model using Qwen3 4B text encoder."""

     Klein9B = "klein_9b"
     """Flux2 Klein 9B variant using Qwen3 8B text encoder (distilled)."""

invokeai/frontend/web/openapi.json

Lines changed: 1 addition & 1 deletion
@@ -19924,7 +19924,7 @@
     },
     "Flux2VariantType": {
       "type": "string",
-      "enum": ["klein_4b", "klein_9b", "klein_9b_base"],
+      "enum": ["klein_4b", "klein_4b_base", "klein_9b", "klein_9b_base"],
       "title": "Flux2VariantType",
       "description": "FLUX.2 model variants."
     },

invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx

Lines changed: 2 additions & 0 deletions
@@ -63,6 +63,8 @@ export const ImageMetadataActions = memo((props: Props) => {
       <SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.QwenImageShift} />
       <SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.CanvasLayers} />
       <CollectionMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.RefImages} />
+      <SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.KleinVAEModel} />
+      <SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.KleinQwen3EncoderModel} />
       <CollectionMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.LoRAs} />
     </Flex>
   );
