File tree Expand file tree Collapse file tree
src/maxdiffusion/models/ltx2 Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -276,7 +276,7 @@ def load_vae_weights(
276276 for pt_key , tensor in tensors .items ():
277277 # Filter keys for combined checkpoint to avoid noise and memory overhead
278278 if filename == "ltx-2.3-22b-dev.safetensors" :
279- if not ( pt_key .startswith ("vae." ) or pt_key . startswith ( "audio_vae." ) ):
279+ if not pt_key .startswith ("vae." ):
280280 continue
281281
282282 # latents_mean and latents_std are nnx.Params and will be loaded correctly.
@@ -289,8 +289,9 @@ def load_vae_weights(
289289 renamed_pt_key = renamed_pt_key .replace ("nin_shortcut" , "conv_shortcut" )
290290
291291 pt_tuple_key = tuple (renamed_pt_key .split ("." ))
292- if needs_vae_prefix and pt_tuple_key [0 ] != "vae" :
293- pt_tuple_key = ("vae" ,) + pt_tuple_key
292+ # Remove 'vae' prefix to match model structure which expects 'encoder'/'decoder' directly
293+ if pt_tuple_key [0 ] == "vae" :
294+ pt_tuple_key = pt_tuple_key [1 :]
294295
295296 pt_list = []
296297 resnet_index = None
You can’t perform that action at this time.
0 commit comments