Skip to content

Commit 44a625e

Browse files
author
The gemma Authors
committed
Make sampler dispatch overridable in ChatSampler.
PiperOrigin-RevId: 906439952
1 parent ae84d95 commit 44a625e

2 files changed

Lines changed: 143 additions & 24 deletions

File tree

gemma/gm/text/_chat_sampler.py

Lines changed: 52 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -218,6 +218,46 @@ def gemma4_sampler(self) -> _gemma4_sampler.Gemma4Sampler:
218218
'Use `sampler` instead.'
219219
)
220220

221+
def _sample(
    self,
    prompt_text: str,
    *,
    images,
    audio,
    audio_lengths,
    sampling,
    max_new_tokens,
    rng,
    last_state,
    stream,
    sharding,
):
  """Dispatches a single sampling call to the correct underlying sampler.

  The dispatch lives in its own method (rather than inline in `chat`) so
  that it can be overridden.

  Both branches request `return_state=True` and forward `prompt_text`,
  `images`, `sampling`, `max_new_tokens`, `rng` and `last_state`. The
  remaining arguments are branch-specific:

  * When `self._is_gemma4` is truthy, `self.gemma4_sampler.sample` is called
    and additionally receives `audio`, `audio_lengths` and `sharding`.
    `stream` is not forwarded on this path.
  * Otherwise `self.sampler.sample` is called and additionally receives
    `stream` coerced to `bool`. `audio`, `audio_lengths` and `sharding` are
    not forwarded on this path.

  Args:
    prompt_text: The prompt, passed positionally to the selected sampler.
    images: Forwarded to both samplers.
    audio: Forwarded to the Gemma4 sampler only.
    audio_lengths: Forwarded to the Gemma4 sampler only.
    sampling: Forwarded to both samplers.
    max_new_tokens: Forwarded to both samplers.
    rng: Forwarded to both samplers.
    last_state: Forwarded to both samplers.
    stream: Forwarded (as `bool(stream)`) to the non-Gemma4 sampler only.
    sharding: Forwarded to the Gemma4 sampler only.

  Returns:
    Whatever the selected sampler's `sample` method returns.
  """
  if self._is_gemma4:
    # Only touch `gemma4_sampler` on this branch; the other sampler attribute
    # is never read here.
    return self.gemma4_sampler.sample(
        prompt_text,
        images=images,
        audio=audio,
        audio_lengths=audio_lengths,
        sampling=sampling,
        max_new_tokens=max_new_tokens,
        rng=rng,
        return_state=True,
        last_state=last_state,
        sharding=sharding,
    )

  return self.sampler.sample(  # pytype: disable=wrong-arg-types
      prompt_text,
      images=images,
      sampling=sampling,
      max_new_tokens=max_new_tokens,
      rng=rng,
      return_state=True,
      last_state=last_state,
      # Normalise to a strict bool (e.g. `None` becomes `False`).
      stream=bool(stream),
  )
260+
221261
def chat(
222262
self,
223263
prompt: str | dialog.Conversation,
@@ -338,30 +378,18 @@ def chat(
338378
)
339379

340380
# --- Dispatch to the correct sampler ---
341-
if self._is_gemma4:
342-
out = self.gemma4_sampler.sample(
343-
prompt_text,
344-
images=images,
345-
audio=audio,
346-
audio_lengths=audio_lengths,
347-
sampling=sampling,
348-
max_new_tokens=max_new_tokens,
349-
rng=rng,
350-
return_state=True,
351-
last_state=last_state,
352-
sharding=sharding,
353-
)
354-
else:
355-
out = self.sampler.sample( # pytype: disable=wrong-arg-types
356-
prompt_text,
357-
images=images,
358-
sampling=sampling,
359-
max_new_tokens=max_new_tokens,
360-
rng=rng,
361-
return_state=True,
362-
last_state=last_state,
363-
stream=bool(stream),
364-
)
381+
out = self._sample(
382+
prompt_text,
383+
images=images,
384+
audio=audio,
385+
audio_lengths=audio_lengths,
386+
sampling=sampling,
387+
max_new_tokens=max_new_tokens,
388+
rng=rng,
389+
last_state=last_state,
390+
stream=stream,
391+
sharding=sharding,
392+
)
365393

366394
# In streaming mode, the output is an iterator, yielding tokens one at a
367395
# time.

gemma/gm/text/_sampler_test.py

Lines changed: 91 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from unittest import mock
1516
from gemma import gm
17+
from gemma.gm.text import _sampler
1618
from gemma.gm.text import _sampler_loop
1719
import jax
1820
import jax.numpy as jnp
@@ -54,3 +56,92 @@ def test_sampler():
5456
)
5557
sampler.sample('Hello world')
5658

59+
60+
def test_chat_sampler_gemma4_dispatch():
  """Tests that _sample() dispatches to gemma4_sampler when _is_gemma4 is True.

  Uses mocks to verify the dispatch logic without requiring a full Gemma4
  model. This catches regressions in the _sample() method that could break
  the Gemma4 path.
  """
  model = gm.testing.DummyGemma()
  params = model.init(
      jax.random.PRNGKey(0),
      jnp.zeros((5,), dtype=jnp.int32),
  )
  params = params['params']
  tokenizer = gm.testing.DummyTokenizer()
  chat_sampler = gm.text.ChatSampler(
      model=model,
      params=params,
      tokenizer=tokenizer,
      cache_length=128,
      max_out_length=128,
  )

  # Stand-in for `gemma4_sampler.sample`; records the call and returns a
  # canned `SamplerOutput`.
  mock_sample = mock.MagicMock(
      return_value=_sampler.SamplerOutput(
          text='mock output',
          state=mock.MagicMock(),
      )
  )

  # Force the Gemma4 dispatch path. Both attributes are replaced on the class
  # (`type(chat_sampler)`) with properties: `_is_gemma4` always reports True
  # and `gemma4_sampler` hands back an object whose `sample` is the mock.
  with mock.patch.object(
      type(chat_sampler),
      '_is_gemma4',
      new_callable=lambda: property(lambda self: True),
  ), mock.patch.object(
      type(chat_sampler),
      'gemma4_sampler',
      new_callable=lambda: property(
          lambda self: mock.MagicMock(sample=mock_sample)
      ),
  ):
    output = chat_sampler.chat('Hello world')
    assert isinstance(output, str)
    # Verify gemma4_sampler.sample was called (not sampler.sample).
    mock_sample.assert_called_once()
105+
106+
107+
def test_chat_sampler_non_gemma4_dispatch():
  """Tests that _sample() dispatches to sampler when _is_gemma4 is False.

  Uses mocks to verify the dispatch logic without exercising the full sampling
  pipeline (which is already covered by test_sampler). This catches regressions
  in _sample() that could break the non-Gemma4 dispatch path.
  """
  model = gm.testing.DummyGemma()
  params = model.init(
      jax.random.PRNGKey(0),
      jnp.zeros((5,), dtype=jnp.int32),
  )
  params = params['params']
  tokenizer = gm.testing.DummyTokenizer()
  chat_sampler = gm.text.ChatSampler(
      model=model,
      params=params,
      tokenizer=tokenizer,
      cache_length=128,
      max_out_length=128,
  )

  assert not chat_sampler._is_gemma4  # Confirm non-Gemma4 dispatch path.

  # Stand-in for `sampler.sample`; records the call and returns a canned
  # `SamplerOutput`.
  mock_sample = mock.MagicMock(
      return_value=_sampler.SamplerOutput(
          text='mock output',
          state=mock.MagicMock(),
      )
  )

  # Replace `sampler` on the class (`type(chat_sampler)`) with a property that
  # hands back an object whose `sample` is the mock.
  with mock.patch.object(
      type(chat_sampler),
      'sampler',
      new_callable=lambda: property(
          lambda self: mock.MagicMock(sample=mock_sample)
      ),
  ):
    output = chat_sampler.chat('Hello world')
    assert isinstance(output, str)
    # Verify sampler.sample was called (not gemma4_sampler.sample).
    mock_sample.assert_called_once()

0 commit comments

Comments
 (0)