Skip to content

Commit afd5e47

Browse files
author
The gemma Authors
committed
Internal
PiperOrigin-RevId: 911763673
1 parent 8815f0c commit afd5e47

4 files changed

Lines changed: 455 additions & 25 deletions

File tree

gemma/gm/ckpts/_checkpoint.py

Lines changed: 145 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,13 @@ def make_tree_for_params(
150150
if self.has_mm_params and not params.has_mm_params:
151151
ckpt_params = _add_skip_mm_params(ckpt_params, metadata)
152152

153+
# Reconcile known structural mismatches between model-init and
154+
# checkpoint (e.g. Gemma4 LoRA: empty wrapper stubs from split_params,
155+
# leaf-vs-dict format from nn.share_scope). Only triggers when
156+
# mismatches are detected; Gemma3/legacy paths are unchanged.
157+
if _needs_reconciliation(ckpt_params, self.nested_tree):
158+
ckpt_params = _reconcile_tree(ckpt_params, self.nested_tree)
159+
153160
# 2. Reformat the nested tree to match the checkpoint structure.
154161
if self.type == _CheckpointType.NESTED:
155162
target_params = ckpt_params # No need to reformat
@@ -170,7 +177,11 @@ def make_tree_for_params(
170177

171178
@functools.cached_property
172179
def has_mm_params(self) -> bool:
173-
return 'vision_encoder' in self.nested_tree
180+
# Check for any known multimodal encoder (vision or audio).
181+
return (
182+
'vision_encoder' in self.nested_tree
183+
or 'audio_encoder' in self.nested_tree
184+
)
174185

175186
@functools.cached_property
176187
def has_audio_input_embedding(self) -> bool:
@@ -395,10 +406,22 @@ def _remove_mm_params(params):
395406
# TODO(epot): Once orbax supports partial restore, we would not need to
396407
# load those extra params in the first place.
397408

398-
del params['vision_encoder']
399-
for k in ('mm_input_projection', 'mm_soft_embedding_norm'):
409+
# Vision params
410+
if 'vision_encoder' in params:
411+
del params['vision_encoder']
412+
for k in ('mm_input_projection', 'mm_soft_embedding_norm',
413+
'mm_pre_projection_norm', 'mm_input_embedding_extra'):
414+
if k in params.get('embedder', {}):
415+
del params['embedder'][k]
416+
417+
# Audio params (Gemma4)
418+
if 'audio_encoder' in params:
419+
del params['audio_encoder']
420+
for k in ('audio_input_projection', 'audio_soft_embedding_norm',
421+
'audio_input_embedding', 'audio_input_embedding_extra'):
400422
if k in params.get('embedder', {}):
401423
del params['embedder'][k]
424+
402425
return params
403426

404427

@@ -407,14 +430,129 @@ def _add_skip_mm_params(params: Params, metadata: _CheckpointTree) -> Params:
407430
params = etree.copy(params)
408431
params_with_mm = metadata.nested_tree
409432

410-
params['vision_encoder'] = params_with_mm['vision_encoder']
411-
for k in ('mm_input_projection', 'mm_soft_embedding_norm'):
412-
if k in params_with_mm.get('embedder', {}):
413-
params['embedder'][k] = params_with_mm['embedder'][k]
433+
# Known top-level multimodal encoder keys.
434+
_MM_TOP_LEVEL_KEYS = ('vision_encoder', 'audio_encoder')
435+
# Known embedder-level multimodal projection/norm keys.
436+
_MM_EMBEDDER_KEYS = (
437+
# Vision
438+
'mm_input_projection',
439+
'mm_soft_embedding_norm',
440+
'mm_pre_projection_norm',
441+
'mm_input_embedding_extra',
442+
# Audio
443+
'audio_input_projection',
444+
'audio_soft_embedding_norm',
445+
'audio_input_embedding',
446+
'audio_input_embedding_extra',
447+
)
448+
449+
for k in _MM_TOP_LEVEL_KEYS:
450+
if k in params_with_mm and k not in params:
451+
params[k] = params_with_mm[k]
452+
453+
embedder_mm = params_with_mm.get('embedder', {})
454+
for k in _MM_EMBEDDER_KEYS:
455+
if k in embedder_mm and k not in params.get('embedder', {}):
456+
params['embedder'][k] = embedder_mm[k]
414457

415458
return params
416459

417460

461+
def _needs_reconciliation(params: Params, metadata_tree: Params) -> bool:
462+
"""Returns True if the model params tree has known structural mismatches.
463+
464+
Detects two patterns that arise when LoRA interceptors interact with
465+
models using ``nn.share_scope`` (e.g. Gemma4 FeedForward):
466+
467+
1. **Empty ``{}`` stubs** left by ``peft.split_params`` at LoRA wrapper
468+
scopes (e.g. ``_LoRAEinsum_0``). These keys exist in the model tree
469+
but not in the checkpoint.
470+
2. **Leaf-vs-dict format**: ``nn.share_scope`` flattens ``{'w': array}``
471+
to bare ``ArrayImpl`` in the model-init tree, while the checkpoint
472+
keeps the dict format.
473+
474+
This check is intentionally conservative — it returns ``False`` for
475+
Gemma3 and legacy checkpoints, so their restore paths are unchanged.
476+
"""
477+
if not isinstance(params, dict) or not isinstance(metadata_tree, dict):
478+
return False
479+
480+
for k, p_val in params.items():
481+
if k not in metadata_tree:
482+
# Key in model but not in checkpoint (e.g. LoRA stub).
483+
if isinstance(p_val, dict) and not p_val:
484+
return True
485+
continue
486+
m_val = metadata_tree[k]
487+
# Leaf-vs-dict mismatch.
488+
if not isinstance(p_val, dict) and isinstance(m_val, dict):
489+
return True
490+
# Recurse into sub-dicts.
491+
if isinstance(p_val, dict) and isinstance(m_val, dict):
492+
if _needs_reconciliation(p_val, m_val):
493+
return True
494+
495+
return False
496+
497+
498+
def _reconcile_tree(params: Params, metadata_tree: Params) -> Params:
  """Rebuilds `params` so that its structure follows the checkpoint metadata.

  Intended to be called only when :func:`_needs_reconciliation` returns
  ``True``. The output is built by walking ``metadata_tree``, so:

  * Keys present only in ``params`` are dropped. This is what strips LoRA
    wrapper stubs such as ``_LoRAEinsum_0``.
  * Keys present only in ``metadata_tree`` are skipped (e.g. multimodal
    params, which ``_add_skip_mm_params`` handles separately).
  * When both sides hold a dict, the sub-tree is reconciled recursively and
    omitted from the result if it ends up empty.
  * When ``params`` holds a leaf and the checkpoint holds a single-entry
    dict (``{'w': ...}``), the leaf is wrapped under that single key. This
    covers the ``nn.share_scope`` flattening in Gemma4 ``FeedForward``.
  * In every other case the ``params`` value is kept unchanged.

  Args:
    params: The model-init params tree (may contain stubs / format
      mismatches).
    metadata_tree: The checkpoint metadata tree (ground-truth structure).

  Returns:
    A new dict aligned to the checkpoint metadata structure, or ``params``
    itself when either argument is not a dict.
  """
  if not (isinstance(params, dict) and isinstance(metadata_tree, dict)):
    return params

  aligned = {}
  for key, ckpt_val in metadata_tree.items():
    if key not in params:
      # Checkpoint-only key: left for other code paths to deal with.
      continue
    model_val = params[key]

    if not isinstance(ckpt_val, dict):
      # Checkpoint has a leaf: keep whatever the model has (leaf or dict).
      aligned[key] = model_val
      continue

    if isinstance(model_val, dict):
      # Dict on both sides: recurse, and drop the key if nothing survives.
      sub_tree = _reconcile_tree(model_val, ckpt_val)
      if sub_tree:
        aligned[key] = sub_tree
      continue

    # Model has a leaf where the checkpoint has a dict.
    if len(ckpt_val) == 1:
      (only_key,) = ckpt_val  # The single inner key, e.g. 'w'.
      aligned[key] = {only_key: model_val}
    else:
      # Ambiguous target layout: keep the leaf as-is.
      aligned[key] = model_val

  return aligned
554+
555+
418556
def _is_flat_layout(params: Params) -> bool:
419557
"""Returns True is the structure is the legacy one."""
420558
return (not _is_stacked_layout(params)) and all(

gemma/gm/data/_tasks.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,42 @@
2121
import einops
2222
from etils.etree import jax as etree # pylint: disable=g-importing-member
2323
from gemma.gm.data import _functional
24-
from gemma.gm.text import _template
2524
from gemma.gm.text import _tokenizer
2625
from grain import python as grain
2726
import jax
2827
from kauldron import kd
2928
import numpy as np
3029

30+
# Turn tag strings indexed by the lower-cased tokenizer FORMAT value.
# Gemma4 uses '<|turn>' / '<turn|>', all others use '<start_of_turn>' /
# '<end_of_turn>'. Importing the `dialog` library is intentionally avoided
# to keep the dep footprint small; the known format strings are inlined.
_TURN_TAGS: dict[str, tuple[str, str]] = {
    'gemma4': ('<|turn>', '<turn|>'),
}

# Tags used for every format not listed in `_TURN_TAGS` (Gemma3 / legacy),
# and for tokenizers that do not expose a `FORMAT` attribute.
_DEFAULT_TURN_TAGS: tuple[str, str] = ('<start_of_turn>', '<end_of_turn>')


def _get_turn_tags(
    tokenizer: _tokenizer.Tokenizer,
) -> tuple[str, str]:
  """Returns (start_of_turn, end_of_turn) tag strings for *tokenizer*.

  The tokenizer's ``FORMAT`` attribute, if any, is normalised to a
  lower-case string and looked up in ``_TURN_TAGS``.

  Args:
    tokenizer: Tokenizer whose optional ``FORMAT`` attribute selects the
      dialog format. It may be a plain string or an enum member.

  Returns:
    The ``(start_of_turn, end_of_turn)`` pair for the tokenizer's format, or
    the Gemma3 / legacy pair when ``FORMAT`` is missing, ``None`` or unknown.
  """
  fmt = getattr(tokenizer, 'FORMAT', None)
  if fmt is None:
    return _DEFAULT_TURN_TAGS
  # Use the enum `.value` when there is one: `str()` of a `(str, Enum)` member
  # is 'Format.GEMMA4', not 'gemma4', so `str(fmt)` alone only works for
  # `StrEnum` and plain strings.
  key = str(getattr(fmt, 'value', fmt)).lower()
  return _TURN_TAGS.get(key, _DEFAULT_TURN_TAGS)
47+
48+
49+
def _format_prompt(prompt: str, tokenizer: _tokenizer.Tokenizer) -> str:
  """Wraps *prompt* as a user turn followed by an opened model turn.

  Args:
    prompt: The raw user prompt text.
    tokenizer: Tokenizer used to select the turn tags.

  Returns:
    The user turn (start tag, ``user``, the prompt, end tag) followed by the
    header of the model turn, using the tags from ``_get_turn_tags``.
  """
  start_tag, end_tag = _get_turn_tags(tokenizer)
  user_turn = f'{start_tag}user\n{prompt}{end_tag}\n'
  model_header = f'{start_tag}model\n'
  return user_turn + model_header
53+
54+
55+
def _format_answer(response: str, tokenizer: _tokenizer.Tokenizer) -> str:
  """Appends the end-of-turn tag for *tokenizer* to *response*.

  Args:
    response: The raw model response text.
    tokenizer: Tokenizer used to select the turn tags.

  Returns:
    ``response`` immediately followed by the end-of-turn tag from
    ``_get_turn_tags``.
  """
  end_tag = _get_turn_tags(tokenizer)[1]
  return f'{response}{end_tag}'
59+
3160

3261
@dataclasses.dataclass(kw_only=True, frozen=True)
3362
class Seq2SeqTask(grain.MapTransform):
@@ -115,10 +144,10 @@ def map(self, element):
115144
prompt = _decode_bytes(prompt)
116145
response = _decode_bytes(response)
117146

118-
# Format the input to match the expected dialog template.
119-
# TODO(epot): Add a `template` protocol to allow customizing this.
120-
prompt = _template.PROMPT.format(prompt)
121-
response = _template.ANSWER.format(response)
147+
# Format the input using tokenizer-aware turn tags.
148+
# TODO(epot): Add a `template` protocol for full customization.
149+
prompt = _format_prompt(prompt, self.tokenizer)
150+
response = _format_answer(response, self.tokenizer)
122151

123152
# For sampling, we don't need to tokenize the input.
124153
if self.sampling:
@@ -219,11 +248,11 @@ def map(self, element):
219248
chosen = _decode_bytes(chosen)
220249
rejected = _decode_bytes(rejected)
221250

222-
# Format the input to match the expected dialog template.
223-
# TODO(epot): Move this in a separate FormatDialog transform.
224-
prompt = _template.PROMPT.format(prompt)
225-
chosen = _template.ANSWER.format(chosen)
226-
rejected = _template.ANSWER.format(rejected)
251+
# Format the input using tokenizer-aware turn tags.
252+
# TODO(epot): Extract into a standalone FormatDialog transform.
253+
prompt = _format_prompt(prompt, self.tokenizer)
254+
chosen = _format_answer(chosen, self.tokenizer)
255+
rejected = _format_answer(rejected, self.tokenizer)
227256

228257
# Tokenize the input and the responses.
229258
# Note: Input should start with begin-of-sequence token.

gemma/gm/nn/_lora.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,26 @@
2121
from flax import linen as nn
2222
from gemma import peft
2323
from gemma.gm.nn import _layers
24+
from gemma.gm.nn.gemma3n import _layers as _gemma3n_layers
25+
from gemma.gm.nn.gemma4 import _layers as _gemma4_layers
2426
import jax
2527
import jax.numpy as jnp
2628
from kauldron import kontext
2729
import numpy as np
2830

2931

30-
_SUPPORTED_MODULES = (nn.Dense, nn.Einsum, nn.DenseGeneral, _layers.Einsum)
32+
_SUPPORTED_MODULES = (
33+
nn.Dense,
34+
nn.Einsum,
35+
nn.DenseGeneral,
36+
_layers.Einsum,
37+
_gemma4_layers.Einsum,
38+
_gemma4_layers.ClippedEinsum,
39+
_gemma3n_layers.Einsum,
40+
# NOTE: nano._layers.Einsum is excluded because nano:nano depends on
41+
# //third_party/py/gemma/gm, creating a circular BUILD dependency.
42+
# To add it, nano._layers needs its own fine-grained BUILD target.
43+
)
3144

3245

3346
class LoRA(nn.Module):
@@ -107,19 +120,19 @@ def _replace_by_lora(
107120
if debug_str:
108121
logging.info(debug_str)
109122

110-
# TODO(epot): Replace by generic LoRA wrapper ?
123+
# TODO(epot): Replace by generic LoRA wrapper ?
111124
match module:
112125
case nn.Dense():
113126
return peft.LoRADense(rank=rank, dtype=dtype, wrapped=module)
114127
case nn.Einsum():
115128
return peft.LoRAEinsum(rank=rank, dtype=dtype, wrapped=module)
116129
case nn.DenseGeneral():
117130
return peft.LoRADenseGeneral(rank=rank, dtype=dtype, wrapped=module)
118-
case _layers.Einsum():
119-
# This hack is required because the FeedForward layer call two different
120-
# Einsum with using `nn.share_scope`, so the two wrappers need a different
121-
# name.
122-
# This seems to be a bug in flax interceptor.
131+
case _ if isinstance(module, _SUPPORTED_MODULES):
132+
# All custom Einsum variants (gm.nn, gemma4, gemma3n, nano, etc.)
133+
# use `_LoRAEinsum` wrapper. The name hack is required because
134+
# FeedForward uses `nn.share_scope` to flatten two Einsum modules
135+
# into the same param scope — the two wrappers need distinct names.
123136
if module.weight_name != 'w':
124137
name = f'_LoRAEinsum_{module.weight_name}'
125138
else:
@@ -135,7 +148,7 @@ class _LoRAEinsum(nn.Module):
135148
_: dataclasses.KW_ONLY
136149
rank: int
137150
dtype: np.dtype
138-
wrapped: _layers.Einsum
151+
wrapped: nn.Module # Any Einsum variant (gm.nn, gemma4, gemma3n, nano)
139152

140153
# Do not use `nn.share_scope` here as the `wrapped` module inside
141154
# `FeedForward` already uses `nn.share_scope`, so the two Einsum used in

0 commit comments

Comments
 (0)