Skip to content

Commit fc9c405

Browse files
committed
fix list-type prompt embeds shape error
1 parent 18f5d63 commit fc9c405

1 file changed

Lines changed: 17 additions & 5 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1399,13 +1399,25 @@ def __call__(
13991399
latents_jax = jnp.concatenate([latents_jax] * 2, axis=0)
14001400
audio_latents_jax = jnp.concatenate([audio_latents_jax] * 2, axis=0)
14011401
def _print_stats_gemma(name, tensor):
1402-
print(
1403-
f"DEBUG {name} shape: {tensor.shape}, mean: {jnp.round(jnp.mean(tensor), 6)}, min: {jnp.round(jnp.min(tensor), 4)}, max: {jnp.round(jnp.max(tensor), 4)}, std: {jnp.round(jnp.std(tensor), 4)}"
1404-
)
1402+
if isinstance(tensor, list):
1403+
for idx, t in enumerate(tensor):
1404+
print(
1405+
f"DEBUG {name}_{idx} shape: {t.shape}, mean: {jnp.round(jnp.mean(t), 6)}, min: {jnp.round(jnp.min(t), 4)}, max: {jnp.round(jnp.max(t), 4)}, std: {jnp.round(jnp.std(t), 4)}"
1406+
)
1407+
else:
1408+
print(
1409+
f"DEBUG {name} shape: {tensor.shape}, mean: {jnp.round(jnp.mean(tensor), 6)}, min: {jnp.round(jnp.min(tensor), 4)}, max: {jnp.round(jnp.max(tensor), 4)}, std: {jnp.round(jnp.std(tensor), 4)}"
1410+
)
1411+
14051412
if do_cfg and do_stg:
1406-
_print_stats_gemma("text_encoder_output_flattened", prompt_embeds_jax[:2])
1413+
if isinstance(prompt_embeds_jax, list):
1414+
prompt_embeds_to_print = [x[:2] for x in prompt_embeds_jax]
1415+
else:
1416+
prompt_embeds_to_print = prompt_embeds_jax[:2]
14071417
else:
1408-
_print_stats_gemma("text_encoder_output_flattened", prompt_embeds_jax)
1418+
prompt_embeds_to_print = prompt_embeds_jax
1419+
1420+
_print_stats_gemma("text_encoder_output_flattened", prompt_embeds_to_print)
14091421

14101422
if hasattr(self, "mesh") and self.mesh is not None:
14111423
data_sharding_3d = NamedSharding(self.mesh, P())

0 commit comments

Comments
 (0)