Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions invokeai/app/invocations/flux2_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@
"flux2_denoise",
title="FLUX2 Denoise",
tags=["image", "flux", "flux2", "klein", "denoise"],
category="latents",
version="1.4.0",
category="image",
version="1.5.0",
classification=Classification.Prototype,
)
class Flux2DenoiseInvocation(BaseInvocation):
Expand Down Expand Up @@ -101,6 +101,13 @@ class Flux2DenoiseInvocation(BaseInvocation):
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
input=Input.Connection,
)
guidance: float = InputField(
default=4.0,
ge=0,
le=20,
description="The guidance strength. Only used by undistilled models (Klein 9B Base). "
"Ignored by distilled models (Klein 4B, Klein 9B).",
)
cfg_scale: float = InputField(
default=1.0,
description=FieldDescriptions.cfg_scale,
Expand Down Expand Up @@ -467,6 +474,7 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
txt_ids=txt_ids,
timesteps=timesteps,
step_callback=self._build_step_callback(context),
guidance=self.guidance,
cfg_scale=cfg_scale_list,
neg_txt=neg_txt,
neg_txt_ids=neg_txt_ids,
Expand Down
17 changes: 16 additions & 1 deletion invokeai/backend/flux/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,26 @@ def get_flux_ae_params() -> AutoEncoderParams:
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
guidance_embed=False,
),
# Flux2 Klein 9B uses Qwen3 8B text encoder with stacked embeddings from layers [9, 18, 27]
# The context_in_dim is 3 * hidden_size of Qwen3 (3 * 4096 = 12288)
Flux2VariantType.Klein9B: FluxParams(
in_channels=64,
vec_in_dim=4096, # Qwen3-8B hidden size (used for pooled output)
context_in_dim=12288, # 3 layers * 4096 = 12288 for Qwen3-8B
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=False,
),
# Flux2 Klein 9B Base is the undistilled foundation model with guidance_embeds=True
Flux2VariantType.Klein9BBase: FluxParams(
in_channels=64,
vec_in_dim=4096, # Qwen3-8B hidden size (used for pooled output)
context_in_dim=12288, # 3 layers * 4096 = 12288 for Qwen3-8B
Expand Down
22 changes: 14 additions & 8 deletions invokeai/backend/flux2/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def denoise(
# sampling parameters
timesteps: list[float],
step_callback: Callable[[PipelineIntermediateState], None],
guidance: float,
cfg_scale: list[float],
# Negative conditioning for CFG
neg_txt: torch.Tensor | None = None,
Expand All @@ -45,7 +46,9 @@ def denoise(
This is a simplified denoise function for FLUX.2 Klein models that uses
the diffusers Flux2Transformer2DModel interface.

Note: FLUX.2 Klein has guidance_embeds=False, so no guidance parameter is used.
Distilled models (Klein 4B, Klein 9B) have guidance_embeds=False, so the guidance
value is passed but ignored by the model. Undistilled models (Klein 9B Base) have
guidance_embeds=True and use the guidance value for generation.
CFG is applied externally using negative conditioning when cfg_scale != 1.0.

Args:
Expand All @@ -56,6 +59,8 @@ def denoise(
txt_ids: Text position IDs tensor.
timesteps: List of timesteps for denoising schedule (linear sigmas from 1.0 to 1/n).
step_callback: Callback function for progress updates.
guidance: Guidance strength. Used by undistilled models (Klein 9B Base),
ignored by distilled models (Klein 4B, Klein 9B).
cfg_scale: List of CFG scale values per step.
neg_txt: Negative text embeddings for CFG (optional).
neg_txt_ids: Negative text position IDs (optional).
Expand All @@ -76,9 +81,10 @@ def denoise(
img = torch.cat([img, img_cond_seq], dim=1)
img_ids = torch.cat([img_ids, img_cond_seq_ids], dim=1)

# Klein has guidance_embeds=False, but the transformer forward() still requires a guidance tensor
# We pass a dummy value (1.0) since it won't affect the output when guidance_embeds=False
guidance = torch.full((img.shape[0],), 1.0, device=img.device, dtype=img.dtype)
# The transformer forward() requires a guidance tensor.
# For distilled models (guidance_embeds=False), this value is ignored by the model.
# For undistilled models (Klein 9B Base, guidance_embeds=True), it controls guidance strength.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

# Use scheduler if provided
use_scheduler = scheduler is not None
Expand Down Expand Up @@ -121,7 +127,7 @@ def denoise(
timestep=t_vec,
img_ids=img_ids,
txt_ids=txt_ids,
guidance=guidance,
guidance=guidance_vec,
return_dict=False,
)

Expand All @@ -141,7 +147,7 @@ def denoise(
timestep=t_vec,
img_ids=img_ids,
txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids,
guidance=guidance,
guidance=guidance_vec,
return_dict=False,
)

Expand Down Expand Up @@ -222,7 +228,7 @@ def denoise(
timestep=t_vec,
img_ids=img_ids,
txt_ids=txt_ids,
guidance=guidance,
guidance=guidance_vec,
return_dict=False,
)

Expand All @@ -242,7 +248,7 @@ def denoise(
timestep=t_vec,
img_ids=img_ids,
txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids,
guidance=guidance,
guidance=guidance_vec,
return_dict=False,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ export const ImageMetadataActions = memo((props: Props) => {
<SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.QwenImageShift} />
<SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.CanvasLayers} />
<CollectionMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.RefImages} />
<SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.KleinVAEModel} />
<SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.KleinQwen3EncoderModel} />
<CollectionMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.LoRAs} />
</Flex>
);
Expand Down
128 changes: 128 additions & 0 deletions invokeai/frontend/web/src/features/metadata/parsing.test.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import type { AppStore } from 'app/store/store';
import type * as paramsSliceModule from 'features/controlLayers/store/paramsSlice';
import { ImageMetadataHandlers } from 'features/metadata/parsing';
import type * as modelsApiModule from 'services/api/endpoints/models';
import { beforeEach, describe, expect, it, vi } from 'vitest';

// ---------------------------------------------------------------------------
// Module mocks
//
// We are testing only the *gating* logic of the model-related metadata
// handlers (`VAEModel`, `KleinVAEModel`, `KleinQwen3EncoderModel`). The actual
// model lookup goes through `parseModelIdentifier`, which dispatches RTK
// Query thunks. We stub the models endpoint so that any lookup resolves to a
// canned model identifier — the parse step then succeeds and the assertions
// inside each handler become observable.
// ---------------------------------------------------------------------------

// Base-model family returned by the mocked `selectBase` selector below.
// Mutated per-test to drive the handlers' gating branches; reset in `beforeEach`.
let currentBase: string | null = 'flux2';

// NOTE(review): `vi.mock` calls are hoisted by vitest above the imports, so the
// factories must only close over `let` bindings (read lazily), never consts
// initialized later — `currentBase` is read at call time, which is why this works.
vi.mock('features/controlLayers/store/paramsSlice', async (importOriginal) => {
  const mod = await importOriginal<typeof paramsSliceModule>();
  // Keep the real module surface; override only `selectBase`.
  return { ...mod, selectBase: () => currentBase };
});

// Minimal model-config stub carrying just the fields the handlers inspect
// (`key`, `type`, `base`); shaped like a models-API config record.
const fakeModel = (type: 'vae' | 'qwen3_encoder', base: string) => ({
  key: `${type}-key`,
  hash: 'hash',
  name: `Some ${type}`,
  base,
  type,
});

// The record every stubbed RTK Query lookup resolves to; mutated per-test.
let nextResolved: ReturnType<typeof fakeModel> = fakeModel('vae', 'flux2');

vi.mock('services/api/endpoints/models', async (importOriginal) => {
  const mod = await importOriginal<typeof modelsApiModule>();
  return {
    ...mod,
    modelsApi: {
      ...mod.modelsApi,
      endpoints: {
        ...mod.modelsApi.endpoints,
        // `initiate` only needs to return something dispatchable; the fake
        // store's dispatch (below) resolves it to `nextResolved`.
        getModelConfig: { initiate: (key: string) => ({ type: 'rtkq/initiate', key }) },
      },
    },
  };
});

// Store double: every dispatched thunk unwraps to `nextResolved`; state is
// empty because the only selector used (`selectBase`) is mocked above.
const makeStore = (): AppStore =>
  ({
    dispatch: vi.fn(() => ({
      unwrap: () => Promise.resolve(nextResolved),
    })),
    getState: () => ({}),
  }) as unknown as AppStore;

beforeEach(() => {
  // Restore the default "FLUX.2 Klein image" scenario before each test.
  currentBase = 'flux2';
  nextResolved = fakeModel('vae', 'flux2');
});

describe('ImageMetadataHandlers — Klein recall gating', () => {
  describe('KleinVAEModel', () => {
    it('parses metadata.vae when the current main model is FLUX.2 Klein', async () => {
      currentBase = 'flux2';
      nextResolved = fakeModel('vae', 'flux2');
      const store = makeStore();

      const parsed = await ImageMetadataHandlers.KleinVAEModel.parse({ vae: nextResolved }, store);

      expect(parsed.key).toBe('vae-key');
      expect(parsed.type).toBe('vae');
    });

    it('rejects parsing when the current main model is not FLUX.2 Klein', async () => {
      currentBase = 'sdxl';
      nextResolved = fakeModel('vae', 'flux2');
      const store = makeStore();

      await expect(ImageMetadataHandlers.KleinVAEModel.parse({ vae: nextResolved }, store)).rejects.toThrow();
    });
  });

  describe('KleinQwen3EncoderModel', () => {
    it('parses metadata.qwen3_encoder when the current main model is FLUX.2 Klein', async () => {
      currentBase = 'flux2';
      nextResolved = fakeModel('qwen3_encoder', 'flux2');
      const store = makeStore();

      const parsed = await ImageMetadataHandlers.KleinQwen3EncoderModel.parse({ qwen3_encoder: nextResolved }, store);

      expect(parsed.key).toBe('qwen3_encoder-key');
      expect(parsed.type).toBe('qwen3_encoder');
    });

    it('rejects parsing when the current main model is not FLUX.2 Klein', async () => {
      currentBase = 'sdxl';
      nextResolved = fakeModel('qwen3_encoder', 'flux2');
      const store = makeStore();

      await expect(
        ImageMetadataHandlers.KleinQwen3EncoderModel.parse({ qwen3_encoder: nextResolved }, store)
      ).rejects.toThrow();
    });
  });

  describe('VAEModel (generic)', () => {
    // The generic VAEModel handler must NOT also fire for FLUX.2 / Z-Image
    // images, otherwise the metadata viewer renders duplicate VAE rows next
    // to the dedicated KleinVAEModel / ZImageVAEModel handlers.
    it.each(['flux2', 'z-image'])('rejects parsing when current base is %s', async (base) => {
      currentBase = base;
      nextResolved = fakeModel('vae', base);
      const store = makeStore();

      await expect(ImageMetadataHandlers.VAEModel.parse({ vae: nextResolved }, store)).rejects.toThrow();
    });

    it('parses successfully for non-Klein, non-Z-Image bases', async () => {
      currentBase = 'sdxl';
      nextResolved = fakeModel('vae', 'sdxl');
      const store = makeStore();

      const parsed = await ImageMetadataHandlers.VAEModel.parse({ vae: nextResolved }, store);
      expect(parsed.key).toBe('vae-key');
    });
  });
});
3 changes: 3 additions & 0 deletions invokeai/frontend/web/src/features/metadata/parsing.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,9 @@ const VAEModel: SingleMetadataHandler<ParameterVAEModel> = {
const parsed = await parseModelIdentifier(raw, store, 'vae');
assert(parsed.type === 'vae');
assert(isCompatibleWithMainModel(parsed, store));
// Z-Image and FLUX.2 Klein have dedicated VAE handlers; avoid rendering a duplicate row.
const base = selectBase(store.getState());
assert(base !== 'z-image' && base !== 'flux2', 'VAEModel handler does not apply to Z-Image or FLUX.2 Klein');
return Promise.resolve(parsed);
},
recall: (value, store) => {
Expand Down
Loading
Loading