Skip to content

Commit 84aa5ef

Browse files
committed
connectors and feat extractors
1 parent 7f5057f commit 84aa5ef

4 files changed

Lines changed: 22 additions & 6 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_3_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
"model.diffusion_model.": "",
1818
"connectors.": "",
1919
"transformer_1d_blocks": "stacked_blocks",
20-
"text_embedding_projection.audio_aggregate_embed": "audio_text_proj_in",
21-
"text_embedding_projection.video_aggregate_embed": "video_text_proj_in",
20+
"text_embedding_projection.audio_aggregate_embed.weight": "feature_extractor.audio_linear.kernel",
21+
"text_embedding_projection.audio_aggregate_embed.bias": "feature_extractor.audio_linear.bias",
22+
"text_embedding_projection.video_aggregate_embed.weight": "feature_extractor.video_linear.kernel",
23+
"text_embedding_projection.video_aggregate_embed.bias": "feature_extractor.video_linear.bias",
2224
"q_norm": "norm_q",
2325
"k_norm": "norm_k",
2426
"norm_q.weight": "norm_q.scale",
@@ -91,6 +93,10 @@ def load_connectors_weights(
9193

9294
accumulated_stacked[base_key][layer_idx] = tensor
9395
else:
96+
# Transpose projection kernels in feature extractor
97+
if "feature_extractor" in segments and segments[-1] == "kernel":
98+
tensor = jnp.transpose(tensor, (1, 0))
99+
94100
flax_key = _tuple_str_to_int(segments)
95101
flax_state_dict[flax_key] = jax.device_put(tensor, device=cpu)
96102

src/maxdiffusion/models/ltx2/ltx2_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,9 @@ def rename_for_ltx2_audio_vae(key):
496496
if "upsample.conv.bias" in key:
497497
key = key.replace("upsample.conv.bias", "upsample.conv.conv.bias")
498498

499+
key = key.replace("per_channel_statistics.mean-of-means", "latents_mean")
500+
key = key.replace("per_channel_statistics.std-of-means", "latents_std")
501+
499502
return key
500503

501504

src/maxdiffusion/models/ltx2/text_encoders/feature_extractor_ltx2.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,17 +104,21 @@ def __init__(
104104
rngs: nnx.Rngs = None,
105105
per_modality_projections: bool = False,
106106
use_bias: bool = False,
107+
video_output_dim: Optional[int] = None,
108+
audio_output_dim: Optional[int] = None,
107109
):
108110
"""
109111
Args:
110112
input_dim: Dimension of flattened hidden states (Gemma dim * Num layers).
111-
output_dim: Target dimension for diffusion conditioning.
113+
output_dim: Target dimension for diffusion conditioning (fallback).
112114
"""
113115
self.per_modality_projections = per_modality_projections
114116

115117
if per_modality_projections:
116-
self.video_linear = nnx.Linear(input_dim, output_dim, use_bias=use_bias, dtype=dtype, rngs=rngs)
117-
self.audio_linear = nnx.Linear(input_dim, output_dim, use_bias=use_bias, dtype=dtype, rngs=rngs)
118+
v_dim = video_output_dim if video_output_dim is not None else output_dim
119+
a_dim = audio_output_dim if audio_output_dim is not None else output_dim
120+
self.video_linear = nnx.Linear(input_dim, v_dim, use_bias=use_bias, dtype=dtype, rngs=rngs)
121+
self.audio_linear = nnx.Linear(input_dim, a_dim, use_bias=use_bias, dtype=dtype, rngs=rngs)
118122
else:
119123
self.linear = nnx.Linear(input_dim, output_dim, use_bias=use_bias, dtype=dtype, rngs=rngs)
120124

src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ def __init__(
6565
audio_gated_attn: bool = False,
6666
**kwargs,
6767
):
68-
input_dim = caption_channels * text_proj_in_factor
68+
gemma_dim = 3840 if video_caption_channels is not None else caption_channels
69+
input_dim = gemma_dim * text_proj_in_factor
6970

7071
v_dim = video_caption_channels if video_caption_channels is not None else caption_channels
7172
a_dim = audio_caption_channels if audio_caption_channels is not None else caption_channels
@@ -79,6 +80,8 @@ def __init__(
7980
rngs=rngs,
8081
per_modality_projections=per_modality_projections,
8182
use_bias=proj_bias,
83+
video_output_dim=v_dim,
84+
audio_output_dim=a_dim,
8285
)
8386

8487
# Two independent connectors

0 commit comments

Comments (0)