Skip to content

Commit 7b28885

Browse files
committed
Offload LTX-2 text encoder to TorchAX and resolve lint issues
1 parent 19d4e4d commit 7b28885

6 files changed

Lines changed: 273 additions & 59 deletions

File tree

dependencies/requirements/base_requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ tensorflow-datasets
3535
tensorflow
3636
tokamax
3737
tokenizers
38+
torchax>=0.0.11
3839
transformers<5.0.0
3940

4041
# pinning torch and torchvision to specific versions to avoid

dependencies/requirements/generated_requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ toml>=0.10.2
179179
tomlkit>=0.14.0
180180
toolz>=1.1.0
181181
torch @ https://download.pytorch.org/whl/cpu/torch-2.10.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl
182+
torchax>=0.0.11
182183
torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.25.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl
183184
tqdm>=4.67.3
184185
transformers>=4.57.6

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,11 @@ skip_first_n_steps_for_profiler: 0
103103
profiler_steps: 5
104104

105105
replicate_vae: False
106-
107106
use_bwe: False
108107

108+
run_text_encoder_on_tpu: True
109+
# Dynamically disables VAE slicing and distributes the batch dimension to avoid HBM OOM for larger batch sizes.
110+
enable_dynamic_vae_sharding: True
109111
allow_split_physical_axes: False
110112
learning_rate_schedule_steps: -1
111113
max_train_steps: 500
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
"""
2+
Copyright 2026 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from typing import Tuple
18+
19+
import torch
20+
import jax
21+
from torchax import interop, default_env
22+
23+
# --- Monkeypatch transformers masking_utils to avoid torchax integer tracing bug ---
24+
import transformers.masking_utils
25+
26+
# Reference to the stock transformers implementation, captured once at import time (before any patching).
_orig_sliding_window_overlay = transformers.masking_utils.sliding_window_overlay
27+
28+
29+
def _patched_sliding_window_overlay(sliding_window: int):
30+
# pylint: disable=unused-argument
31+
32+
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
33+
# Explicit Sequence Length Assumption:
34+
# This patch assumes that the maximum sequence length used for text prompts (typically <= 1024)
35+
# is strictly less than the sliding window size of Gemma-3 (typically 4096).
36+
# Under this assumption, the sliding window causal constraint `kv_idx > q_idx - sliding_window`
37+
# is mathematically always True for all valid query/key indices (0 <= q_idx, kv_idx < seq_len).
38+
#
39+
# We return a standard boolean tensor `q_idx.new_ones((), dtype=torch.bool)` to guarantee
40+
# Torchax compatibility and prevent any implicit tracing crashes.
41+
# If a future model uses a sequence length exceeding the sliding window, this assumption must be re-evaluated.
42+
return q_idx.new_ones((), dtype=torch.bool)
43+
44+
return inner_mask
45+
46+
47+
class TorchaxGemma3TextEncoder(interop.JittableModule):
  """Jittable Torchax wrapper around the HuggingFace PyTorch text encoder.

  Wraps a ``Gemma3ForConditionalGeneration`` module so that it can be invoked
  with JAX arrays and hands back the model's hidden states as JAX arrays.
  """

  def __init__(self, text_encoder):
    """Wrap ``text_encoder`` and mark ``output_hidden_states`` as a static jit argument."""
    super().__init__(text_encoder, extra_jit_args={"static_argnames": ["output_hidden_states"]})

  def __call__(
      self, input_ids: jax.Array, attention_mask: jax.Array, output_hidden_states: bool = True
  ) -> Tuple[jax.Array, ...]:
    """Run the wrapped encoder and return its hidden states.

    Args:
      input_ids: Token ids as a JAX array.
      attention_mask: Attention mask as a JAX array.
      output_hidden_states: Forwarded to the wrapped model.

    Returns:
      The model output's ``hidden_states`` converted to JAX arrays.
    """
    masking_utils = transformers.masking_utils
    # Snapshot whatever is installed *right now* rather than relying on an import-time
    # reference: this keeps nested calls correct and does not clobber a patch that some
    # other component may have installed after this module was imported.
    previous_overlay = masking_utils.sliding_window_overlay
    # The patch is only active for the duration of this call.
    masking_utils.sliding_window_overlay = _patched_sliding_window_overlay
    try:
      with default_env():
        torch_input_ids = interop.torch_view(input_ids)
        torch_attention_mask = interop.torch_view(attention_mask)

        output = self.functional_call(
            self._forward_inner,
            params=self.params,
            buffers=self.buffers,
            input_ids=torch_input_ids,
            attention_mask=torch_attention_mask,
            output_hidden_states=output_hidden_states,
        )
        return interop.jax_view(output)
    finally:
      # Always undo the patch, even on failure, so other models in the same process
      # see the masking behaviour they had before this call.
      masking_utils.sliding_window_overlay = previous_overlay

  @staticmethod
  def _forward_inner(model, input_ids, attention_mask, output_hidden_states=True):
    """Call ``model`` and return only ``hidden_states``.

    Returning just the hidden states (rather than the full model output object)
    lets ``interop.jax_view`` convert the result into JAX arrays.
    """
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=output_hidden_states,
    )
    return outputs.hidden_states

0 commit comments

Comments
 (0)