Skip to content

Commit 8453d4e

Browse files
committed
feat(metrics): add Perplexity metric to NLP metrics
Expose a new token-level Perplexity metric in ignite.metrics.nlp and top-level ignite.metrics, with dedicated unit tests to validate correctness and behavior. Made-with: Cursor
1 parent c0ceca5 commit 8453d4e

4 files changed

Lines changed: 206 additions & 0 deletions

File tree

ignite/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from ignite.metrics.mutual_information import MutualInformation
3434
from ignite.metrics.nlp.bleu import Bleu
3535
from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN
36+
from ignite.metrics.nlp.perplexity import Perplexity
3637
from ignite.metrics.precision import Precision
3738
from ignite.metrics.precision_recall_curve import PrecisionRecallCurve
3839
from ignite.metrics.psnr import PSNR
@@ -93,6 +94,7 @@
9394
"Rouge",
9495
"RougeN",
9596
"RougeL",
97+
"Perplexity",
9698
"regression",
9799
"clustering",
98100
"fairness",

ignite/metrics/nlp/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from ignite.metrics.nlp.bleu import Bleu
2+
from ignite.metrics.nlp.perplexity import Perplexity
23
from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN
34

45
__all__ = [
56
"Bleu",
7+
"Perplexity",
68
"Rouge",
79
"RougeN",
810
"RougeL",

ignite/metrics/nlp/perplexity.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from collections.abc import Callable
2+
3+
import torch
4+
import torch.nn.functional as F
5+
6+
from ignite.exceptions import NotComputableError
7+
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
8+
9+
__all__ = ["Perplexity"]
10+
11+
12+
class Perplexity(Metric):
13+
r"""Calculates the `Perplexity <https://en.wikipedia.org/wiki/Perplexity>`_ of a language model.
14+
15+
.. math::
16+
\text{PPL}(W) = \exp \left( -\frac{1}{N} \sum_{i=1}^{N} \log P(w_i | w_1, \ldots, w_{i-1}) \right)
17+
18+
where :math:`N` is the total number of tokens and :math:`P(w_i | w_1, \ldots, w_{i-1})` is the
19+
conditional probability of token :math:`w_i` given the preceding tokens.
20+
21+
Perplexity is computed as :math:`\exp(\text{NLL})` where NLL is the mean negative log-likelihood
22+
over all tokens. Lower perplexity indicates a better language model.
23+
24+
- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
25+
- `y_pred` must be a floating-point tensor of shape ``(batch_size, vocab_size, seq_len)``
26+
containing the unnormalized log-probabilities (logits).
27+
- `y` must be a long tensor of shape ``(batch_size, seq_len)`` containing the target token indices.
28+
29+
Note:
30+
Perplexity uses token-weighted accumulation rather than batch-average to avoid bias
31+
towards shorter sequences. The total NLL and total token count are accumulated across
32+
all batches, and the final perplexity is computed as ``exp(total_nll / total_tokens)``.
33+
34+
Args:
35+
output_transform: a callable that is used to transform the
36+
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
37+
form expected by the metric. This can be useful if, for example, you have a multi-output model and
38+
you want to compute the metric with respect to one of the outputs.
39+
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
40+
device: specifies which device updates are accumulated on. Setting the
41+
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
42+
non-blocking. By default, CPU.
43+
44+
Examples:
45+
46+
For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.
47+
48+
.. testcode::
49+
50+
from ignite.metrics.nlp import Perplexity
51+
import torch
52+
53+
ppl = Perplexity()
54+
55+
# batch_size=2, vocab_size=5, seq_len=3
56+
y_pred = torch.log_softmax(torch.randn(2, 5, 3), dim=1)
57+
y = torch.randint(0, 5, (2, 3))
58+
59+
ppl.update((y_pred, y))
60+
61+
print(type(ppl.compute()))
62+
63+
.. testoutput::
64+
65+
<class 'float'>
66+
67+
.. versionadded:: 0.5.2
68+
"""
69+
70+
_state_dict_all_req_keys = ("_sum_of_nll", "_num_tokens")
71+
72+
def __init__(
73+
self,
74+
output_transform: Callable = lambda x: x,
75+
device: str | torch.device = torch.device("cpu"),
76+
):
77+
super().__init__(output_transform=output_transform, device=device)
78+
79+
@reinit__is_reduced
80+
def reset(self) -> None:
81+
self._sum_of_nll = torch.tensor(0.0, dtype=torch.double, device=self._device)
82+
self._num_tokens = torch.tensor(0, dtype=torch.long, device=self._device)
83+
84+
@reinit__is_reduced
85+
def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
86+
y_pred, y = output
87+
88+
if y_pred.ndim < 2:
89+
raise ValueError(f"y_pred must be at least 2-dimensional (got shape: {y_pred.shape})")
90+
91+
if y.ndim < 1:
92+
raise ValueError(f"y must be at least 1-dimensional (got shape: {y.shape})")
93+
94+
nll = F.cross_entropy(y_pred, y, reduction="sum")
95+
self._sum_of_nll += nll.to(self._device, dtype=torch.double)
96+
self._num_tokens += y.numel()
97+
98+
@sync_all_reduce("_sum_of_nll", "_num_tokens")
99+
def compute(self) -> float:
100+
if self._num_tokens == 0:
101+
raise NotComputableError("Perplexity must have at least one example before it can be computed.")
102+
103+
return torch.exp(self._sum_of_nll / self._num_tokens).item()
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import pytest
2+
import torch
3+
import torch.nn.functional as F
4+
5+
from ignite.exceptions import NotComputableError
6+
from ignite.metrics.nlp import Perplexity
7+
8+
9+
def test_zero_sample():
10+
ppl = Perplexity()
11+
ppl.reset()
12+
with pytest.raises(NotComputableError):
13+
ppl.compute()
14+
15+
16+
def test_compute_matches_manual():
17+
torch.manual_seed(42)
18+
ppl = Perplexity()
19+
ppl.reset()
20+
21+
y_pred = torch.randn(4, 10, 5)
22+
y = torch.randint(0, 10, (4, 5))
23+
24+
ppl.update((y_pred, y))
25+
26+
nll_manual = F.cross_entropy(y_pred, y, reduction="sum").item()
27+
ppl_manual = torch.exp(torch.tensor(nll_manual / y.numel())).item()
28+
29+
assert abs(ppl.compute() - ppl_manual) < 1e-4
30+
31+
32+
def test_token_weighted_accumulation():
33+
"""Token-weighted accumulation must differ from naive batch average."""
34+
torch.manual_seed(0)
35+
ppl = Perplexity()
36+
ppl.reset()
37+
38+
# Two batches with different sequence lengths
39+
b1_pred = torch.randn(2, 5, 4)
40+
b1_y = torch.randint(0, 5, (2, 4))
41+
b2_pred = torch.randn(3, 5, 10)
42+
b2_y = torch.randint(0, 5, (3, 10))
43+
44+
ppl.update((b1_pred, b1_y))
45+
ppl.update((b2_pred, b2_y))
46+
47+
nll1 = F.cross_entropy(b1_pred, b1_y, reduction="sum").item()
48+
nll2 = F.cross_entropy(b2_pred, b2_y, reduction="sum").item()
49+
total_tokens = b1_y.numel() + b2_y.numel()
50+
ppl_ref = torch.exp(torch.tensor((nll1 + nll2) / total_tokens)).item()
51+
52+
assert abs(ppl.compute() - ppl_ref) < 1e-4
53+
54+
55+
def test_returns_float():
56+
torch.manual_seed(1)
57+
ppl = Perplexity()
58+
ppl.reset()
59+
60+
y_pred = torch.randn(2, 5, 3)
61+
y = torch.randint(0, 5, (2, 3))
62+
ppl.update((y_pred, y))
63+
64+
result = ppl.compute()
65+
assert isinstance(result, float)
66+
67+
68+
def test_invalid_y_pred_shape():
69+
ppl = Perplexity()
70+
ppl.reset()
71+
72+
with pytest.raises(ValueError, match="y_pred must be at least 2-dimensional"):
73+
ppl.update((torch.tensor([1.0, 2.0]), torch.tensor([0])))
74+
75+
76+
def test_reset_clears_state():
77+
torch.manual_seed(2)
78+
ppl = Perplexity()
79+
80+
y_pred = torch.randn(2, 5, 3)
81+
y = torch.randint(0, 5, (2, 3))
82+
ppl.update((y_pred, y))
83+
84+
ppl.reset()
85+
with pytest.raises(NotComputableError):
86+
ppl.compute()
87+
88+
89+
def test_single_token():
90+
ppl = Perplexity()
91+
ppl.reset()
92+
93+
y_pred = torch.randn(1, 5, 1)
94+
y = torch.randint(0, 5, (1, 1))
95+
ppl.update((y_pred, y))
96+
97+
result = ppl.compute()
98+
assert result > 0
99+
assert isinstance(result, float)

0 commit comments

Comments
 (0)