Skip to content

Commit a449a5c

Browse files
committed
Offload LTX-2 text encoder to TorchAX and resolve lint issues
1 parent 71b4138 commit a449a5c

5 files changed

Lines changed: 227 additions & 53 deletions

File tree

dependencies/requirements/base_requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ tensorflow-datasets
3535
tensorflow
3636
tokamax
3737
tokenizers
38+
torchax>=0.0.11
3839
transformers<5.0.0
3940

4041
# pinning torch and torchvision to specific versions to avoid

dependencies/requirements/generated_requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ toml>=0.10.2
179179
tomlkit>=0.14.0
180180
toolz>=1.1.0
181181
torch @ https://download.pytorch.org/whl/cpu/torch-2.10.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl
182+
torchax>=0.0.11
182183
torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.25.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl
183184
tqdm>=4.67.3
184185
transformers>=4.57.6
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""
2+
Copyright 2026 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from typing import Tuple
18+
19+
import torch
20+
import jax
21+
from torchax import interop, default_env
22+
23+
# --- Monkeypatch transformers masking_utils to avoid torchax integer tracing bug ---
24+
import transformers.masking_utils
25+
26+
# Keep a handle to the stock implementation so the monkeypatch below is
# reversible (useful for debugging or restoring default transformers behavior).
_orig_sliding_window_overlay = transformers.masking_utils.sliding_window_overlay
28+
29+
def _patched_sliding_window_overlay(sliding_window: int):
30+
# pylint: disable=unused-argument
31+
32+
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
33+
# Since sequence length < sliding window (e.g. 256 < 4096), this mask is always True.
34+
# We return a standard boolean tensor using new_ones to guarantee Torchax compatibility
35+
# and prevent any implicit tracing crashes.
36+
return q_idx.new_ones((), dtype=torch.bool)
37+
38+
return inner_mask
39+
40+
41+
# Install the patch module-wide: any transformers code path that builds a
# sliding-window attention mask will now use the Torchax-safe version above.
transformers.masking_utils.sliding_window_overlay = _patched_sliding_window_overlay
# -----------------------------------------------------------------------------------
43+
44+
45+
class TorchaxGemma3TextEncoder(interop.JittableModule):
    """
    A jittable Torchax module for wrapping the HuggingFace PyTorch
    Gemma3ForConditionalGeneration text encoder.

    Accepts JAX arrays, runs the wrapped PyTorch model inside the Torchax
    default environment, and returns the model's hidden states converted back
    to a tuple of JAX arrays.
    """

    def __init__(self, text_encoder):
        # output_hidden_states is a plain Python bool that changes the shape of
        # the traced computation, so it must be a static (non-traced) jit arg.
        super().__init__(text_encoder, extra_jit_args={"static_argnames": ["output_hidden_states"]})

    def __call__(
        self, input_ids: jax.Array, attention_mask: jax.Array, output_hidden_states: bool = True
    ) -> Tuple[jax.Array, ...]:
        """Encode token ids and return the model's hidden states as JAX arrays.

        Args:
            input_ids: Token ids as a JAX array. NOTE(review): presumably shaped
                (batch, seq_len) — confirm against the caller.
            attention_mask: Attention mask for ``input_ids``. NOTE(review):
                assumed to match ``input_ids``' shape — verify.
            output_hidden_states: Whether the wrapped model emits all layers'
                hidden states (static jit argument; see ``__init__``).

        Returns:
            A tuple of JAX arrays — the wrapped model's ``.hidden_states``.
        """
        with default_env():
            # Wrap the incoming JAX arrays as torch-view tensors so the PyTorch
            # model can consume them inside the Torchax environment.
            input_ids = interop.torch_view(input_ids)
            attention_mask = interop.torch_view(attention_mask)

            # functional_call threads params/buffers explicitly so the forward
            # pass is purely functional and therefore jittable.
            output = self.functional_call(
                self._forward_inner,
                params=self.params,
                buffers=self.buffers,
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=output_hidden_states,
            )
            # Convert the torch-view result back into JAX arrays for the caller.
            return interop.jax_view(output)

    @staticmethod
    def _forward_inner(model, input_ids, attention_mask, output_hidden_states=True):
        # We only return hidden states as a tuple of tensors.
        # That allows interop.jax_view to convert them into a tuple of jax Arrays
        return model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=output_hidden_states).hidden_states

0 commit comments

Comments
 (0)