423 changes: 423 additions & 0 deletions colabs/prefix_finetuning.ipynb

Large diffs are not rendered by default.

261 changes: 261 additions & 0 deletions colabs/prefix_sampling.ipynb
@@ -0,0 +1,261 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "qKlB5QTDIV6S"
},
"source": [
"# Prefix Tuning (Sampling)\n",
"Example of using Prefix Tuning with Gemma (for inference)."
]
},
{
"metadata": {
"id": "TR-L25KVKT_F"
},
"cell_type": "code",
"source": [
"!pip install -q gemma"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I6fEKB1tISVW"
},
"outputs": [],
"source": [
"# Common imports\n",
"import os\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import treescope\n",
"\n",
"# Gemma imports\n",
"from gemma import gm\n",
"from gemma import peft"
]
},
{
"metadata": {
"id": "cxGT2XeU4L47"
},
"cell_type": "markdown",
"source": [
"By default, JAX does not use the full GPU memory, but this can be overridden. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):"
]
},
{
"metadata": {
"id": "o4MidM--4L47"
},
"cell_type": "code",
"source": [
"os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\""
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {
"id": "-kdAZkvOIryQ"
},
"source": [
"## Initializing the model\n",
"\n",
"To use Gemma with Prefix Tuning, simply wrap any Gemma model in `gm.nn.PrefixTuning.from_model`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x-BbrzCVIupV"
},
"outputs": [],
"source": [
"model = gm.nn.PrefixTuning.from_model(\n",
" prefix_length=100,\n",
" global_layers_only=True,\n",
" model=gm.nn.Gemma3_4B(text_only=True),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hI3Lg07SJff4"
},
"source": [
"Initialize the weights:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1shC1DpiJfsw"
},
"outputs": [],
"source": [
"token_ids = jnp.zeros((1, 256,), dtype=jnp.int32) # Dummy input of shape (batch_size, seq_length)\n",
"\n",
"params = model.init(\n",
" jax.random.key(0),\n",
" token_ids,\n",
")\n",
"\n",
"params = params['params']"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T3dWILqKKzG3"
},
"source": [
"Inspect the shape and structure of the params. We can see that the prefix weights have been added."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LMq2Z9nXKcad"
},
"outputs": [],
"source": [
"treescope.show(params)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bGJl5YpKKOf-"
},
"source": [
"Restore the pre-trained params. We use `peft.split_params` and `peft.merge_params` to replace the randomly initialized params with the pre-trained ones.\n",
"\n",
"When using `gm.ckpts.load_params`, make sure to pass the `params=original` kwarg. This ensures that:\n",
"\n",
"* The memory from the old params is released (so only a single copy of the weights stays in memory)\n",
"* The restored params reuse the same sharding as the input (there is no sharding here, so this part is not required)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AcO6oBuLKNjb"
},
"outputs": [],
"source": [
"# Split the params into the pre-trained weights and the PEFT weights added by the wrapper\n",
"original, lora = peft.split_params(params)\n",
"\n",
"# Load the pre-trained params from the checkpoint\n",
"original = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT, params=original)\n",
"\n",
"# Merge the pre-trained params back with the PEFT weights\n",
"params = peft.merge_params(original, lora)"
]
},
{
"metadata": {
"id": "b8y4YAAi9_Sv"
},
"cell_type": "markdown",
"source": [
"## Fine-tuning\n",
"\n",
"See our [finetuning guide](https://gemma-llm.readthedocs.io/en/latest/lora_finetuning.html) for more info.\n",
"\n",
"For an end-to-end fine-tuning example, refer to the [prefix tuning colab](third_party/py/gemma/colabs/prefix_finetuning.ipynb)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MvsQbQM4I4Cs"
},
"source": [
"## Inference\n",
"\n",
"Here's an example of running a single model call:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eqU7a4eCI5Wr"
},
"outputs": [],
"source": [
"tokenizer = gm.text.Gemma3Tokenizer()\n",
"\n",
"prompt = tokenizer.encode('The capital of France is')\n",
"prompt = jnp.asarray([tokenizer.special_tokens.BOS] + prompt)\n",
"\n",
"\n",
"# Run the model\n",
"out = model.apply(\n",
" {'params': params},\n",
" tokens=prompt,\n",
" return_last_only=True, # Only predict the last token\n",
")\n",
"\n",
"\n",
"# Show the token distribution\n",
"tokenizer.plot_logits(out.logits)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6dOSL9MHuMUa"
},
"source": [
"To sample an entire sentence:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_ckwREdyqown"
},
"outputs": [],
"source": [
"sampler = gm.text.ChatSampler(\n",
" model=model,\n",
" params=params,\n",
" tokenizer=tokenizer,\n",
")\n",
"\n",
"sampler.chat('The capital of France is?')"
]
}
],
"metadata": {
"colab": {
"last_runtime": {},
"private_outputs": true,
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
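The notebook above uses `gm.nn.PrefixTuning` without showing what a prefix does inside attention. As a rough illustration of the general technique (not the implementation in `gemma/gm/nn/_prefix.py`, which is not visible in this diff), here is a minimal single-head sketch in plain JAX. It assumes the prefix is a set of learned key/value vectors prepended to the keys and values of an attention layer. The function name, shapes, and random inputs are all hypothetical.

```python
import jax
import jax.numpy as jnp


def attention_with_prefix(q, k, v, prefix_k, prefix_v):
    """Causal single-head attention with learned prefix keys/values prepended.

    Hypothetical helper for illustration only; not part of the gemma API.
    q, k, v: (seq_len, head_dim). prefix_k, prefix_v: (prefix_len, head_dim).
    """
    seq_len, head_dim = q.shape
    prefix_len = prefix_k.shape[0]

    # Prepend the trainable prefix to the keys/values computed from the tokens.
    k_all = jnp.concatenate([prefix_k, k], axis=0)
    v_all = jnp.concatenate([prefix_v, v], axis=0)

    # Every query may attend to all prefix slots, and causally to the tokens.
    mask = jnp.concatenate(
        [
            jnp.ones((seq_len, prefix_len), dtype=bool),
            jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool)),
        ],
        axis=1,
    )

    scores = (q @ k_all.T) / jnp.sqrt(head_dim)
    scores = jnp.where(mask, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v_all, weights


seq_len, prefix_len, head_dim = 4, 3, 8
k_q, k_k, k_v, k_pk, k_pv = jax.random.split(jax.random.key(0), 5)

q = jax.random.normal(k_q, (seq_len, head_dim))
k = jax.random.normal(k_k, (seq_len, head_dim))
v = jax.random.normal(k_v, (seq_len, head_dim))
# In prefix tuning these two arrays would be the only trainable parameters.
prefix_k = jax.random.normal(k_pk, (prefix_len, head_dim))
prefix_v = jax.random.normal(k_pv, (prefix_len, head_dim))

out, weights = attention_with_prefix(q, k, v, prefix_k, prefix_v)
print(out.shape)      # (4, 8)
print(weights.shape)  # (4, 7)
```

In this sketch the output keeps the shape of the token sequence, while each query's attention distribution spans `prefix_len + seq_len` positions. During fine-tuning only `prefix_k` and `prefix_v` would receive gradients, which is why the notebook separates them from the pre-trained weights before loading the checkpoint.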
1 change: 1 addition & 0 deletions gemma/gm/ckpts/__init__.py
@@ -22,3 +22,4 @@
from gemma.gm.ckpts._lora import SkipLoRA
from gemma.gm.ckpts._paths import CheckpointPath
from gemma.gm.ckpts._policy import AnchoredPolicyLoader
SkipPeft = SkipLoRA # Alias for SkipLoRA with a more generic name.
1 change: 1 addition & 0 deletions gemma/gm/nn/__init__.py
@@ -49,6 +49,7 @@
# Wrapper (LoRA, quantization, DPO,...)
# ****************************************************************************
from gemma.gm.nn._lora import LoRA
from gemma.gm.nn._prefix import PrefixTuning
from gemma.gm.nn._quantization import QuantizationAwareWrapper
from gemma.gm.nn._quantization import IntWrapper
from gemma.gm.nn._policy import AnchoredPolicy