From 65d755775c32c5b920d7879ed1eb01527b381c69 Mon Sep 17 00:00:00 2001 From: steaphenai Date: Tue, 3 Mar 2026 17:53:50 +0530 Subject: [PATCH 01/14] Add NDCG metric to rec_sys --- docs/source/metrics.rst | 2 +- ignite/metrics/__init__.py | 2 + ignite/metrics/rec_sys/__init__.py | 3 + ignite/metrics/rec_sys/ndcg.py | 216 ++++++++++++++++++++ tests/ignite/metrics/rec_sys/test_ndcg.py | 232 ++++++++++++++++++++++ 5 files changed, 454 insertions(+), 1 deletion(-) create mode 100644 ignite/metrics/rec_sys/ndcg.py create mode 100644 tests/ignite/metrics/rec_sys/test_ndcg.py diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 3941f39345b7..e7cb2d9c136d 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -391,7 +391,7 @@ Complete list of metrics clustering.DaviesBouldinScore clustering.CalinskiHarabaszScore rec_sys.HitRate - + rec_sys.NDGC .. note:: diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index c26660d58fe3..c6d973177203 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -37,6 +37,7 @@ from ignite.metrics.psnr import PSNR from ignite.metrics.recall import Recall from ignite.metrics.rec_sys.hitrate import HitRate +from ignite.metrics.rec_sys.ndcg import NDCG from ignite.metrics.roc_auc import ROC_AUC, RocCurve from ignite.metrics.root_mean_squared_error import RootMeanSquaredError from ignite.metrics.running_average import RunningAverage @@ -104,4 +105,5 @@ "CommonObjectDetectionMetrics", "coco_tensor_list_to_dict_list", "HitRate", + "NDGC", ] diff --git a/ignite/metrics/rec_sys/__init__.py b/ignite/metrics/rec_sys/__init__.py index f6f37785cb4e..6fa17fdca928 100644 --- a/ignite/metrics/rec_sys/__init__.py +++ b/ignite/metrics/rec_sys/__init__.py @@ -1 +1,4 @@ from ignite.metrics.rec_sys.hitrate import HitRate +from ignite.metrics.rec_sys.ndcg import NDCG + +__all__ = ["HitRate", "NDCG"] diff --git a/ignite/metrics/rec_sys/ndcg.py b/ignite/metrics/rec_sys/ndcg.py new file mode 100644 index 000000000000..a5f38acd1a31 --- /dev/null +++ b/ignite/metrics/rec_sys/ndcg.py @@ -0,0 +1,216 @@ +from typing import Callable + +import torch + +from ignite.exceptions import NotComputableError +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce + +__all__ = ["NDCG"] + + +class NDCG(Metric): + r"""Calculates the Normalized Discounted Cumulative Gain (NDCG) at `k` for Recommendation Systems. + + NDCG measures the quality of ranking by considering both the relevance of items and their + positions in the ranked list. It compares the achieved DCG against the ideal DCG (IDCG) + obtained by sorting items by their true relevance. + + .. math:: + \text{NDCG}@K = \frac{\text{DCG}@K}{\text{IDCG}@K} + + where: + + .. math:: + \text{DCG}@K = \sum_{i=1}^{K} \frac{2^{\text{rel}_i} - 1}{\log_2(i + 1)} + + and :math:`\text{rel}_i` is the relevance score of the item at position :math:`i` in the + ranked list (1-indexed). + + - ``update`` must receive output of the form ``(y_pred, y)``. + - ``y_pred`` is expected to be raw logits or probability score for each item in the catalog. + - ``y`` is expected to contain relevance scores (can be binary or graded). + - ``y_pred`` and ``y`` are only allowed shape :math:`(batch, num\_items)`. + - returns a list of NDCG ordered by the sorted values of ``top_k``. + + Args: + top_k: a list of sorted positive integers that specifies `k` for calculating NDCG@top-k. 
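A quick way to sanity-check the formulas above is to compute DCG@k and IDCG@k by hand with plain PyTorch. The sketch below is illustrative only and is not part of the patch; the tensor values mirror the perfect-ranking case used later in test_perfect_prediction, so NDCG comes out as 1.0 at every k:

    import torch

    def dcg_at_k(rel, k):
        # rel: (batch, num_items), already ordered by the ranking being scored
        k = min(k, rel.shape[1])
        positions = torch.arange(1, k + 1, dtype=torch.float32)
        gains = torch.pow(2.0, rel[:, :k]) - 1.0  # 2^rel_i - 1
        return (gains / torch.log2(positions + 1)).sum(dim=-1)

    y_pred = torch.tensor([[5.0, 3.0, 4.0, 1.0]])
    y_true = torch.tensor([[3.0, 1.0, 2.0, 0.0]])
    order = torch.argsort(y_pred, dim=-1, descending=True)
    ranked = torch.gather(y_true, -1, order)      # relevance in predicted order
    ideal = torch.sort(y_true, dim=-1, descending=True).values
    for k in (1, 3):
        print((dcg_at_k(ranked, k) / dcg_at_k(ideal, k)).item())  # 1.0 at both k
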
+ ignore_zero_hits: if True, users with no relevant items (ground truth tensor being all zeros) + are ignored in computation of NDCG. If set False, such users are counted with NDCG of 0. + By default, True. + relevance_threshold: minimum label value to be considered relevant. Defaults to ``1``, + which handles standard binary labels and graded relevance scales (e.g. TREC-style + 0-4) by treating any label >= 1 as relevant. Items below this threshold contribute + 0 to DCG/IDCG calculations. + output_transform: a callable that is used to transform the + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the + form expected by the metric. + The output is expected to be a tuple `(prediction, target)` + where `prediction` and `target` are tensors + of shape ``(batch, num_items)``. + device: specifies which device updates are accumulated on. Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. + skip_unrolling: specifies whether input should be unrolled or not before being + processed. Should be true for multi-output models.. + + Examples: + To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. + The output of the engine's ``process_function`` needs to be in the format of + ``(y_pred, y)``. If not, ``output_tranform`` can be added + to the metric to transform the output into the form expected by the metric. + + For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`. + + .. include:: defaults.rst + :start-after: :orphan: + + ignore_zero_hits=True case + + .. testcode:: 1 + + metric = NDCG(top_k=[1, 2, 3, 4]) + metric.attach(default_evaluator, "ndcg") + y_pred=torch.Tensor([ + [4.0, 2.0, 3.0, 1.0], + [1.0, 2.0, 3.0, 4.0] + ]) + y_true=torch.Tensor([ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0] + ]) + state = default_evaluator.run([(y_pred, y_true)]) + print(state.metrics["ndcg"]) + + .. testoutput:: 1 + + [0.0, 0.63..., 0.63..., 0.63...] + + ignore_zero_hits=False case + + .. testcode:: 2 + + metric = NDCG(top_k=[1, 2, 3, 4], ignore_zero_hits=False) + metric.attach(default_evaluator, "ndcg") + y_pred=torch.Tensor([ + [4.0, 2.0, 3.0, 1.0], + [1.0, 2.0, 3.0, 4.0] + ]) + y_true=torch.Tensor([ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0] + ]) + state = default_evaluator.run([(y_pred, y_true)]) + print(state.metrics["ndcg"]) + + .. testoutput:: 2 + + [0.0, 0.31..., 0.31..., 0.31...] + + .. versionadded:: 0.6.0 + """ + + required_output_keys = ("y_pred", "y") + _state_dict_all_req_keys = ("_sum_ndcg_per_k", "_num_examples") + + def __init__( + self, + top_k: list[int], + ignore_zero_hits: bool = True, + relevance_threshold: float = 1.0, + output_transform: Callable = lambda x: x, + device: str | torch.device = torch.device("cpu"), + skip_unrolling: bool = False, + ): + if any(k <= 0 for k in top_k): + raise ValueError(" top_k must be list of positive integers only.") + + self.top_k = sorted(top_k) + self.ignore_zero_hits = ignore_zero_hits + self.relevance_threshold = relevance_threshold + super(NDCG, self).__init__(output_transform, device=device, skip_unrolling=skip_unrolling) + + @reinit__is_reduced + def reset(self) -> None: + self._sum_ndcg_per_k = torch.zeros(len(self.top_k), device=self._device) + self._num_examples = 0 + + def _compute_dcg(self, relevance_scores: torch.Tensor, k: int) -> torch.Tensor: + """Compute DCG@k for a batch of relevance scores. 
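Outside of an Engine, the metric can also be driven directly through reset/update/compute, which is handy for quick checks. A minimal sketch (assuming the NDCG class from this patch is importable from ignite.metrics) of how ignore_zero_hits decides what happens for a user with no relevant items, mirroring test_all_zeros_relevance and test_all_zero_targets_ignore in the tests:

    import torch
    from ignite.metrics import NDCG

    y_pred = torch.tensor([[5.0, 3.0, 4.0]])
    y_true = torch.zeros(1, 3)  # this user has no relevant items

    # counted as NDCG = 0 when ignore_zero_hits=False
    metric = NDCG(top_k=[2], ignore_zero_hits=False)
    metric.update((y_pred, y_true))
    print(metric.compute())  # [0.0]

    # dropped from the average when ignore_zero_hits=True (the default);
    # compute() raises NotComputableError if every example was dropped
    metric = NDCG(top_k=[2])
    metric.update((y_pred, y_true))
    # metric.compute()  # would raise NotComputableError here
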
+ + Args: + relevance_scores: Tensor of shape (batch, num_items) with relevance scores at ranked positions + k: Number of positions to consider + + Returns: + DCG scores of shape (batch,) + """ + # Handle case where k > actual number of items + actual_k = min(k, relevance_scores.shape[1]) + + # Create position weights: 1/log2(position + 1) for position in [1, actual_k] + # Positions are 1-indexed in the DCG formula + positions = torch.arange(1, actual_k + 1, dtype=torch.float32, device=relevance_scores.device) + discounts = 1.0 / torch.log2(positions + 1) # log2(i+1) for i in [1, actual_k] + + # Compute gains: 2^rel - 1 + gains = torch.pow(2.0, relevance_scores[:, :actual_k]) - 1.0 + + # DCG = sum of (gain / discount) + dcg = (gains * discounts).sum(dim=-1) + return dcg + + @reinit__is_reduced + def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: + if len(output) != 2: + raise ValueError(f"output should be in format `(y_pred,y)` but got tuple of {len(output)} tensors.") + + y_pred, y = output + if y_pred.shape != y.shape: + raise ValueError(f"y_pred and y must be in the same shape, got {y_pred.shape} != {y.shape}.") + + # Filter out examples with no relevant items if ignore_zero_hits is True + if self.ignore_zero_hits: + valid_mask = torch.any(y >= self.relevance_threshold, dim=-1) + y_pred = y_pred[valid_mask] + y = y[valid_mask] + + if y.shape[0] == 0: + return + + # Zero out items below relevance threshold for DCG computation + y_for_dcg = torch.where(y >= self.relevance_threshold, y, torch.zeros_like(y)) + + max_k = self.top_k[-1] + + # Get ranked indices based on predictions (stable=True for deterministic tie-breaking) + ranked_indices = torch.argsort(y_pred, dim=-1, descending=True, stable=True)[:, :max_k] + + # Get relevance scores in the predicted ranking order + ranked_relevance = torch.gather(y_for_dcg, dim=-1, index=ranked_indices) + + # Compute ideal ranking by sorting true relevance scores + ideal_relevance = torch.sort(y_for_dcg, dim=-1, descending=True, stable=True)[0][:, :max_k] + + for i, k in enumerate(self.top_k): + # Compute DCG@k and IDCG@k + dcg_k = self._compute_dcg(ranked_relevance, k) + idcg_k = self._compute_dcg(ideal_relevance, k) + + # NDCG = DCG / IDCG, handle division by zero (when IDCG = 0, NDCG = 0) + ndcg_k = torch.where( + idcg_k > 0, + dcg_k / idcg_k, + torch.zeros_like(dcg_k) + ) + + self._sum_ndcg_per_k[i] += ndcg_k.sum().to(self._device) + + self._num_examples += y.shape[0] + + @sync_all_reduce("_sum_ndcg_per_k", "_num_examples") + def compute(self) -> list[float]: + if self._num_examples == 0: + raise NotComputableError("NDCG must have at least one example.") + + ndcg_scores = (self._sum_ndcg_per_k / self._num_examples).tolist() + return ndcg_scores diff --git a/tests/ignite/metrics/rec_sys/test_ndcg.py b/tests/ignite/metrics/rec_sys/test_ndcg.py new file mode 100644 index 000000000000..347f00e07f30 --- /dev/null +++ b/tests/ignite/metrics/rec_sys/test_ndcg.py @@ -0,0 +1,232 @@ +import numpy as np +import pytest +import torch + +import ignite.distributed as idist +from ignite.engine import Engine +from ignite.exceptions import NotComputableError +from ignite.metrics.rec_sys.ndcg import NDCG + + +def ranx_ndcg( + y_pred: np.ndarray, + y: np.ndarray, + top_k: list[int], + ignore_zero_hits: bool = True, +) -> list[float]: + """Reference NDCG implementation using ranx for verification. 
https://github.com/AmenRa/ranx """ + from ranx import Qrels, Run, evaluate + + sorted_top_k = sorted(top_k) + results = [] + + for k in sorted_top_k: + qrels_dict = {} + run_dict = {} + + for i, (scores, labels) in enumerate(zip(y_pred, y)): + qid = f"q{i}" + relevant = {f"d{j}": int(label) for j, label in enumerate(labels) if label > 0} + + if ignore_zero_hits and not relevant: + continue + + qrels_dict[qid] = relevant if relevant else {f"d0": 0} + run_dict[qid] = {f"d{j}": float(s) for j, s in enumerate(scores)} + + if not qrels_dict: + results.append(0.0) + continue + + run_dict = {q: run_dict[q] for q in qrels_dict} + results.append(float(evaluate(Qrels(qrels_dict), Run(run_dict), f"ndcg@{k}"))) + + return results + + +def test_zero_sample(): + metric = NDCG(top_k=[1, 5]) + with pytest.raises(NotComputableError, match=r"NDCG must have at least one example"): + metric.compute() + + +def test_shape_mismatch(): + metric = NDCG(top_k=[1]) + y_pred = torch.randn(4, 10) + y = torch.ones(4, 5) # Mismatched items count + with pytest.raises(ValueError, match="y_pred and y must be in the same shape"): + metric.update((y_pred, y)) + + +def test_invalid_top_k(): + with pytest.raises(ValueError, match="positive integers"): + NDCG(top_k=[0]) + with pytest.raises(ValueError, match="positive integers"): + NDCG(top_k=[-1, 5]) + + +@pytest.mark.parametrize("top_k", [[1], [1, 2, 4]]) +@pytest.mark.parametrize("ignore_zero_hits", [True, False]) +def test_compute(top_k, ignore_zero_hits, available_device): + metric = NDCG( + top_k=top_k, + ignore_zero_hits=ignore_zero_hits, + device=available_device, + ) + + y_pred = torch.tensor([[4.0, 2.0, 3.0, 1.0], [1.0, 2.0, 3.0, 4.0]]) + y_true = torch.tensor([[0, 0, 1.0, 1.0], [0, 0, 0.0, 0.0]]) + + metric.update((y_pred, y_true)) + res = metric.compute() + + expected = ranx_ndcg( + y_pred.numpy(), + y_true.numpy(), + top_k, + ignore_zero_hits=ignore_zero_hits, + ) + + assert isinstance(res, list) + assert len(res) == len(top_k) + np.testing.assert_allclose(res, expected, rtol=1e-5) + + +@pytest.mark.parametrize("num_queries", [10, 100]) +@pytest.mark.parametrize("num_items", [20, 100]) +@pytest.mark.parametrize("k", [1, 5]) +@pytest.mark.parametrize("ignore_zero_hits", [True, False]) +def test_compute_vs_ranx(num_queries, num_items, k, ignore_zero_hits, available_device): + """Verify NDCG matches ranx across a wide range of input shapes and k values.""" + torch.manual_seed(42) + y_pred = torch.randn(num_queries, num_items) + y_true = torch.randint(0, 2, (num_queries, num_items)).float() + + metric = NDCG( + top_k=[k], + ignore_zero_hits=ignore_zero_hits, + device=available_device, + ) + metric.update((y_pred, y_true)) + + try: + res = metric.compute() + except NotComputableError: + res = [0.0] + + expected = ranx_ndcg( + y_pred.numpy(), + y_true.numpy(), + top_k=[k], + ignore_zero_hits=ignore_zero_hits, + ) + + np.testing.assert_allclose(res, expected, rtol=1e-5) + + +def test_perfect_prediction(): + """Perfect ranking -> NDCG = 1.0.""" + metric = NDCG(top_k=[1, 3]) + y_pred = torch.tensor([[5.0, 3.0, 4.0, 1.0]]) + y_true = torch.tensor([[3.0, 1.0, 2.0, 0.0]]) # Matches ranking order + metric.update((y_pred, y_true)) + assert metric.compute() == pytest.approx([1.0, 1.0]) + + +def test_all_zeros_relevance(): + """When all relevance is 0, IDCG=0, so NDCG should be 0 (or ignored if ignore_zero_hits=True).""" + metric = NDCG(top_k=[2], ignore_zero_hits=False) + y_pred = torch.tensor([[5.0, 3.0, 4.0]]) + y_true = torch.tensor([[0.0, 0.0, 0.0]]) + metric.update((y_pred, 
y_true)) + # NDCG should be 0 when there are no relevant items + assert metric.compute() == pytest.approx([0.0]) + + +def test_graded_relevance_threshold(): + """Labels >= relevance_threshold are considered, but contribute their full value to DCG.""" + # relevance_threshold=2: labels < 2 are zeroed out + metric = NDCG(top_k=[3], relevance_threshold=2.0) + + # Predictions rank: doc0, doc2, doc1 + # True relevance: [3, 1, 2] -> After threshold: [3, 0, 2] + # Ranked relevance (after threshold): [3, 2, 0] + y_pred = torch.tensor([[0.9, 0.3, 0.7]]) + y_true = torch.tensor([[3.0, 1.0, 2.0]]) + metric.update((y_pred, y_true)) + + # DCG: (2^3-1)/log2(2) + (2^2-1)/log2(3) + (2^0-1)/log2(4) + # = 7/1 + 3/1.585 + 0/2 = 7 + 1.893 = 8.893 + # IDCG: (2^3-1)/log2(2) + (2^2-1)/log2(3) + (2^0-1)/log2(4) + # = 7/1 + 3/1.585 + 0/2 = 8.893 + # NDCG = 1.0 (perfect ranking of items that meet threshold) + + result = metric.compute() + assert result[0] == pytest.approx(1.0, rel=1e-5) + + +@pytest.mark.usefixtures("distributed") +class TestDistributed: + def test_integration(self): + n_iters = 10 + batch_size = 4 + num_items = 20 + top_k = [1, 5, 10] + + rank = idist.get_rank() + torch.manual_seed(42 + rank) + device = idist.device() + + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + for metric_device in metric_devices: + all_y_true = torch.randint(0, 4, (n_iters * batch_size, num_items)).float().to(device) + all_y_pred = torch.randn((n_iters * batch_size, num_items)).to(device) + + for ignore_zero_hits in [True, False]: + engine = Engine( + lambda e, i: ( + all_y_pred[i * batch_size : (i + 1) * batch_size], + all_y_true[i * batch_size : (i + 1) * batch_size], + ) + ) + m = NDCG( + top_k=top_k, + ignore_zero_hits=ignore_zero_hits, + device=metric_device, + ) + m.attach(engine, "ndcg") + + engine.run(range(n_iters), max_epochs=1) + + global_y_true = idist.all_gather(all_y_true).cpu().numpy() + global_y_pred = idist.all_gather(all_y_pred).cpu().numpy() + + res = engine.state.metrics["ndcg"] + + true_res = ranx_ndcg( + global_y_pred, + global_y_true, + top_k, + ignore_zero_hits=ignore_zero_hits, + ) + + assert isinstance(res, list) + np.testing.assert_allclose(res, true_res, rtol=1e-5) + + engine.state.metrics.clear() + + def test_accumulator_device(self): + device = idist.device() + metric = NDCG(top_k=[1, 5], device=device) + + assert metric._device == device + assert metric._sum_ndcg_per_k.device == device + + y_pred = torch.randn(2, 10) + y = torch.zeros(2, 10) + metric.update((y_pred, y)) + + assert metric._sum_ndcg_per_k.device == device From d0515d40270e491a1749e77b0022b472100d28e5 Mon Sep 17 00:00:00 2001 From: Steaphen Date: Fri, 6 Mar 2026 13:44:08 +0530 Subject: [PATCH 02/14] Make ranx optional in NDCG tests Added ranx = pytest.importorskip(...) to make ranx optional Wrapped ranx validation checks with if ranx is not None Added pytest.skip() for tests that specifically require ranx Tests now pass without ranx installed while still validating against it when available. The NDCG metric implementation itself has no ranx dependency. 
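The two approaches discussed here and in the next commit (a module-level pytest.importorskip versus a try/except guard) differ in granularity: importorskip skips the whole test module at collection time, while the guard leaves ranx as None so individual tests can opt out. A minimal sketch of the guard pattern that patch 03 settles on, for illustration only:

    import pytest

    try:
        import ranx  # optional reference library used only for cross-checks
    except ImportError:
        ranx = None


    def test_needs_ranx():
        if ranx is None:
            pytest.skip("ranx not installed")
        # ... compare the metric against ranx here ...
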
--- tests/ignite/metrics/rec_sys/test_ndcg.py | 24 +++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/ignite/metrics/rec_sys/test_ndcg.py b/tests/ignite/metrics/rec_sys/test_ndcg.py index 347f00e07f30..c7a92e68d3bd 100644 --- a/tests/ignite/metrics/rec_sys/test_ndcg.py +++ b/tests/ignite/metrics/rec_sys/test_ndcg.py @@ -7,6 +7,7 @@ from ignite.exceptions import NotComputableError from ignite.metrics.rec_sys.ndcg import NDCG +ranx = pytest.importorskip("ranx", reason="ranx is required for reference validation tests") def ranx_ndcg( y_pred: np.ndarray, @@ -80,16 +81,18 @@ def test_compute(top_k, ignore_zero_hits, available_device): metric.update((y_pred, y_true)) res = metric.compute() - expected = ranx_ndcg( - y_pred.numpy(), - y_true.numpy(), - top_k, - ignore_zero_hits=ignore_zero_hits, - ) - assert isinstance(res, list) assert len(res) == len(top_k) - np.testing.assert_allclose(res, expected, rtol=1e-5) + + # Validate against ranx if available + if ranx is not None: + expected = ranx_ndcg( + y_pred.numpy(), + y_true.numpy(), + top_k, + ignore_zero_hits=ignore_zero_hits, + ) + np.testing.assert_allclose(res, expected, rtol=1e-5) @pytest.mark.parametrize("num_queries", [10, 100]) @@ -98,6 +101,9 @@ def test_compute(top_k, ignore_zero_hits, available_device): @pytest.mark.parametrize("ignore_zero_hits", [True, False]) def test_compute_vs_ranx(num_queries, num_items, k, ignore_zero_hits, available_device): """Verify NDCG matches ranx across a wide range of input shapes and k values.""" + if ranx is None: + pytest.skip("ranx not installed") + torch.manual_seed(42) y_pred = torch.randn(num_queries, num_items) y_true = torch.randint(0, 2, (num_queries, num_items)).float() @@ -168,6 +174,8 @@ def test_graded_relevance_threshold(): @pytest.mark.usefixtures("distributed") class TestDistributed: def test_integration(self): + if ranx is None: + pytest.skip("ranx not installed") n_iters = 10 batch_size = 4 num_items = 20 From 3a032d0f7b4efcdef497b16ba5e416b5acea4001 Mon Sep 17 00:00:00 2001 From: Steaphen Date: Fri, 6 Mar 2026 13:50:24 +0530 Subject: [PATCH 03/14] Add try-except for ranx import in test_ndcg.py Use try/except to handle missing ranx dependency instead of pytest.importorskip --- tests/ignite/metrics/rec_sys/test_ndcg.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/ignite/metrics/rec_sys/test_ndcg.py b/tests/ignite/metrics/rec_sys/test_ndcg.py index c7a92e68d3bd..c51b9526accc 100644 --- a/tests/ignite/metrics/rec_sys/test_ndcg.py +++ b/tests/ignite/metrics/rec_sys/test_ndcg.py @@ -7,7 +7,10 @@ from ignite.exceptions import NotComputableError from ignite.metrics.rec_sys.ndcg import NDCG -ranx = pytest.importorskip("ranx", reason="ranx is required for reference validation tests") +try: + import ranx +except ImportError: + ranx = None def ranx_ndcg( y_pred: np.ndarray, From 9e7f2a1381a586c6b7f8be24d1a8e8d6eaf25759 Mon Sep 17 00:00:00 2001 From: Steaphen Date: Fri, 6 Mar 2026 14:51:10 +0530 Subject: [PATCH 04/14] Add ranx to development requirements --- requirements-dev.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index a9584e587b7f..1dfb307127c3 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -36,3 +36,4 @@ pandas gymnasium # temporary fix: E AttributeError: module 'mpmath' has no attribute 'rational' mpmath<1.4 +ranx From db0e6a04f939bce4a8c0207eee670c4ddb49ca89 Mon Sep 17 00:00:00 2001 From: Steaphen Date: Fri, 6 Mar 2026 14:58:15 +0530 Subject: 
[PATCH 05/14] Remove optional ranx import - ranx now in requirements-dev.txt Removed try/except import and conditional checks for ranx. ranx is now a required dev dependency in requirements-dev.txt, so tests can use it directly without fallback logic. --- tests/ignite/metrics/rec_sys/test_ndcg.py | 38 +++++------------------ 1 file changed, 8 insertions(+), 30 deletions(-) diff --git a/tests/ignite/metrics/rec_sys/test_ndcg.py b/tests/ignite/metrics/rec_sys/test_ndcg.py index c51b9526accc..8fd829a479c1 100644 --- a/tests/ignite/metrics/rec_sys/test_ndcg.py +++ b/tests/ignite/metrics/rec_sys/test_ndcg.py @@ -7,10 +7,6 @@ from ignite.exceptions import NotComputableError from ignite.metrics.rec_sys.ndcg import NDCG -try: - import ranx -except ImportError: - ranx = None def ranx_ndcg( y_pred: np.ndarray, @@ -84,18 +80,16 @@ def test_compute(top_k, ignore_zero_hits, available_device): metric.update((y_pred, y_true)) res = metric.compute() + expected = ranx_ndcg( + y_pred.numpy(), + y_true.numpy(), + top_k, + ignore_zero_hits=ignore_zero_hits, + ) + assert isinstance(res, list) assert len(res) == len(top_k) - - # Validate against ranx if available - if ranx is not None: - expected = ranx_ndcg( - y_pred.numpy(), - y_true.numpy(), - top_k, - ignore_zero_hits=ignore_zero_hits, - ) - np.testing.assert_allclose(res, expected, rtol=1e-5) + np.testing.assert_allclose(res, expected, rtol=1e-5) @pytest.mark.parametrize("num_queries", [10, 100]) @@ -104,9 +98,6 @@ def test_compute(top_k, ignore_zero_hits, available_device): @pytest.mark.parametrize("ignore_zero_hits", [True, False]) def test_compute_vs_ranx(num_queries, num_items, k, ignore_zero_hits, available_device): """Verify NDCG matches ranx across a wide range of input shapes and k values.""" - if ranx is None: - pytest.skip("ranx not installed") - torch.manual_seed(42) y_pred = torch.randn(num_queries, num_items) y_true = torch.randint(0, 2, (num_queries, num_items)).float() @@ -148,28 +139,17 @@ def test_all_zeros_relevance(): y_pred = torch.tensor([[5.0, 3.0, 4.0]]) y_true = torch.tensor([[0.0, 0.0, 0.0]]) metric.update((y_pred, y_true)) - # NDCG should be 0 when there are no relevant items assert metric.compute() == pytest.approx([0.0]) def test_graded_relevance_threshold(): """Labels >= relevance_threshold are considered, but contribute their full value to DCG.""" - # relevance_threshold=2: labels < 2 are zeroed out metric = NDCG(top_k=[3], relevance_threshold=2.0) - # Predictions rank: doc0, doc2, doc1 - # True relevance: [3, 1, 2] -> After threshold: [3, 0, 2] - # Ranked relevance (after threshold): [3, 2, 0] y_pred = torch.tensor([[0.9, 0.3, 0.7]]) y_true = torch.tensor([[3.0, 1.0, 2.0]]) metric.update((y_pred, y_true)) - # DCG: (2^3-1)/log2(2) + (2^2-1)/log2(3) + (2^0-1)/log2(4) - # = 7/1 + 3/1.585 + 0/2 = 7 + 1.893 = 8.893 - # IDCG: (2^3-1)/log2(2) + (2^2-1)/log2(3) + (2^0-1)/log2(4) - # = 7/1 + 3/1.585 + 0/2 = 8.893 - # NDCG = 1.0 (perfect ranking of items that meet threshold) - result = metric.compute() assert result[0] == pytest.approx(1.0, rel=1e-5) @@ -177,8 +157,6 @@ def test_graded_relevance_threshold(): @pytest.mark.usefixtures("distributed") class TestDistributed: def test_integration(self): - if ranx is None: - pytest.skip("ranx not installed") n_iters = 10 batch_size = 4 num_items = 20 From 7afeccf5f2f820b3fa0ec83fbef99a2c956abe34 Mon Sep 17 00:00:00 2001 From: Steaphen Date: Fri, 6 Mar 2026 15:14:24 +0530 Subject: [PATCH 06/14] Update ignite/metrics/rec_sys/ndcg.py Co-authored-by: vfdev --- ignite/metrics/rec_sys/ndcg.py | 2 
+- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/metrics/rec_sys/ndcg.py b/ignite/metrics/rec_sys/ndcg.py index a5f38acd1a31..fe74b7d5abee 100644 --- a/ignite/metrics/rec_sys/ndcg.py +++ b/ignite/metrics/rec_sys/ndcg.py @@ -178,7 +178,7 @@ def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: return # Zero out items below relevance threshold for DCG computation - y_for_dcg = torch.where(y >= self.relevance_threshold, y, torch.zeros_like(y)) + y_for_dcg = torch.where(y >= self.relevance_threshold, y, 0) max_k = self.top_k[-1] From d8f69e6ef5c3afbb283267d74f7d7692fd94d54e Mon Sep 17 00:00:00 2001 From: Steaphen Date: Fri, 6 Mar 2026 18:47:40 +0530 Subject: [PATCH 07/14] Added the graded relevance explanation to the docstring. Add detailed explanation of relevance types for NDCG. --- ignite/metrics/rec_sys/ndcg.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/ignite/metrics/rec_sys/ndcg.py b/ignite/metrics/rec_sys/ndcg.py index fe74b7d5abee..33e22f1dff70 100644 --- a/ignite/metrics/rec_sys/ndcg.py +++ b/ignite/metrics/rec_sys/ndcg.py @@ -29,6 +29,21 @@ class NDCG(Metric): - ``update`` must receive output of the form ``(y_pred, y)``. - ``y_pred`` is expected to be raw logits or probability score for each item in the catalog. - ``y`` is expected to contain relevance scores (can be binary or graded). + ``` +Relevance Types: + - **Binary relevance**: Labels are either 0 (not relevant) or 1 (relevant) + - **Graded relevance**: Labels can have multiple levels (e.g., 0-4 scale) + + Common graded scales: + - 0: Not relevant + - 1: Marginally relevant + - 2: Relevant + - 3: Highly relevant + - 4: Perfectly relevant + + The NDCG formula handles both types through the gain function: 2^relevance - 1. + Higher relevance scores contribute more to the metric. + - ``y_pred`` and ``y`` are only allowed shape :math:`(batch, num\_items)`. - returns a list of NDCG ordered by the sorted values of ``top_k``. From a5b06974c198c3339941d74240d8c1249d3ae723 Mon Sep 17 00:00:00 2001 From: Steaphen Date: Fri, 6 Mar 2026 19:29:07 +0530 Subject: [PATCH 08/14] Add helper function --- ignite/metrics/rec_sys/ndcg.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/ignite/metrics/rec_sys/ndcg.py b/ignite/metrics/rec_sys/ndcg.py index 33e22f1dff70..1115ef157d14 100644 --- a/ignite/metrics/rec_sys/ndcg.py +++ b/ignite/metrics/rec_sys/ndcg.py @@ -8,6 +8,12 @@ __all__ = ["NDCG"] +def _get_ranked_items(scores: torch.Tensor, items: torch.Tensor, k: int) -> torch.Tensor: + """Get top-k items ranked by scores.""" + indices = torch.argsort(scores, dim=-1, descending=True, stable=True)[:, :k] + return torch.gather(items, dim=-1, index=indices) + + class NDCG(Metric): r"""Calculates the Normalized Discounted Cumulative Gain (NDCG) at `k` for Recommendation Systems. 
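To make the behaviour of the new _get_ranked_items helper concrete, here is a small illustrative sketch (the tensor values are invented for the example): it reorders the relevance labels by the prediction scores, which is how update() obtains ranked_relevance.

    import torch

    def _get_ranked_items(scores, items, k):
        # same logic as the helper added in this commit
        indices = torch.argsort(scores, dim=-1, descending=True, stable=True)[:, :k]
        return torch.gather(items, dim=-1, index=indices)

    scores = torch.tensor([[0.2, 0.9, 0.5]])  # model scores per item
    labels = torch.tensor([[1.0, 0.0, 2.0]])  # true relevance per item
    print(_get_ranked_items(scores, labels, k=2))  # tensor([[0., 2.]])
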
@@ -197,11 +203,7 @@ def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: max_k = self.top_k[-1] - # Get ranked indices based on predictions (stable=True for deterministic tie-breaking) - ranked_indices = torch.argsort(y_pred, dim=-1, descending=True, stable=True)[:, :max_k] - - # Get relevance scores in the predicted ranking order - ranked_relevance = torch.gather(y_for_dcg, dim=-1, index=ranked_indices) + ranked_relevance = _get_ranked_items(y_pred, y_for_dcg, max_k) # Compute ideal ranking by sorting true relevance scores ideal_relevance = torch.sort(y_for_dcg, dim=-1, descending=True, stable=True)[0][:, :max_k] From 694d947b750afcbcdc1146c3a9373ce9c8a140ae Mon Sep 17 00:00:00 2001 From: Steaphen Date: Fri, 6 Mar 2026 19:54:38 +0530 Subject: [PATCH 09/14] Add gain_function parameter to NDCG calculation --- ignite/metrics/rec_sys/ndcg.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/ignite/metrics/rec_sys/ndcg.py b/ignite/metrics/rec_sys/ndcg.py index 1115ef157d14..5f5ff358b3d6 100644 --- a/ignite/metrics/rec_sys/ndcg.py +++ b/ignite/metrics/rec_sys/ndcg.py @@ -62,6 +62,12 @@ class NDCG(Metric): which handles standard binary labels and graded relevance scales (e.g. TREC-style 0-4) by treating any label >= 1 as relevant. Items below this threshold contribute 0 to DCG/IDCG calculations. + + gain_function (str): Gain function for relevance scores.Options: + - ``'exp_rank'``: 2^relevance - 1 (emphasizes high relevance, default) + - ``'linear_rank'``: relevance (simpler, linear scale) + Defaults to ``'exp_rank'``. + output_transform: a callable that is used to transform the :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. @@ -138,12 +144,16 @@ def __init__( top_k: list[int], ignore_zero_hits: bool = True, relevance_threshold: float = 1.0, + gain_function: str = "exp_rank", output_transform: Callable = lambda x: x, device: str | torch.device = torch.device("cpu"), skip_unrolling: bool = False, ): if any(k <= 0 for k in top_k): raise ValueError(" top_k must be list of positive integers only.") + + if gain_function not in ["exp_rank", "linear_rank"]: + raise ValueError("gain_function must be either 'exp_rank' or 'linear_rank'") self.top_k = sorted(top_k) self.ignore_zero_hits = ignore_zero_hits @@ -173,8 +183,10 @@ def _compute_dcg(self, relevance_scores: torch.Tensor, k: int) -> torch.Tensor: positions = torch.arange(1, actual_k + 1, dtype=torch.float32, device=relevance_scores.device) discounts = 1.0 / torch.log2(positions + 1) # log2(i+1) for i in [1, actual_k] - # Compute gains: 2^rel - 1 - gains = torch.pow(2.0, relevance_scores[:, :actual_k]) - 1.0 + if self.gain_function == "exp_rank": + gains = torch.pow(2.0, relevance_scores) - 1 + else: # linear_rank + gains = relevance_scores # DCG = sum of (gain / discount) dcg = (gains * discounts).sum(dim=-1) From a35acb33ba88314b637707f2a2bb79fb7b2abc93 Mon Sep 17 00:00:00 2001 From: steaphenai Date: Mon, 9 Mar 2026 03:41:30 +0530 Subject: [PATCH 10/14] Add gain_function support to NDCG with comprehensive ranx/Catalyst validation --- ignite/metrics/__init__.py | 2 +- ignite/metrics/rec_sys/ndcg.py | 97 +++++------- tests/ignite/metrics/rec_sys/test_ndcg.py | 183 ++++++++++++++++++++-- 3 files changed, 210 insertions(+), 72 deletions(-) diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index c6d973177203..acbbce8325e4 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -105,5 +105,5 @@ 
"CommonObjectDetectionMetrics", "coco_tensor_list_to_dict_list", "HitRate", - "NDGC", + "NDCG", ] diff --git a/ignite/metrics/rec_sys/ndcg.py b/ignite/metrics/rec_sys/ndcg.py index 5f5ff358b3d6..c617fab7f981 100644 --- a/ignite/metrics/rec_sys/ndcg.py +++ b/ignite/metrics/rec_sys/ndcg.py @@ -35,21 +35,21 @@ class NDCG(Metric): - ``update`` must receive output of the form ``(y_pred, y)``. - ``y_pred`` is expected to be raw logits or probability score for each item in the catalog. - ``y`` is expected to contain relevance scores (can be binary or graded). - ``` -Relevance Types: - - **Binary relevance**: Labels are either 0 (not relevant) or 1 (relevant) - - **Graded relevance**: Labels can have multiple levels (e.g., 0-4 scale) - - Common graded scales: - - 0: Not relevant - - 1: Marginally relevant - - 2: Relevant - - 3: Highly relevant - - 4: Perfectly relevant - - The NDCG formula handles both types through the gain function: 2^relevance - 1. - Higher relevance scores contribute more to the metric. - + + Relevance Types: + - **Binary relevance**: Labels are either 0 (not relevant) or 1 (relevant) + - **Graded relevance**: Labels can have multiple levels (e.g., 0-4 scale) + + Common graded scales: + - 0: Not relevant + - 1: Marginally relevant + - 2: Relevant + - 3: Highly relevant + - 4: Perfectly relevant + + The NDCG formula handles both types through the gain function: 2^relevance - 1. + Higher relevance scores contribute more to the metric. + - ``y_pred`` and ``y`` are only allowed shape :math:`(batch, num\_items)`. - returns a list of NDCG ordered by the sorted values of ``top_k``. @@ -62,12 +62,10 @@ class NDCG(Metric): which handles standard binary labels and graded relevance scales (e.g. TREC-style 0-4) by treating any label >= 1 as relevant. Items below this threshold contribute 0 to DCG/IDCG calculations. - - gain_function (str): Gain function for relevance scores.Options: - - ``'exp_rank'``: 2^relevance - 1 (emphasizes high relevance, default) - - ``'linear_rank'``: relevance (simpler, linear scale) - Defaults to ``'exp_rank'``. - + gain_function (str): Gain function for relevance scores. Options: + - ``'exp_rank'``: 2^relevance - 1 (emphasizes high relevance, default) + - ``'linear_rank'``: relevance (simpler, linear scale) + Defaults to ``'exp_rank'``. output_transform: a callable that is used to transform the :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. @@ -151,13 +149,14 @@ def __init__( ): if any(k <= 0 for k in top_k): raise ValueError(" top_k must be list of positive integers only.") - + if gain_function not in ["exp_rank", "linear_rank"]: - raise ValueError("gain_function must be either 'exp_rank' or 'linear_rank'") + raise ValueError("gain_function must be either 'exp_rank' or 'linear_rank'") self.top_k = sorted(top_k) self.ignore_zero_hits = ignore_zero_hits self.relevance_threshold = relevance_threshold + self.gain_function = gain_function super(NDCG, self).__init__(output_transform, device=device, skip_unrolling=skip_unrolling) @reinit__is_reduced @@ -166,29 +165,18 @@ def reset(self) -> None: self._num_examples = 0 def _compute_dcg(self, relevance_scores: torch.Tensor, k: int) -> torch.Tensor: - """Compute DCG@k for a batch of relevance scores. 
- - Args: - relevance_scores: Tensor of shape (batch, num_items) with relevance scores at ranked positions - k: Number of positions to consider - - Returns: - DCG scores of shape (batch,) - """ - # Handle case where k > actual number of items + """Compute DCG@k for a batch of relevance scores.""" actual_k = min(k, relevance_scores.shape[1]) - - # Create position weights: 1/log2(position + 1) for position in [1, actual_k] - # Positions are 1-indexed in the DCG formula + positions = torch.arange(1, actual_k + 1, dtype=torch.float32, device=relevance_scores.device) - discounts = 1.0 / torch.log2(positions + 1) # log2(i+1) for i in [1, actual_k] - - if self.gain_function == "exp_rank": - gains = torch.pow(2.0, relevance_scores) - 1 - else: # linear_rank - gains = relevance_scores - - # DCG = sum of (gain / discount) + discounts = 1.0 / torch.log2(positions + 1) + + topk_relevance = relevance_scores[:, :actual_k] + if self.gain_function == "exp_rank": + gains = torch.pow(2.0, topk_relevance) - 1 + else: + gains = topk_relevance + dcg = (gains * discounts).sum(dim=-1) return dcg @@ -201,7 +189,6 @@ def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: if y_pred.shape != y.shape: raise ValueError(f"y_pred and y must be in the same shape, got {y_pred.shape} != {y.shape}.") - # Filter out examples with no relevant items if ignore_zero_hits is True if self.ignore_zero_hits: valid_mask = torch.any(y >= self.relevance_threshold, dim=-1) y_pred = y_pred[valid_mask] @@ -210,28 +197,22 @@ def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: if y.shape[0] == 0: return - # Zero out items below relevance threshold for DCG computation - y_for_dcg = torch.where(y >= self.relevance_threshold, y, 0) + y_for_dcg = torch.where(y >= self.relevance_threshold, y, torch.zeros_like(y)) max_k = self.top_k[-1] - - ranked_relevance = _get_ranked_items(y_pred, y_for_dcg, max_k) - - # Compute ideal ranking by sorting true relevance scores + ranked_relevance = _get_ranked_items(y_pred, y_for_dcg, max_k) ideal_relevance = torch.sort(y_for_dcg, dim=-1, descending=True, stable=True)[0][:, :max_k] for i, k in enumerate(self.top_k): - # Compute DCG@k and IDCG@k dcg_k = self._compute_dcg(ranked_relevance, k) idcg_k = self._compute_dcg(ideal_relevance, k) - - # NDCG = DCG / IDCG, handle division by zero (when IDCG = 0, NDCG = 0) + ndcg_k = torch.where( idcg_k > 0, dcg_k / idcg_k, - torch.zeros_like(dcg_k) + torch.zeros_like(dcg_k), ) - + self._sum_ndcg_per_k[i] += ndcg_k.sum().to(self._device) self._num_examples += y.shape[0] @@ -241,5 +222,5 @@ def compute(self) -> list[float]: if self._num_examples == 0: raise NotComputableError("NDCG must have at least one example.") - ndcg_scores = (self._sum_ndcg_per_k / self._num_examples).tolist() - return ndcg_scores + rates = (self._sum_ndcg_per_k / self._num_examples).tolist() + return rates diff --git a/tests/ignite/metrics/rec_sys/test_ndcg.py b/tests/ignite/metrics/rec_sys/test_ndcg.py index 8fd829a479c1..5a6da4593536 100644 --- a/tests/ignite/metrics/rec_sys/test_ndcg.py +++ b/tests/ignite/metrics/rec_sys/test_ndcg.py @@ -13,6 +13,8 @@ def ranx_ndcg( y: np.ndarray, top_k: list[int], ignore_zero_hits: bool = True, + relevance_threshold: float = 1.0, + gain_function: str = "exp_rank", ) -> list[float]: """Reference NDCG implementation using ranx for verification. 
https://github.com/AmenRa/ranx """ from ranx import Qrels, Run, evaluate @@ -26,12 +28,12 @@ def ranx_ndcg( for i, (scores, labels) in enumerate(zip(y_pred, y)): qid = f"q{i}" - relevant = {f"d{j}": int(label) for j, label in enumerate(labels) if label > 0} + relevant = {f"d{j}": float(label) for j, label in enumerate(labels) if label >= relevance_threshold} if ignore_zero_hits and not relevant: continue - qrels_dict[qid] = relevant if relevant else {f"d0": 0} + qrels_dict[qid] = relevant if relevant else {f"d0": 0.0} run_dict[qid] = {f"d{j}": float(s) for j, s in enumerate(scores)} if not qrels_dict: @@ -39,11 +41,47 @@ def ranx_ndcg( continue run_dict = {q: run_dict[q] for q in qrels_dict} - results.append(float(evaluate(Qrels(qrels_dict), Run(run_dict), f"ndcg@{k}"))) - + metric_name = f"{'ndcg_burges' if gain_function == 'exp_rank' else 'ndcg'}@{k}" + results.append(float(evaluate(Qrels(qrels_dict), Run(run_dict), metric_name))) return results +def catalyst_ndcg( + y_pred: np.ndarray, + y: np.ndarray, + top_k: list[int], + ignore_zero_hits: bool = True, + relevance_threshold: float = 1.0, + gain_function: str = "exp_rank", +) -> list[float]: + """Reference NDCG implementation using catalyst for verification.""" + pytest.importorskip("catalyst", reason="catalyst is required for catalyst parity checks") + from catalyst.metrics.functional import ndcg as catalyst_ndcg_fn + + sorted_top_k = sorted(top_k) + + outputs = torch.from_numpy(y_pred).float() + targets = torch.from_numpy(y).float() + + if ignore_zero_hits: + valid_mask = torch.any(targets >= relevance_threshold, dim=-1) + outputs = outputs[valid_mask] + targets = targets[valid_mask] + + if targets.shape[0] == 0: + return [0.0] * len(sorted_top_k) + + targets_for_dcg = torch.where(targets >= relevance_threshold, targets, torch.zeros_like(targets)) + values = catalyst_ndcg_fn( + outputs=outputs, + targets=targets_for_dcg, + topk=sorted_top_k, + gain_function=gain_function, + ) + + return [float(v) for v in values] + + def test_zero_sample(): metric = NDCG(top_k=[1, 5]) with pytest.raises(NotComputableError, match=r"NDCG must have at least one example"): @@ -65,6 +103,11 @@ def test_invalid_top_k(): NDCG(top_k=[-1, 5]) +def test_invalid_gain_function(): + with pytest.raises(ValueError, match="gain_function must be either"): + NDCG(top_k=[1], gain_function="invalid") + + @pytest.mark.parametrize("top_k", [[1], [1, 2, 4]]) @pytest.mark.parametrize("ignore_zero_hits", [True, False]) def test_compute(top_k, ignore_zero_hits, available_device): @@ -92,9 +135,9 @@ def test_compute(top_k, ignore_zero_hits, available_device): np.testing.assert_allclose(res, expected, rtol=1e-5) -@pytest.mark.parametrize("num_queries", [10, 100]) -@pytest.mark.parametrize("num_items", [20, 100]) -@pytest.mark.parametrize("k", [1, 5]) +@pytest.mark.parametrize("num_queries", [1, 10, 100]) +@pytest.mark.parametrize("num_items", [5, 20, 100]) +@pytest.mark.parametrize("k", [1, 5, 10]) @pytest.mark.parametrize("ignore_zero_hits", [True, False]) def test_compute_vs_ranx(num_queries, num_items, k, ignore_zero_hits, available_device): """Verify NDCG matches ranx across a wide range of input shapes and k values.""" @@ -124,17 +167,112 @@ def test_compute_vs_ranx(num_queries, num_items, k, ignore_zero_hits, available_ np.testing.assert_allclose(res, expected, rtol=1e-5) +@pytest.mark.parametrize("top_k", [[3], [2, 5]]) +@pytest.mark.parametrize("ignore_zero_hits", [True, False]) +def test_compute_vs_ranx_and_catalyst_with_ties(top_k, ignore_zero_hits, available_device): 
+ """Validate tie handling against ranx and catalyst with non-trivial tie cases.""" + y_pred = torch.tensor([ + [0.7, 0.7, 0.7, 0.5, 0.5], + [0.9, 0.9, 0.6, 0.6, 0.3], + [0.8, 0.8, 0.5, 0.1, 0.1], + ]) + y_true = torch.tensor([ + [3.0, 2.0, 1.0, 1.0, 0.0], + [4.0, 3.0, 2.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + ]) + + metric = NDCG( + top_k=top_k, + ignore_zero_hits=ignore_zero_hits, + gain_function="exp_rank", + device=available_device, + ) + metric.update((y_pred, y_true)) + res = metric.compute() + + expected_ranx = ranx_ndcg( + y_pred.numpy(), + y_true.numpy(), + top_k, + ignore_zero_hits=ignore_zero_hits, + ) + expected_catalyst = catalyst_ndcg( + y_pred.numpy(), + y_true.numpy(), + top_k, + ignore_zero_hits=ignore_zero_hits, + gain_function="exp_rank", + ) + + np.testing.assert_allclose(res, expected_ranx, rtol=1e-5) + np.testing.assert_allclose(res, expected_catalyst, rtol=1e-5) + + +@pytest.mark.parametrize("top_k", [[3], [2, 5]]) +@pytest.mark.parametrize("ignore_zero_hits", [True, False]) +def test_compute_vs_catalyst_linear_gain(top_k, ignore_zero_hits, available_device): + """Validate linear gain mode against catalyst.""" + y_pred = torch.tensor([ + [0.7, 0.7, 0.7, 0.5, 0.5], + [0.9, 0.9, 0.6, 0.6, 0.3], + ]) + y_true = torch.tensor([ + [3.0, 2.0, 1.0, 1.0, 0.0], + [4.0, 3.0, 2.0, 1.0, 0.0], + ]) + + metric = NDCG( + top_k=top_k, + ignore_zero_hits=ignore_zero_hits, + gain_function="linear_rank", + device=available_device, + ) + metric.update((y_pred, y_true)) + res = metric.compute() + + expected_catalyst = catalyst_ndcg( + y_pred.numpy(), + y_true.numpy(), + top_k, + ignore_zero_hits=ignore_zero_hits, + gain_function="linear_rank", + ) + + np.testing.assert_allclose(res, expected_catalyst, rtol=1e-5) + + def test_perfect_prediction(): """Perfect ranking -> NDCG = 1.0.""" metric = NDCG(top_k=[1, 3]) y_pred = torch.tensor([[5.0, 3.0, 4.0, 1.0]]) - y_true = torch.tensor([[3.0, 1.0, 2.0, 0.0]]) # Matches ranking order + y_true = torch.tensor([[3.0, 1.0, 2.0, 0.0]]) metric.update((y_pred, y_true)) assert metric.compute() == pytest.approx([1.0, 1.0]) +def test_multiple_batches(): + """NDCG accumulates correctly across multiple update() calls.""" + metric = NDCG(top_k=[3], ignore_zero_hits=False) + + y_pred_1 = torch.tensor([[3.0, 2.0, 1.0]]) + y_true_1 = torch.tensor([[3.0, 2.0, 1.0]]) + + y_pred_2 = torch.tensor([[1.0, 2.0, 3.0]]) + y_true_2 = torch.tensor([[3.0, 2.0, 1.0]]) + + metric.update((y_pred_1, y_true_1)) + metric.update((y_pred_2, y_true_2)) + + expected_1 = ranx_ndcg(y_pred_1.numpy(), y_true_1.numpy(), [3], ignore_zero_hits=False, gain_function="exp_rank")[0] + expected_2 = ranx_ndcg(y_pred_2.numpy(), y_true_2.numpy(), [3], ignore_zero_hits=False, gain_function="exp_rank")[0] + expected = (expected_1 + expected_2) / 2.0 + + assert metric.compute() == pytest.approx([expected], rel=1e-5) + + def test_all_zeros_relevance(): - """When all relevance is 0, IDCG=0, so NDCG should be 0 (or ignored if ignore_zero_hits=True).""" + """When all relevance is 0, IDCG=0, so NDCG should be 0 if ignore_zero_hits=False.""" metric = NDCG(top_k=[2], ignore_zero_hits=False) y_pred = torch.tensor([[5.0, 3.0, 4.0]]) y_true = torch.tensor([[0.0, 0.0, 0.0]]) @@ -145,22 +283,41 @@ def test_all_zeros_relevance(): def test_graded_relevance_threshold(): """Labels >= relevance_threshold are considered, but contribute their full value to DCG.""" metric = NDCG(top_k=[3], relevance_threshold=2.0) - + y_pred = torch.tensor([[0.9, 0.3, 0.7]]) y_true = torch.tensor([[3.0, 1.0, 2.0]]) 
metric.update((y_pred, y_true)) - + result = metric.compute() assert result[0] == pytest.approx(1.0, rel=1e-5) +def test_accumulator_detached(available_device): + metric = NDCG(top_k=[1], device=available_device) + y_pred = torch.randn(4, 5, requires_grad=True) + y = torch.randint(0, 2, (4, 5)).float() + metric.update((y_pred, y)) + + assert metric._sum_ndcg_per_k.requires_grad is False + assert metric._sum_ndcg_per_k.is_leaf is True + + +def test_all_zero_targets_ignore(): + metric = NDCG(top_k=[1, 3], ignore_zero_hits=True) + y_pred = torch.randn(4, 5) + y = torch.zeros(4, 5) + metric.update((y_pred, y)) + with pytest.raises(NotComputableError): + metric.compute() + + @pytest.mark.usefixtures("distributed") class TestDistributed: def test_integration(self): n_iters = 10 batch_size = 4 num_items = 20 - top_k = [1, 5, 10] + top_k = [1, 5] rank = idist.get_rank() torch.manual_seed(42 + rank) @@ -203,7 +360,7 @@ def test_integration(self): ) assert isinstance(res, list) - np.testing.assert_allclose(res, true_res, rtol=1e-5) + assert res == pytest.approx(true_res, rel=1e-5) engine.state.metrics.clear() From c97dfff09c3e14cc429020400415e3a60d41b0da Mon Sep 17 00:00:00 2001 From: Steaphen Date: Fri, 13 Mar 2026 16:06:33 +0530 Subject: [PATCH 11/14] Refactor comments and improve code readability --- ignite/metrics/rec_sys/ndcg.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/ignite/metrics/rec_sys/ndcg.py b/ignite/metrics/rec_sys/ndcg.py index c617fab7f981..7bc82065bc60 100644 --- a/ignite/metrics/rec_sys/ndcg.py +++ b/ignite/metrics/rec_sys/ndcg.py @@ -34,7 +34,7 @@ class NDCG(Metric): - ``update`` must receive output of the form ``(y_pred, y)``. - ``y_pred`` is expected to be raw logits or probability score for each item in the catalog. - - ``y`` is expected to contain relevance scores (can be binary or graded). + - ``y`` is expected to contain relevance scores (can be binary or graded). Relevance Types: - **Binary relevance**: Labels are either 0 (not relevant) or 1 (relevant) @@ -102,7 +102,7 @@ class NDCG(Metric): y_true=torch.Tensor([ [0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0] - ]) + ]) state = default_evaluator.run([(y_pred, y_true)]) print(state.metrics["ndcg"]) @@ -137,7 +137,7 @@ class NDCG(Metric): required_output_keys = ("y_pred", "y") _state_dict_all_req_keys = ("_sum_ndcg_per_k", "_num_examples") - def __init__( + def __init__( self, top_k: list[int], ignore_zero_hits: bool = True, @@ -197,11 +197,16 @@ def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: if y.shape[0] == 0: return - y_for_dcg = torch.where(y >= self.relevance_threshold, y, torch.zeros_like(y)) + y_for_dcg = torch.where(y >= self.relevance_threshold, y, 0) max_k = self.top_k[-1] ranked_relevance = _get_ranked_items(y_pred, y_for_dcg, max_k) - ideal_relevance = torch.sort(y_for_dcg, dim=-1, descending=True, stable=True)[0][:, :max_k] + # Compute ideal ranking by sorting true relevance scores (descending).
# This aligns with standard IDCG computation in reference libraries: + # ranx: https://github.com/AmenRa/ranx/blob/master/ranx/metrics/ndcg.py#L52 + # catalyst: https://github.com/catalyst-team/catalyst/blob/master/catalyst/metrics/functional/_ndcg.py#L197 + + ideal_relevance = torch.sort(y_for_dcg, dim=-1, descending=True, stable=True)[0][:, :max_k] for i, k in enumerate(self.top_k): dcg_k = self._compute_dcg(ranked_relevance, k) From 8df6fcfd642fedc27feaa6ad9dc8345064720e57 Mon Sep 17 00:00:00 2001 From: Steaphen Date: Fri, 13 Mar 2026 16:08:06 +0530 Subject: [PATCH 12/14] Refactor test_ndcg.py for import and y_true changes Updated import statements and modified y_true generation in tests. --- tests/ignite/metrics/rec_sys/test_ndcg.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/ignite/metrics/rec_sys/test_ndcg.py b/tests/ignite/metrics/rec_sys/test_ndcg.py index 5a6da4593536..c07efc3b17e8 100644 --- a/tests/ignite/metrics/rec_sys/test_ndcg.py +++ b/tests/ignite/metrics/rec_sys/test_ndcg.py @@ -17,7 +17,7 @@ def ranx_ndcg( gain_function: str = "exp_rank", ) -> list[float]: """Reference NDCG implementation using ranx for verification. https://github.com/AmenRa/ranx """ - from ranx import Qrels, Run, evaluate + from ranx import Qrels, Run, evaluate sorted_top_k = sorted(top_k) results = [] @@ -56,7 +56,7 @@ def catalyst_ndcg( ) -> list[float]: """Reference NDCG implementation using catalyst for verification.""" pytest.importorskip("catalyst", reason="catalyst is required for catalyst parity checks") - from catalyst.metrics.functional import ndcg as catalyst_ndcg_fn + from catalyst.metrics.functional import ndcg as catalyst_ndcg_fn sorted_top_k = sorted(top_k) @@ -143,7 +143,7 @@ def test_compute_vs_ranx(num_queries, num_items, k, ignore_zero_hits, available_ """Verify NDCG matches ranx across a wide range of input shapes and k values.""" torch.manual_seed(42) y_pred = torch.randn(num_queries, num_items) - y_true = torch.randint(0, 2, (num_queries, num_items)).float() + y_true = torch.randint(0, 5, (num_queries, num_items)).float() metric = NDCG( top_k=[k], @@ -376,3 +376,4 @@ def test_accumulator_device(self): metric.update((y_pred, y)) assert metric._sum_ndcg_per_k.device == device + From bb12d23d7e3b0ce25d0af78dac60ab4b24cd90ed Mon Sep 17 00:00:00 2001 From: Steaphen Date: Fri, 13 Mar 2026 16:08:42 +0530 Subject: [PATCH 13/14] Add catalyst to development requirements --- requirements-dev.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index c29d027cab07..6a0686d133f7 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -35,3 +35,4 @@ gymnasium # temporary fix: E AttributeError: module 'mpmath' has no attribute 'rational' mpmath<1.4 ranx +catalyst From 3a88ce51a0e1fc89bd841fcad7696366f67de6fd Mon Sep 17 00:00:00 2001 From: Steaphen Date: Fri, 13 Mar 2026 16:27:32 +0530 Subject: [PATCH 14/14] Update metrics.rst --- docs/source/metrics.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index e7cb2d9c136d..7f0b03df52e1 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -391,7 +391,7 @@ Complete list of metrics clustering.DaviesBouldinScore clustering.CalinskiHarabaszScore rec_sys.HitRate - rec_sys.NDGC + rec_sys.NDCG .. note::
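For reviewers who want to reproduce the parity checks locally, here is a condensed, illustrative sketch of the ranx comparison used in the tests (it assumes ranx is installed from requirements-dev.txt and the patched ignite is on the path); ranx's ndcg_burges corresponds to the metric's default exp_rank gain:

    import torch
    from ranx import Qrels, Run, evaluate
    from ignite.metrics import NDCG

    y_pred = torch.tensor([[0.9, 0.3, 0.7]])
    y_true = torch.tensor([[3.0, 1.0, 2.0]])

    metric = NDCG(top_k=[3])
    metric.update((y_pred, y_true))
    print(metric.compute())  # expected: [1.0] for this perfectly ranked example

    qrels = Qrels({"q0": {f"d{j}": int(r) for j, r in enumerate(y_true[0]) if r >= 1}})
    run = Run({"q0": {f"d{j}": float(s) for j, s in enumerate(y_pred[0])}})
    print(evaluate(qrels, run, "ndcg_burges@3"))  # should agree with the value above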