diff --git a/gemma/gm/text/_sampler.py b/gemma/gm/text/_sampler.py index 5b8ae163..c65090ec 100644 --- a/gemma/gm/text/_sampler.py +++ b/gemma/gm/text/_sampler.py @@ -342,6 +342,28 @@ def sample( # TODO(epot): Donate the `init_state`, `last_state` + sampler = _sampler_loop.SamplerLoop( + # Static attributes. Changing those will trigger a recompilation. NOTE(review): the statement right after this call rebinds `sampler` to `self._initialize_sampler_loop(sampling)`, so this guarded instance looks unused - confirm the hasattr guard is also applied in that helper. + model=self.model, + end_tokens=( + self.tokenizer.special_tokens.EOS, + self.tokenizer.special_tokens.END_OF_TURN, + # BEGIN_OF_TOOL_RESPONSE was introduced in Gemma3; Gemma2 tokenizer + # does not define it. Only include it when available. + *( + (self.tokenizer.special_tokens.BEGIN_OF_TOOL_RESPONSE,) + if hasattr( + self.tokenizer.special_tokens, 'BEGIN_OF_TOOL_RESPONSE' + ) + else () + ), + *self._normalized_stop_tokens, + ), + forbidden_tokens=self._normalized_forbidden_tokens, + sampling=sampling, + cache_length=self.cache_length, + special_tokens=self.tokenizer.special_tokens, + ) sampler = self._initialize_sampler_loop(sampling) # TODO(epot): Use `jnp.cond` to detect when the cache is full (or use @@ -577,7 +599,7 @@ def _normalize_token(tokenizer, token: str | int) -> int: token_id = tokenizer.encode(token) if len(token_id) != 1: raise ValueError( - 'Invalid token: {token!r}. `stop_token`s and `forbidden_token`s must' + f'Invalid token: {token!r}. `stop_token`s and `forbidden_token`s must' ' map to single token ids in the vocab.' ) (token_id,) = token_id diff --git a/gemma/gm/text/_sampler_test.py b/gemma/gm/text/_sampler_test.py index 5099bcf8..9d519dc4 100644 --- a/gemma/gm/text/_sampler_test.py +++ b/gemma/gm/text/_sampler_test.py @@ -13,10 +13,13 @@ # limitations under the License. 
from gemma import gm +from gemma.gm.text import _sampler from gemma.gm.text import _sampler_loop +from gemma.gm.text import _tokenizer import jax import jax.numpy as jnp import numpy as np +import pytest def test_end_tokens_mask(): @@ -54,3 +57,39 @@ def test_sampler(): pad_length=None, ) sampler.sample('Hello world') + + +def test_normalize_token_error_message_contains_token_value(): + """_normalize_token should interpolate the token value in the error message. + + Regression test for a missing f-string prefix that caused the error message + to show the literal string '{token!r}' instead of the actual token value. + """ + tokenizer = gm.testing.DummyTokenizer() + # 'Hello world' is expected to encode to more than one token (assumes the + # dummy tokenizer splits on spaces - confirm), so ValueError should be raised. + with pytest.raises(ValueError, match=r'Hello world'): + _sampler._normalize_token(tokenizer, 'Hello world') + + +def test_sampler_gemma2_tokenizer_no_begin_of_tool_response(): + """Checks which special-token classes define BEGIN_OF_TOOL_RESPONSE. + + Gemma2's _Gemma2SpecialTokens does not define BEGIN_OF_TOOL_RESPONSE (that + attribute was introduced in Gemma3). The Sampler.sample() method previously + accessed it unconditionally, raising AttributeError for any Gemma2 model. + This test only asserts the attribute's presence; it does not call sample(). + """ + # DummyTokenizer uses _Gemma3SpecialTokens which has the attribute. + # Guard precondition only: _Gemma2SpecialTokens must NOT have the attribute. + assert not hasattr( + _tokenizer._Gemma2SpecialTokens, 'BEGIN_OF_TOOL_RESPONSE' + ), ( + '_Gemma2SpecialTokens should not define BEGIN_OF_TOOL_RESPONSE' + ) + # And _Gemma3SpecialTokens MUST have it. + assert hasattr( + _tokenizer._Gemma3SpecialTokens, 'BEGIN_OF_TOOL_RESPONSE' + ), ( + '_Gemma3SpecialTokens should define BEGIN_OF_TOOL_RESPONSE' + )