2 changes: 1 addition & 1 deletion docs/source/metrics.rst
@@ -391,7 +391,7 @@ Complete list of metrics
clustering.DaviesBouldinScore
clustering.CalinskiHarabaszScore
rec_sys.HitRate

rec_sys.NDCG

.. note::

2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
@@ -37,6 +37,7 @@
from ignite.metrics.psnr import PSNR
from ignite.metrics.recall import Recall
from ignite.metrics.rec_sys.hitrate import HitRate
from ignite.metrics.rec_sys.ndcg import NDCG
from ignite.metrics.roc_auc import ROC_AUC, RocCurve
from ignite.metrics.root_mean_squared_error import RootMeanSquaredError
from ignite.metrics.running_average import RunningAverage
@@ -104,4 +105,5 @@
"CommonObjectDetectionMetrics",
"coco_tensor_list_to_dict_list",
"HitRate",
"NDCG",
]
3 changes: 3 additions & 0 deletions ignite/metrics/rec_sys/__init__.py
@@ -1 +1,4 @@
from ignite.metrics.rec_sys.hitrate import HitRate
from ignite.metrics.rec_sys.ndcg import NDCG

__all__ = ["HitRate", "NDCG"]
231 changes: 231 additions & 0 deletions ignite/metrics/rec_sys/ndcg.py
@@ -0,0 +1,231 @@
from typing import Callable

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["NDCG"]


def _get_ranked_items(scores: torch.Tensor, items: torch.Tensor, k: int) -> torch.Tensor:
"""Get top-k items ranked by scores."""
indices = torch.argsort(scores, dim=-1, descending=True, stable=True)[:, :k]
return torch.gather(items, dim=-1, index=indices)


class NDCG(Metric):
r"""Calculates the Normalized Discounted Cumulative Gain (NDCG) at `k` for Recommendation Systems.

NDCG measures the quality of ranking by considering both the relevance of items and their
positions in the ranked list. It compares the achieved DCG against the ideal DCG (IDCG)
obtained by sorting items by their true relevance.

.. math::
\text{NDCG}@K = \frac{\text{DCG}@K}{\text{IDCG}@K}

where:

.. math::
\text{DCG}@K = \sum_{i=1}^{K} \frac{2^{\text{rel}_i} - 1}{\log_2(i + 1)}

and :math:`\text{rel}_i` is the relevance score of the item at position :math:`i` in the
ranked list (1-indexed).

- ``update`` must receive output of the form ``(y_pred, y)``.
- ``y_pred`` is expected to be raw logits or probability scores for each item in the catalog.
- ``y`` is expected to contain relevance scores (can be binary or graded).

Relevance Types:
- **Binary relevance**: Labels are either 0 (not relevant) or 1 (relevant)
- **Graded relevance**: Labels can have multiple levels (e.g., 0-4 scale)

Common graded scales:
- 0: Not relevant
- 1: Marginally relevant
- 2: Relevant
- 3: Highly relevant
- 4: Perfectly relevant

The NDCG formula handles both types through the gain function: 2^relevance - 1.
Higher relevance scores contribute more to the metric.
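For example, a label of 3 yields a gain of :math:`2^3 - 1 = 7` under ``'exp_rank'`` but only 3 under ``'linear_rank'``.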

- ``y_pred`` and ``y`` must both be of shape :math:`(batch, num\_items)`.
- ``compute`` returns a list of NDCG values ordered by the sorted values of ``top_k``.

Args:
top_k: a list of positive integers that specifies the values of `k` for calculating NDCG@k.
The list is sorted internally.
ignore_zero_hits: if True, users with no relevant items (a ground-truth row of all zeros)
are ignored when computing NDCG. If set to False, such users are counted with an NDCG of 0.
By default, True.
relevance_threshold: minimum label value to be considered relevant. Defaults to ``1``,
which handles standard binary labels and graded relevance scales (e.g. TREC-style
0-4) by treating any label >= 1 as relevant. Items below this threshold contribute
0 to DCG/IDCG calculations.
gain_function: gain function applied to relevance scores. Options:
- ``'exp_rank'``: 2^relevance - 1 (emphasizes high relevance, default)
- ``'linear_rank'``: relevance (simpler, linear scale)
Defaults to ``'exp_rank'``.
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric.
The output is expected to be a tuple `(prediction, target)`
where `prediction` and `target` are tensors
of shape ``(batch, num_items)``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.
skip_unrolling: specifies whether input should be unrolled or not before being
processed. Should be true for multi-output models.

Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in the format of
``(y_pred, y)``. If not, ``output_transform`` can be added
to the metric to transform the output into the form expected by the metric.

For more information on how the metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

.. include:: defaults.rst
:start-after: :orphan:

ignore_zero_hits=True case

.. testcode:: 1

metric = NDCG(top_k=[1, 2, 3, 4])
metric.attach(default_evaluator, "ndcg")
y_pred = torch.Tensor([
[4.0, 2.0, 3.0, 1.0],
[1.0, 2.0, 3.0, 4.0]
])
y_true = torch.Tensor([
[0.0, 0.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 0.0]
])
state = default_evaluator.run([(y_pred, y_true)])
print(state.metrics["ndcg"])

.. testoutput:: 1

[0.0, 0.38..., 0.38..., 0.65...]

ignore_zero_hits=False case

.. testcode:: 2

metric = NDCG(top_k=[1, 2, 3, 4], ignore_zero_hits=False)
metric.attach(default_evaluator, "ndcg")
y_pred = torch.Tensor([
[4.0, 2.0, 3.0, 1.0],
[1.0, 2.0, 3.0, 4.0]
])
y_true = torch.Tensor([
[0.0, 0.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 0.0]
])
state = default_evaluator.run([(y_pred, y_true)])
print(state.metrics["ndcg"])

.. testoutput:: 2

[0.0, 0.19..., 0.19..., 0.32...]

.. versionadded:: 0.6.0
"""

required_output_keys = ("y_pred", "y")
_state_dict_all_req_keys = ("_sum_ndcg_per_k", "_num_examples")

def __init__(
self,
top_k: list[int],
ignore_zero_hits: bool = True,
relevance_threshold: float = 1.0,
gain_function: str = "exp_rank",
output_transform: Callable = lambda x: x,
device: str | torch.device = torch.device("cpu"),
skip_unrolling: bool = False,
):
if any(k <= 0 for k in top_k):
raise ValueError(" top_k must be list of positive integers only.")

if gain_function not in ["exp_rank", "linear_rank"]:
raise ValueError("gain_function must be either 'exp_rank' or 'linear_rank'")

self.top_k = sorted(top_k)
self.ignore_zero_hits = ignore_zero_hits
self.relevance_threshold = relevance_threshold
self.gain_function = gain_function
super(NDCG, self).__init__(output_transform, device=device, skip_unrolling=skip_unrolling)

@reinit__is_reduced
def reset(self) -> None:
self._sum_ndcg_per_k = torch.zeros(len(self.top_k), device=self._device)
self._num_examples = 0

def _compute_dcg(self, relevance_scores: torch.Tensor, k: int) -> torch.Tensor:
"""Compute DCG@k for a batch of relevance scores."""
actual_k = min(k, relevance_scores.shape[1])

positions = torch.arange(1, actual_k + 1, dtype=torch.float32, device=relevance_scores.device)
discounts = 1.0 / torch.log2(positions + 1)

topk_relevance = relevance_scores[:, :actual_k]
if self.gain_function == "exp_rank":
gains = torch.pow(2.0, topk_relevance) - 1
else:
gains = topk_relevance

dcg = (gains * discounts).sum(dim=-1)
return dcg

@reinit__is_reduced
def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
if len(output) != 2:
raise ValueError(f"output should be in format `(y_pred,y)` but got tuple of {len(output)} tensors.")

y_pred, y = output
if y_pred.shape != y.shape:
raise ValueError(f"y_pred and y must be in the same shape, got {y_pred.shape} != {y.shape}.")

if self.ignore_zero_hits:
valid_mask = torch.any(y >= self.relevance_threshold, dim=-1)
y_pred = y_pred[valid_mask]
y = y[valid_mask]

if y.shape[0] == 0:
return

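# Zero out labels below the relevance threshold so they contribute no gain
# to either DCG or IDCG.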
y_for_dcg = torch.where(y >= self.relevance_threshold, y, 0)

max_k = self.top_k[-1]
ranked_relevance = _get_ranked_items(y_pred, y_for_dcg, max_k)
# Compute ideal ranking by sorting true relevance scores (descending).
# This aligns with standard IDCG computation in reference libraries:
# ranx: https://github.com/AmenRa/ranx/blob/master/ranx/metrics/ndcg.py#L52
# catalyst: https://github.com/catalyst-team/catalyst/blob/master/catalyst/metrics/functional/_ndcg.py#L197

ideal_relevance = torch.sort(y_for_dcg, dim=-1, descending=True, stable=True)[0][:, :max_k]

for i, k in enumerate(self.top_k):
dcg_k = self._compute_dcg(ranked_relevance, k)
idcg_k = self._compute_dcg(ideal_relevance, k)

ndcg_k = torch.where(
idcg_k > 0,
dcg_k / idcg_k,
torch.zeros_like(dcg_k),
)

self._sum_ndcg_per_k[i] += ndcg_k.sum().to(self._device)

self._num_examples += y.shape[0]

@sync_all_reduce("_sum_ndcg_per_k", "_num_examples")
def compute(self) -> list[float]:
if self._num_examples == 0:
raise NotComputableError("NDCG must have at least one example.")

rates = (self._sum_ndcg_per_k / self._num_examples).tolist()
return rates
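
For a quick sanity check outside of an ``Engine``, the metric can also be driven directly through ``update`` and ``compute``. A minimal sketch (the tensors and ``top_k`` values here are illustrative only):

import torch
from ignite.metrics import NDCG

metric = NDCG(top_k=[1, 3])
# Two users, three items each; a higher predicted score ranks an item earlier.
y_pred = torch.tensor([[0.9, 0.2, 0.5], [0.1, 0.8, 0.3]])
y_true = torch.tensor([[1.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
metric.update((y_pred, y_true))
print(metric.compute())  # [NDCG@1, NDCG@3], ordered by the sorted top_k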
2 changes: 2 additions & 0 deletions requirements-dev.txt
@@ -34,3 +34,5 @@ pandas
gymnasium
# temporary fix: E AttributeError: module 'mpmath' has no attribute 'rational'
mpmath<1.4
ranx
catalyst