@@ -319,7 +319,7 @@ def __init__(
319319 )
320320
321321 key = rngs .params ()
322- k1 , k2 , k3 , k4 = jax .random .split (key , 4 )
322+ k1 , k2 , k3 , k4 , k5 , k6 = jax .random .split (key , 6 )
323323
324324 self .cross_attn_mod = cross_attn_mod
325325 table_size = 9 if cross_attn_mod else 6
@@ -339,6 +339,15 @@ def __init__(
339339 jax .random .normal (k4 , (5 , audio_dim ), dtype = weights_dtype ),
340340 kernel_init = nnx .with_partitioning (nnx .initializers .zeros , (None , "embed" )),
341341 )
342+ if self .cross_attn_mod :
343+ self .prompt_scale_shift_table = nnx .Param (
344+ jax .random .normal (k5 , (2 , self .dim ), dtype = weights_dtype ) / jnp .sqrt (self .dim ),
345+ kernel_init = nnx .with_partitioning (nnx .initializers .zeros , (None , "embed" )),
346+ )
347+ self .audio_prompt_scale_shift_table = nnx .Param (
348+ jax .random .normal (k6 , (2 , audio_dim ), dtype = weights_dtype ) / jnp .sqrt (audio_dim ),
349+ kernel_init = nnx .with_partitioning (nnx .initializers .zeros , (None , "embed" )),
350+ )
342351
343352 def __call__ (
344353 self ,
@@ -353,6 +362,8 @@ def __call__(
353362 temb_ca_audio_scale_shift : jax .Array ,
354363 temb_ca_gate : jax .Array ,
355364 temb_ca_audio_gate : jax .Array ,
365+ temb_prompt : Optional [jax .Array ] = None ,
366+ temb_prompt_audio : Optional [jax .Array ] = None ,
356367 # RoPE
357368 video_rotary_emb : Optional [Tuple [jax .Array , jax .Array ]] = None ,
358369 audio_rotary_emb : Optional [Tuple [jax .Array , jax .Array ]] = None ,
@@ -445,6 +456,14 @@ def __call__(
445456 if getattr (self , "cross_attn_mod" , False ):
446457 norm_hidden_states = norm_hidden_states * (1 + scale_q ) + shift_q
447458
459+ if getattr (self , "cross_attn_mod" , False ) and temb_prompt is not None :
460+ prompt_table_reshaped = jnp .expand_dims (self .prompt_scale_shift_table , axis = (0 , 1 ))
461+ temb_prompt_reshaped = temb_prompt .reshape (batch_size , 1 , 2 , - 1 )
462+ prompt_ada_values = prompt_table_reshaped + temb_prompt_reshaped
463+ shift_text_kv = prompt_ada_values [:, :, 0 , :]
464+ scale_text_kv = prompt_ada_values [:, :, 1 , :]
465+ encoder_hidden_states = encoder_hidden_states * (1 + scale_text_kv ) + shift_text_kv
466+
448467 attn_hidden_states = self .attn2 (
449468 norm_hidden_states ,
450469 encoder_hidden_states = encoder_hidden_states ,
@@ -461,6 +480,14 @@ def __call__(
461480 if getattr (self , "cross_attn_mod" , False ):
462481 norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_q ) + audio_shift_q
463482
483+ if getattr (self , "cross_attn_mod" , False ) and temb_prompt_audio is not None :
484+ audio_prompt_table_reshaped = jnp .expand_dims (self .audio_prompt_scale_shift_table , axis = (0 , 1 ))
485+ temb_prompt_audio_reshaped = temb_prompt_audio .reshape (batch_size , 1 , 2 , - 1 )
486+ audio_prompt_ada_values = audio_prompt_table_reshaped + temb_prompt_audio_reshaped
487+ audio_shift_text_kv = audio_prompt_ada_values [:, :, 0 , :]
488+ audio_scale_text_kv = audio_prompt_ada_values [:, :, 1 , :]
489+ audio_encoder_hidden_states = audio_encoder_hidden_states * (1 + audio_scale_text_kv ) + audio_shift_text_kv
490+
464491 attn_audio_hidden_states = self .audio_attn2 (
465492 norm_audio_hidden_states ,
466493 encoder_hidden_states = audio_encoder_hidden_states ,
@@ -785,6 +812,25 @@ def __init__(
785812 weights_dtype = self .weights_dtype ,
786813 )
787814
815+ # 3.4. Prompt Scale/Shift Modulation parameters (LTX-2.3)
816+ if self .cross_attn_mod :
817+ self .prompt_adaln = LTX2AdaLayerNormSingle (
818+ rngs = rngs ,
819+ embedding_dim = inner_dim ,
820+ num_mod_params = 2 ,
821+ use_additional_conditions = False ,
822+ dtype = self .dtype ,
823+ weights_dtype = self .weights_dtype ,
824+ )
825+ self .audio_prompt_adaln = LTX2AdaLayerNormSingle (
826+ rngs = rngs ,
827+ embedding_dim = audio_inner_dim ,
828+ num_mod_params = 2 ,
829+ use_additional_conditions = False ,
830+ dtype = self .dtype ,
831+ weights_dtype = self .weights_dtype ,
832+ )
833+
788834 # 3. Output Layer Scale/Shift Modulation parameters
789835 param_rng = rngs .params ()
790836 self .scale_shift_table = nnx .Param (
@@ -969,6 +1015,8 @@ def __call__(
9691015 audio_encoder_hidden_states : jax .Array ,
9701016 timestep : jax .Array ,
9711017 audio_timestep : Optional [jax .Array ] = None ,
1018+ sigma : Optional [jax .Array ] = None ,
1019+ audio_sigma : Optional [jax .Array ] = None ,
9721020 encoder_attention_mask : Optional [jax .Array ] = None ,
9731021 audio_encoder_attention_mask : Optional [jax .Array ] = None ,
9741022 num_frames : Optional [int ] = None ,
@@ -1032,6 +1080,22 @@ def __call__(
10321080 temb_audio = temb_audio .reshape (batch_size , - 1 , temb_audio .shape [- 1 ])
10331081 audio_embedded_timestep = audio_embedded_timestep .reshape (batch_size , - 1 , audio_embedded_timestep .shape [- 1 ])
10341082
1083+ if self .cross_attn_mod and sigma is not None :
1084+ audio_sigma = audio_sigma if audio_sigma is not None else sigma
1085+ temb_prompt , _ = self .prompt_adaln (
1086+ sigma .flatten (),
1087+ hidden_dtype = hidden_states .dtype ,
1088+ )
1089+ temb_prompt_audio , _ = self .audio_prompt_adaln (
1090+ audio_sigma .flatten (),
1091+ hidden_dtype = audio_hidden_states .dtype ,
1092+ )
1093+ temb_prompt = temb_prompt .reshape (batch_size , - 1 , temb_prompt .shape [- 1 ])
1094+ temb_prompt_audio = temb_prompt_audio .reshape (batch_size , - 1 , temb_prompt_audio .shape [- 1 ])
1095+ else :
1096+ temb_prompt = None
1097+ temb_prompt_audio = None
1098+
10351099 video_cross_attn_scale_shift , _ = self .av_cross_attn_video_scale_shift (
10361100 timestep .flatten (),
10371101 hidden_dtype = hidden_states .dtype ,
@@ -1094,6 +1158,8 @@ def scan_fn(carry, block_and_mask):
10941158 temb_ca_audio_scale_shift = audio_cross_attn_scale_shift ,
10951159 temb_ca_gate = video_cross_attn_a2v_gate ,
10961160 temb_ca_audio_gate = audio_cross_attn_v2a_gate ,
1161+ temb_prompt = temb_prompt ,
1162+ temb_prompt_audio = temb_prompt_audio ,
10971163 video_rotary_emb = video_rotary_emb ,
10981164 audio_rotary_emb = audio_rotary_emb ,
10991165 ca_video_rotary_emb = video_cross_attn_rotary_emb ,
@@ -1135,6 +1201,8 @@ def scan_fn(carry, block_and_mask):
11351201 temb_ca_audio_scale_shift = audio_cross_attn_scale_shift ,
11361202 temb_ca_gate = video_cross_attn_a2v_gate ,
11371203 temb_ca_audio_gate = audio_cross_attn_v2a_gate ,
1204+ temb_prompt = temb_prompt ,
1205+ temb_prompt_audio = temb_prompt_audio ,
11381206 video_rotary_emb = video_rotary_emb ,
11391207 audio_rotary_emb = audio_rotary_emb ,
11401208 ca_video_rotary_emb = video_cross_attn_rotary_emb ,
0 commit comments