Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 145 additions & 7 deletions gemma/gm/ckpts/_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ def make_tree_for_params(
if self.has_mm_params and not params.has_mm_params:
ckpt_params = _add_skip_mm_params(ckpt_params, metadata)

# Reconcile known structural mismatches between model-init and
# checkpoint (e.g. Gemma4 LoRA: empty wrapper stubs from split_params,
# leaf-vs-dict format from nn.share_scope). Only triggers when
# mismatches are detected; Gemma3/legacy paths are unchanged.
if _needs_reconciliation(ckpt_params, self.nested_tree):
ckpt_params = _reconcile_tree(ckpt_params, self.nested_tree)

# 2. Reformat the nested tree to match the checkpoint structure.
if self.type == _CheckpointType.NESTED:
target_params = ckpt_params # No need to reformat
Expand All @@ -170,7 +177,11 @@ def make_tree_for_params(

@functools.cached_property
def has_mm_params(self) -> bool:
  """Whether the checkpoint tree contains a multimodal encoder.

  Returns:
    True if `self.nested_tree` has a top-level `vision_encoder` or
    `audio_encoder` entry, False otherwise.
  """
  # Check for any known multimodal encoder (vision or audio). Both keys must
  # be tested in a single expression so that neither check is skipped.
  return any(
      key in self.nested_tree for key in ('vision_encoder', 'audio_encoder')
  )

@functools.cached_property
def has_audio_input_embedding(self) -> bool:
Expand Down Expand Up @@ -395,10 +406,22 @@ def _remove_mm_params(params):
# TODO(epot): Once orbax supports partial restore, we would not need to
# load those extra params in the first place.

del params['vision_encoder']
for k in ('mm_input_projection', 'mm_soft_embedding_norm'):
# Vision params
if 'vision_encoder' in params:
del params['vision_encoder']
for k in ('mm_input_projection', 'mm_soft_embedding_norm',
'mm_pre_projection_norm', 'mm_input_embedding_extra'):
if k in params.get('embedder', {}):
del params['embedder'][k]

# Audio params (Gemma4)
if 'audio_encoder' in params:
del params['audio_encoder']
for k in ('audio_input_projection', 'audio_soft_embedding_norm',
'audio_input_embedding', 'audio_input_embedding_extra'):
if k in params.get('embedder', {}):
del params['embedder'][k]

return params


Expand All @@ -407,14 +430,129 @@ def _add_skip_mm_params(params: Params, metadata: _CheckpointTree) -> Params:
params = etree.copy(params)
params_with_mm = metadata.nested_tree

params['vision_encoder'] = params_with_mm['vision_encoder']
for k in ('mm_input_projection', 'mm_soft_embedding_norm'):
if k in params_with_mm.get('embedder', {}):
params['embedder'][k] = params_with_mm['embedder'][k]
# Known top-level multimodal encoder keys.
_MM_TOP_LEVEL_KEYS = ('vision_encoder', 'audio_encoder')
# Known embedder-level multimodal projection/norm keys.
_MM_EMBEDDER_KEYS = (
# Vision
'mm_input_projection',
'mm_soft_embedding_norm',
'mm_pre_projection_norm',
'mm_input_embedding_extra',
# Audio
'audio_input_projection',
'audio_soft_embedding_norm',
'audio_input_embedding',
'audio_input_embedding_extra',
)

for k in _MM_TOP_LEVEL_KEYS:
if k in params_with_mm and k not in params:
params[k] = params_with_mm[k]

embedder_mm = params_with_mm.get('embedder', {})
for k in _MM_EMBEDDER_KEYS:
if k in embedder_mm and k not in params.get('embedder', {}):
params['embedder'][k] = embedder_mm[k]

return params


def _needs_reconciliation(params: Params, metadata_tree: Params) -> bool:
"""Returns True if the model params tree has known structural mismatches.

Detects two patterns that arise when LoRA interceptors interact with
models using ``nn.share_scope`` (e.g. Gemma4 FeedForward):

1. **Empty ``{}`` stubs** left by ``peft.split_params`` at LoRA wrapper
scopes (e.g. ``_LoRAEinsum_0``). These keys exist in the model tree
but not in the checkpoint.
2. **Leaf-vs-dict format**: ``nn.share_scope`` flattens ``{'w': array}``
to bare ``ArrayImpl`` in the model-init tree, while the checkpoint
keeps the dict format.

This check is intentionally conservative — it returns ``False`` for
Gemma3 and legacy checkpoints, so their restore paths are unchanged.
"""
if not isinstance(params, dict) or not isinstance(metadata_tree, dict):
return False

for k, p_val in params.items():
if k not in metadata_tree:
# Key in model but not in checkpoint (e.g. LoRA stub).
if isinstance(p_val, dict) and not p_val:
return True
continue
m_val = metadata_tree[k]
# Leaf-vs-dict mismatch.
if not isinstance(p_val, dict) and isinstance(m_val, dict):
return True
# Recurse into sub-dicts.
if isinstance(p_val, dict) and isinstance(m_val, dict):
if _needs_reconciliation(p_val, m_val):
return True

return False


def _reconcile_tree(params: Params, metadata_tree: Params) -> Params:
"""Align model-init params tree to match checkpoint metadata structure.

Only called when :func:`_needs_reconciliation` returns ``True``.

Handles two known mismatches between ``model.init()`` and on-disk
checkpoints:

1. **Empty stubs**: LoRA wrappers (or other interceptors) may leave
empty dict scopes in the params tree that don't exist in the
checkpoint. These are dropped.
2. **Leaf-vs-dict format**: ``nn.share_scope`` in Gemma4 ``FeedForward``
flattens ``{'w': array}`` to bare ``ArrayImpl`` during model init.
When the checkpoint stores ``{'w': array}``, the leaf is wrapped to
match.

Args:
params: The model-init params tree (may contain stubs / format
mismatches).
metadata_tree: The checkpoint metadata tree (ground-truth structure).

Returns:
A new params tree aligned to the checkpoint metadata structure.
"""
if not isinstance(params, dict) or not isinstance(metadata_tree, dict):
return params

result = {}
for k in metadata_tree:
if k not in params:
# Key in checkpoint but not in model (e.g. MM params handled by
# _add_skip_mm_params separately) — skip.
continue
p_val = params[k]
m_val = metadata_tree[k]

if isinstance(p_val, dict) and isinstance(m_val, dict):
# Both dicts — recurse.
reconciled = _reconcile_tree(p_val, m_val)
if reconciled: # Drop if empty after reconciliation.
result[k] = reconciled
elif not isinstance(p_val, dict) and isinstance(m_val, dict):
# Model has leaf (ArrayImpl), checkpoint has dict ({'w': ...}).
# Wrap the leaf to match checkpoint format.
if len(m_val) == 1:
inner_key = next(iter(m_val))
result[k] = {inner_key: p_val}
else:
result[k] = p_val # Fallback: keep as-is.
else:
# Both leaves, or model has dict but checkpoint has leaf.
result[k] = p_val

# Keys in params but NOT in metadata are intentionally dropped.
# This strips LoRA wrapper stubs (_LoRAEinsum_0, etc.).
return result


def _is_flat_layout(params: Params) -> bool:
"""Returns True is the structure is the legacy one."""
return (not _is_stacked_layout(params)) and all(
Expand Down
49 changes: 39 additions & 10 deletions gemma/gm/data/_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,42 @@
import einops
from etils.etree import jax as etree # pylint: disable=g-importing-member
from gemma.gm.data import _functional
from gemma.gm.text import _template
from gemma.gm.text import _tokenizer
from grain import python as grain
import jax
from kauldron import kd
import numpy as np

# Turn tag strings indexed by tokenizer FORMAT.
# Gemma4 uses '<|turn>' / '<turn|>', all others use '<start_of_turn>' /
# '<end_of_turn>'. Importing the `dialog` library is intentionally avoided
# to keep the dep footprint small; the two known format strings are inlined.
_TURN_TAGS: dict[str, tuple[str, str]] = {}


def _get_turn_tags(
tokenizer: _tokenizer.Tokenizer,
) -> tuple[str, str]:
"""Returns (start_of_turn, end_of_turn) tag strings for *tokenizer*."""
fmt = getattr(tokenizer, 'FORMAT', None)
# dialog.Format.GEMMA4 has value 'gemma4' (StrEnum).
if fmt is not None and str(fmt).lower() == 'gemma4':
return ('<|turn>', '<turn|>')
# Default: Gemma3 / legacy format.
return ('<start_of_turn>', '<end_of_turn>')


def _format_prompt(prompt: str, tokenizer: _tokenizer.Tokenizer) -> str:
  """Wraps *prompt* as a user turn followed by an opened model turn.

  Args:
    prompt: The raw user text.
    tokenizer: Tokenizer used to select the turn tags (via
      `_get_turn_tags`).

  Returns:
    The user turn (start tag, ``user``, the prompt, end tag) followed by the
    header of the model turn, ready for the answer to be appended.
  """
  start_tag, end_tag = _get_turn_tags(tokenizer)
  user_turn = f'{start_tag}user\n{prompt}{end_tag}\n'
  model_header = f'{start_tag}model\n'
  return user_turn + model_header


def _format_answer(response: str, tokenizer: _tokenizer.Tokenizer) -> str:
  """Terminates *response* with the end-of-turn tag for *tokenizer*.

  Args:
    response: The raw model answer.
    tokenizer: Tokenizer used to select the turn tags (via
      `_get_turn_tags`).

  Returns:
    The response immediately followed by the end-of-turn tag.
  """
  # Only the closing tag is needed; the opening one is part of the prompt.
  end_tag = _get_turn_tags(tokenizer)[1]
  return f'{response}{end_tag}'


@dataclasses.dataclass(kw_only=True, frozen=True)
class Seq2SeqTask(grain.MapTransform):
Expand Down Expand Up @@ -115,10 +144,10 @@ def map(self, element):
prompt = _decode_bytes(prompt)
response = _decode_bytes(response)

# Format the input to match the expected dialog template.
# TODO(epot): Add a `template` protocol to allow customizing this.
prompt = _template.PROMPT.format(prompt)
response = _template.ANSWER.format(response)
# Format the input using tokenizer-aware turn tags.
# TODO(epot): Add a `template` protocol for full customization.
prompt = _format_prompt(prompt, self.tokenizer)
response = _format_answer(response, self.tokenizer)

# For sampling, we don't need to tokenize the input.
if self.sampling:
Expand Down Expand Up @@ -219,11 +248,11 @@ def map(self, element):
chosen = _decode_bytes(chosen)
rejected = _decode_bytes(rejected)

# Format the input to match the expected dialog template.
# TODO(epot): Move this in a separate FormatDialog transform.
prompt = _template.PROMPT.format(prompt)
chosen = _template.ANSWER.format(chosen)
rejected = _template.ANSWER.format(rejected)
# Format the input using tokenizer-aware turn tags.
# TODO(epot): Extract into a standalone FormatDialog transform.
prompt = _format_prompt(prompt, self.tokenizer)
chosen = _format_answer(chosen, self.tokenizer)
rejected = _format_answer(rejected, self.tokenizer)

# Tokenize the input and the responses.
# Note: Input should start with begin-of-sequence token.
Expand Down
29 changes: 21 additions & 8 deletions gemma/gm/nn/_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,26 @@
from flax import linen as nn
from gemma import peft
from gemma.gm.nn import _layers
from gemma.gm.nn.gemma3n import _layers as _gemma3n_layers
from gemma.gm.nn.gemma4 import _layers as _gemma4_layers
import jax
import jax.numpy as jnp
from kauldron import kontext
import numpy as np


# Module classes matched with `isinstance` when deciding whether a module can
# be replaced by a LoRA wrapper. Defined once: a second assignment would
# silently shadow the first.
_SUPPORTED_MODULES = (
    nn.Dense,
    nn.Einsum,
    nn.DenseGeneral,
    _layers.Einsum,
    _gemma4_layers.Einsum,
    _gemma4_layers.ClippedEinsum,
    _gemma3n_layers.Einsum,
    # NOTE: nano._layers.Einsum is excluded because nano:nano depends on
    # //third_party/py/gemma/gm, creating a circular BUILD dependency.
    # To add it, nano._layers needs its own fine-grained BUILD target.
)


class LoRA(nn.Module):
Expand Down Expand Up @@ -107,19 +120,19 @@ def _replace_by_lora(
if debug_str:
logging.info(debug_str)

# TODO(epot): Replace by generic LoRA wrapper ?
# TODO(epot): Replace by generic LoRA wrapper ?
match module:
case nn.Dense():
return peft.LoRADense(rank=rank, dtype=dtype, wrapped=module)
case nn.Einsum():
return peft.LoRAEinsum(rank=rank, dtype=dtype, wrapped=module)
case nn.DenseGeneral():
return peft.LoRADenseGeneral(rank=rank, dtype=dtype, wrapped=module)
case _layers.Einsum():
# This hack is required because the FeedForward layer call two different
# Einsum with using `nn.share_scope`, so the two wrappers need a different
# name.
# This seems to be a bug in flax interceptor.
case _ if isinstance(module, _SUPPORTED_MODULES):
# All custom Einsum variants (gm.nn, gemma4, gemma3n, nano, etc.)
# use `_LoRAEinsum` wrapper. The name hack is required because
# FeedForward uses `nn.share_scope` to flatten two Einsum modules
# into the same param scope — the two wrappers need distinct names.
if module.weight_name != 'w':
name = f'_LoRAEinsum_{module.weight_name}'
else:
Expand All @@ -135,7 +148,7 @@ class _LoRAEinsum(nn.Module):
_: dataclasses.KW_ONLY
rank: int
dtype: np.dtype
wrapped: _layers.Einsum
wrapped: nn.Module # Any Einsum variant (gm.nn, gemma4, gemma3n, nano)

# Do not use `nn.share_scope` here as the `wrapped` module inside
# `FeedForward` already uses `nn.share_scope`, so the two Einsum used in
Expand Down
Loading
Loading