Skip to content

Commit cdfab1a

Browse files
Clamp top_k to vocab size in sample_logits and add edge-case tests (#1347)
sample_logits crashed with 'selected index k out of range' when top_k exceeded the vocabulary size (reachable via model.generate(top_k=...)). Clamp top_k to the vocab size, matching HuggingFace's TopKLogitsWarper. Also add the first unit tests for sample_logits (top_k clamping, greedy temperature=0, top_p, frequency penalty, repetition penalty), resolving the standing in-code TODO.
1 parent 08c9ef9 commit cdfab1a

2 files changed

Lines changed: 103 additions & 5 deletions

File tree

tests/unit/utilities/test_logits_utils.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
import pytest
55
import torch
66

7-
from transformer_lens.utilities.logits_utils import logits_to_df
7+
from transformer_lens.utilities.logits_utils import (
8+
_apply_repetition_penalty,
9+
logits_to_df,
10+
sample_logits,
11+
)
812

913

1014
class _StubTokenizer:
@@ -85,3 +89,96 @@ def test_rejects_non_1d_input(self):
8589

8690
with pytest.raises(BeartypeCallHintParamViolation):
8791
logits_to_df(torch.zeros(3, 4))
92+
93+
94+
class TestSampleLogitsTopK:
95+
def test_top_k_larger_than_vocab_does_not_crash(self):
96+
# Regression test: before clamping top_k, final_logits.topk(top_k)
97+
# raised "selected index k out of range" when top_k > vocab size.
98+
out = sample_logits(torch.randn(1, 3), top_k=10)
99+
assert out.shape == (1,)
100+
assert 0 <= out.item() < 3
101+
102+
def test_top_k_larger_than_vocab_batched(self):
103+
out = sample_logits(torch.randn(4, 5), top_k=8)
104+
assert out.shape == (4,)
105+
assert torch.all((out >= 0) & (out < 5))
106+
107+
def test_top_k_equal_to_vocab(self):
108+
out = sample_logits(torch.randn(1, 4), top_k=4)
109+
assert out.shape == (1,)
110+
assert 0 <= out.item() < 4
111+
112+
def test_top_k_restricts_to_dominant_token(self):
113+
# With top_k=1 only the argmax token is ever sampled.
114+
logits = torch.tensor([[0.0, 100.0, 0.0, 0.0]])
115+
outs = [sample_logits(logits, top_k=1).item() for _ in range(20)]
116+
assert set(outs) == {1}
117+
118+
def test_top_k_rejects_non_positive(self):
119+
with pytest.raises(AssertionError):
120+
sample_logits(torch.randn(1, 4), top_k=0)
121+
122+
123+
class TestSampleLogitsTemperature:
124+
def test_temperature_zero_is_greedy_argmax(self):
125+
logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
126+
out = sample_logits(logits, temperature=0.0)
127+
assert out.tolist() == [1]
128+
129+
def test_temperature_zero_batched_argmax(self):
130+
logits = torch.tensor([[1.0, 3.0, 2.0], [5.0, 0.0, 1.0]])
131+
out = sample_logits(logits, temperature=0.0)
132+
assert out.tolist() == [1, 0]
133+
134+
def test_temperature_zero_applies_repetition_penalty(self):
135+
# Token 1 is the argmax but has appeared, so the penalty should push
136+
# the greedy choice onto the next-best unseen token (token 2).
137+
logits = torch.tensor([[0.0, 10.0, 9.0, 0.0]])
138+
tokens = torch.tensor([[1]])
139+
out = sample_logits(logits, temperature=0.0, repetition_penalty=100.0, tokens=tokens)
140+
assert out.tolist() == [2]
141+
142+
143+
class TestSampleLogitsTopP:
144+
def test_top_p_keeps_dominant_token(self):
145+
# One token holds essentially all the probability mass, so even a small
146+
# top_p must keep it and it is the only token ever sampled.
147+
logits = torch.tensor([[0.0, 50.0, 0.0, 0.0]])
148+
outs = [sample_logits(logits, top_p=0.5).item() for _ in range(20)]
149+
assert set(outs) == {1}
150+
151+
def test_top_p_rejects_out_of_range(self):
152+
with pytest.raises(AssertionError):
153+
sample_logits(torch.randn(1, 4), top_p=0.0)
154+
with pytest.raises(AssertionError):
155+
sample_logits(torch.randn(1, 4), top_p=1.5)
156+
157+
158+
class TestSampleLogitsFreqPenalty:
159+
def test_freq_penalty_suppresses_repeated_token(self):
160+
# Token 0 starts as the clear favourite, but appears many times in the
161+
# context; a large frequency penalty should make it never get sampled.
162+
logits = torch.tensor([[5.0, 4.0, 4.0, 4.0]])
163+
tokens = torch.zeros((1, 50), dtype=torch.long) # token 0 repeated 50x
164+
outs = [sample_logits(logits, freq_penalty=10.0, tokens=tokens).item() for _ in range(50)]
165+
assert 0 not in outs
166+
167+
def test_freq_penalty_requires_tokens(self):
168+
with pytest.raises(AssertionError):
169+
sample_logits(torch.randn(1, 4), freq_penalty=1.0)
170+
171+
172+
class TestApplyRepetitionPenalty:
173+
def test_positive_logits_divided_negative_multiplied(self):
174+
logits = torch.tensor([[2.0, -2.0, 0.0]])
175+
tokens = torch.tensor([[0, 1]])
176+
out = _apply_repetition_penalty(logits, tokens, penalty=2.0)
177+
# token 0 positive -> divided; token 1 negative -> multiplied; token 2 untouched
178+
assert out.tolist() == [[1.0, -4.0, 0.0]]
179+
180+
def test_does_not_mutate_input(self):
181+
logits = torch.tensor([[2.0, -2.0, 0.0]])
182+
original = logits.clone()
183+
_apply_repetition_penalty(logits, torch.tensor([[0, 1]]), penalty=2.0)
184+
assert torch.equal(logits, original)

transformer_lens/utilities/logits_utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,7 @@ def sample_logits(
9696
9797
Repetition penalty (HuggingFace-style) divides positive logits by the penalty value and multiplies negative logits by it for any token that has appeared in the sequence. A value of 1.0 means no penalty. Values > 1.0 discourage repetition. This is applied before temperature scaling.
9898
99-
#! TODO: Finish testing all the edge cases here. Useful testing code:
100-
logits = torch.randn(4)
101-
print(logits)
102-
np.unique(np.array([sample_logits(logits, top_k=2).item() for i in range(1000)]), return_counts=True)
99+
When ``top_k`` exceeds the vocabulary size it is clamped to the vocabulary size (matching HuggingFace), rather than raising an error.
103100
"""
104101
if temperature == 0.0:
105102
# Greedy sampling - still apply repetition penalty before argmax
@@ -128,6 +125,10 @@ def sample_logits(
128125
)
129126
if top_k is not None:
130127
assert top_k > 0, "top_k has to be greater than 0"
128+
# Clamp top_k to the vocab size so a large value does not raise
129+
# "selected index k out of range" (matches HuggingFace's
130+
# TopKLogitsWarper, which does top_k = min(top_k, logits.size(-1))).
131+
top_k = min(top_k, final_logits.shape[-1])
131132
top_logits, top_idx = final_logits.topk(top_k, dim=-1)
132133
indices_to_remove = final_logits < top_logits[..., -1].unsqueeze(-1)
133134
final_logits = final_logits.masked_fill(indices_to_remove, -float("inf"))

0 commit comments

Comments
 (0)