|
23 | 23 | import jax |
24 | 24 | import jax.numpy as jnp |
25 | 25 | from jax.sharding import Mesh, NamedSharding, PartitionSpec as P |
26 | | -from torchax import default_env |
27 | | -from maxdiffusion.models.ltx2.text_encoders.torchax_text_encoder import TorchaxGemma3TextEncoder |
28 | 26 | from maxdiffusion.maxdiffusion_utils import get_dummy_ltx2_inputs |
29 | 27 | import contextlib |
30 | 28 | import flax |
@@ -352,7 +350,10 @@ def load_text_encoder(cls, config: HyperParameters): |
352 | 350 | ) |
353 | 351 | text_encoder.eval() |
354 | 352 |
|
355 | | - if getattr(config, "run_text_encoder_on_tpu", True): |
| 353 | + if getattr(config, "run_text_encoder_on_tpu", False): |
| 354 | + from torchax import default_env |
| 355 | + from maxdiffusion.models.ltx2.text_encoders.torchax_text_encoder import TorchaxGemma3TextEncoder |
| 356 | + |
356 | 357 | with default_env(): |
357 | 358 | text_encoder = text_encoder.to("jax") |
358 | 359 | text_encoder = TorchaxGemma3TextEncoder(text_encoder) |
@@ -855,7 +856,7 @@ def _get_gemma_prompt_embeds( |
855 | 856 | prompt = [p.strip() for p in prompt] |
856 | 857 |
|
857 | 858 | if self.text_encoder is not None: |
858 | | - run_text_encoder_on_tpu = getattr(self.config, "run_text_encoder_on_tpu", True) if hasattr(self, "config") else True |
| 859 | + run_text_encoder_on_tpu = getattr(self.config, "run_text_encoder_on_tpu", False) if hasattr(self, "config") else False |
859 | 860 | if run_text_encoder_on_tpu: |
860 | 861 | # Torchax Text Encoder |
861 | 862 | text_inputs = self.tokenizer( |
|
0 commit comments