Skip to content

Commit 01b9702

Browse files
committed
force 4way
1 parent f122c7c commit 01b9702

1 file changed

Lines changed: 11 additions & 10 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1370,8 +1370,9 @@ def __call__(
13701370

13711371
do_cfg = guidance_scale > 1.0
13721372
do_stg = stg_scale > 0.0
1373+
force_4way = getattr(self.config, "model_name", "") == "ltx2.3"
13731374

1374-
if do_cfg and do_stg:
1375+
if force_4way or (do_cfg and do_stg):
13751376
negative_prompt_embeds_jax = negative_prompt_embeds
13761377
negative_prompt_attention_mask_jax = negative_prompt_attention_mask
13771378
if isinstance(prompt_embeds_jax, list):
@@ -1473,6 +1474,7 @@ def __call__(
14731474
self.scheduler.step,
14741475
tuple(tuple(rule) if isinstance(rule, list) else rule for rule in self.config.logical_axis_rules),
14751476
use_cross_timestep=use_cross_timestep,
1477+
force_4way=force_4way,
14761478
)
14771479
else:
14781480
# Old Python loop path
@@ -1767,7 +1769,7 @@ def transformer_forward_pass(
17671769
sigma=None,
17681770
audio_sigma=None,
17691771
use_cross_timestep=False,
1770-
is_cfg_stg_mode: bool = False,
1772+
is_4way: bool = False,
17711773
):
17721774
transformer = nnx.merge(graphdef, state)
17731775

@@ -1784,7 +1786,7 @@ def transformer_forward_pass(
17841786
else:
17851787
audio_sigma = jnp.expand_dims(audio_sigma, 0).repeat(latents.shape[0])
17861788

1787-
if is_cfg_stg_mode:
1789+
if is_4way:
17881790
# 4-way split layout: [Uncond, Cond, Perturb, Isolated]
17891791
ones_mask = jnp.ones((global_batch_size, 1, 1), dtype=latents.dtype)
17901792
zeros_mask = jnp.zeros((global_batch_size, 1, 1), dtype=latents.dtype)
@@ -1874,13 +1876,15 @@ def run_diffusion_loop(
18741876
logical_axis_rules,
18751877
perturbation_mask=None,
18761878
use_cross_timestep=False,
1879+
force_4way=False,
18771880
):
18781881
latents_jax = latents_jax.astype(jnp.float32)
18791882
audio_latents_jax = audio_latents_jax.astype(jnp.float32)
18801883
transformer = nnx.merge(graphdef, state)
18811884

18821885
do_cfg = guidance_scale > 1.0
18831886
do_stg = stg_scale > 0.0
1887+
use_4way = force_4way or (do_cfg and do_stg)
18841888

18851889
# Helper functions matching Diffusers Delta formulation
18861890
def convert_to_x0(lat, vel, sigma_t):
@@ -1924,25 +1928,22 @@ def scan_body(carry, inputs):
19241928
sigma=sigma_t,
19251929
audio_sigma=sigma_t,
19261930
use_cross_timestep=use_cross_timestep,
1927-
is_cfg_stg_mode=do_cfg and do_stg,
1931+
is_4way=use_4way,
19281932
)
19291933

19301934
# Extract latents_step based on stacking strategy
1931-
if do_cfg and do_stg:
1935+
if use_4way:
19321936
latents_step = latents[batch_size : 2 * batch_size]
19331937
audio_latents_step = audio_latents[batch_size : 2 * batch_size]
19341938
elif do_cfg:
19351939
latents_step = latents[batch_size:]
19361940
audio_latents_step = audio_latents[batch_size:]
1937-
elif do_stg:
1938-
latents_step = latents[:batch_size]
1939-
audio_latents_step = audio_latents[:batch_size]
19401941
else:
19411942
latents_step = latents
19421943
audio_latents_step = audio_latents
19431944

19441945
# Apply Diffusers STG + CFG + Modality Delta Logic
1945-
if do_cfg and do_stg:
1946+
if use_4way:
19461947
noise_pred_uncond, noise_pred_text, noise_pred_perturb, noise_pred_isolated = jnp.split(noise_pred, 4, axis=0)
19471948

19481949
# Convert to x0
@@ -1997,7 +1998,7 @@ def scan_body(carry, inputs):
19971998
audio_latents_step, _ = scheduler_step(s_state, noise_pred_audio, t, audio_latents_step, return_dict=False)
19981999
audio_latents_step = audio_latents_step.astype(audio_latents.dtype)
19992000

2000-
if do_cfg and do_stg:
2001+
if use_4way:
20012002
latents_next = jnp.concatenate([latents_step] * 4, axis=0)
20022003
audio_latents_next = jnp.concatenate([audio_latents_step] * 4, axis=0)
20032004
elif do_cfg:

0 commit comments

Comments
 (0)