Commit 58330ad

Fixing issues with Multimodal text generation
1 parent 52e4f6d commit 58330ad

4 files changed

Lines changed: 485 additions & 2 deletions

File tree

demos/Gemma3_Multimodal.ipynb

Lines changed: 341 additions & 0 deletions
@@ -0,0 +1,341 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Gemma3_Multimodal.ipynb\">\n",
8+
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
9+
"</a>"
10+
]
11+
},
12+
{
13+
"cell_type": "markdown",
14+
"metadata": {},
15+
"source": [
16+
"### Gemma 3 Multimodal Demo with TransformerBridge\n",
17+
"\n",
18+
"This notebook demonstrates how to use TransformerBridge with `Gemma3ForConditionalGeneration`,\n",
19+
"the vision-language variant of Gemma 3. The model pairs a SigLIP vision encoder with the\n",
20+
"Gemma 3 language model and is the same architecture used by MedGemma.\n",
21+
"\n",
22+
"We demonstrate:\n",
23+
"1. Loading Gemma 3 (4B-it) through TransformerBridge\n",
24+
"2. Multimodal generation from an image + text prompt\n",
25+
"3. Capturing vision-language activations with `run_with_cache()`\n",
26+
"\n",
27+
"> **Gated model.** The `google/gemma-3-*` checkpoints are gated on Hugging Face. Accept\n",
28+
"> the license at https://huggingface.co/google/gemma-3-4b-it and run `huggingface-cli login`\n",
29+
"> (or set `HF_TOKEN`) before executing this notebook."
30+
]
31+
},
32+
{
33+
"cell_type": "code",
34+
"execution_count": null,
35+
"metadata": {},
36+
"outputs": [],
37+
"source": [
38+
"# Detect Colab and install dependencies if needed\n",
39+
"DEVELOPMENT_MODE = False\n",
40+
"try:\n",
41+
" import google.colab\n",
42+
" IN_COLAB = True\n",
43+
" print(\"Running as a Colab notebook\")\n",
44+
" %pip install transformer_lens\n",
45+
" %pip install circuitsvis\n",
46+
"except:\n",
47+
" IN_COLAB = False"
48+
]
49+
},
50+
{
51+
"cell_type": "code",
52+
"execution_count": null,
53+
"metadata": {},
54+
"outputs": [],
55+
"source": [
56+
"# NBVAL_IGNORE_OUTPUT\n",
57+
"import torch\n",
58+
"from PIL import Image\n",
59+
"import requests\n",
60+
"from io import BytesIO\n",
61+
"\n",
62+
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
63+
"\n",
64+
"import matplotlib.pyplot as plt\n",
65+
"%matplotlib inline\n",
66+
"\n",
67+
"from transformer_lens.model_bridge import TransformerBridge\n",
68+
"\n",
69+
"try:\n",
70+
" import circuitsvis as cv\n",
71+
"except ImportError:\n",
72+
" print('circuitsvis not installed, attention visualization will not work')\n",
73+
" cv = None"
74+
]
75+
},
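{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Gemma 3 checkpoint is gated, so the next cell is a minimal authentication sketch: it assumes an `HF_TOKEN` environment variable and calls `huggingface_hub.login`. If you have already run `huggingface-cli login` in a terminal, the cell simply reports that and does nothing."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"# Optional: authenticate for the gated google/gemma-3-4b-it checkpoint.\n",
"# A minimal sketch - assumes HF_TOKEN is set; otherwise we rely on a token\n",
"# already stored by a prior `huggingface-cli login`.\n",
"import os\n",
"from huggingface_hub import login\n",
"\n",
"hf_token = os.environ.get('HF_TOKEN')\n",
"if hf_token:\n",
"    login(token=hf_token)\n",
"    print('Logged in to Hugging Face with HF_TOKEN')\n",
"else:\n",
"    print('HF_TOKEN not set - assuming an existing huggingface-cli login')"
]
},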
76+
{
77+
"cell_type": "markdown",
78+
"metadata": {},
79+
"source": [
80+
"## Load Gemma 3 through TransformerBridge\n",
81+
"\n",
82+
"TransformerBridge maps `Gemma3ForConditionalGeneration` to its multimodal adapter, which\n",
83+
"wraps the SigLIP vision tower, the multimodal projector, and the Gemma 3 language model\n",
84+
"into a single hooked model.\n",
85+
"\n",
86+
"We use **bfloat16** here \u2014 Gemma 3 is trained in bf16 and fp16 can produce unstable activations.\n",
87+
"The 4B-it variant is the smallest multimodal Gemma 3 (the 270m and 1B checkpoints are text-only)."
88+
]
89+
},
90+
{
91+
"cell_type": "code",
92+
"execution_count": null,
93+
"metadata": {},
94+
"outputs": [],
95+
"source": [
96+
"# NBVAL_IGNORE_OUTPUT\n",
97+
"model = TransformerBridge.boot_transformers(\n",
98+
" \"google/gemma-3-4b-it\",\n",
99+
" device=device,\n",
100+
" dtype=torch.bfloat16,\n",
101+
")\n",
102+
"\n",
103+
"for param in model.parameters():\n",
104+
" param.requires_grad = False\n",
105+
"\n",
106+
"print(f\"Model loaded on {device}\")\n",
107+
"print(f\"Multimodal: {getattr(model.cfg, 'is_multimodal', False)}\")\n",
108+
"print(f\"Layers: {model.cfg.n_layers}, Heads: {model.cfg.n_heads}\")\n",
109+
"print(f\"Vision tokens per image: {getattr(model.cfg, 'mm_tokens_per_image', None)}\")"
110+
]
111+
},
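{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see concretely that the vision tower, projector, and language model live inside one hooked module, the sketch below lists the top-level submodule names. It relies only on standard `torch.nn.Module` introspection (`named_modules`) rather than any bridge-specific API, so the exact names it prints depend on the adapter."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"# List the top-level submodules wrapped by the bridge (a sketch using plain\n",
"# nn.Module introspection; exact names depend on the multimodal adapter).\n",
"top_level = sorted({name.split('.')[0] for name, _ in model.named_modules() if name})\n",
"print('Top-level submodules:', top_level)"
]
},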
112+
{
113+
"cell_type": "markdown",
114+
"metadata": {},
115+
"source": [
116+
"## Load a test image\n",
117+
"\n",
118+
"We'll use a stop-sign photo from Australia to test the model's visual understanding."
119+
]
120+
},
121+
{
122+
"cell_type": "code",
123+
"execution_count": null,
124+
"metadata": {},
125+
"outputs": [],
126+
"source": [
127+
"# NBVAL_IGNORE_OUTPUT\n",
128+
"image_url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n",
129+
"response = requests.get(image_url)\n",
130+
"image = Image.open(BytesIO(response.content)).convert(\"RGB\")\n",
131+
"plt.imshow(image)\n",
132+
"plt.axis('off')\n",
133+
"plt.title('Test Image')\n",
134+
"plt.show()"
135+
]
136+
},
137+
{
138+
"cell_type": "markdown",
139+
"metadata": {},
140+
"source": [
141+
"## Multimodal Generation\n",
142+
"\n",
143+
"Gemma 3 instruction-tuned models expect the chat template format with\n",
144+
"`<start_of_turn>` / `<end_of_turn>` markers, and use `<start_of_image>` as the image\n",
145+
"placeholder (rather than LLaVA's `<image>`). The processor expands `<start_of_image>`\n",
146+
"into the appropriate number of vision tokens (256 by default for Gemma 3 4B).\n",
147+
"\n",
148+
"We call `prepare_multimodal_inputs()` to run the processor on text + image, then pass\n",
149+
"`pixel_values` to `generate()`. The bridge's `generate()` keeps a KV cache\n",
150+
"(`use_past_kv_cache=True` by default) for efficient autoregressive decoding while\n",
151+
"preserving full hook access."
152+
]
153+
},
154+
{
155+
"cell_type": "code",
156+
"execution_count": null,
157+
"metadata": {},
158+
"outputs": [],
159+
"source": [
160+
"# NBVAL_IGNORE_OUTPUT\n",
161+
"question = \"What do you see in this photo?\"\n",
162+
"prompt = (\n",
163+
" \"<start_of_turn>user\\n\"\n",
164+
" f\"<start_of_image>{question}<end_of_turn>\\n\"\n",
165+
" \"<start_of_turn>model\\n\"\n",
166+
")\n",
167+
"\n",
168+
"# Prepare multimodal inputs (handles image processing + tokenization)\n",
169+
"inputs = model.prepare_multimodal_inputs(text=prompt, images=image)\n",
170+
"input_ids = inputs['input_ids']\n",
171+
"pixel_values = inputs['pixel_values']\n",
172+
"\n",
173+
"# Pass any extra processor outputs (e.g. token_type_ids for Gemma 3)\n",
174+
"extra_kwargs = {k: v for k, v in inputs.items()\n",
175+
" if k not in ('input_ids', 'pixel_values')}\n",
176+
"\n",
177+
"generated_text = model.generate(\n",
178+
" input_ids,\n",
179+
" pixel_values=pixel_values,\n",
180+
" max_new_tokens=80,\n",
181+
" do_sample=False,\n",
182+
" use_past_kv_cache=True,\n",
183+
" return_type=\"str\",\n",
184+
" **extra_kwargs,\n",
185+
")\n",
186+
"\n",
187+
"print('Generated text:', generated_text)"
188+
]
189+
},
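{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check on the `<start_of_image>` expansion described above, the sketch below compares the expanded multimodal sequence length against a plain text-only tokenization of the same prompt. The difference should be roughly the 256 vision tokens, plus any image begin/end markers the processor adds."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"# Sanity check: how many tokens did the image placeholder expand into?\n",
"# A sketch - compares the processor's expanded sequence with plain tokenization.\n",
"text_only_len = len(model.tokenizer(prompt)['input_ids'])\n",
"expanded_len = input_ids.shape[-1]\n",
"print(f'Text-only prompt tokens: {text_only_len}')\n",
"print(f'Multimodal input tokens: {expanded_len}')\n",
"print(f'Tokens contributed by the image: {expanded_len - text_only_len}')"
]
},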
190+
{
191+
"cell_type": "markdown",
192+
"metadata": {},
193+
"source": [
194+
"Let's try a second image to confirm the model adapts its description:"
195+
]
196+
},
197+
{
198+
"cell_type": "code",
199+
"execution_count": null,
200+
"metadata": {},
201+
"outputs": [],
202+
"source": [
203+
"# NBVAL_IGNORE_OUTPUT\n",
204+
"image_url_2 = \"https://github.com/zazamrykh/PicFinder/blob/main/images/doge.jpg?raw=true\"\n",
205+
"response = requests.get(image_url_2)\n",
206+
"image_2 = Image.open(BytesIO(response.content)).convert(\"RGB\")\n",
207+
"plt.imshow(image_2)\n",
208+
"plt.axis('off')\n",
209+
"plt.show()\n",
210+
"\n",
211+
"inputs = model.prepare_multimodal_inputs(text=prompt, images=image_2)\n",
212+
"input_ids = inputs['input_ids']\n",
213+
"pixel_values = inputs['pixel_values']\n",
214+
"extra_kwargs = {k: v for k, v in inputs.items()\n",
215+
" if k not in ('input_ids', 'pixel_values')}\n",
216+
"\n",
217+
"generated_text = model.generate(\n",
218+
" input_ids,\n",
219+
" pixel_values=pixel_values,\n",
220+
" max_new_tokens=80,\n",
221+
" do_sample=False,\n",
222+
" use_past_kv_cache=True,\n",
223+
" return_type=\"str\",\n",
224+
" **extra_kwargs,\n",
225+
")\n",
226+
"print('Generated text:', generated_text)"
227+
]
228+
},
229+
{
230+
"cell_type": "markdown",
231+
"metadata": {},
232+
"source": [
233+
"## Inspecting Vision-Language Activations\n",
234+
"\n",
235+
"`run_with_cache()` accepts the same `pixel_values` argument and captures activations from\n",
236+
"the vision encoder, the multimodal projector, and every transformer block in the language\n",
237+
"model. This lets us inspect how the language tokens attend to image tokens during\n",
238+
"multimodal processing."
239+
]
240+
},
241+
{
242+
"cell_type": "code",
243+
"execution_count": null,
244+
"metadata": {},
245+
"outputs": [],
246+
"source": [
247+
"# NBVAL_IGNORE_OUTPUT\n",
248+
"inputs = model.prepare_multimodal_inputs(text=prompt, images=image)\n",
249+
"extra_kwargs = {k: v for k, v in inputs.items()\n",
250+
" if k not in ('input_ids', 'pixel_values')}\n",
251+
"\n",
252+
"with torch.no_grad():\n",
253+
" logits, cache = model.run_with_cache(\n",
254+
" inputs['input_ids'],\n",
255+
" pixel_values=inputs['pixel_values'],\n",
256+
" **extra_kwargs,\n",
257+
" )\n",
258+
"\n",
259+
"print(f'Logits shape: {logits.shape}')\n",
260+
"print(f'Cache entries: {len(cache)}')\n",
261+
"vision_keys = [k for k in cache.keys() if 'vision' in k.lower()]\n",
262+
"print(f'Vision-related cache entries: {len(vision_keys)}')\n",
263+
"print(f'Sample vision keys: {vision_keys[:5]}')"
264+
]
265+
},
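{
"cell_type": "markdown",
"metadata": {},
"source": [
"The image tokens occupy a contiguous block of the sequence. The sketch below locates them via the `token_type_ids` the Gemma 3 processor returns (assuming the usual convention that 1 marks image positions), then inspects one residual-stream entry for block 0. It discovers the cache key by pattern matching, as the cells above do, rather than assuming an exact hook name."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"# Where do the image tokens sit? Gemma 3's processor marks them in token_type_ids\n",
"# (assumed convention: 1 = image position), when that field is returned.\n",
"token_type_ids = inputs.get('token_type_ids')\n",
"if token_type_ids is not None:\n",
"    image_positions = (token_type_ids[0] == 1).nonzero(as_tuple=True)[0]\n",
"    print(f'Image tokens span positions {image_positions.min().item()}-{image_positions.max().item()} '\n",
"          f'({len(image_positions)} tokens)')\n",
"\n",
"# Discover a residual-stream-like cache key for block 0 instead of hard-coding it\n",
"resid_keys = [k for k in cache.keys() if 'blocks.0.' in k and 'resid' in k.lower()]\n",
"print(f'Residual-stream keys in block 0: {resid_keys}')\n",
"if resid_keys:\n",
"    act = cache[resid_keys[0]]\n",
"    print(f'{resid_keys[0]} shape: {tuple(act.shape)}')"
]
},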
266+
{
267+
"cell_type": "code",
268+
"execution_count": null,
269+
"metadata": {},
270+
"outputs": [],
271+
"source": [
272+
"# NBVAL_IGNORE_OUTPUT\n",
273+
"if cv is not None:\n",
274+
" layer_to_visualize = 16\n",
275+
" tokens_to_show = 30\n",
276+
"\n",
277+
" pattern_keys = [k for k in cache.keys() if f'blocks.{layer_to_visualize}' in k and 'pattern' in k]\n",
278+
" if pattern_keys:\n",
279+
" attention_pattern = cache[pattern_keys[0]]\n",
280+
" if attention_pattern.ndim == 4:\n",
281+
" attention_pattern = attention_pattern[0]\n",
282+
"\n",
283+
" token_ids = inputs['input_ids'][0].cpu()\n",
284+
" str_tokens = model.tokenizer.convert_ids_to_tokens(token_ids)\n",
285+
"\n",
286+
" print(f'Layer {layer_to_visualize} Head Attention Patterns (last {tokens_to_show} tokens):')\n",
287+
" display(cv.attention.attention_patterns(\n",
288+
" tokens=str_tokens[-tokens_to_show:],\n",
289+
" attention=attention_pattern[:, -tokens_to_show:, -tokens_to_show:].float().cpu(),\n",
290+
" ))\n",
291+
" else:\n",
292+
" print(f'No attention pattern found for layer {layer_to_visualize}')\n",
293+
" print(f'Available attention-related keys: {[k for k in cache.keys() if \"attn\" in k][:10]}')\n",
294+
"else:\n",
295+
" print('circuitsvis not available \u2014 skipping visualization')"
296+
]
297+
},
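{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the bridge keys off the `Gemma3ForConditionalGeneration` architecture rather than one specific checkpoint, the same workflow should carry over to MedGemma. The cell below is an illustrative sketch only: the MedGemma checkpoints are likewise gated and several GB in size, so it is disabled by default behind a flag."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"# Illustrative sketch: the same code path with a MedGemma checkpoint.\n",
"# Disabled by default - flip the flag to actually download and run it.\n",
"LOAD_MEDGEMMA = False\n",
"if LOAD_MEDGEMMA:\n",
"    medgemma = TransformerBridge.boot_transformers(\n",
"        'google/medgemma-4b-it',\n",
"        device=device,\n",
"        dtype=torch.bfloat16,\n",
"    )\n",
"    med_inputs = medgemma.prepare_multimodal_inputs(text=prompt, images=image)\n",
"    med_extra = {k: v for k, v in med_inputs.items()\n",
"                 if k not in ('input_ids', 'pixel_values')}\n",
"    print(medgemma.generate(\n",
"        med_inputs['input_ids'],\n",
"        pixel_values=med_inputs['pixel_values'],\n",
"        max_new_tokens=40,\n",
"        do_sample=False,\n",
"        return_type='str',\n",
"        **med_extra,\n",
"    ))"
]
},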
298+
{
299+
"cell_type": "markdown",
300+
"metadata": {},
301+
"source": [
302+
"## Summary\n",
303+
"\n",
304+
"TransformerBridge provides native multimodal support for `Gemma3ForConditionalGeneration`:\n",
305+
"\n",
306+
"- **`boot_transformers(\"google/gemma-3-4b-it\")`** loads the full vision + projector + language pipeline\n",
307+
"- **`prepare_multimodal_inputs(text=..., images=...)`** handles image processing and tokenization\n",
308+
"- **`generate(input_ids, pixel_values=...)`** runs multimodal generation with KV cache and hooks\n",
309+
"- **`run_with_cache(input_ids, pixel_values=...)`** captures activations including SigLIP vision tokens\n",
310+
"\n",
311+
"A few Gemma 3 specifics worth noting:\n",
312+
"\n",
313+
"- Use the chat-template format (`<start_of_turn>user ... <end_of_turn><start_of_turn>model`) for instruction-tuned variants\n",
314+
"- The image placeholder is `<start_of_image>`, not `<image>`\n",
315+
"- Gemma 3 is trained in bf16 \u2014 prefer `torch.bfloat16` over `torch.float16`\n",
316+
"- The same code path works for MedGemma (`google/medgemma-4b-it`, `google/medgemma-27b-it`) and any other `Gemma3ForConditionalGeneration` checkpoint"
317+
]
318+
}
319+
],
320+
"metadata": {
321+
"kernelspec": {
322+
"display_name": "transformer-lens",
323+
"language": "python",
324+
"name": "python3"
325+
},
326+
"language_info": {
327+
"codemirror_mode": {
328+
"name": "ipython",
329+
"version": 3
330+
},
331+
"file_extension": ".py",
332+
"mimetype": "text/x-python",
333+
"name": "python",
334+
"nbconvert_exporter": "python",
335+
"pygments_lexer": "ipython3",
336+
"version": "3.12.12"
337+
}
338+
},
339+
"nbformat": 4,
340+
"nbformat_minor": 5
341+
}
