@@ -1370,8 +1370,9 @@ def __call__(
13701370
13711371 do_cfg = guidance_scale > 1.0
13721372 do_stg = stg_scale > 0.0
1373+ force_4way = getattr (self .config , "model_name" , "" ) == "ltx2.3"
13731374
1374- if do_cfg and do_stg :
1375+ if force_4way or ( do_cfg and do_stg ) :
13751376 negative_prompt_embeds_jax = negative_prompt_embeds
13761377 negative_prompt_attention_mask_jax = negative_prompt_attention_mask
13771378 if isinstance (prompt_embeds_jax , list ):
@@ -1473,6 +1474,7 @@ def __call__(
14731474 self .scheduler .step ,
14741475 tuple (tuple (rule ) if isinstance (rule , list ) else rule for rule in self .config .logical_axis_rules ),
14751476 use_cross_timestep = use_cross_timestep ,
1477+ force_4way = force_4way ,
14761478 )
14771479 else :
14781480 # Old Python loop path
@@ -1767,7 +1769,7 @@ def transformer_forward_pass(
17671769 sigma = None ,
17681770 audio_sigma = None ,
17691771 use_cross_timestep = False ,
1770- is_cfg_stg_mode : bool = False ,
1772+ is_4way : bool = False ,
17711773):
17721774 transformer = nnx .merge (graphdef , state )
17731775
@@ -1784,7 +1786,7 @@ def transformer_forward_pass(
17841786 else :
17851787 audio_sigma = jnp .expand_dims (audio_sigma , 0 ).repeat (latents .shape [0 ])
17861788
1787- if is_cfg_stg_mode :
1789+ if is_4way :
17881790 # 4-way split layout: [Uncond, Cond, Perturb, Isolated]
17891791 ones_mask = jnp .ones ((global_batch_size , 1 , 1 ), dtype = latents .dtype )
17901792 zeros_mask = jnp .zeros ((global_batch_size , 1 , 1 ), dtype = latents .dtype )
@@ -1874,13 +1876,15 @@ def run_diffusion_loop(
18741876 logical_axis_rules ,
18751877 perturbation_mask = None ,
18761878 use_cross_timestep = False ,
1879+ force_4way = False ,
18771880):
18781881 latents_jax = latents_jax .astype (jnp .float32 )
18791882 audio_latents_jax = audio_latents_jax .astype (jnp .float32 )
18801883 transformer = nnx .merge (graphdef , state )
18811884
18821885 do_cfg = guidance_scale > 1.0
18831886 do_stg = stg_scale > 0.0
1887+ use_4way = force_4way or (do_cfg and do_stg )
18841888
18851889 # Helper functions matching Diffusers Delta formulation
18861890 def convert_to_x0 (lat , vel , sigma_t ):
@@ -1924,25 +1928,22 @@ def scan_body(carry, inputs):
19241928 sigma = sigma_t ,
19251929 audio_sigma = sigma_t ,
19261930 use_cross_timestep = use_cross_timestep ,
1927- is_cfg_stg_mode = do_cfg and do_stg ,
1931+ is_4way = use_4way ,
19281932 )
19291933
19301934 # Extract latents_step based on stacking strategy
1931- if do_cfg and do_stg :
1935+ if use_4way :
19321936 latents_step = latents [batch_size : 2 * batch_size ]
19331937 audio_latents_step = audio_latents [batch_size : 2 * batch_size ]
19341938 elif do_cfg :
19351939 latents_step = latents [batch_size :]
19361940 audio_latents_step = audio_latents [batch_size :]
1937- elif do_stg :
1938- latents_step = latents [:batch_size ]
1939- audio_latents_step = audio_latents [:batch_size ]
19401941 else :
19411942 latents_step = latents
19421943 audio_latents_step = audio_latents
19431944
19441945 # Apply Diffusers STG + CFG + Modality Delta Logic
1945- if do_cfg and do_stg :
1946+ if use_4way :
19461947 noise_pred_uncond , noise_pred_text , noise_pred_perturb , noise_pred_isolated = jnp .split (noise_pred , 4 , axis = 0 )
19471948
19481949 # Convert to x0
@@ -1997,7 +1998,7 @@ def scan_body(carry, inputs):
19971998 audio_latents_step , _ = scheduler_step (s_state , noise_pred_audio , t , audio_latents_step , return_dict = False )
19981999 audio_latents_step = audio_latents_step .astype (audio_latents .dtype )
19992000
2000- if do_cfg and do_stg :
2001+ if use_4way :
20012002 latents_next = jnp .concatenate ([latents_step ] * 4 , axis = 0 )
20022003 audio_latents_next = jnp .concatenate ([audio_latents_step ] * 4 , axis = 0 )
20032004 elif do_cfg :
0 commit comments