Skip to content

Commit 8a404dd

Browse files
committed
weight loading
1 parent 0b86f50 commit 8a404dd

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def load_vae_weights(
276276
for pt_key, tensor in tensors.items():
277277
# Filter keys for combined checkpoint to avoid noise and memory overhead
278278
if filename == "ltx-2.3-22b-dev.safetensors":
279-
if not (pt_key.startswith("vae.") or pt_key.startswith("audio_vae.")):
279+
if not pt_key.startswith("vae."):
280280
continue
281281

282282
# latents_mean and latents_std are nnx.Params and will be loaded correctly.
@@ -289,8 +289,9 @@ def load_vae_weights(
289289
renamed_pt_key = renamed_pt_key.replace("nin_shortcut", "conv_shortcut")
290290

291291
pt_tuple_key = tuple(renamed_pt_key.split("."))
292-
if needs_vae_prefix and pt_tuple_key[0] != "vae":
293-
pt_tuple_key = ("vae",) + pt_tuple_key
292+
# Remove 'vae' prefix to match model structure which expects 'encoder'/'decoder' directly
293+
if pt_tuple_key[0] == "vae":
294+
pt_tuple_key = pt_tuple_key[1:]
294295

295296
pt_list = []
296297
resnet_index = None

0 commit comments

Comments
 (0)