@@ -274,6 +274,11 @@ def load_vae_weights(
274274 needs_vae_prefix = any (key [0 ] == "vae" for key in random_flax_state_dict )
275275
276276 for pt_key , tensor in tensors .items ():
277+ # Filter keys for combined checkpoint to avoid noise and memory overhead
278+ if filename == "ltx-2.3-22b-dev.safetensors" :
279+ if not (pt_key .startswith ("vae." ) or pt_key .startswith ("audio_vae." )):
280+ continue
281+
277282 # latents_mean and latents_std are nnx.Params and will be loaded correctly.
278283 new_key = pt_key
279284 if filename == "ltx-2.3-22b-dev.safetensors" :
@@ -295,7 +300,7 @@ def load_vae_weights(
295300 name = "_" .join (part .split ("_" )[:- 1 ])
296301 idx = int (part .split ("_" )[- 1 ])
297302
298- if name == "resnets" :
303+ if name == "resnets" or name == "block" :
299304 pt_list .append ("resnets" )
300305 resnet_index = idx
301306 elif name == "upsamplers" :
@@ -322,7 +327,7 @@ def load_vae_weights(
322327
323328 flax_key , flax_tensor = rename_key_and_reshape_tensor (pt_tuple_key , tensor , random_flax_state_dict )
324329 flax_key = _tuple_str_to_int (flax_key )
325- max_logging .log (f"Mapped VAE key: { pt_key } -> { flax_key } " )
330+ max_logging .log (f"Mapped key: { pt_key } -> { flax_key } " )
326331
327332 if resnet_index is not None :
328333 str_flax_key = tuple ([str (x ) for x in flax_key ])
0 commit comments