{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Gemma3_Multimodal.ipynb\">\n",
"<img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Gemma 3 Multimodal Demo with TransformerBridge\n",
"\n",
"This notebook demonstrates how to use TransformerBridge with `Gemma3ForConditionalGeneration`,\n",
"the vision-language variant of Gemma 3. The model pairs a SigLIP vision encoder with the\n",
"Gemma 3 language model and is the same architecture used by MedGemma.\n",
"\n",
"We demonstrate:\n",
"1. Loading Gemma 3 (4B-it) through TransformerBridge\n",
"2. Multimodal generation from an image + text prompt\n",
"3. Capturing vision-language activations with `run_with_cache()`\n",
"\n",
"> **Gated model.** The `google/gemma-3-*` checkpoints are gated on Hugging Face. Accept\n",
"> the license at https://huggingface.co/google/gemma-3-4b-it and run `huggingface-cli login`\n",
"> (or set `HF_TOKEN`) before executing this notebook."
]
},
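{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Optional:* the next cell is a small authentication sketch for running the notebook end-to-end. It only relies on `huggingface_hub` (installed alongside `transformers`); skip it if you already ran `huggingface-cli login` or exported `HF_TOKEN`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional authentication sketch -- skip if you are already logged in or HF_TOKEN is set.\n",
"import os\n",
"\n",
"if os.environ.get(\"HF_TOKEN\") is None:\n",
"    try:\n",
"        from huggingface_hub import login\n",
"        login()  # prompts interactively for a Hugging Face access token\n",
"    except Exception as exc:\n",
"        print(f\"Skipping Hugging Face login: {exc}\")"
]
},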
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Detect Colab and install dependencies if needed\n",
"DEVELOPMENT_MODE = False\n",
"try:\n",
"    import google.colab\n",
"    IN_COLAB = True\n",
"    print(\"Running as a Colab notebook\")\n",
"    %pip install transformer_lens\n",
"    %pip install circuitsvis\n",
"except ImportError:\n",
"    IN_COLAB = False"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"import torch\n",
"from PIL import Image\n",
"import requests\n",
"from io import BytesIO\n",
"\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"from transformer_lens.model_bridge import TransformerBridge\n",
"\n",
"try:\n",
"    import circuitsvis as cv\n",
"except ImportError:\n",
"    print('circuitsvis not installed, attention visualization will not work')\n",
"    cv = None"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Gemma 3 through TransformerBridge\n",
"\n",
"TransformerBridge maps `Gemma3ForConditionalGeneration` to its multimodal adapter, which\n",
"wraps the SigLIP vision tower, the multimodal projector, and the Gemma 3 language model\n",
"into a single hooked model.\n",
"\n",
"We use **bfloat16** here \u2014 Gemma 3 is trained in bf16 and fp16 can produce unstable activations.\n",
"The 4B-it variant is the smallest multimodal Gemma 3 (the 270M and 1B checkpoints are text-only)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"model = TransformerBridge.boot_transformers(\n",
"    \"google/gemma-3-4b-it\",\n",
"    device=device,\n",
"    dtype=torch.bfloat16,\n",
")\n",
"\n",
"for param in model.parameters():\n",
"    param.requires_grad = False\n",
"\n",
"print(f\"Model loaded on {device}\")\n",
"print(f\"Multimodal: {getattr(model.cfg, 'is_multimodal', False)}\")\n",
"print(f\"Layers: {model.cfg.n_layers}, Heads: {model.cfg.n_heads}\")\n",
"print(f\"Vision tokens per image: {getattr(model.cfg, 'mm_tokens_per_image', None)}\")"
]
},
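{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, we can list the bridge's top-level submodules. The exact module names are not guaranteed (they depend on how the bridge wraps the underlying Hugging Face model), but you should see the vision tower, the multimodal projector, and the language model among them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"# Sketch: list top-level submodules; names depend on how the bridge wraps the HF model.\n",
"for name, module in model.named_children():\n",
"    print(f\"{name}: {type(module).__name__}\")"
]
},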
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load a test image\n",
"\n",
"We'll use a stop-sign photo from Australia to test the model's visual understanding."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"image_url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n",
"response = requests.get(image_url)\n",
"image = Image.open(BytesIO(response.content)).convert(\"RGB\")\n",
"plt.imshow(image)\n",
"plt.axis('off')\n",
"plt.title('Test Image')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multimodal Generation\n",
"\n",
"Gemma 3 instruction-tuned models expect the chat template format with\n",
"`<start_of_turn>` / `<end_of_turn>` markers, and use `<start_of_image>` as the image\n",
"placeholder (rather than LLaVA's `<image>`). The processor expands `<start_of_image>`\n",
"into the appropriate number of vision tokens (256 by default for Gemma 3 4B).\n",
"\n",
"We call `prepare_multimodal_inputs()` to run the processor on text + image, then pass\n",
"`pixel_values` to `generate()`. The bridge's `generate()` keeps a KV cache\n",
"(`use_past_kv_cache=True` by default) for efficient autoregressive decoding while\n",
"preserving full hook access."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"question = \"What do you see in this photo?\"\n",
"prompt = (\n",
"    \"<start_of_turn>user\\n\"\n",
"    f\"<start_of_image>{question}<end_of_turn>\\n\"\n",
"    \"<start_of_turn>model\\n\"\n",
")\n",
"\n",
"# Prepare multimodal inputs (handles image processing + tokenization)\n",
"inputs = model.prepare_multimodal_inputs(text=prompt, images=image)\n",
"input_ids = inputs['input_ids']\n",
"pixel_values = inputs['pixel_values']\n",
"\n",
"# Pass any extra processor outputs (e.g. token_type_ids for Gemma 3)\n",
"extra_kwargs = {k: v for k, v in inputs.items()\n",
"                if k not in ('input_ids', 'pixel_values')}\n",
"\n",
"generated_text = model.generate(\n",
"    input_ids,\n",
"    pixel_values=pixel_values,\n",
"    max_new_tokens=80,\n",
"    do_sample=False,\n",
"    use_past_kv_cache=True,\n",
"    return_type=\"str\",\n",
"    **extra_kwargs,\n",
")\n",
"\n",
"print('Generated text:', generated_text)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's try a second image to confirm the model adapts its description:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"image_url_2 = \"https://github.com/zazamrykh/PicFinder/blob/main/images/doge.jpg?raw=true\"\n",
"response = requests.get(image_url_2)\n",
"image_2 = Image.open(BytesIO(response.content)).convert(\"RGB\")\n",
"plt.imshow(image_2)\n",
"plt.axis('off')\n",
"plt.show()\n",
"\n",
"inputs = model.prepare_multimodal_inputs(text=prompt, images=image_2)\n",
"input_ids = inputs['input_ids']\n",
"pixel_values = inputs['pixel_values']\n",
"extra_kwargs = {k: v for k, v in inputs.items()\n",
"                if k not in ('input_ids', 'pixel_values')}\n",
"\n",
"generated_text = model.generate(\n",
"    input_ids,\n",
"    pixel_values=pixel_values,\n",
"    max_new_tokens=80,\n",
"    do_sample=False,\n",
"    use_past_kv_cache=True,\n",
"    return_type=\"str\",\n",
"    **extra_kwargs,\n",
")\n",
"print('Generated text:', generated_text)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inspecting Vision-Language Activations\n",
"\n",
"`run_with_cache()` accepts the same `pixel_values` argument and captures activations from\n",
"the vision encoder, the multimodal projector, and every transformer block in the language\n",
"model. This lets us inspect how the language tokens attend to image tokens during\n",
"multimodal processing."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"inputs = model.prepare_multimodal_inputs(text=prompt, images=image)\n",
"extra_kwargs = {k: v for k, v in inputs.items()\n",
"                if k not in ('input_ids', 'pixel_values')}\n",
"\n",
"with torch.no_grad():\n",
"    logits, cache = model.run_with_cache(\n",
"        inputs['input_ids'],\n",
"        pixel_values=inputs['pixel_values'],\n",
"        **extra_kwargs,\n",
"    )\n",
"\n",
"print(f'Logits shape: {logits.shape}')\n",
"print(f'Cache entries: {len(cache)}')\n",
"vision_keys = [k for k in cache.keys() if 'vision' in k.lower()]\n",
"print(f'Vision-related cache entries: {len(vision_keys)}')\n",
"print(f'Sample vision keys: {vision_keys[:5]}')"
]
},
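{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sketch of what this cache enables, the next cell compares the residual-stream norm at image-token positions against text-token positions. It assumes the Gemma 3 processor returns `token_type_ids` with `1` marking image tokens, and that the bridge exposes TransformerLens-style `hook_resid_post` names (as it does for the attention `pattern` hooks used below); both assumptions are checked before use."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"# Sketch: compare residual-stream norms at image vs. text positions.\n",
"# Assumes token_type_ids marks image tokens with 1 and that hook_resid_post\n",
"# entries exist in the cache; both are checked before use.\n",
"token_type_ids = inputs.get('token_type_ids')\n",
"resid_keys = [k for k in cache.keys() if k.endswith('hook_resid_post')]\n",
"if token_type_ids is not None and resid_keys:\n",
"    resid = cache[resid_keys[-1]][0].float().cpu()  # [seq, d_model] at the final block\n",
"    image_mask = token_type_ids[0].bool().cpu()\n",
"    image_norm = resid[image_mask].norm(dim=-1).mean().item()\n",
"    text_norm = resid[~image_mask].norm(dim=-1).mean().item()\n",
"    print(f'Image positions: {int(image_mask.sum())}, text positions: {int((~image_mask).sum())}')\n",
"    print(f'Mean residual norm at image positions: {image_norm:.2f}, at text positions: {text_norm:.2f}')\n",
"else:\n",
"    print('token_type_ids or hook_resid_post entries not available; skipping this sketch')"
]
},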
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"if cv is not None:\n",
"    layer_to_visualize = 16\n",
"    tokens_to_show = 30\n",
"\n",
"    pattern_keys = [k for k in cache.keys() if f'blocks.{layer_to_visualize}' in k and 'pattern' in k]\n",
"    if pattern_keys:\n",
"        attention_pattern = cache[pattern_keys[0]]\n",
"        if attention_pattern.ndim == 4:\n",
"            attention_pattern = attention_pattern[0]\n",
"\n",
"        token_ids = inputs['input_ids'][0].cpu()\n",
"        str_tokens = model.tokenizer.convert_ids_to_tokens(token_ids)\n",
"\n",
"        print(f'Layer {layer_to_visualize} Head Attention Patterns (last {tokens_to_show} tokens):')\n",
"        display(cv.attention.attention_patterns(\n",
"            tokens=str_tokens[-tokens_to_show:],\n",
"            attention=attention_pattern[:, -tokens_to_show:, -tokens_to_show:].float().cpu(),\n",
"        ))\n",
"    else:\n",
"        print(f'No attention pattern found for layer {layer_to_visualize}')\n",
"        print(f'Available attention-related keys: {[k for k in cache.keys() if \"attn\" in k][:10]}')\n",
"else:\n",
"    print('circuitsvis not available \u2014 skipping visualization')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary\n",
"\n",
"TransformerBridge provides native multimodal support for `Gemma3ForConditionalGeneration`:\n",
"\n",
"- **`boot_transformers(\"google/gemma-3-4b-it\")`** loads the full vision + projector + language pipeline\n",
"- **`prepare_multimodal_inputs(text=..., images=...)`** handles image processing and tokenization\n",
"- **`generate(input_ids, pixel_values=...)`** runs multimodal generation with KV cache and hooks\n",
"- **`run_with_cache(input_ids, pixel_values=...)`** captures activations including SigLIP vision tokens\n",
"\n",
"A few Gemma 3 specifics worth noting:\n",
"\n",
"- Use the chat-template format (`<start_of_turn>user ... <end_of_turn><start_of_turn>model`) for instruction-tuned variants\n",
"- The image placeholder is `<start_of_image>`, not `<image>`\n",
"- Gemma 3 is trained in bf16 \u2014 prefer `torch.bfloat16` over `torch.float16`\n",
"- The same code path works for MedGemma (`google/medgemma-4b-it`, `google/medgemma-27b-it`) and any other `Gemma3ForConditionalGeneration` checkpoint"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "transformer-lens",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}