Skip to content

Commit dbbf289

Browse files
fix: detect FLUX.2 Klein 9B Base variant via filename heuristic (#9011)
Klein 9B Base (undistilled) and Klein 9B (distilled) have identical architectures and cannot be distinguished from the state dict alone. Use a filename heuristic ("base" in the name) to detect the Base variant for checkpoint, GGUF, and diffusers format models. Also fixes the incorrect guidance_embeds-based detection for diffusers format, since both variants have guidance_embeds=False.
1 parent f08b802 commit dbbf289

File tree

1 file changed

+28
-17
lines changed
  • invokeai/backend/model_manager/configs

1 file changed

+28
-17
lines changed

invokeai/backend/model_manager/configs/main.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -323,16 +323,26 @@ def _is_flux2_model(state_dict: dict[str | int, Any]) -> bool:
323323
return False
324324

325325

326+
def _filename_suggests_base(name: str) -> bool:
327+
"""Check if a model name/filename suggests it is a Base (undistilled) variant.
328+
329+
Klein 9B Base and Klein 9B have identical architectures and cannot be distinguished
330+
from the state dict. We use the filename as a heuristic: filenames containing "base"
331+
(e.g. "flux-2-klein-base-9b", "FLUX.2-klein-base-9B") indicate the undistilled model.
332+
"""
333+
return "base" in name.lower()
334+
335+
326336
def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | None:
327337
"""Determine FLUX.2 variant from state dict.
328338
329339
Distinguishes between Klein 4B and Klein 9B based on context embedding dimension:
330340
- Klein 4B: context_in_dim = 7680 (3 × Qwen3-4B hidden_size 2560)
331341
- Klein 9B: context_in_dim = 12288 (3 × Qwen3-8B hidden_size 4096)
332342
333-
Note: Klein 9B Base (undistilled) also has context_in_dim = 12288 but is rare.
334-
We default to Klein9B (distilled) for all 9B models since GGUF models may not
335-
include guidance embedding keys needed to distinguish them.
343+
Note: Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures
344+
and cannot be distinguished from the state dict alone. This function defaults to Klein9B
345+
for all 9B models. Callers should use filename heuristics to detect Klein9BBase.
336346
337347
Supports both BFL format (checkpoint) and diffusers format keys:
338348
- BFL format: txt_in.weight (context embedder)
@@ -366,7 +376,7 @@ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | N
366376
context_in_dim = shape[1]
367377
# Determine variant based on context dimension
368378
if context_in_dim == KLEIN_9B_CONTEXT_DIM:
369-
# Default to Klein9B (distilled) - the official/common 9B model
379+
# Default to Klein9B - callers use filename heuristics to detect Klein9BBase
370380
return Flux2VariantType.Klein9B
371381
elif context_in_dim == KLEIN_4B_CONTEXT_DIM:
372382
return Flux2VariantType.Klein4B
@@ -553,6 +563,11 @@ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
553563
if variant is None:
554564
raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")
555565

566+
# Klein 9B Base and Klein 9B have identical architectures.
567+
# Use filename heuristic to detect the Base (undistilled) variant.
568+
if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name):
569+
return Flux2VariantType.Klein9BBase
570+
556571
return variant
557572

558573
@classmethod
@@ -720,6 +735,11 @@ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
720735
if variant is None:
721736
raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")
722737

738+
# Klein 9B Base and Klein 9B have identical architectures.
739+
# Use filename heuristic to detect the Base (undistilled) variant.
740+
if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name):
741+
return Flux2VariantType.Klein9BBase
742+
723743
return variant
724744

725745
@classmethod
@@ -829,30 +849,21 @@ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
829849
- Klein 4B: joint_attention_dim = 7680 (3×Qwen3-4B hidden size)
830850
- Klein 9B/9B Base: joint_attention_dim = 12288 (3×Qwen3-8B hidden size)
831851
832-
To distinguish Klein 9B (distilled) from Klein 9B Base (undistilled),
833-
we check guidance_embeds:
834-
- Klein 9B (distilled): guidance_embeds = False (guidance is "baked in" during distillation)
835-
- Klein 9B Base (undistilled): guidance_embeds = True (needs guidance at inference)
836-
837-
Note: The official BFL Klein 9B model is the distilled version with guidance_embeds=False.
852+
Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures
853+
and both have guidance_embeds=False. We use a filename heuristic to detect Base models.
838854
"""
839855
KLEIN_4B_CONTEXT_DIM = 7680 # 3 × 2560
840856
KLEIN_9B_CONTEXT_DIM = 12288 # 3 × 4096
841857

842858
transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json")
843859

844860
joint_attention_dim = transformer_config.get("joint_attention_dim", 4096)
845-
guidance_embeds = transformer_config.get("guidance_embeds", False)
846861

847862
# Determine variant based on joint_attention_dim
848863
if joint_attention_dim == KLEIN_9B_CONTEXT_DIM:
849-
# Check guidance_embeds to distinguish distilled from undistilled
850-
# Klein 9B (distilled): guidance_embeds = False (guidance is baked in)
851-
# Klein 9B Base (undistilled): guidance_embeds = True (needs guidance)
852-
if guidance_embeds:
864+
if _filename_suggests_base(mod.name):
853865
return Flux2VariantType.Klein9BBase
854-
else:
855-
return Flux2VariantType.Klein9B
866+
return Flux2VariantType.Klein9B
856867
elif joint_attention_dim == KLEIN_4B_CONTEXT_DIM:
857868
return Flux2VariantType.Klein4B
858869
elif joint_attention_dim > 4096:

0 commit comments

Comments
 (0)