File tree Expand file tree Collapse file tree
invokeai/backend/model_manager/load/model_loaders Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -253,16 +253,35 @@ def _load_from_singlefile(
253253 target_device = TorchDevice .choose_torch_device ()
254254 model_dtype = TorchDevice .choose_bfloat16_safe_dtype (target_device )
255255
256+ # Filter out keys that don't belong to the ZImageTransformer2DModel.
257+ # Merged checkpoints (e.g. LoRA-baked models) may bundle text encoder weights
258+ # (text_encoders.*) or other non-transformer keys alongside the transformer weights.
259+ # Also filter FP8 quantization metadata (scale_weight, scaled_fp8).
260+ valid_prefixes = (
261+ "all_x_embedder." ,
262+ "all_final_layer." ,
263+ "layers." ,
264+ "noise_refiner." ,
265+ "context_refiner." ,
266+ "t_embedder." ,
267+ "cap_embedder." ,
268+ "rope_embedder." ,
269+ )
270+ valid_exact = {"x_pad_token" , "cap_pad_token" }
271+ keys_to_remove = [
272+ k
273+ for k in sd .keys ()
274+ if not (k .startswith (valid_prefixes ) or k in valid_exact )
275+ or k .endswith (".scale_weight" )
276+ or k == "scaled_fp8"
277+ ]
278+ for k in keys_to_remove :
279+ del sd [k ]
280+
256281 # Handle memory management and dtype conversion
257282 new_sd_size = sum ([ten .nelement () * model_dtype .itemsize for ten in sd .values ()])
258283 self ._ram_cache .make_room (new_sd_size )
259284
260- # Filter out FP8 scale_weight and scaled_fp8 metadata keys
261- # These are quantization metadata that shouldn't be loaded into the model
262- keys_to_remove = [k for k in sd .keys () if k .endswith (".scale_weight" ) or k == "scaled_fp8" ]
263- for k in keys_to_remove :
264- del sd [k ]
265-
266285 # Convert to target dtype
267286 for k in sd .keys ():
268287 sd [k ] = sd [k ].to (model_dtype )
You can’t perform that action at this time.
0 commit comments