2 changes: 1 addition & 1 deletion gemma/gm/text/__init__.py
@@ -33,12 +33,12 @@
from gemma.gm.text._tool_sampler import ToolSampler

# Sampling methods
# TODO(mblondel): Add nucleus sampling
from gemma.gm.text._sampling import SamplingMethod
from gemma.gm.text._sampling import Greedy
from gemma.gm.text._sampling import RandomSampling
from gemma.gm.text._sampling import TopkSampling
from gemma.gm.text._sampling import TopPSampling
from gemma.gm.text._sampling import MinPSampling

# Other utils
# from gemma.gm.text import _template as template
38 changes: 38 additions & 0 deletions gemma/gm/text/_sampling.py
@@ -116,3 +116,41 @@ def get_next_tokens(self, logits: Float['... V'], rng: PRNGKey) -> Int['...']:

    return jax.random.categorical(rng, logits, axis=-1)


@dataclasses.dataclass(frozen=True, kw_only=True)
class MinPSampling(SamplingMethod):
  """Min-p sampling.

  Min-p sampling keeps all tokens whose probability is at least
  `min_p` times the probability of the most likely token. Unlike top-p,
  which uses an absolute cumulative threshold, min-p adapts the number
  of candidate tokens to the model's confidence: when the model is
  confident (one token dominates), fewer alternatives are kept; when it
  is uncertain, more tokens remain in the candidate set.

  Reference: https://arxiv.org/abs/2407.01082
  """

  min_p: float = 0.1
  temperature: float = 1.0

  @typechecked
  def get_next_tokens(self, logits: Float['... V'], rng: PRNGKey) -> Int['...']:
    # Temperature scaling.
    logits = logits / self.temperature

    if self.min_p > 0.0:
      probs = jax.nn.softmax(logits, axis=-1)

      # Compute the threshold: min_p * max probability in each batch.
      max_prob = jnp.max(probs, axis=-1, keepdims=True)
      threshold = self.min_p * max_prob

      # Mask out tokens below the threshold.
      logits = jnp.where(
          probs < threshold,
          jnp.finfo(logits.dtype).min,
          logits,
      )

    return jax.random.categorical(rng, logits, axis=-1)
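
To make the adaptive behavior described in the docstring concrete, here is a small standalone sketch (not part of this PR; the helper name `minp_candidate_count` and the example logits are illustrative only) that counts how many tokens survive the min-p threshold for a confident versus an uncertain distribution:

```python
import jax
import jax.numpy as jnp


def minp_candidate_count(logits, min_p=0.1):
  # Same masking rule as MinPSampling.get_next_tokens above.
  probs = jax.nn.softmax(logits, axis=-1)
  threshold = min_p * jnp.max(probs, axis=-1, keepdims=True)
  return jnp.sum(probs >= threshold, axis=-1)


confident = jnp.array([[8.0, 1.0, 1.0, 1.0, 1.0]])  # One token dominates.
uncertain = jnp.array([[1.0, 0.9, 0.8, 0.7, 0.6]])  # Near-uniform.

print(minp_candidate_count(confident))  # [1]: only the dominant token survives.
print(minp_candidate_count(uncertain))  # [5]: all tokens stay in the candidate set.
```

With `min_p=0.1`, the confident distribution keeps a single candidate while the near-uniform one keeps all five, which is the adaptivity the docstring refers to.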
36 changes: 36 additions & 0 deletions gemma/gm/text/_sampling_test.py
@@ -79,3 +79,39 @@ def test_top1_sampling_matches_greedy_sampling():
  tokens_top1 = top1_sampling.get_next_tokens(logits, rng)
  np.testing.assert_array_equal(tokens_greedy, tokens_top1)


def test_minp_sampling():
  sampling = gm.text.MinPSampling(min_p=0.1)
  rng = jax.random.PRNGKey(0)
  batch_size = 2
  vocab_size = 5
  logits = jax.random.normal(rng, shape=(batch_size, vocab_size))
  tokens = sampling.get_next_tokens(logits, rng)
  assert tokens.shape == (batch_size,)


def test_minp_sampling_with_skewed_logits():
  """With highly skewed logits, min-p should only keep the dominant token."""
  sampling = gm.text.MinPSampling(min_p=0.5)
  rng = jax.random.PRNGKey(0)
  # Token 0 has probability ~0.99, others are negligible.
  logits = jax.numpy.array([
      [10.0, 0.0, 0.0, 0.0, 0.0],
  ])
  tokens = sampling.get_next_tokens(logits, rng)
  # With min_p=0.5, threshold = 0.5 * 0.99 ≈ 0.49.
  # Only token 0 (prob ~0.99) exceeds the threshold.
  np.testing.assert_array_equal(tokens, [0])


def test_minp_zero_disables_filtering():
  """With min_p=0, all tokens should be candidates (no filtering)."""
  sampling_off = gm.text.MinPSampling(min_p=0.0, temperature=1.0)
  sampling_rand = gm.text.RandomSampling(temperature=1.0)
  rng = jax.random.PRNGKey(42)
  logits = jax.random.normal(rng, shape=(100, 5))
  # With min_p=0, MinPSampling should behave like RandomSampling.
  tokens_minp = sampling_off.get_next_tokens(logits, rng)
  tokens_rand = sampling_rand.get_next_tokens(logits, rng)
  np.testing.assert_array_equal(tokens_minp, tokens_rand)
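
For completeness, a hedged usage sketch of the new method: this diff only adds `MinPSampling` and its export, so the sampler wiring below is an assumption (the `ChatSampler` constructor arguments, including `sampling=`, and the `model`/`params` objects are not confirmed by this change):

```python
# Hypothetical usage; the sampler construction below is an assumption and is
# not part of this diff. Only MinPSampling itself is added here.
from gemma import gm

sampler = gm.text.ChatSampler(
    model=model,    # A loaded Gemma model (assumed to exist in your setup).
    params=params,  # Matching checkpoint params (assumed).
    sampling=gm.text.MinPSampling(min_p=0.1, temperature=1.0),
)
print(sampler.chat('Tell me a short story.'))
```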