diff --git a/gemma/gm/text/__init__.py b/gemma/gm/text/__init__.py
index c6f0df80..620b6cb3 100644
--- a/gemma/gm/text/__init__.py
+++ b/gemma/gm/text/__init__.py
@@ -33,12 +33,12 @@
 from gemma.gm.text._tool_sampler import ToolSampler
 
 # Sampling methods
-# TODO(mblondel): Add nucleus sampling
 from gemma.gm.text._sampling import SamplingMethod
 from gemma.gm.text._sampling import Greedy
 from gemma.gm.text._sampling import RandomSampling
 from gemma.gm.text._sampling import TopkSampling
 from gemma.gm.text._sampling import TopPSampling
+from gemma.gm.text._sampling import MinPSampling
 
 # Other utils
 # from gemma.gm.text import _template as template
diff --git a/gemma/gm/text/_sampling.py b/gemma/gm/text/_sampling.py
index 9202d563..f1cc97e7 100644
--- a/gemma/gm/text/_sampling.py
+++ b/gemma/gm/text/_sampling.py
@@ -116,3 +116,41 @@
   @typechecked
   def get_next_tokens(self, logits: Float['... V'], rng: PRNGKey) -> Int['...']:
     return jax.random.categorical(rng, logits, axis=-1)
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class MinPSampling(SamplingMethod):
+  """Min-p sampling.
+
+  Min-p sampling keeps all tokens whose probability is at least
+  `min_p` times the probability of the most likely token. Unlike top-p
+  which uses an absolute cumulative threshold, min-p adapts the number
+  of candidate tokens based on model confidence: when the model is
+  confident (one token dominates), fewer alternatives are kept; when
+  uncertain, more tokens remain in the candidate set.
+
+  Reference: https://arxiv.org/abs/2407.01082
+  """
+
+  min_p: float = 0.1
+  temperature: float = 1.0
+
+  @typechecked
+  def get_next_tokens(self, logits: Float['... V'], rng: PRNGKey) -> Int['...']:
+    # Temperature scaling (applied before computing the probabilities).
+    logits = logits / self.temperature
+
+    if self.min_p > 0.0:
+      probs = jax.nn.softmax(logits, axis=-1)
+
+      # Compute the threshold: min_p * max probability in each batch.
+      max_prob = jnp.max(probs, axis=-1, keepdims=True)
+      threshold = self.min_p * max_prob
+
+      # Mask out tokens below the threshold.
+      logits = jnp.where(
+          probs < threshold,
+          jnp.finfo(logits.dtype).min,
+          logits,
+      )
+
+    return jax.random.categorical(rng, logits, axis=-1)
diff --git a/gemma/gm/text/_sampling_test.py b/gemma/gm/text/_sampling_test.py
index 286a49d6..1a9fc953 100644
--- a/gemma/gm/text/_sampling_test.py
+++ b/gemma/gm/text/_sampling_test.py
@@ -79,3 +79,39 @@ def test_top1_sampling_matches_greedy_sampling():
   tokens_top1 = top1_sampling.get_next_tokens(logits, rng)
 
   np.testing.assert_array_equal(tokens_greedy, tokens_top1)
+
+def test_minp_sampling():
+  sampling = gm.text.MinPSampling(min_p=0.1)
+  rng = jax.random.PRNGKey(0)
+  batch_size = 2
+  vocab_size = 5
+  logits = jax.random.normal(rng, shape=(batch_size, vocab_size))
+  tokens = sampling.get_next_tokens(logits, rng)
+  assert tokens.shape == (batch_size,)
+
+
+def test_minp_sampling_with_skewed_logits():
+  """With highly skewed logits, min-p should only keep the dominant token."""
+  sampling = gm.text.MinPSampling(min_p=0.5)
+  rng = jax.random.PRNGKey(0)
+  # Token 0 has probability ~0.99, others are negligible.
+  logits = jax.numpy.array([
+      [10.0, 0.0, 0.0, 0.0, 0.0],
+  ])
+  tokens = sampling.get_next_tokens(logits, rng)
+  # With min_p=0.5, threshold = 0.5 * 0.99 ≈ 0.49.
+  # Only token 0 (prob ~0.99) exceeds the threshold.
+  np.testing.assert_array_equal(tokens, [0])
+
+
+def test_minp_zero_disables_filtering():
+  """With min_p=0, all tokens should be candidates (no filtering)."""
+  sampling_off = gm.text.MinPSampling(min_p=0.0, temperature=1.0)
+  sampling_rand = gm.text.RandomSampling(temperature=1.0)
+  rng = jax.random.PRNGKey(42)
+  logits = jax.random.normal(rng, shape=(100, 5))
+  # With min_p=0, MinPSampling should behave like RandomSampling.
+  tokens_minp = sampling_off.get_next_tokens(logits, rng)
+  tokens_rand = sampling_rand.get_next_tokens(logits, rng)
+  np.testing.assert_array_equal(tokens_minp, tokens_rand)
+