Commit 45bb025

feat: Add HF-based Gemma 3 text encoder for LTX-2 CPU feature extraction
1 parent 6e3b58b · commit 45bb025

2 files changed: 111 additions & 0 deletions

maxdiffusion/models/ltx2/text_encoders/hf_gemma3_encoder.py
Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel


class HFGemma3TextEncoder:
  """
  A lightweight wrapper around Hugging Face's Gemma 3 model for extracting hidden states.
  This module forces execution on CPU to avoid OOM or XLA collisions when used alongside
  JAX/MaxDiffusion on TPUs.
  """

  def __init__(self, model_id: str = "google/gemma-3-12b-it", max_length: int = 8192):
    self.model_id = model_id
    self.max_length = max_length
    # Initialize the tokenizer
    self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

    # Load the model directly to CPU in bfloat16 to save memory
    print(f"Loading {model_id} onto CPU. This may take a few moments...")
    self.model = AutoModel.from_pretrained(
        self.model_id,
        torch_dtype=torch.bfloat16,
        device_map="cpu",  # Force CPU to avoid TPU memory contention with MaxDiffusion
    )
    self.model.eval()  # Set to evaluation mode

  def encode(self, text: str | list[str]) -> np.ndarray:
    """
    Tokenizes the input text, passes it through the HF Gemma 3 model,
    and extracts ALL hidden states.

    Args:
      text: A single string or a list of strings to encode.

    Returns:
      A numpy array representing the flattened, stacked hidden states,
      compatible with GemmaFeaturesExtractorProjLinear.
      Shape: (batch_size, sequence_length, 49 * 3840)
    """
    # 1. Tokenize input text
    inputs = self.tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=self.max_length,
        return_tensors="pt",
    )

    # Ensure inputs are on the same device as the model (CPU)
    inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

    # 2. Forward pass to get hidden states
    # output_hidden_states=True is the key to retrieving all 49 layers
    with torch.no_grad():
      outputs = self.model(**inputs, output_hidden_states=True)

    # 3. Extract and stack hidden states
    # outputs.hidden_states is a tuple of 49 tensors, each shaped (batch, seq_len, 3840)
    all_hidden_states = outputs.hidden_states

    # Stack along a new leading dimension so the layers are easy to flatten.
    # Stacked shape: (49, batch, seq_len, 3840)
    stacked_states = torch.stack(all_hidden_states, dim=0)

    # Transpose to: (batch, seq_len, 49, 3840)
    transposed_states = stacked_states.permute(1, 2, 0, 3)

    # Flatten the last two dimensions to match the Feature Extractor's expectation
    # Shape becomes: (batch, seq_len, 49 * 3840) -> (batch, seq_len, 188160)
    batch_size, seq_len, num_layers, hidden_dim = transposed_states.shape
    flattened_states = transposed_states.reshape(batch_size, seq_len, num_layers * hidden_dim)

    # 4. Convert the PyTorch tensor to a NumPy array.
    # JAX/Flax can seamlessly accept and convert numpy arrays to JAX Arrays.
    numpy_hidden_states = flattened_states.cpu().float().numpy()

    return numpy_hidden_states
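
As a usage note (not part of the commit): a minimal sketch of how the numpy output could be handed to the JAX side. The jnp.asarray conversion is standard; the final projection step is left as a comment because GemmaFeaturesExtractorProjLinear's API is not shown in this diff, and the prompt and max_length below are illustrative.

import jax.numpy as jnp

from maxdiffusion.models.ltx2.text_encoders.hf_gemma3_encoder import HFGemma3TextEncoder

# Hypothetical usage sketch: encode a prompt on CPU, then move features to JAX.
encoder = HFGemma3TextEncoder("google/gemma-3-12b-it", max_length=64)
features = encoder.encode(["A cinematic shot of a lighthouse at dusk"])  # (1, 64, 188160)

# numpy -> JAX array; the shape is unchanged.
jax_features = jnp.asarray(features)
assert jax_features.shape == (1, 64, 49 * 3840)

# Each token's 188160-wide vector is its 49 per-layer activations (embeddings +
# 48 transformer layers) concatenated in layer order by the permute/reshape above;
# a projection such as GemmaFeaturesExtractorProjLinear would then map this down
# to the diffusion model's conditioning width.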
Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
import pytest
import numpy as np

from maxdiffusion.models.ltx2.text_encoders.hf_gemma3_encoder import HFGemma3TextEncoder


class TestHFGemma3TextEncoder:
  """Test suite for the Hugging Face CPU-based Gemma 3 Text Encoder."""

  @pytest.fixture(scope="class")
  def encoder(self):
    """Initialize the encoder. We use a small max_length to save memory and time."""
    print("Initializing HFGemma3TextEncoder on CPU...")
    # Note: Depending on your system memory, loading 12B on CPU might take ~25GB RAM.
    # Ensure the test node has enough CPU RAM.
    encoder = HFGemma3TextEncoder("google/gemma-3-12b-it", max_length=16)
    return encoder

  def test_encode_output_shape(self, encoder):
    """Verify that the encode method returns the correctly flattened numpy array."""
    prompt = "A test prompt for HF Gemma 3"

    # Run encode
    print("Running encode forward pass on CPU...")
    output_array = encoder.encode(prompt)

    # Verify it's a numpy array
    assert isinstance(output_array, np.ndarray), "Output must be a numpy array for JAX integration."

    # Verify shape
    # Expected: (batch_size, sequence_length, 49 * 3840) -> (1, 16, 188160)
    expected_shape = (1, 16, 49 * 3840)
    assert output_array.shape == expected_shape, f"Expected shape {expected_shape}, got {output_array.shape}"

    print(f"✅ Output successfully shaped for GemmaFeaturesExtractorProjLinear: {output_array.shape}")
