Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions gemma/gm/nn/_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,30 @@ def __post_init__(self):
@nn.compact
def __call__(self, *args, **kwargs):
"""Calls the model."""
with self._lora_interceptor():
return self.model(*args, **kwargs)

def _lora_interceptor(self):
"""Returns the LoRA ModuleInterceptor context manager."""
replace_module_fn = functools.partial(
_replace_by_lora,
rank=self.rank,
dtype=self.dtype,
verbose=self.verbose,
)
with peft.ModuleInterceptor(replace_module_fn):
return self.model(*args, **kwargs)
return peft.ModuleInterceptor(replace_module_fn)

@nn.compact
def encoder_call(self, *args, **kwargs):
"""Calls the model's encoder_call with LoRA adapters active."""
with self._lora_interceptor():
return self.model.encoder_call(*args, **kwargs)

@nn.compact
def init_cache(self, *args, **kwargs):
"""Calls the model's init_cache with LoRA adapters active."""
with self._lora_interceptor():
return self.model.init_cache(*args, **kwargs)

def __kontext_keys__(self) -> dict[str, str]:
"""Kauldron keys when calling `kontext.get_from_keys_obj`."""
Expand Down
76 changes: 76 additions & 0 deletions gemma/gm/nn/_lora_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for LoRA wrapper delegation."""

from absl.testing import absltest
from flax import linen as nn
from gemma.gm.nn import _lora
import jax
import jax.numpy as jnp


class DummyModel(nn.Module):
some_attr: int = 42

@nn.compact
def __call__(self, x):
return nn.Dense(features=4, name='dense')(x)

@nn.compact
def encoder_call(self, x):
return nn.Dense(features=4, name='dense')(x)

@nn.compact
def init_cache(self, x):
return nn.Dense(features=4, name='dense')(x)


class LoRATest(absltest.TestCase):

def test_call_intercepts_dense(self):
model = _lora.LoRA(rank=2, model=DummyModel())
params = model.init(jax.random.key(0), jnp.zeros((1, 4)))['params']
self.assertIn('dense', params)
self.assertIn('lora', params['dense'])
self.assertIn('a', params['dense']['lora'])
self.assertIn('b', params['dense']['lora'])

def test_encoder_call_intercepts_dense(self):
model = _lora.LoRA(rank=2, model=DummyModel())
params = model.init(
jax.random.key(0), jnp.zeros((1, 4)), method=model.encoder_call
)['params']
self.assertIn('dense', params)
self.assertIn('lora', params['dense'])
self.assertIn('a', params['dense']['lora'])
self.assertIn('b', params['dense']['lora'])

def test_init_cache_intercepts_dense(self):
model = _lora.LoRA(rank=2, model=DummyModel())
params = model.init(
jax.random.key(0), jnp.zeros((1, 4)), method=model.init_cache
)['params']
self.assertIn('dense', params)
self.assertIn('lora', params['dense'])
self.assertIn('a', params['dense']['lora'])
self.assertIn('b', params['dense']['lora'])

def test_getattr_forwarding(self):
model = _lora.LoRA(rank=2, model=DummyModel())
self.assertEqual(model.some_attr, 42)


if __name__ == '__main__':
absltest.main()
Loading