@@ -323,16 +323,26 @@ def _is_flux2_model(state_dict: dict[str | int, Any]) -> bool:
323323 return False
324324
325325
326+ def _filename_suggests_base (name : str ) -> bool :
327+ """Check if a model name/filename suggests it is a Base (undistilled) variant.
328+
329+ Klein 9B Base and Klein 9B have identical architectures and cannot be distinguished
330+ from the state dict. We use the filename as a heuristic: filenames containing "base"
331+ (e.g. "flux-2-klein-base-9b", "FLUX.2-klein-base-9B") indicate the undistilled model.
332+ """
333+ return "base" in name .lower ()
334+
335+
326336def _get_flux2_variant (state_dict : dict [str | int , Any ]) -> Flux2VariantType | None :
327337 """Determine FLUX.2 variant from state dict.
328338
329339 Distinguishes between Klein 4B and Klein 9B based on context embedding dimension:
330340 - Klein 4B: context_in_dim = 7680 (3 × Qwen3-4B hidden_size 2560)
331341 - Klein 9B: context_in_dim = 12288 (3 × Qwen3-8B hidden_size 4096)
332342
333- Note: Klein 9B Base (undistilled) also has context_in_dim = 12288 but is rare.
334- We default to Klein9B (distilled) for all 9B models since GGUF models may not
335- include guidance embedding keys needed to distinguish them .
343+ Note: Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures
344+ and cannot be distinguished from the state dict alone. This function defaults to Klein9B
345+ for all 9B models. Callers should use filename heuristics to detect Klein9BBase .
336346
337347 Supports both BFL format (checkpoint) and diffusers format keys:
338348 - BFL format: txt_in.weight (context embedder)
@@ -366,7 +376,7 @@ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | N
366376 context_in_dim = shape [1 ]
367377 # Determine variant based on context dimension
368378 if context_in_dim == KLEIN_9B_CONTEXT_DIM :
369- # Default to Klein9B (distilled) - the official/common 9B model
379+ # Default to Klein9B - callers use filename heuristics to detect Klein9BBase
370380 return Flux2VariantType .Klein9B
371381 elif context_in_dim == KLEIN_4B_CONTEXT_DIM :
372382 return Flux2VariantType .Klein4B
@@ -553,6 +563,11 @@ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
553563 if variant is None :
554564 raise NotAMatchError ("unable to determine FLUX.2 model variant from state dict" )
555565
566+ # Klein 9B Base and Klein 9B have identical architectures.
567+ # Use filename heuristic to detect the Base (undistilled) variant.
568+ if variant == Flux2VariantType .Klein9B and _filename_suggests_base (mod .name ):
569+ return Flux2VariantType .Klein9BBase
570+
556571 return variant
557572
558573 @classmethod
@@ -720,6 +735,11 @@ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
720735 if variant is None :
721736 raise NotAMatchError ("unable to determine FLUX.2 model variant from state dict" )
722737
738+ # Klein 9B Base and Klein 9B have identical architectures.
739+ # Use filename heuristic to detect the Base (undistilled) variant.
740+ if variant == Flux2VariantType .Klein9B and _filename_suggests_base (mod .name ):
741+ return Flux2VariantType .Klein9BBase
742+
723743 return variant
724744
725745 @classmethod
@@ -829,30 +849,21 @@ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
829849 - Klein 4B: joint_attention_dim = 7680 (3×Qwen3-4B hidden size)
830850 - Klein 9B/9B Base: joint_attention_dim = 12288 (3×Qwen3-8B hidden size)
831851
832- To distinguish Klein 9B (distilled) from Klein 9B Base (undistilled),
833- we check guidance_embeds:
834- - Klein 9B (distilled): guidance_embeds = False (guidance is "baked in" during distillation)
835- - Klein 9B Base (undistilled): guidance_embeds = True (needs guidance at inference)
836-
837- Note: The official BFL Klein 9B model is the distilled version with guidance_embeds=False.
852+ Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures
853+ and both have guidance_embeds=False. We use a filename heuristic to detect Base models.
838854 """
839855 KLEIN_4B_CONTEXT_DIM = 7680 # 3 × 2560
840856 KLEIN_9B_CONTEXT_DIM = 12288 # 3 × 4096
841857
842858 transformer_config = get_config_dict_or_raise (mod .path / "transformer" / "config.json" )
843859
844860 joint_attention_dim = transformer_config .get ("joint_attention_dim" , 4096 )
845- guidance_embeds = transformer_config .get ("guidance_embeds" , False )
846861
847862 # Determine variant based on joint_attention_dim
848863 if joint_attention_dim == KLEIN_9B_CONTEXT_DIM :
849- # Check guidance_embeds to distinguish distilled from undistilled
850- # Klein 9B (distilled): guidance_embeds = False (guidance is baked in)
851- # Klein 9B Base (undistilled): guidance_embeds = True (needs guidance)
852- if guidance_embeds :
864+ if _filename_suggests_base (mod .name ):
853865 return Flux2VariantType .Klein9BBase
854- else :
855- return Flux2VariantType .Klein9B
866+ return Flux2VariantType .Klein9B
856867 elif joint_attention_dim == KLEIN_4B_CONTEXT_DIM :
857868 return Flux2VariantType .Klein4B
858869 elif joint_attention_dim > 4096 :
0 commit comments