Commit 9a0f5ed

Author: The gemma Authors
Implement Prefix Tuning for Gemma models.
PiperOrigin-RevId: 899479928
1 parent ae84d95 commit 9a0f5ed

8 files changed

Lines changed: 1101 additions & 2 deletions

colabs/prefix_finetuning.ipynb

Lines changed: 423 additions & 0 deletions
Large diffs are not rendered by default.

colabs/prefix_sampling.ipynb

Lines changed: 261 additions & 0 deletions
@@ -0,0 +1,261 @@
+{
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "qKlB5QTDIV6S"
+      },
+      "source": [
+        "# Prefix Tuning (Sampling)\n",
+        "Example of using Prefix Tuning with Gemma (for inference)."
+      ]
+    },
+    {
+      "metadata": {
+        "id": "TR-L25KVKT_F"
+      },
+      "cell_type": "code",
+      "source": [
+        "!pip install -q gemma"
+      ],
+      "outputs": [],
+      "execution_count": null
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "I6fEKB1tISVW"
+      },
+      "outputs": [],
+      "source": [
+        "# Common imports\n",
+        "import os\n",
+        "import jax\n",
+        "import jax.numpy as jnp\n",
+        "import treescope\n",
+        "\n",
+        "# Gemma imports\n",
+        "from gemma import gm\n",
+        "from gemma import peft"
+      ]
+    },
+    {
+      "metadata": {
+        "id": "cxGT2XeU4L47"
+      },
+      "cell_type": "markdown",
+      "source": [
+        "By default, JAX does not use the full GPU memory, but this can be overridden. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):"
+      ]
+    },
+    {
+      "metadata": {
+        "id": "o4MidM--4L47"
+      },
+      "cell_type": "code",
+      "source": [
+        "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\""
+      ],
+      "outputs": [],
+      "execution_count": null
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "-kdAZkvOIryQ"
+      },
+      "source": [
+        "## Initializing the model\n",
+        "\n",
+        "To use Gemma with Prefix Tuning, wrap any Gemma model in `gm.nn.PrefixTuning.from_model`:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "x-BbrzCVIupV"
+      },
+      "outputs": [],
+      "source": [
+        "model = gm.nn.PrefixTuning.from_model(\n",
+        "    prefix_length=100,\n",
+        "    global_layers_only=True,\n",
+        "    model=gm.nn.Gemma3_4B(text_only=True),\n",
+        ")"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "hI3Lg07SJff4"
+      },
+      "source": [
+        "Initialize the weights:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "1shC1DpiJfsw"
+      },
+      "outputs": [],
+      "source": [
+        "token_ids = jnp.zeros((1, 256,), dtype=jnp.int32)  # Dummy input of shape (batch_size, seq_length)\n",
+        "\n",
+        "params = model.init(\n",
+        "    jax.random.key(0),\n",
+        "    token_ids,\n",
+        ")\n",
+        "\n",
+        "params = params['params']"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "T3dWILqKKzG3"
+      },
+      "source": [
+        "Inspect the shape and structure of the params. The prefix weights have been added."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "LMq2Z9nXKcad"
+      },
+      "outputs": [],
+      "source": [
+        "treescope.show(params)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "bGJl5YpKKOf-"
+      },
+      "source": [
+        "Restore the pre-trained params. We use `peft.split_params` and `peft.merge_params` to replace the randomly initialized params with the pre-trained ones.\n",
+        "\n",
+        "When using `gm.ckpts.load_params`, make sure to pass the `params=original` kwarg. This ensures that:\n",
+        "\n",
+        "* The memory from the old params is released (so only a single copy of the weights stays in memory)\n",
+        "* The restored params reuse the same sharding as the input (there is no sharding here, so it isn't required)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "AcO6oBuLKNjb"
+      },
+      "outputs": [],
+      "source": [
+        "# Split the params into original and PEFT weights (the prefix weights land in `lora`)\n",
+        "original, lora = peft.split_params(params)\n",
+        "\n",
+        "# Load the params from the checkpoint\n",
+        "original = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT, params=original)\n",
+        "\n",
+        "# Merge the pre-trained params back with the PEFT weights\n",
+        "params = peft.merge_params(original, lora)"
+      ]
+    },
+    {
+      "metadata": {
+        "id": "b8y4YAAi9_Sv"
+      },
+      "cell_type": "markdown",
+      "source": [
+        "## Fine-tuning\n",
+        "\n",
+        "See our [fine-tuning guide](https://gemma-llm.readthedocs.io/en/latest/lora_finetuning.html) for more info.\n",
+        "\n",
+        "For an end-to-end fine-tuning example, refer to [Prefix Tuning](third_party/py/gemma/colabs/prefix_finetuning.ipynb)."
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "MvsQbQM4I4Cs"
+      },
+      "source": [
+        "## Inference\n",
+        "\n",
+        "Here's an example of running a single model call:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "eqU7a4eCI5Wr"
+      },
+      "outputs": [],
+      "source": [
+        "tokenizer = gm.text.Gemma3Tokenizer()\n",
+        "\n",
+        "prompt = tokenizer.encode('The capital of France is')\n",
+        "prompt = jnp.asarray([tokenizer.special_tokens.BOS] + prompt)\n",
+        "\n",
+        "\n",
+        "# Run the model\n",
+        "out = model.apply(\n",
+        "    {'params': params},\n",
+        "    tokens=prompt,\n",
+        "    return_last_only=True,  # Only predict the last token\n",
+        ")\n",
+        "\n",
+        "\n",
+        "# Show the token distribution\n",
+        "tokenizer.plot_logits(out.logits)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "6dOSL9MHuMUa"
+      },
+      "source": [
+        "To sample an entire sentence:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "_ckwREdyqown"
+      },
+      "outputs": [],
+      "source": [
+        "sampler = gm.text.ChatSampler(\n",
+        "    model=model,\n",
+        "    params=params,\n",
+        "    tokenizer=tokenizer,\n",
+        ")\n",
+        "\n",
+        "sampler.chat('The capital of France is?')"
+      ]
+    }
+  ],
+  "metadata": {
+    "colab": {
+      "last_runtime": {},
+      "private_outputs": true,
+      "provenance": []
+    },
+    "kernelspec": {
+      "display_name": "Python 3",
+      "name": "python3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 0
+}
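
The notebook's "restore the pre-trained params" step follows a split, load, merge pattern. The sketch below illustrates that pattern on plain nested dicts. It is not the `gemma.peft` implementation: the helper bodies, the `"prefix"` key used to spot adapter weights, and the toy values are all assumptions made for illustration.

```python
# Hypothetical stand-ins for peft.split_params / peft.merge_params.
# Assumption: adapter weights sit under a key named "prefix"; the real
# library may identify them differently.

def split_params(params, is_adapter=lambda key: key == "prefix"):
    """Split a nested dict into (original, adapter) trees."""
    original, adapter = {}, {}
    for key, value in params.items():
        if is_adapter(key):
            adapter[key] = value
        elif isinstance(value, dict):
            sub_original, sub_adapter = split_params(value, is_adapter)
            if sub_original:
                original[key] = sub_original
            if sub_adapter:
                adapter[key] = sub_adapter
        else:
            original[key] = value
    return original, adapter


def merge_params(original, adapter):
    """Recursively merge two nested dicts whose leaves do not overlap."""
    merged = dict(original)
    for key, value in adapter.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_params(merged[key], value)
        else:
            merged[key] = value
    return merged


# Randomly initialized params (toy values).
params = {"layer_0": {"attn": {"kernel": [0.1, 0.2], "prefix": [9.0, 9.0]}}}

original, adapter = split_params(params)
# Pretend this is what a checkpoint load returns for the original tree.
original = {"layer_0": {"attn": {"kernel": [1.0, 2.0]}}}
params = merge_params(original, adapter)
```

After the merge, the base weights come from the (pretend) checkpoint while the adapter weights keep their freshly initialized values, which is the state the notebook needs before fine-tuning or sampling.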

gemma/gm/ckpts/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -22,3 +22,4 @@
 from gemma.gm.ckpts._lora import SkipLoRA
 from gemma.gm.ckpts._paths import CheckpointPath
 from gemma.gm.ckpts._policy import AnchoredPolicyLoader
+SkipPeft = SkipLoRA  # Alias for SkipLoRA with a more generic name.
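
The added line binds a second name to an existing class rather than defining a new one. A minimal sketch of what that means in Python, using an empty stand-in class (the real `SkipLoRA` behavior is not modeled here):

```python
class SkipLoRA:  # Stand-in for the real class; its behavior is not modeled.
    pass


SkipPeft = SkipLoRA  # A second name bound to the same class object.

same_object = SkipPeft is SkipLoRA              # Both names refer to one class.
is_instance = isinstance(SkipPeft(), SkipLoRA)  # Instances are interchangeable.
class_name = SkipPeft.__name__                  # Still reports the original name.
```

Because the alias is the same object, existing `isinstance` checks and imports of `SkipLoRA` should keep working unchanged, while error messages and reprs still show `SkipLoRA`.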

gemma/gm/nn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@
 # Wrapper (LoRA, quantization, DPO,...)
 # ****************************************************************************
 from gemma.gm.nn._lora import LoRA
+from gemma.gm.nn._prefix import PrefixTuning
 from gemma.gm.nn._quantization import QuantizationAwareWrapper
 from gemma.gm.nn._quantization import IntWrapper
 from gemma.gm.nn._policy import AnchoredPolicy
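
For readers unfamiliar with the technique behind `PrefixTuning`: in its general form, prefix tuning prepends a small set of learned key/value vectors to each attention layer while the base weights stay frozen. The pure-Python sketch below shows that idea for a single head and a single query. It is an assumption-laden illustration of the general method, not the code in `gemma/gm/nn/_prefix.py`; all function names and values are hypothetical.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    weights = softmax(scores)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]


def attend_with_prefix(query, keys, values, prefix_keys, prefix_values):
    """Attention where learned prefix keys/values are prepended to the cache."""
    return attend(query, prefix_keys + keys, prefix_values + values)


# Toy inputs: two real tokens plus one learned prefix slot.
query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 0.0], [0.0, 1.0]]
prefix_keys = [[1.0, 0.0]]
prefix_values = [[5.0, 5.0]]

plain = attend(query, keys, values)
with_prefix = attend_with_prefix(query, keys, values, prefix_keys, prefix_values)
no_prefix = attend_with_prefix(query, keys, values, [], [])
```

With an empty prefix the output matches plain attention; with a prefix, the learned slots receive part of the attention mass and shift the output, which is how training only the prefix can steer a frozen model.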
