import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel

class HFGemma3TextEncoder:
    """
    A lightweight wrapper around Hugging Face's Gemma 3 model for extracting hidden states.
    This module forces execution on CPU to avoid out-of-memory errors and XLA device
    contention when it runs alongside JAX/MaxDiffusion on TPUs.
    """
    def __init__(self, model_id: str = "google/gemma-3-12b-it", max_length: int = 8192):
        self.model_id = model_id
        self.max_length = max_length
        # Initialize the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

        # Load the model directly onto CPU in bfloat16 to save memory
        print(f"Loading {model_id} onto CPU. This may take a few moments...")
        self.model = AutoModel.from_pretrained(
            self.model_id,
            torch_dtype=torch.bfloat16,
            device_map="cpu",  # Force CPU to avoid TPU memory contention with MaxDiffusion
        )
        self.model.eval()  # Set to evaluation mode

    def encode(self, text: str | list[str]) -> np.ndarray:
        """
        Tokenizes the input text, passes it through the HF Gemma 3 model,
        and extracts ALL hidden states.

        Args:
            text: A single string or a list of strings to encode.

        Returns:
            A numpy array of the flattened, stacked hidden states, compatible
            with GemmaFeaturesExtractorProjLinear.
            Shape: (batch_size, sequence_length, 49 * 3840)
        """
        # 1. Tokenize the input text, padding/truncating to a fixed length
        inputs = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # Move the inputs to the same device as the model (CPU)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        # 2. Forward pass to get hidden states.
        # output_hidden_states=True is the key to retrieving all 49 states:
        # the embedding output plus the outputs of the 48 decoder layers.
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)

        # 3. Extract and stack the hidden states.
        # outputs.hidden_states is a tuple of 49 tensors, each shaped (batch, seq_len, 3840)
        all_hidden_states = outputs.hidden_states

        # Stack along a new leading dimension: (49, batch, seq_len, 3840)
        stacked_states = torch.stack(all_hidden_states, dim=0)

        # Transpose to (batch, seq_len, 49, 3840) so the layer axis sits next to
        # the hidden axis and the two can be flattened together
        transposed_states = stacked_states.permute(1, 2, 0, 3)

        # Flatten the last two dimensions to match the feature extractor's expectation:
        # (batch, seq_len, 49 * 3840) -> (batch, seq_len, 188160)
        batch_size, seq_len, num_layers, hidden_dim = transposed_states.shape
        flattened_states = transposed_states.reshape(batch_size, seq_len, num_layers * hidden_dim)

        # 4. Convert the PyTorch tensor to a NumPy array.
        # Cast bfloat16 -> float32 first, since NumPy has no bfloat16 dtype;
        # JAX/Flax can then consume the numpy array directly as a JAX Array.
        numpy_hidden_states = flattened_states.cpu().float().numpy()

        return numpy_hidden_states
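

# --- Usage sketch (illustrative; not part of the original commit) ---
# A minimal smoke test, assuming the Gemma 3 checkpoint is available via the
# HF Hub (or a local cache) and that jax is installed; even on CPU, loading
# the 12B weights needs substantial RAM. The short max_length here only keeps
# the demo forward pass cheap.
if __name__ == "__main__":
    import jax.numpy as jnp

    encoder = HFGemma3TextEncoder(max_length=32)
    embeddings = encoder.encode(["a photo of an astronaut riding a horse"])
    print(embeddings.shape)  # expected: (1, 32, 188160)

    # JAX accepts the NumPy array directly; jnp.asarray makes the hand-off explicit.
    jax_embeddings = jnp.asarray(embeddings)
    print(jax_embeddings.dtype, jax_embeddings.shape)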