|
4 | 4 | import pytest |
5 | 5 | import torch |
6 | 6 |
|
7 | | -from transformer_lens.utilities.logits_utils import logits_to_df |
| 7 | +from transformer_lens.utilities.logits_utils import ( |
| 8 | + _apply_repetition_penalty, |
| 9 | + logits_to_df, |
| 10 | + sample_logits, |
| 11 | +) |
8 | 12 |
|
9 | 13 |
|
10 | 14 | class _StubTokenizer: |
@@ -85,3 +89,96 @@ def test_rejects_non_1d_input(self): |
85 | 89 |
|
86 | 90 | with pytest.raises(BeartypeCallHintParamViolation): |
87 | 91 | logits_to_df(torch.zeros(3, 4)) |
| 92 | + |
| 93 | + |
| 94 | +class TestSampleLogitsTopK: |
| 95 | + def test_top_k_larger_than_vocab_does_not_crash(self): |
| 96 | + # Regression test: before clamping top_k, final_logits.topk(top_k) |
| 97 | + # raised "selected index k out of range" when top_k > vocab size. |
| 98 | + out = sample_logits(torch.randn(1, 3), top_k=10) |
| 99 | + assert out.shape == (1,) |
| 100 | + assert 0 <= out.item() < 3 |
| 101 | + |
| 102 | + def test_top_k_larger_than_vocab_batched(self): |
| 103 | + out = sample_logits(torch.randn(4, 5), top_k=8) |
| 104 | + assert out.shape == (4,) |
| 105 | + assert torch.all((out >= 0) & (out < 5)) |
| 106 | + |
| 107 | + def test_top_k_equal_to_vocab(self): |
| 108 | + out = sample_logits(torch.randn(1, 4), top_k=4) |
| 109 | + assert out.shape == (1,) |
| 110 | + assert 0 <= out.item() < 4 |
| 111 | + |
| 112 | + def test_top_k_restricts_to_dominant_token(self): |
| 113 | + # With top_k=1 only the argmax token is ever sampled. |
| 114 | + logits = torch.tensor([[0.0, 100.0, 0.0, 0.0]]) |
| 115 | + outs = [sample_logits(logits, top_k=1).item() for _ in range(20)] |
| 116 | + assert set(outs) == {1} |
| 117 | + |
| 118 | + def test_top_k_rejects_non_positive(self): |
| 119 | + with pytest.raises(AssertionError): |
| 120 | + sample_logits(torch.randn(1, 4), top_k=0) |
| 121 | + |
| 122 | + |
| 123 | +class TestSampleLogitsTemperature: |
| 124 | + def test_temperature_zero_is_greedy_argmax(self): |
| 125 | + logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]]) |
| 126 | + out = sample_logits(logits, temperature=0.0) |
| 127 | + assert out.tolist() == [1] |
| 128 | + |
| 129 | + def test_temperature_zero_batched_argmax(self): |
| 130 | + logits = torch.tensor([[1.0, 3.0, 2.0], [5.0, 0.0, 1.0]]) |
| 131 | + out = sample_logits(logits, temperature=0.0) |
| 132 | + assert out.tolist() == [1, 0] |
| 133 | + |
| 134 | + def test_temperature_zero_applies_repetition_penalty(self): |
| 135 | + # Token 1 is the argmax but has appeared, so the penalty should push |
| 136 | + # the greedy choice onto the next-best unseen token (token 2). |
| 137 | + logits = torch.tensor([[0.0, 10.0, 9.0, 0.0]]) |
| 138 | + tokens = torch.tensor([[1]]) |
| 139 | + out = sample_logits(logits, temperature=0.0, repetition_penalty=100.0, tokens=tokens) |
| 140 | + assert out.tolist() == [2] |
| 141 | + |
| 142 | + |
| 143 | +class TestSampleLogitsTopP: |
| 144 | + def test_top_p_keeps_dominant_token(self): |
| 145 | + # One token holds essentially all the probability mass, so even a small |
| 146 | + # top_p must keep it and it is the only token ever sampled. |
| 147 | + logits = torch.tensor([[0.0, 50.0, 0.0, 0.0]]) |
| 148 | + outs = [sample_logits(logits, top_p=0.5).item() for _ in range(20)] |
| 149 | + assert set(outs) == {1} |
| 150 | + |
| 151 | + def test_top_p_rejects_out_of_range(self): |
| 152 | + with pytest.raises(AssertionError): |
| 153 | + sample_logits(torch.randn(1, 4), top_p=0.0) |
| 154 | + with pytest.raises(AssertionError): |
| 155 | + sample_logits(torch.randn(1, 4), top_p=1.5) |
| 156 | + |
| 157 | + |
| 158 | +class TestSampleLogitsFreqPenalty: |
| 159 | + def test_freq_penalty_suppresses_repeated_token(self): |
| 160 | + # Token 0 starts as the clear favourite, but appears many times in the |
| 161 | + # context; a large frequency penalty should make it never get sampled. |
| 162 | + logits = torch.tensor([[5.0, 4.0, 4.0, 4.0]]) |
| 163 | + tokens = torch.zeros((1, 50), dtype=torch.long) # token 0 repeated 50x |
| 164 | + outs = [sample_logits(logits, freq_penalty=10.0, tokens=tokens).item() for _ in range(50)] |
| 165 | + assert 0 not in outs |
| 166 | + |
| 167 | + def test_freq_penalty_requires_tokens(self): |
| 168 | + with pytest.raises(AssertionError): |
| 169 | + sample_logits(torch.randn(1, 4), freq_penalty=1.0) |
| 170 | + |
| 171 | + |
| 172 | +class TestApplyRepetitionPenalty: |
| 173 | + def test_positive_logits_divided_negative_multiplied(self): |
| 174 | + logits = torch.tensor([[2.0, -2.0, 0.0]]) |
| 175 | + tokens = torch.tensor([[0, 1]]) |
| 176 | + out = _apply_repetition_penalty(logits, tokens, penalty=2.0) |
| 177 | + # token 0 positive -> divided; token 1 negative -> multiplied; token 2 untouched |
| 178 | + assert out.tolist() == [[1.0, -4.0, 0.0]] |
| 179 | + |
| 180 | + def test_does_not_mutate_input(self): |
| 181 | + logits = torch.tensor([[2.0, -2.0, 0.0]]) |
| 182 | + original = logits.clone() |
| 183 | + _apply_repetition_penalty(logits, torch.tensor([[0, 1]]), penalty=2.0) |
| 184 | + assert torch.equal(logits, original) |
0 commit comments