Skip to content

Commit 7707c3d

Browse files
committed
Offload LTX-2 text encoder to TorchAX and resolve lint issues
1 parent 71b4138 commit 7707c3d

4 files changed

Lines changed: 176 additions & 44 deletions

File tree

dependencies/requirements/base_requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ tensorflow-datasets
3535
tensorflow
3636
tokamax
3737
tokenizers
38+
torchax>=0.0.11
3839
transformers<5.0.0
3940

4041
# pinning torch and torchvision to specific versions to avoid

dependencies/requirements/generated_requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ toml>=0.10.2
179179
tomlkit>=0.14.0
180180
toolz>=1.1.0
181181
torch @ https://download.pytorch.org/whl/cpu/torch-2.10.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl
182+
torchax>=0.0.11
182183
torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.25.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl
183184
tqdm>=4.67.3
184185
transformers>=4.57.6
src/maxdiffusion/models/ltx2/text_encoders/torchax_text_encoder.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""
2+
Copyright 2026 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from typing import Tuple
18+
19+
import torch
20+
import jax
21+
from torchax import interop, default_env
22+
23+
# --- Monkeypatch transformers masking_utils to avoid torchax integer tracing bug ---
24+
import transformers.masking_utils
25+
26+
# Keep a handle on the library's own sliding_window_overlay before the module
# attribute is rebound further down.
# NOTE(review): nothing in the visible code reads this name again; presumably it
# is kept so the patch can be reverted or compared against - confirm.
_orig_sliding_window_overlay = transformers.masking_utils.sliding_window_overlay
27+
28+
29+
def _patched_sliding_window_overlay(sliding_window: int):
30+
# pylint: disable=unused-argument
31+
32+
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
33+
# Since sequence length < sliding window (e.g. 256 < 4096), this mask is always True.
34+
# We return a standard boolean tensor using new_ones to guarantee Torchax compatibility
35+
# and prevent any implicit tracing crashes.
36+
return q_idx.new_ones((), dtype=torch.bool)
37+
38+
return inner_mask
39+
40+
41+
# Rebind the module attribute so that later lookups of
# transformers.masking_utils.sliding_window_overlay resolve to the patched factory.
transformers.masking_utils.sliding_window_overlay = _patched_sliding_window_overlay
42+
# -----------------------------------------------------------------------------------
43+
44+
45+
class TorchaxGemma3TextEncoder(interop.JittableModule):
  """
  Torchax module that wraps the HuggingFace PyTorch
  Gemma3ForConditionalGeneration text encoder so it can be called with JAX arrays.
  """

  def __init__(self, text_encoder):
    """Wraps `text_encoder`, declaring `output_hidden_states` as a static jit argument name."""
    jit_options = {"static_argnames": ["output_hidden_states"]}
    super().__init__(text_encoder, extra_jit_args=jit_options)

  def __call__(
      self, input_ids: jax.Array, attention_mask: jax.Array, output_hidden_states: bool = True
  ) -> Tuple[jax.Array, ...]:
    """Runs the wrapped encoder on JAX inputs and returns its hidden states as JAX arrays.

    Args:
      input_ids: Token ids as a JAX array.
      attention_mask: Attention mask as a JAX array.
      output_hidden_states: Forwarded unchanged to the wrapped model.

    Returns:
      The model's `hidden_states`, converted with `interop.jax_view`.
    """
    # NOTE(review): this goes through `functional_call` directly; confirm whether
    # a jitted entry point was intended for this path.
    with default_env():
      # View the JAX inputs as torch tensors before handing them to the model.
      model_kwargs = {
          "input_ids": interop.torch_view(input_ids),
          "attention_mask": interop.torch_view(attention_mask),
          "output_hidden_states": output_hidden_states,
      }
      hidden_states = self.functional_call(
          self._forward_inner,
          params=self.params,
          buffers=self.buffers,
          **model_kwargs,
      )
      return interop.jax_view(hidden_states)

  @staticmethod
  def _forward_inner(model, input_ids, attention_mask, output_hidden_states=True):
    """Calls `model` and returns only the `hidden_states` field of its output."""
    # Only the hidden states (a tuple of tensors) are returned, which lets
    # interop.jax_view convert them into a tuple of jax Arrays.
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=output_hidden_states,
    )
    return outputs.hidden_states

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 99 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,20 @@
2222
import jax
2323
import jax.numpy as jnp
2424
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
25+
from torchax import default_env
26+
from maxdiffusion.models.ltx2.text_encoders.torchax_text_encoder import TorchaxGemma3TextEncoder
27+
from maxdiffusion.tpu_utils import get_tpu_type, TpuType
28+
from maxdiffusion.maxdiffusion_utils import get_dummy_ltx2_inputs
29+
import contextlib
2530
import flax
2631
import flax.linen as nn
2732
import flax.traverse_util
2833
from flax import nnx
2934
from flax.linen import partitioning as nn_partitioning
3035
from transformers import AutoTokenizer, GemmaTokenizer, GemmaTokenizerFast, Gemma3ForConditionalGeneration
31-
from maxdiffusion.tpu_utils import get_tpu_type, TpuType
3236
import qwix
3337
from ...utils import logging
34-
from ...schedulers import FlaxFlowMatchScheduler
38+
from ...schedulers import FlaxFlowMatchScheduler # pylint: disable=no-name-in-module
3539
from ...models.ltx2.autoencoder_kl_ltx2 import LTX2VideoAutoencoderKL
3640
from ...models.ltx2.autoencoder_kl_ltx2_audio import FlaxAutoencoderKLLTX2Audio
3741
from ...models.ltx2.vocoder_ltx2 import LTX2Vocoder
@@ -53,7 +57,6 @@
5357
from ... import max_logging
5458
from ... import max_utils
5559
from ...max_utils import get_precision, device_put_replicated, get_flash_block_sizes
56-
from maxdiffusion.maxdiffusion_utils import get_dummy_ltx2_inputs
5760

5861

5962
@flax.struct.dataclass
@@ -65,7 +68,8 @@ class LTX2PipelineOutput:
6568
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
6669
"""
6770
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure.
68-
Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891).
71+
Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are Flawed]
72+
(https://huggingface.co/papers/2305.08891).
6973
"""
7074
std_text = jnp.std(noise_pred_text, axis=list(range(1, noise_pred_text.ndim)), keepdims=True)
7175
std_cfg = jnp.std(noise_cfg, axis=list(range(1, noise_cfg.ndim)), keepdims=True)
@@ -110,6 +114,9 @@ def create_sharded_logical_transformer(
110114
restored_checkpoint=None,
111115
subfolder: str = "",
112116
):
117+
"""Creates a sharded logical transformer."""
118+
119+
# pylint: disable=too-many-positional-arguments,unused-argument
113120
def create_model(rngs: nnx.Rngs, ltx2_config: dict):
114121
transformer = LTX2VideoTransformer3DModel(**ltx2_config, rngs=rngs)
115122
return transformer
@@ -186,6 +193,8 @@ def calculate_shift(
186193
base_shift: float = 0.5,
187194
max_shift: float = 1.15,
188195
):
196+
"""Calculates the shift for the timestep schedule."""
197+
# pylint: disable=too-many-positional-arguments
189198
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
190199
b = base_shift - m * base_seq_len
191200
mu = image_seq_len * m + b
@@ -200,6 +209,8 @@ def retrieve_timesteps(
200209
sigmas: Optional[List[float]] = None,
201210
**kwargs,
202211
):
212+
"""Retrieves timesteps for the scheduler."""
213+
# pylint: disable=too-many-positional-arguments
203214
if timesteps is not None and sigmas is not None:
204215
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
205216

@@ -222,6 +233,8 @@ class LTX2Pipeline:
222233
Pipeline for LTX-2.
223234
"""
224235

236+
# pylint: disable=missing-function-docstring,too-many-positional-arguments,unused-argument
237+
225238
def __init__(
226239
self,
227240
scheduler: FlaxFlowMatchScheduler,
@@ -245,6 +258,8 @@ def __init__(
245258
self.transformer = transformer
246259
self.latent_upsampler = latent_upsampler
247260
self.latent_upsampler_params = latent_upsampler_params
261+
self.mesh = None
262+
self.config = None
248263

249264
# VAE compression ratios
250265
self.vae_spatial_compression_ratio = getattr(self.vae, "spatial_compression_ratio", 32)
@@ -316,6 +331,11 @@ def load_text_encoder(cls, config: HyperParameters):
316331
torch_dtype=torch.bfloat16,
317332
)
318333
text_encoder.eval()
334+
335+
with default_env():
336+
text_encoder = text_encoder.to("jax")
337+
text_encoder = TorchaxGemma3TextEncoder(text_encoder)
338+
319339
return text_encoder
320340

321341
@classmethod
@@ -396,7 +416,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
396416
sharding = sharding.get_value()
397417
try:
398418
replicate_vae = config.replicate_vae
399-
except ValueError:
419+
except Exception: # pylint: disable=broad-exception-caught
400420
replicate_vae = False
401421
if replicate_vae:
402422
sharding = NamedSharding(mesh, P())
@@ -444,7 +464,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
444464
sharding = sharding.get_value()
445465
try:
446466
replicate_vae = config.replicate_vae
447-
except ValueError:
467+
except Exception: # pylint: disable=broad-exception-caught
448468
replicate_vae = False
449469
if replicate_vae:
450470
sharding = NamedSharding(mesh, P())
@@ -750,39 +770,48 @@ def _get_gemma_prompt_embeds(
750770
prompt = [p.strip() for p in prompt]
751771

752772
if self.text_encoder is not None:
753-
# PyTorch Text Encoder
773+
# Torchax Text Encoder
754774
text_inputs = self.tokenizer(
755775
prompt,
756776
padding="max_length",
757777
max_length=max_sequence_length,
758778
truncation=True,
759779
add_special_tokens=True,
760-
return_tensors="pt",
780+
return_tensors="np",
781+
)
782+
text_input_ids = jnp.array(text_inputs.input_ids)
783+
prompt_attention_mask = jnp.array(text_inputs.attention_mask)
784+
785+
# Distribute the batch dimension across available TPUs to prevent Softmax OOM
786+
# (reduces 512MB allocation down to 64MB per TPU for batch size 16)
787+
devices = np.array(jax.devices())
788+
num_shards = 1
789+
for i in range(len(devices), 0, -1):
790+
if text_input_ids.shape[0] % i == 0:
791+
num_shards = i
792+
break
793+
794+
if num_shards > 1:
795+
mesh = Mesh(devices[:num_shards], axis_names=("batch",))
796+
sharding = NamedSharding(mesh, P("batch"))
797+
text_input_ids = jax.device_put(text_input_ids, sharding)
798+
prompt_attention_mask = jax.device_put(prompt_attention_mask, sharding)
799+
800+
# Torchax wrapper returns tuple of hidden states natively
801+
text_encoder_hidden_states = self.text_encoder(
802+
input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
761803
)
762-
text_input_ids = text_inputs.input_ids
763-
prompt_attention_mask = text_inputs.attention_mask
764-
765-
text_input_ids = text_input_ids.to(self.text_encoder.device)
766-
prompt_attention_mask = prompt_attention_mask.to(self.text_encoder.device)
767-
768-
with torch.no_grad():
769-
text_encoder_outputs = self.text_encoder(
770-
input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
771-
)
772-
773-
text_encoder_hidden_states = text_encoder_outputs.hidden_states
774-
del text_encoder_outputs # Free memory
775804

776805
prompt_embeds_list = []
777806
# Iterate instead of stacking eagerly to avoid 5.7+ GB HBM allocations outside JIT
778807
for state in text_encoder_hidden_states:
779-
state_np = state.cpu().to(torch.float32).numpy()
780-
prompt_embeds_list.append(jnp.array(state_np, dtype=jnp.bfloat16))
808+
state = jax.device_put(state, jax.devices()[0])
809+
prompt_embeds_list.append(state.astype(jnp.bfloat16))
781810

782811
prompt_embeds = prompt_embeds_list
783-
del text_encoder_hidden_states # Free PyTorch tensor memory
812+
del text_encoder_hidden_states # Free memory
784813

785-
prompt_attention_mask = jnp.array(prompt_attention_mask.cpu().to(torch.float32).numpy(), dtype=jnp.bool_)
814+
prompt_attention_mask = prompt_attention_mask.astype(jnp.bool_)
786815
else:
787816
raise ValueError("`text_encoder` is required to encode prompts.")
788817

@@ -939,7 +968,7 @@ def check_inputs(
939968

940969
@staticmethod
941970
def _pack_latents(latents: jax.Array, patch_size: int = 1, patch_size_t: int = 1) -> jax.Array:
942-
batch_size, num_channels, num_frames, height, width = latents.shape
971+
batch_size, _, num_frames, height, width = latents.shape
943972
post_patch_num_frames = num_frames // patch_size_t
944973
post_patch_height = height // patch_size
945974
post_patch_width = width // patch_size
@@ -1028,7 +1057,7 @@ def _pack_audio_latents(
10281057
latents: jax.Array, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None
10291058
) -> jax.Array:
10301059
if patch_size is not None and patch_size_t is not None:
1031-
batch_size, num_channels, latent_length, latent_mel_bins = latents.shape
1060+
batch_size, _, latent_length, latent_mel_bins = latents.shape
10321061
post_patch_latent_length = latent_length // patch_size_t
10331062
post_patch_mel_bins = latent_mel_bins // patch_size
10341063
latents = latents.reshape(batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size)
@@ -1327,7 +1356,6 @@ def __call__(
13271356
graphdef, state = nnx.split(self.transformer)
13281357

13291358
# 7. Denoising Loop
1330-
import contextlib
13311359

13321360
context_manager = self.mesh if hasattr(self, "mesh") and self.mesh is not None else contextlib.nullcontext()
13331361
axis_rules_context = (
@@ -1380,9 +1408,7 @@ def __call__(
13801408
)
13811409
else:
13821410
# Old Python loop path
1383-
for i in range(len(timesteps_jax)):
1384-
t = timesteps_jax[i]
1385-
1411+
for _, t in enumerate(timesteps_jax):
13861412
# Isolate input sharding to scan_layers=False to avoid affecting the standard path
13871413
latents_jax_sharded = latents_jax
13881414
audio_latents_jax_sharded = audio_latents_jax
@@ -1503,20 +1529,45 @@ def __call__(
15031529
if output_type == "latent":
15041530
return LTX2PipelineOutput(frames=latents, audio=audio_latents)
15051531

1506-
# Force latents and VAE weights to be fully replicated using with_sharding_constraint, this speeds up single video latency ~3x
1507-
try:
1508-
mesh = latents.sharding.mesh
1509-
replicated_sharding = NamedSharding(mesh, P())
1510-
latents = jax.lax.with_sharding_constraint(latents, replicated_sharding)
1511-
1512-
# Replicate VAE weights
1513-
graphdef, state = nnx.split(self.vae)
1514-
state = jax.tree_util.tree_map(
1515-
lambda x: jax.lax.with_sharding_constraint(x, replicated_sharding) if isinstance(x, jax.Array) else x, state
1532+
# Force latents and VAE weights to be fully replicated using with_sharding_constraint,
1533+
# this speeds up single video latency ~3x
1534+
if batch_size <= 2:
1535+
try:
1536+
mesh = latents.sharding.mesh
1537+
replicated_sharding = NamedSharding(mesh, P())
1538+
latents = jax.lax.with_sharding_constraint(latents, replicated_sharding)
1539+
1540+
# Replicate VAE weights
1541+
graphdef, state = nnx.split(self.vae)
1542+
state = jax.tree_util.tree_map(
1543+
lambda x: jax.lax.with_sharding_constraint(x, replicated_sharding) if isinstance(x, jax.Array) else x, state
1544+
)
1545+
self.vae = nnx.merge(graphdef, state)
1546+
except Exception: # pylint: disable=broad-exception-caught
1547+
max_logging.log("[Tuning] Failed to apply sharding constraint")
1548+
else:
1549+
max_logging.log(
1550+
f"[Tuning] Skipping VAE replication and disabling slicing to prevent HBM OOM for batch_size {batch_size} > 2"
15161551
)
1517-
self.vae = nnx.merge(graphdef, state)
1518-
except Exception as e:
1519-
max_logging.log(f"[Tuning] Failed to apply sharding constraint: {e}")
1552+
try:
1553+
# Disable sequential slicing to avoid JAX concatenating 17GB arrays on the TPU
1554+
self.vae.use_slicing = False
1555+
1556+
# Distribute the batch dimension across the existing mesh to ensure topological compatibility
1557+
mesh = latents.sharding.mesh
1558+
active_axes = []
1559+
current_shards = 1
1560+
1561+
for axis_name, size in mesh.shape.items():
1562+
if size > 1 and batch_size % (current_shards * size) == 0:
1563+
active_axes.append(axis_name)
1564+
current_shards *= size
1565+
1566+
if active_axes:
1567+
batch_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(tuple(active_axes)))
1568+
latents = jax.lax.with_sharding_constraint(latents, batch_sharding)
1569+
except Exception: # pylint: disable=broad-exception-caught
1570+
max_logging.log("[Tuning] Failed to apply batch sharding constraint to VAE")
15201571

15211572
if getattr(self.vae.config, "timestep_conditioning", False):
15221573
noise = jax.random.normal(generator, latents.shape, dtype=latents.dtype)
@@ -1587,6 +1638,8 @@ def transformer_forward_pass(
15871638
audio_num_frames,
15881639
fps,
15891640
):
1641+
"""Forward pass for the transformer."""
1642+
# pylint: disable=too-many-positional-arguments,unused-argument
15901643
transformer = nnx.merge(graphdef, state)
15911644

15921645
# Expand timestep to batch size
@@ -1647,6 +1700,8 @@ def run_diffusion_loop(
16471700
scheduler_step,
16481701
logical_axis_rules,
16491702
):
1703+
"""Runs the diffusion loop."""
1704+
# pylint: disable=too-many-positional-arguments
16501705
latents_jax = latents_jax.astype(jnp.float32)
16511706
audio_latents_jax = audio_latents_jax.astype(jnp.float32)
16521707
transformer = nnx.merge(graphdef, state)

0 commit comments

Comments
 (0)