From fa394ca81fa219c7f4e2a04675fca14715545388 Mon Sep 17 00:00:00 2001 From: steaphenai Date: Mon, 20 Apr 2026 14:53:05 +0530 Subject: [PATCH 01/11] feat(metrics): add Perplexity metric to NLP metrics Expose a new token-level Perplexity metric in ignite.metrics.nlp and top-level ignite.metrics, with dedicated unit tests to validate correctness and behavior. --- ignite/metrics/__init__.py | 2 + ignite/metrics/nlp/__init__.py | 2 + ignite/metrics/nlp/perplexity.py | 103 ++++++++++++++++++++ tests/ignite/metrics/nlp/test_perplexity.py | 99 +++++++++++++++++++ 4 files changed, 206 insertions(+) create mode 100644 ignite/metrics/nlp/perplexity.py create mode 100644 tests/ignite/metrics/nlp/test_perplexity.py diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index b1813cc92935..dd81e6e41f09 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -33,6 +33,7 @@ from ignite.metrics.mutual_information import MutualInformation from ignite.metrics.nlp.bleu import Bleu from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN +from ignite.metrics.nlp.perplexity import Perplexity from ignite.metrics.precision import Precision from ignite.metrics.precision_recall_curve import PrecisionRecallCurve from ignite.metrics.psnr import PSNR @@ -93,6 +94,7 @@ "Rouge", "RougeN", "RougeL", + "Perplexity", "regression", "clustering", "fairness", diff --git a/ignite/metrics/nlp/__init__.py b/ignite/metrics/nlp/__init__.py index 506f0bab51e1..d0212882b78b 100644 --- a/ignite/metrics/nlp/__init__.py +++ b/ignite/metrics/nlp/__init__.py @@ -1,8 +1,10 @@ from ignite.metrics.nlp.bleu import Bleu +from ignite.metrics.nlp.perplexity import Perplexity from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN __all__ = [ "Bleu", + "Perplexity", "Rouge", "RougeN", "RougeL", diff --git a/ignite/metrics/nlp/perplexity.py b/ignite/metrics/nlp/perplexity.py new file mode 100644 index 000000000000..97f4dbc588fb --- /dev/null +++ b/ignite/metrics/nlp/perplexity.py @@ -0,0 +1,103 @@ +from collections.abc import Callable + +import torch +import torch.nn.functional as F + +from ignite.exceptions import NotComputableError +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce + +__all__ = ["Perplexity"] + + +class Perplexity(Metric): + r"""Calculates the `Perplexity `_ of a language model. + + .. math:: + \text{PPL}(W) = \exp \left( -\frac{1}{N} \sum_{i=1}^{N} \log P(w_i | w_1, \ldots, w_{i-1}) \right) + + where :math:`N` is the total number of tokens and :math:`P(w_i | w_1, \ldots, w_{i-1})` is the + conditional probability of token :math:`w_i` given the preceding tokens. + + Perplexity is computed as :math:`\exp(\text{NLL})` where NLL is the mean negative log-likelihood + over all tokens. Lower perplexity indicates a better language model. + + - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. + - `y_pred` must be a floating-point tensor of shape ``(batch_size, vocab_size, seq_len)`` + containing the unnormalized log-probabilities (logits). + - `y` must be a long tensor of shape ``(batch_size, seq_len)`` containing the target token indices. + + Note: + Perplexity uses token-weighted accumulation rather than batch-average to avoid bias + towards shorter sequences. The total NLL and total token count are accumulated across + all batches, and the final perplexity is computed as ``exp(total_nll / total_tokens)``. + + Args: + output_transform: a callable that is used to transform the + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the + form expected by the metric. This can be useful if, for example, you have a multi-output model and + you want to compute the metric with respect to one of the outputs. + By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. + device: specifies which device updates are accumulated on. Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. + + Examples: + + For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`. + + .. testcode:: + + from ignite.metrics.nlp import Perplexity + import torch + + ppl = Perplexity() + + # batch_size=2, vocab_size=5, seq_len=3 + y_pred = torch.log_softmax(torch.randn(2, 5, 3), dim=1) + y = torch.randint(0, 5, (2, 3)) + + ppl.update((y_pred, y)) + + print(type(ppl.compute())) + + .. testoutput:: + + + + .. versionadded:: 0.5.2 + """ + + _state_dict_all_req_keys = ("_sum_of_nll", "_num_tokens") + + def __init__( + self, + output_transform: Callable = lambda x: x, + device: str | torch.device = torch.device("cpu"), + ): + super().__init__(output_transform=output_transform, device=device) + + @reinit__is_reduced + def reset(self) -> None: + self._sum_of_nll = torch.tensor(0.0, dtype=torch.double, device=self._device) + self._num_tokens = torch.tensor(0, dtype=torch.long, device=self._device) + + @reinit__is_reduced + def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: + y_pred, y = output + + if y_pred.ndim < 2: + raise ValueError(f"y_pred must be at least 2-dimensional (got shape: {y_pred.shape})") + + if y.ndim < 1: + raise ValueError(f"y must be at least 1-dimensional (got shape: {y.shape})") + + nll = F.cross_entropy(y_pred, y, reduction="sum") + self._sum_of_nll += nll.to(self._device, dtype=torch.double) + self._num_tokens += y.numel() + + @sync_all_reduce("_sum_of_nll", "_num_tokens") + def compute(self) -> float: + if self._num_tokens == 0: + raise NotComputableError("Perplexity must have at least one example before it can be computed.") + + return torch.exp(self._sum_of_nll / self._num_tokens).item() diff --git a/tests/ignite/metrics/nlp/test_perplexity.py b/tests/ignite/metrics/nlp/test_perplexity.py new file mode 100644 index 000000000000..59b35467b9f7 --- /dev/null +++ b/tests/ignite/metrics/nlp/test_perplexity.py @@ -0,0 +1,99 @@ +import pytest +import torch +import torch.nn.functional as F + +from ignite.exceptions import NotComputableError +from ignite.metrics.nlp import Perplexity + + +def test_zero_sample(): + ppl = Perplexity() + ppl.reset() + with pytest.raises(NotComputableError): + ppl.compute() + + +def test_compute_matches_manual(): + torch.manual_seed(42) + ppl = Perplexity() + ppl.reset() + + y_pred = torch.randn(4, 10, 5) + y = torch.randint(0, 10, (4, 5)) + + ppl.update((y_pred, y)) + + nll_manual = F.cross_entropy(y_pred, y, reduction="sum").item() + ppl_manual = torch.exp(torch.tensor(nll_manual / y.numel())).item() + + assert abs(ppl.compute() - ppl_manual) < 1e-4 + + +def test_token_weighted_accumulation(): + """Token-weighted accumulation must differ from naive batch average.""" + torch.manual_seed(0) + ppl = Perplexity() + ppl.reset() + + # Two batches with different sequence lengths + b1_pred = torch.randn(2, 5, 4) + b1_y = torch.randint(0, 5, (2, 4)) + b2_pred = torch.randn(3, 5, 10) + b2_y = torch.randint(0, 5, (3, 10)) + + ppl.update((b1_pred, b1_y)) + ppl.update((b2_pred, b2_y)) + + nll1 = F.cross_entropy(b1_pred, b1_y, reduction="sum").item() + nll2 = F.cross_entropy(b2_pred, b2_y, reduction="sum").item() + total_tokens = b1_y.numel() + b2_y.numel() + ppl_ref = torch.exp(torch.tensor((nll1 + nll2) / total_tokens)).item() + + assert abs(ppl.compute() - ppl_ref) < 1e-4 + + +def test_returns_float(): + torch.manual_seed(1) + ppl = Perplexity() + ppl.reset() + + y_pred = torch.randn(2, 5, 3) + y = torch.randint(0, 5, (2, 3)) + ppl.update((y_pred, y)) + + result = ppl.compute() + assert isinstance(result, float) + + +def test_invalid_y_pred_shape(): + ppl = Perplexity() + ppl.reset() + + with pytest.raises(ValueError, match="y_pred must be at least 2-dimensional"): + ppl.update((torch.tensor([1.0, 2.0]), torch.tensor([0]))) + + +def test_reset_clears_state(): + torch.manual_seed(2) + ppl = Perplexity() + + y_pred = torch.randn(2, 5, 3) + y = torch.randint(0, 5, (2, 3)) + ppl.update((y_pred, y)) + + ppl.reset() + with pytest.raises(NotComputableError): + ppl.compute() + + +def test_single_token(): + ppl = Perplexity() + ppl.reset() + + y_pred = torch.randn(1, 5, 1) + y = torch.randint(0, 5, (1, 1)) + ppl.update((y_pred, y)) + + result = ppl.compute() + assert result > 0 + assert isinstance(result, float) From e1e1e5fb807beab3ebc9a88943a5b3dc41726863 Mon Sep 17 00:00:00 2001 From: steaphenai Date: Tue, 21 Apr 2026 15:45:50 +0530 Subject: [PATCH 02/11] fix(metrics): detach Perplexity accumulators and refine tests --- ignite/metrics/nlp/perplexity.py | 2 +- tests/ignite/metrics/nlp/test_perplexity.py | 24 ++++++++++----------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/ignite/metrics/nlp/perplexity.py b/ignite/metrics/nlp/perplexity.py index 97f4dbc588fb..be523a90fac5 100644 --- a/ignite/metrics/nlp/perplexity.py +++ b/ignite/metrics/nlp/perplexity.py @@ -92,7 +92,7 @@ def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: raise ValueError(f"y must be at least 1-dimensional (got shape: {y.shape})") nll = F.cross_entropy(y_pred, y, reduction="sum") - self._sum_of_nll += nll.to(self._device, dtype=torch.double) + self._sum_of_nll += nll.detach().to(self._device, dtype=torch.double) self._num_tokens += y.numel() @sync_all_reduce("_sum_of_nll", "_num_tokens") diff --git a/tests/ignite/metrics/nlp/test_perplexity.py b/tests/ignite/metrics/nlp/test_perplexity.py index 59b35467b9f7..9a10bfd8ec97 100644 --- a/tests/ignite/metrics/nlp/test_perplexity.py +++ b/tests/ignite/metrics/nlp/test_perplexity.py @@ -52,19 +52,6 @@ def test_token_weighted_accumulation(): assert abs(ppl.compute() - ppl_ref) < 1e-4 -def test_returns_float(): - torch.manual_seed(1) - ppl = Perplexity() - ppl.reset() - - y_pred = torch.randn(2, 5, 3) - y = torch.randint(0, 5, (2, 3)) - ppl.update((y_pred, y)) - - result = ppl.compute() - assert isinstance(result, float) - - def test_invalid_y_pred_shape(): ppl = Perplexity() ppl.reset() @@ -97,3 +84,14 @@ def test_single_token(): result = ppl.compute() assert result > 0 assert isinstance(result, float) + + +def test_accumulator_detached(available_device): + ppl = Perplexity(device=available_device) + y_pred = torch.randn(4, 6, 3, device=available_device, requires_grad=True) + y = torch.randint(0, 6, (4, 3), device=available_device) + + ppl.update((y_pred, y)) + + assert ppl._sum_of_nll.requires_grad is False + assert ppl._sum_of_nll.is_leaf is True From 433535943e26530aec56f6f0cf53f1f536075a88 Mon Sep 17 00:00:00 2001 From: steaphenai Date: Tue, 21 Apr 2026 16:56:04 +0530 Subject: [PATCH 03/11] test(metrics): align Perplexity tests with metric patterns --- tests/ignite/metrics/nlp/test_perplexity.py | 174 ++++++++++++++------ 1 file changed, 123 insertions(+), 51 deletions(-) diff --git a/tests/ignite/metrics/nlp/test_perplexity.py b/tests/ignite/metrics/nlp/test_perplexity.py index 9a10bfd8ec97..8253738cb1b9 100644 --- a/tests/ignite/metrics/nlp/test_perplexity.py +++ b/tests/ignite/metrics/nlp/test_perplexity.py @@ -2,45 +2,75 @@ import torch import torch.nn.functional as F +import ignite.distributed as idist +from ignite.engine import Engine from ignite.exceptions import NotComputableError from ignite.metrics.nlp import Perplexity +torch.manual_seed(12) + def test_zero_sample(): ppl = Perplexity() - ppl.reset() - with pytest.raises(NotComputableError): + with pytest.raises(NotComputableError, match=r"Perplexity must have at least one example before it can be computed"): ppl.compute() -def test_compute_matches_manual(): - torch.manual_seed(42) +def test_invalid_y_pred_shape(): ppl = Perplexity() + with pytest.raises(ValueError, match=r"y_pred must be at least 2-dimensional"): + ppl.update((torch.tensor([1.0, 2.0]), torch.tensor([0]))) + + +def test_reset_clears_state(): + torch.manual_seed(2) + ppl = Perplexity() + + y_pred = torch.randn(2, 5, 3) + y = torch.randint(0, 5, (2, 3)) + ppl.update((y_pred, y)) ppl.reset() + with pytest.raises(NotComputableError): + ppl.compute() + + +def _reference_perplexity(y_pred, y): + """Reference implementation: token-weighted NLL.""" + nll = F.cross_entropy(y_pred, y, reduction="sum").item() + return torch.exp(torch.tensor(nll / y.numel())).item() + + +@pytest.mark.parametrize("n_times", range(3)) +def test_compute_matches_reference(n_times, available_device): + ppl = Perplexity(device=available_device) + assert ppl._device == torch.device(available_device) + + torch.manual_seed(n_times) y_pred = torch.randn(4, 10, 5) y = torch.randint(0, 10, (4, 5)) + ppl.reset() ppl.update((y_pred, y)) - nll_manual = F.cross_entropy(y_pred, y, reduction="sum").item() - ppl_manual = torch.exp(torch.tensor(nll_manual / y.numel())).item() + ref = _reference_perplexity(y_pred, y) + assert pytest.approx(ppl.compute(), abs=1e-4) == ref - assert abs(ppl.compute() - ppl_manual) < 1e-4 +@pytest.mark.parametrize("n_times", range(3)) +def test_token_weighted_accumulation(n_times, available_device): + """Token-weighted accumulation across multiple batches.""" + ppl = Perplexity(device=available_device) + assert ppl._device == torch.device(available_device) -def test_token_weighted_accumulation(): - """Token-weighted accumulation must differ from naive batch average.""" - torch.manual_seed(0) - ppl = Perplexity() - ppl.reset() + torch.manual_seed(n_times) - # Two batches with different sequence lengths b1_pred = torch.randn(2, 5, 4) b1_y = torch.randint(0, 5, (2, 4)) b2_pred = torch.randn(3, 5, 10) b2_y = torch.randint(0, 5, (3, 10)) + ppl.reset() ppl.update((b1_pred, b1_y)) ppl.update((b2_pred, b2_y)) @@ -49,49 +79,91 @@ def test_token_weighted_accumulation(): total_tokens = b1_y.numel() + b2_y.numel() ppl_ref = torch.exp(torch.tensor((nll1 + nll2) / total_tokens)).item() - assert abs(ppl.compute() - ppl_ref) < 1e-4 + assert pytest.approx(ppl.compute(), abs=1e-4) == ppl_ref -def test_invalid_y_pred_shape(): +def test_accumulator_detached(): + """Metric state tensors must be detached from the computation graph.""" ppl = Perplexity() ppl.reset() - with pytest.raises(ValueError, match="y_pred must be at least 2-dimensional"): - ppl.update((torch.tensor([1.0, 2.0]), torch.tensor([0]))) - - -def test_reset_clears_state(): - torch.manual_seed(2) - ppl = Perplexity() - - y_pred = torch.randn(2, 5, 3) + y_pred = torch.randn(2, 5, 3, requires_grad=True) y = torch.randint(0, 5, (2, 3)) ppl.update((y_pred, y)) - ppl.reset() - with pytest.raises(NotComputableError): - ppl.compute() - - -def test_single_token(): - ppl = Perplexity() - ppl.reset() - - y_pred = torch.randn(1, 5, 1) - y = torch.randint(0, 5, (1, 1)) - ppl.update((y_pred, y)) - - result = ppl.compute() - assert result > 0 - assert isinstance(result, float) - - -def test_accumulator_detached(available_device): - ppl = Perplexity(device=available_device) - y_pred = torch.randn(4, 6, 3, device=available_device, requires_grad=True) - y = torch.randint(0, 6, (4, 3), device=available_device) - - ppl.update((y_pred, y)) - - assert ppl._sum_of_nll.requires_grad is False - assert ppl._sum_of_nll.is_leaf is True + assert not ppl._sum_of_nll.requires_grad + assert not ppl._num_tokens.requires_grad + + +@pytest.mark.distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +@pytest.mark.usefixtures("distributed") +class TestDistributed: + def test_accumulator_device(self): + metric_devices = [torch.device("cpu")] + device = idist.device() + if device.type != "xla": + metric_devices.append(device) + + for metric_device in metric_devices: + ppl = Perplexity(device=metric_device) + assert ppl._device == metric_device + assert ppl._sum_of_nll.device == metric_device, ( + f"{ppl._sum_of_nll.device} vs {metric_device}" + ) + + y_pred = torch.randn(2, 5, 3, device=device) + y = torch.randint(0, 5, (2, 3), device=device) + ppl.update((y_pred, y)) + + assert ppl._sum_of_nll.device == metric_device, ( + f"{ppl._sum_of_nll.device} vs {metric_device}" + ) + + @pytest.mark.parametrize("n_epochs", [1, 2]) + def test_integration(self, n_epochs): + rank = idist.get_rank() + torch.manual_seed(10 + rank) + + n_iters = 20 + batch_size = 4 + vocab_size = 10 + seq_len = 5 + + metric_devices = [torch.device("cpu")] + device = idist.device() + if device.type != "xla": + metric_devices.append(device) + + for metric_device in metric_devices: + y_true = torch.randint(0, vocab_size, size=(n_iters * batch_size, seq_len)).to(device) + y_preds = torch.randn(n_iters * batch_size, vocab_size, seq_len).to(device) + + def update(engine, i): + return ( + y_preds[i * batch_size: (i + 1) * batch_size], + y_true[i * batch_size: (i + 1) * batch_size], + ) + + engine = Engine(update) + ppl = Perplexity(device=metric_device) + ppl.attach(engine, "ppl") + + data = list(range(n_iters)) + engine.run(data=data, max_epochs=n_epochs) + + y_true_gathered = idist.all_gather(y_true) + y_preds_gathered = idist.all_gather(y_preds) + + assert "ppl" in engine.state.metrics + res = engine.state.metrics["ppl"] + + # Reference + nll = F.cross_entropy( + y_preds_gathered, + y_true_gathered, + reduction="sum" + ).item() + ref = torch.exp(torch.tensor(nll / y_true_gathered.numel())).item() + + assert pytest.approx(res, abs=1e-4) == ref From d5ab43325f17fa72f489d2f6a38a35063c766a44 Mon Sep 17 00:00:00 2001 From: steaphenai Date: Tue, 21 Apr 2026 19:51:28 +0530 Subject: [PATCH 04/11] fix(metrics): address Perplexity review follow-ups --- ignite/metrics/nlp/perplexity.py | 4 +++- tests/ignite/metrics/nlp/test_perplexity.py | 8 +------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/ignite/metrics/nlp/perplexity.py b/ignite/metrics/nlp/perplexity.py index be523a90fac5..32adfef873de 100644 --- a/ignite/metrics/nlp/perplexity.py +++ b/ignite/metrics/nlp/perplexity.py @@ -84,6 +84,8 @@ def reset(self) -> None: @reinit__is_reduced def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: y_pred, y = output + y_pred = y_pred.detach() + y = y.detach() if y_pred.ndim < 2: raise ValueError(f"y_pred must be at least 2-dimensional (got shape: {y_pred.shape})") @@ -92,7 +94,7 @@ def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: raise ValueError(f"y must be at least 1-dimensional (got shape: {y.shape})") nll = F.cross_entropy(y_pred, y, reduction="sum") - self._sum_of_nll += nll.detach().to(self._device, dtype=torch.double) + self._sum_of_nll += nll.to(self._device, dtype=torch.double) self._num_tokens += y.numel() @sync_all_reduce("_sum_of_nll", "_num_tokens") diff --git a/tests/ignite/metrics/nlp/test_perplexity.py b/tests/ignite/metrics/nlp/test_perplexity.py index 8253738cb1b9..abbc3fcbc549 100644 --- a/tests/ignite/metrics/nlp/test_perplexity.py +++ b/tests/ignite/metrics/nlp/test_perplexity.py @@ -158,12 +158,6 @@ def update(engine, i): assert "ppl" in engine.state.metrics res = engine.state.metrics["ppl"] - # Reference - nll = F.cross_entropy( - y_preds_gathered, - y_true_gathered, - reduction="sum" - ).item() - ref = torch.exp(torch.tensor(nll / y_true_gathered.numel())).item() + ref = _reference_perplexity(y_preds_gathered, y_true_gathered) assert pytest.approx(res, abs=1e-4) == ref From dae37e995d653a6006991f6325480d4c01299c24 Mon Sep 17 00:00:00 2001 From: steaphenai Date: Tue, 21 Apr 2026 21:28:36 +0530 Subject: [PATCH 05/11] test(metrics): use _reference_perplexity in token-weighted accumulation test --- tests/ignite/metrics/nlp/test_perplexity.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/ignite/metrics/nlp/test_perplexity.py b/tests/ignite/metrics/nlp/test_perplexity.py index abbc3fcbc549..da45e4d060bf 100644 --- a/tests/ignite/metrics/nlp/test_perplexity.py +++ b/tests/ignite/metrics/nlp/test_perplexity.py @@ -74,10 +74,9 @@ def test_token_weighted_accumulation(n_times, available_device): ppl.update((b1_pred, b1_y)) ppl.update((b2_pred, b2_y)) - nll1 = F.cross_entropy(b1_pred, b1_y, reduction="sum").item() - nll2 = F.cross_entropy(b2_pred, b2_y, reduction="sum").item() - total_tokens = b1_y.numel() + b2_y.numel() - ppl_ref = torch.exp(torch.tensor((nll1 + nll2) / total_tokens)).item() + combined_pred = torch.cat([b1_pred, b2_pred], dim=0) + combined_y = torch.cat([b1_y, b2_y], dim=0) + ppl_ref = _reference_perplexity(combined_pred, combined_y) assert pytest.approx(ppl.compute(), abs=1e-4) == ppl_ref From f9ecaa14cb9d8e24c3eae94be4ca26666bf51870 Mon Sep 17 00:00:00 2001 From: Steaphen Date: Thu, 23 Apr 2026 14:07:50 +0530 Subject: [PATCH 06/11] Update tests/ignite/metrics/nlp/test_perplexity.py Co-authored-by: vfdev --- tests/ignite/metrics/nlp/test_perplexity.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/ignite/metrics/nlp/test_perplexity.py b/tests/ignite/metrics/nlp/test_perplexity.py index da45e4d060bf..fc982bf87677 100644 --- a/tests/ignite/metrics/nlp/test_perplexity.py +++ b/tests/ignite/metrics/nlp/test_perplexity.py @@ -37,8 +37,8 @@ def test_reset_clears_state(): def _reference_perplexity(y_pred, y): """Reference implementation: token-weighted NLL.""" - nll = F.cross_entropy(y_pred, y, reduction="sum").item() - return torch.exp(torch.tensor(nll / y.numel())).item() + nll = F.cross_entropy(y_pred, y, reduction="sum") + return torch.exp(nll / y.numel()).item() @pytest.mark.parametrize("n_times", range(3)) From e5c0cfd6a596876b2c7ee68c1969eed5e99493b5 Mon Sep 17 00:00:00 2001 From: steaphenai Date: Thu, 23 Apr 2026 16:00:08 +0530 Subject: [PATCH 07/11] feat(metrics): add ignore_index to Perplexity, expose in docs, remove trivial test --- docs/source/metrics.rst | 1 + ignite/metrics/nlp/perplexity.py | 6 ++++-- tests/ignite/metrics/nlp/test_perplexity.py | 13 ------------- 3 files changed, 5 insertions(+), 15 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 2261498c8be0..6517db169e19 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -351,6 +351,7 @@ Complete list of metrics SSIM TopKCategoricalAccuracy Bleu + Perplexity Rouge RougeL RougeN diff --git a/ignite/metrics/nlp/perplexity.py b/ignite/metrics/nlp/perplexity.py index 32adfef873de..35dd7b81b5ef 100644 --- a/ignite/metrics/nlp/perplexity.py +++ b/ignite/metrics/nlp/perplexity.py @@ -73,7 +73,9 @@ def __init__( self, output_transform: Callable = lambda x: x, device: str | torch.device = torch.device("cpu"), + ignore_index: int = -100, ): + self._ignore_index = ignore_index super().__init__(output_transform=output_transform, device=device) @reinit__is_reduced @@ -93,9 +95,9 @@ def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: if y.ndim < 1: raise ValueError(f"y must be at least 1-dimensional (got shape: {y.shape})") - nll = F.cross_entropy(y_pred, y, reduction="sum") + nll = F.cross_entropy(y_pred, y, reduction="sum", ignore_index=self._ignore_index) self._sum_of_nll += nll.to(self._device, dtype=torch.double) - self._num_tokens += y.numel() + self._num_tokens += (y != self._ignore_index).sum() @sync_all_reduce("_sum_of_nll", "_num_tokens") def compute(self) -> float: diff --git a/tests/ignite/metrics/nlp/test_perplexity.py b/tests/ignite/metrics/nlp/test_perplexity.py index fc982bf87677..45b80fbe946d 100644 --- a/tests/ignite/metrics/nlp/test_perplexity.py +++ b/tests/ignite/metrics/nlp/test_perplexity.py @@ -81,19 +81,6 @@ def test_token_weighted_accumulation(n_times, available_device): assert pytest.approx(ppl.compute(), abs=1e-4) == ppl_ref -def test_accumulator_detached(): - """Metric state tensors must be detached from the computation graph.""" - ppl = Perplexity() - ppl.reset() - - y_pred = torch.randn(2, 5, 3, requires_grad=True) - y = torch.randint(0, 5, (2, 3)) - ppl.update((y_pred, y)) - - assert not ppl._sum_of_nll.requires_grad - assert not ppl._num_tokens.requires_grad - - @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.usefixtures("distributed") From fa8fb7f64d23231495a2703ba059aa8e8e001926 Mon Sep 17 00:00:00 2001 From: steaphenai Date: Thu, 23 Apr 2026 21:15:01 +0530 Subject: [PATCH 08/11] style: fix ruff formatting in test_perplexity.py --- tests/ignite/metrics/nlp/test_perplexity.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/ignite/metrics/nlp/test_perplexity.py b/tests/ignite/metrics/nlp/test_perplexity.py index 45b80fbe946d..39bdce101dc9 100644 --- a/tests/ignite/metrics/nlp/test_perplexity.py +++ b/tests/ignite/metrics/nlp/test_perplexity.py @@ -12,7 +12,9 @@ def test_zero_sample(): ppl = Perplexity() - with pytest.raises(NotComputableError, match=r"Perplexity must have at least one example before it can be computed"): + with pytest.raises( + NotComputableError, match=r"Perplexity must have at least one example before it can be computed" + ): ppl.compute() @@ -94,17 +96,13 @@ def test_accumulator_device(self): for metric_device in metric_devices: ppl = Perplexity(device=metric_device) assert ppl._device == metric_device - assert ppl._sum_of_nll.device == metric_device, ( - f"{ppl._sum_of_nll.device} vs {metric_device}" - ) + assert ppl._sum_of_nll.device == metric_device, f"{ppl._sum_of_nll.device} vs {metric_device}" y_pred = torch.randn(2, 5, 3, device=device) y = torch.randint(0, 5, (2, 3), device=device) ppl.update((y_pred, y)) - assert ppl._sum_of_nll.device == metric_device, ( - f"{ppl._sum_of_nll.device} vs {metric_device}" - ) + assert ppl._sum_of_nll.device == metric_device, f"{ppl._sum_of_nll.device} vs {metric_device}" @pytest.mark.parametrize("n_epochs", [1, 2]) def test_integration(self, n_epochs): @@ -127,8 +125,8 @@ def test_integration(self, n_epochs): def update(engine, i): return ( - y_preds[i * batch_size: (i + 1) * batch_size], - y_true[i * batch_size: (i + 1) * batch_size], + y_preds[i * batch_size : (i + 1) * batch_size], + y_true[i * batch_size : (i + 1) * batch_size], ) engine = Engine(update) From c7a3720b68fa8625919941c762e048aaabdd3091 Mon Sep 17 00:00:00 2001 From: steaphenai Date: Thu, 23 Apr 2026 21:55:07 +0530 Subject: [PATCH 09/11] fix(tests): fix token weighted accumulation test with different seq lengths --- tests/ignite/metrics/nlp/test_perplexity.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/ignite/metrics/nlp/test_perplexity.py b/tests/ignite/metrics/nlp/test_perplexity.py index 39bdce101dc9..feff1752ee38 100644 --- a/tests/ignite/metrics/nlp/test_perplexity.py +++ b/tests/ignite/metrics/nlp/test_perplexity.py @@ -76,9 +76,10 @@ def test_token_weighted_accumulation(n_times, available_device): ppl.update((b1_pred, b1_y)) ppl.update((b2_pred, b2_y)) - combined_pred = torch.cat([b1_pred, b2_pred], dim=0) - combined_y = torch.cat([b1_y, b2_y], dim=0) - ppl_ref = _reference_perplexity(combined_pred, combined_y) + nll1 = F.cross_entropy(b1_pred, b1_y, reduction="sum").item() + nll2 = F.cross_entropy(b2_pred, b2_y, reduction="sum").item() + total_tokens = b1_y.numel() + b2_y.numel() + ppl_ref = torch.exp(torch.tensor((nll1 + nll2) / total_tokens)).item() assert pytest.approx(ppl.compute(), abs=1e-4) == ppl_ref From 3143650ef86f52e5f1f6d750ca43c19befb39aa7 Mon Sep 17 00:00:00 2001 From: steaphenai Date: Thu, 23 Apr 2026 22:09:59 +0530 Subject: [PATCH 10/11] fix(tests): use _reference_perplexity and matching seq lengths in accumulation test --- tests/ignite/metrics/nlp/test_perplexity.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/ignite/metrics/nlp/test_perplexity.py b/tests/ignite/metrics/nlp/test_perplexity.py index feff1752ee38..26c207c03815 100644 --- a/tests/ignite/metrics/nlp/test_perplexity.py +++ b/tests/ignite/metrics/nlp/test_perplexity.py @@ -69,17 +69,16 @@ def test_token_weighted_accumulation(n_times, available_device): b1_pred = torch.randn(2, 5, 4) b1_y = torch.randint(0, 5, (2, 4)) - b2_pred = torch.randn(3, 5, 10) - b2_y = torch.randint(0, 5, (3, 10)) + b2_pred = torch.randn(3, 5, 4) + b2_y = torch.randint(0, 5, (3, 4)) ppl.reset() ppl.update((b1_pred, b1_y)) ppl.update((b2_pred, b2_y)) - nll1 = F.cross_entropy(b1_pred, b1_y, reduction="sum").item() - nll2 = F.cross_entropy(b2_pred, b2_y, reduction="sum").item() - total_tokens = b1_y.numel() + b2_y.numel() - ppl_ref = torch.exp(torch.tensor((nll1 + nll2) / total_tokens)).item() + combined_pred = torch.cat([b1_pred, b2_pred], dim=0) + combined_y = torch.cat([b1_y, b2_y], dim=0) + ppl_ref = _reference_perplexity(combined_pred, combined_y) assert pytest.approx(ppl.compute(), abs=1e-4) == ppl_ref From b514e1229ba76a37bee9d3cc9dfd29456e34fef8 Mon Sep 17 00:00:00 2001 From: steaphenai Date: Sat, 25 Apr 2026 02:33:19 +0530 Subject: [PATCH 11/11] fix(metrics): remove explicit double dtype from Perplexity accumulators for MPS compatibility --- ignite/metrics/nlp/perplexity.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ignite/metrics/nlp/perplexity.py b/ignite/metrics/nlp/perplexity.py index 35dd7b81b5ef..45c4ddf98138 100644 --- a/ignite/metrics/nlp/perplexity.py +++ b/ignite/metrics/nlp/perplexity.py @@ -80,8 +80,8 @@ def __init__( @reinit__is_reduced def reset(self) -> None: - self._sum_of_nll = torch.tensor(0.0, dtype=torch.double, device=self._device) - self._num_tokens = torch.tensor(0, dtype=torch.long, device=self._device) + self._sum_of_nll = torch.tensor(0.0, device=self._device) + self._num_tokens = torch.tensor(0, device=self._device) @reinit__is_reduced def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: @@ -96,7 +96,7 @@ def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: raise ValueError(f"y must be at least 1-dimensional (got shape: {y.shape})") nll = F.cross_entropy(y_pred, y, reduction="sum", ignore_index=self._ignore_index) - self._sum_of_nll += nll.to(self._device, dtype=torch.double) + self._sum_of_nll += nll.to(self._device) self._num_tokens += (y != self._ignore_index).sum() @sync_all_reduce("_sum_of_nll", "_num_tokens")