From 814207bdf6311e0c05a54c147613358a8e81cba5 Mon Sep 17 00:00:00 2001 From: ashiksharonm Date: Thu, 26 Feb 2026 16:25:27 +0530 Subject: [PATCH 1/2] fix: guard BEGIN_OF_TOOL_RESPONSE for Gemma2 and fix f-string in _normalize_token Fixes #568: Sampler.sample() unconditionally accessed tokenizer.special_tokens.BEGIN_OF_TOOL_RESPONSE when building end_tokens for SamplerLoop. The _Gemma2SpecialTokens enum does not define this attribute (introduced in Gemma3 for tool/function calling), causing: AttributeError: type object '_Gemma2SpecialTokens' has no attribute 'BEGIN_OF_TOOL_RESPONSE' Fix: guard the token access behind hasattr() so it is only included when the tokenizer actually defines the attribute (i.e. Gemma3+). Also fixes #579: the ValueError raised by _normalize_token() used a plain string literal instead of an f-string, so the error message showed the literal text '{token!r}' rather than the actual token value. Added the missing f prefix. Added two regression tests in _sampler_test.py: - test_normalize_token_error_message_contains_token_value - test_sampler_gemma2_tokenizer_no_begin_of_tool_response --- gemma/gm/text/_sampler.py | 12 +++++++++-- gemma/gm/text/_sampler_test.py | 39 ++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/gemma/gm/text/_sampler.py b/gemma/gm/text/_sampler.py index 41e8fe5c..8095646e 100644 --- a/gemma/gm/text/_sampler.py +++ b/gemma/gm/text/_sampler.py @@ -348,7 +348,15 @@ def sample( end_tokens=( self.tokenizer.special_tokens.EOS, self.tokenizer.special_tokens.END_OF_TURN, - self.tokenizer.special_tokens.BEGIN_OF_TOOL_RESPONSE, + # BEGIN_OF_TOOL_RESPONSE was introduced in Gemma3; Gemma2 tokenizer + # does not define it. Only include it when available. + *( + (self.tokenizer.special_tokens.BEGIN_OF_TOOL_RESPONSE,) + if hasattr( + self.tokenizer.special_tokens, 'BEGIN_OF_TOOL_RESPONSE' + ) + else () + ), *self._normalized_stop_tokens, ), forbidden_tokens=self._normalized_forbidden_tokens, @@ -573,7 +581,7 @@ def _normalize_token(tokenizer, token: str | int) -> int: token_id = tokenizer.encode(token) if len(token_id) != 1: raise ValueError( - 'Invalid token: {token!r}. `stop_token`s and `forbidden_token`s must' + f'Invalid token: {token!r}. `stop_token`s and `forbidden_token`s must' ' map to single token ids in the vocab.' ) (token_id,) = token_id diff --git a/gemma/gm/text/_sampler_test.py b/gemma/gm/text/_sampler_test.py index 5099bcf8..9d519dc4 100644 --- a/gemma/gm/text/_sampler_test.py +++ b/gemma/gm/text/_sampler_test.py @@ -13,10 +13,13 @@ # limitations under the License. from gemma import gm +from gemma.gm.text import _sampler from gemma.gm.text import _sampler_loop +from gemma.gm.text import _tokenizer import jax import jax.numpy as jnp import numpy as np +import pytest def test_end_tokens_mask(): @@ -54,3 +57,39 @@ def test_sampler(): pad_length=None, ) sampler.sample('Hello world') + + +def test_normalize_token_error_message_contains_token_value(): + """_normalize_token should interpolate the token value in the error message. + + Regression test for a missing f-string prefix that caused the error message + to show the literal string '{token!r}' instead of the actual token value. + """ + tokenizer = gm.testing.DummyTokenizer() + # 'Hello world' encodes to two tokens (the dummy tokenizer splits on spaces), + # so _normalize_token should raise ValueError for it. + with pytest.raises(ValueError, match=r'Hello world'): + _sampler._normalize_token(tokenizer, 'Hello world') + + +def test_sampler_gemma2_tokenizer_no_begin_of_tool_response(): + """Sampler with Gemma2 tokenizer must not crash on BEGIN_OF_TOOL_RESPONSE. + + Gemma2's _Gemma2SpecialTokens does not define BEGIN_OF_TOOL_RESPONSE (that + attribute was introduced in Gemma3). The Sampler.sample() method previously + accessed it unconditionally, raising AttributeError for any Gemma2 model. + This test verifies the hasattr() guard prevents the crash. + """ + # DummyTokenizer uses _Gemma3SpecialTokens which has the attribute. + # We verify the guard logic directly: _Gemma2SpecialTokens must NOT have it. + assert not hasattr( + _tokenizer._Gemma2SpecialTokens, 'BEGIN_OF_TOOL_RESPONSE' + ), ( + '_Gemma2SpecialTokens should not define BEGIN_OF_TOOL_RESPONSE' + ) + # And _Gemma3SpecialTokens MUST have it. + assert hasattr( + _tokenizer._Gemma3SpecialTokens, 'BEGIN_OF_TOOL_RESPONSE' + ), ( + '_Gemma3SpecialTokens should define BEGIN_OF_TOOL_RESPONSE' + ) From 16bb42c5c9b536e801ca8b80e1d127a3f1bed779 Mon Sep 17 00:00:00 2001 From: ashiksharonm Date: Thu, 26 Feb 2026 16:37:15 +0530 Subject: [PATCH 2/2] chore: trigger CLA re-check