|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +from unittest import mock |
15 | 16 | from gemma import gm |
| 17 | +from gemma.gm.text import _sampler |
16 | 18 | from gemma.gm.text import _sampler_loop |
17 | 19 | import jax |
18 | 20 | import jax.numpy as jnp |
@@ -54,3 +56,92 @@ def test_sampler(): |
54 | 56 | ) |
55 | 57 | sampler.sample('Hello world') |
56 | 58 |
|
| 59 | + |
def test_chat_sampler_gemma4_dispatch():
  """Checks that chatting goes through `gemma4_sampler` when `_is_gemma4` is set.

  Both `_is_gemma4` and `gemma4_sampler` are replaced on the `ChatSampler`
  class with stand-in properties, so no real Gemma4 model is needed. The test
  passes when `chat()` returns a string and the stand-in `sample` callable was
  invoked exactly once.
  """
  dummy_model = gm.testing.DummyGemma()
  init_vars = dummy_model.init(
      jax.random.PRNGKey(0),
      jnp.zeros((5,), dtype=jnp.int32),
  )
  chat_sampler = gm.text.ChatSampler(
      model=dummy_model,
      params=init_vars['params'],
      tokenizer=gm.testing.DummyTokenizer(),
      cache_length=128,
      max_out_length=128,
  )
  sampler_cls = type(chat_sampler)

  # Stand-in for `gemma4_sampler.sample`; it records calls and hands back a
  # canned `SamplerOutput`.
  canned_output = _sampler.SamplerOutput(
      text='mock output',
      state=mock.MagicMock(),
  )
  mock_sample = mock.MagicMock(return_value=canned_output)

  # Force the Gemma4 branch, then route `gemma4_sampler` to the stand-in.
  # Every property access builds a fresh MagicMock, but they all share the
  # same `sample` callable, so the call count is tracked in one place.
  force_gemma4 = property(lambda self: True)
  fake_gemma4_sampler = property(
      lambda self: mock.MagicMock(sample=mock_sample)
  )
  with mock.patch.object(
      sampler_cls, '_is_gemma4', new=force_gemma4
  ), mock.patch.object(
      sampler_cls, 'gemma4_sampler', new=fake_gemma4_sampler
  ):
    output = chat_sampler.chat('Hello world')

  assert isinstance(output, str)
  # The patched `gemma4_sampler.sample` must have been hit exactly once.
  mock_sample.assert_called_once()
| 105 | + |
| 106 | + |
def test_chat_sampler_non_gemma4_dispatch():
  """Checks that chatting goes through `sampler` when `_is_gemma4` is false.

  Only `sampler` is replaced on the `ChatSampler` class with a stand-in
  property, so the full sampling pipeline is not run here. The test passes
  when `chat()` returns a string and the stand-in `sample` callable was
  invoked exactly once.
  """
  dummy_model = gm.testing.DummyGemma()
  init_vars = dummy_model.init(
      jax.random.PRNGKey(0),
      jnp.zeros((5,), dtype=jnp.int32),
  )
  chat_sampler = gm.text.ChatSampler(
      model=dummy_model,
      params=init_vars['params'],
      tokenizer=gm.testing.DummyTokenizer(),
      cache_length=128,
      max_out_length=128,
  )

  # Precondition: this sampler must not already be on the Gemma4 branch.
  assert not chat_sampler._is_gemma4

  # Stand-in for `sampler.sample`; it records calls and hands back a canned
  # `SamplerOutput`.
  canned_output = _sampler.SamplerOutput(
      text='mock output',
      state=mock.MagicMock(),
  )
  mock_sample = mock.MagicMock(return_value=canned_output)

  # Every property access builds a fresh MagicMock, but they all share the
  # same `sample` callable, so the call count is tracked in one place.
  fake_sampler = property(lambda self: mock.MagicMock(sample=mock_sample))
  with mock.patch.object(type(chat_sampler), 'sampler', new=fake_sampler):
    output = chat_sampler.chat('Hello world')

  assert isinstance(output, str)
  # The patched `sampler.sample` must have been hit exactly once.
  mock_sample.assert_called_once()
0 commit comments