Skip to content

Commit 4851a8b

Browse files
committed
feat(ltx2): make run_text_encoder_on_tpu default False and dynamically load torchax
1 parent cfff435 commit 4851a8b

3 files changed

Lines changed: 10 additions & 5 deletions

File tree

src/maxdiffusion/configs/ltx2_3_video.yml

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -108,6 +108,10 @@ profiler_steps: 5
108108

109109
replicate_vae: False
110110

111+
run_text_encoder_on_tpu: False
112+
# Dynamically disables VAE slicing and distributes the batch dimension to avoid HBM OOM for larger batch sizes.
113+
enable_dynamic_vae_sharding: True
114+
111115
allow_split_physical_axes: False
112116
learning_rate_schedule_steps: -1
113117
max_train_steps: 500

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -114,7 +114,7 @@ profiler_steps: 5
114114
replicate_vae: False
115115
use_bwe: False
116116

117-
run_text_encoder_on_tpu: True
117+
run_text_encoder_on_tpu: False
118118
# Dynamically disables VAE slicing and distributes the batch dimension to avoid HBM OOM for larger batch sizes.
119119
enable_dynamic_vae_sharding: True
120120
allow_split_physical_axes: False

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -23,8 +23,6 @@
2323
import jax
2424
import jax.numpy as jnp
2525
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
26-
from torchax import default_env
27-
from maxdiffusion.models.ltx2.text_encoders.torchax_text_encoder import TorchaxGemma3TextEncoder
2826
from maxdiffusion.maxdiffusion_utils import get_dummy_ltx2_inputs
2927
import contextlib
3028
import flax
@@ -352,7 +350,10 @@ def load_text_encoder(cls, config: HyperParameters):
352350
)
353351
text_encoder.eval()
354352

355-
if getattr(config, "run_text_encoder_on_tpu", True):
353+
if getattr(config, "run_text_encoder_on_tpu", False):
354+
from torchax import default_env
355+
from maxdiffusion.models.ltx2.text_encoders.torchax_text_encoder import TorchaxGemma3TextEncoder
356+
356357
with default_env():
357358
text_encoder = text_encoder.to("jax")
358359
text_encoder = TorchaxGemma3TextEncoder(text_encoder)
@@ -855,7 +856,7 @@ def _get_gemma_prompt_embeds(
855856
prompt = [p.strip() for p in prompt]
856857

857858
if self.text_encoder is not None:
858-
run_text_encoder_on_tpu = getattr(self.config, "run_text_encoder_on_tpu", True) if hasattr(self, "config") else True
859+
run_text_encoder_on_tpu = getattr(self.config, "run_text_encoder_on_tpu", False) if hasattr(self, "config") else False
859860
if run_text_encoder_on_tpu:
860861
# Torchax Text Encoder
861862
text_inputs = self.tokenizer(

0 commit comments

Comments
 (0)