
Commit 6c3dfd9

steaphenai and vfdev-5 authored
Add NDCG metric to rec_sys (#3608)
Fixes #3581
Refs #3568 #3569 #3610

Description: Implements the NDCG (Normalized Discounted Cumulative Gain) metric for recommendation systems and ranking evaluation.

## Implementation Details

- Follows the same pattern as the `HitRate` and `MRR` metrics
- Uses the `ranx` library for verification (addressing #3569 feedback)
- Supports both binary (0/1) and graded (0-4) labels
- `relevance_threshold` defaults to 1.0 for binary labels
- Uses `stable=True` for reproducible tie-breaking
- Handles k > num_items
- Will align with the top_k API changes from #3568 when finalized

## Mathematical Background

- **DCG**: Discounted Cumulative Gain - rewards relevant items and penalizes low positions
- **IDCG**: Ideal DCG - the perfect-ranking baseline
- **NDCG**: DCG / IDCG - a normalized score from 0.0 to 1.0

Formula: `NDCG@K = Σ (2^rel_i - 1) / log2(i + 1)`

## Testing

- 26 tests passing locally (verified against ranx)
- Parametrized tests across multiple configurations
- Edge cases covered (perfect prediction, zero relevance, graded labels)
- Distributed training tested

Check list:

- [x] New tests are added (if a new feature is added)
- [x] New doc strings: description and/or example code are in RST format
- [ ] Documentation is updated (if required)

---------

Co-authored-by: vfdev <vfdev.5@gmail.com>
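The formula in the commit message can be sanity-checked with a small standalone computation (plain Python; the helper names `dcg_at_k` and `ndcg_at_k` are illustrative, not part of this commit):

```python
import math

def dcg_at_k(relevances, k):
    # DCG@K = sum over 1-indexed positions i of (2^rel_i - 1) / log2(i + 1)
    return sum((2 ** rel - 1) / math.log2(i + 1)
               for i, rel in enumerate(relevances[:k], start=1))

def ndcg_at_k(relevances, k):
    # NDCG@K = DCG@K / IDCG@K, where IDCG uses the ideal (descending) ordering
    idcg = dcg_at_k(sorted(relevances, reverse=True), k)
    return dcg_at_k(relevances, k) / idcg if idcg > 0 else 0.0

# Graded (0-4) labels in predicted rank order: the best item sits at position 2,
# so the score is below 1.0; a perfectly ordered list scores exactly 1.0.
print(ndcg_at_k([2, 4, 0, 1], k=4))
print(ndcg_at_k([4, 2, 1, 0], k=4))
```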
1 parent 696cba6 commit 6c3dfd9

6 files changed

Lines changed: 582 additions & 1 deletion


docs/source/metrics.rst

Lines changed: 1 addition & 1 deletion
```diff
@@ -396,7 +396,7 @@ Complete list of metrics
    fairness.SubgroupDifference
    fairness.SubgroupMetric
    rec_sys.HitRate
-
+   rec_sys.NDCG

.. note::
```

ignite/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -38,6 +38,7 @@
 from ignite.metrics.psnr import PSNR
 from ignite.metrics.recall import Recall
 from ignite.metrics.rec_sys.hitrate import HitRate
+from ignite.metrics.rec_sys.ndcg import NDCG
 from ignite.metrics.roc_auc import ROC_AUC, RocCurve
 from ignite.metrics.root_mean_squared_error import RootMeanSquaredError
 from ignite.metrics.running_average import RunningAverage
@@ -106,4 +107,5 @@
     "CommonObjectDetectionMetrics",
     "coco_tensor_list_to_dict_list",
     "HitRate",
+    "NDCG",
 ]
```

ignite/metrics/rec_sys/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -1 +1,4 @@
 from ignite.metrics.rec_sys.hitrate import HitRate
+from ignite.metrics.rec_sys.ndcg import NDCG
+
+__all__ = ["HitRate", "NDCG"]
```

ignite/metrics/rec_sys/ndcg.py

Lines changed: 234 additions & 0 deletions
```python
from typing import Callable

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["NDCG"]


def _get_ranked_items(scores: torch.Tensor, items: torch.Tensor, k: int) -> torch.Tensor:
    """Get top-k items ranked by scores."""
    indices = torch.argsort(scores, dim=-1, descending=True, stable=True)[:, :k]
    return torch.gather(items, dim=-1, index=indices)


class NDCG(Metric):
    r"""Calculates the Normalized Discounted Cumulative Gain (NDCG) at ``k`` for Recommendation Systems.

    For a step-by-step guide on how to use this metric, please refer to the
    `NDCG Tutorial <https://github.com/pytorch-ignite/examples/tree/main/tutorials/intermediate/ndcg-metric-tutorial.ipynb>`_

    NDCG measures the quality of a ranking by considering both the relevance of items and their
    positions in the ranked list. It compares the achieved DCG against the ideal DCG (IDCG)
    obtained by sorting items by their true relevance.

    .. math::
        \text{NDCG}@K = \frac{\text{DCG}@K}{\text{IDCG}@K}

    where:

    .. math::
        \text{DCG}@K = \sum_{i=1}^{K} \frac{2^{\text{rel}_i} - 1}{\log_2(i + 1)}

    and :math:`\text{rel}_i` is the relevance score of the item at position :math:`i` in the
    ranked list (1-indexed).

    - ``update`` must receive output of the form ``(y_pred, y)``.
    - ``y_pred`` is expected to be raw logits or a probability score for each item in the catalog.
    - ``y`` is expected to contain relevance scores (binary or graded).

    Relevance types:

    - **Binary relevance**: labels are either 0 (not relevant) or 1 (relevant)
    - **Graded relevance**: labels can have multiple levels (e.g. a 0-4 scale)

    A common graded scale:

    - 0: not relevant
    - 1: marginally relevant
    - 2: relevant
    - 3: highly relevant
    - 4: perfectly relevant

    The NDCG formula handles both types through the gain function ``2^relevance - 1``:
    higher relevance scores contribute more to the metric.

    - ``y_pred`` and ``y`` must both have shape :math:`(batch, num\_items)`.
    - ``compute`` returns a list of NDCG values ordered by the sorted values of ``top_k``.

    Args:
        top_k: a list of positive integers that specifies each ``k`` for calculating NDCG@k.
        ignore_zero_hits: if True, users with no relevant items (ground truth tensor being all
            zeros) are ignored in the computation of NDCG. If set to False, such users are
            counted with an NDCG of 0. By default, True.
        relevance_threshold: minimum label value to be considered relevant. Defaults to ``1``,
            which handles standard binary labels and graded relevance scales (e.g. TREC-style
            0-4) by treating any label >= 1 as relevant. Items below this threshold contribute
            0 to DCG/IDCG calculations.
        gain_function: gain function applied to relevance scores. Options:

            - ``'exp_rank'``: 2^relevance - 1 (emphasizes high relevance, default)
            - ``'linear_rank'``: relevance (simpler, linear scale)

            Defaults to ``'exp_rank'``.
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. The output is expected to be a tuple
            ``(prediction, target)`` where ``prediction`` and ``target`` are tensors
            of shape ``(batch, num_items)``.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update``
            method is non-blocking. By default, CPU.
        skip_unrolling: specifies whether the input should be unrolled or not before it is
            processed. Should be true for multi-output models.

    Examples:
        To use with ``Engine`` and ``process_function``, simply attach the metric instance to
        the engine. The output of the engine's ``process_function`` needs to be in the format
        of ``(y_pred, y)``. If not, ``output_transform`` can be added to the metric to
        transform the output into the form expected by the metric.

        For more information on how the metric works with :class:`~ignite.engine.engine.Engine`,
        visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        ``ignore_zero_hits=True`` case:

        .. testcode:: 1

            metric = NDCG(top_k=[1, 2, 3, 4])
            metric.attach(default_evaluator, "ndcg")
            y_pred = torch.Tensor([
                [4.0, 2.0, 3.0, 1.0],
                [1.0, 2.0, 3.0, 4.0]
            ])
            y_true = torch.Tensor([
                [0.0, 0.0, 1.0, 1.0],
                [0.0, 0.0, 0.0, 0.0]
            ])
            state = default_evaluator.run([(y_pred, y_true)])
            print(state.metrics["ndcg"])

        .. testoutput:: 1

            [0.0, 0.38..., 0.38..., 0.65...]

        ``ignore_zero_hits=False`` case:

        .. testcode:: 2

            metric = NDCG(top_k=[1, 2, 3, 4], ignore_zero_hits=False)
            metric.attach(default_evaluator, "ndcg")
            y_pred = torch.Tensor([
                [4.0, 2.0, 3.0, 1.0],
                [1.0, 2.0, 3.0, 4.0]
            ])
            y_true = torch.Tensor([
                [0.0, 0.0, 1.0, 1.0],
                [0.0, 0.0, 0.0, 0.0]
            ])
            state = default_evaluator.run([(y_pred, y_true)])
            print(state.metrics["ndcg"])

        .. testoutput:: 2

            [0.0, 0.19..., 0.19..., 0.32...]

    .. versionadded:: 0.6.0
    """

    required_output_keys = ("y_pred", "y")
    _state_dict_all_req_keys = ("_sum_ndcg_per_k", "_num_examples")

    def __init__(
        self,
        top_k: list[int],
        ignore_zero_hits: bool = True,
        relevance_threshold: float = 1.0,
        gain_function: str = "exp_rank",
        output_transform: Callable = lambda x: x,
        device: str | torch.device = torch.device("cpu"),
        skip_unrolling: bool = False,
    ):
        if any(k <= 0 for k in top_k):
            raise ValueError(f"top_k must be a list of positive integers only, but given {top_k}")

        if gain_function not in ["exp_rank", "linear_rank"]:
            raise ValueError(f"gain_function must be either 'exp_rank' or 'linear_rank', but given {gain_function}")

        self.top_k = sorted(top_k)
        self.ignore_zero_hits = ignore_zero_hits
        self.relevance_threshold = relevance_threshold
        self.gain_function = gain_function
        super(NDCG, self).__init__(output_transform, device=device, skip_unrolling=skip_unrolling)

    @reinit__is_reduced
    def reset(self) -> None:
        self._sum_ndcg_per_k = torch.zeros(len(self.top_k), device=self._device)
        self._num_examples = 0

    def _compute_dcg(self, relevance_scores: torch.Tensor, k: int) -> torch.Tensor:
        """Compute DCG@k for a batch of relevance scores."""
        actual_k = min(k, relevance_scores.shape[1])

        positions = torch.arange(1, actual_k + 1, dtype=torch.float32, device=relevance_scores.device)
        discounts = 1.0 / torch.log2(positions + 1)

        topk_relevance = relevance_scores[:, :actual_k]
        if self.gain_function == "exp_rank":
            gains = torch.pow(2.0, topk_relevance) - 1
        else:
            gains = topk_relevance

        dcg = (gains * discounts).sum(dim=-1)
        return dcg

    @reinit__is_reduced
    def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
        if len(output) != 2:
            raise ValueError(f"output should be in format `(y_pred, y)` but got tuple of {len(output)} tensors.")

        y_pred, y = output
        if y_pred.shape != y.shape:
            raise ValueError(f"y_pred and y must have the same shape, got {y_pred.shape} != {y.shape}.")

        if self.ignore_zero_hits:
            valid_mask = torch.any(y >= self.relevance_threshold, dim=-1)
            y_pred = y_pred[valid_mask]
            y = y[valid_mask]

        if y.shape[0] == 0:
            return

        y_for_dcg = torch.where(y >= self.relevance_threshold, y, 0)

        max_k = self.top_k[-1]
        ranked_relevance = _get_ranked_items(y_pred, y_for_dcg, max_k)
        # Compute the ideal ranking by sorting true relevance scores (descending).
        # This aligns with the standard IDCG computation in reference libraries:
        # ranx: https://github.com/AmenRa/ranx/blob/master/ranx/metrics/ndcg.py#L52
        # catalyst: https://github.com/catalyst-team/catalyst/blob/master/catalyst/metrics/functional/_ndcg.py#L197
        ideal_relevance = torch.sort(y_for_dcg, dim=-1, descending=True, stable=True)[0][:, :max_k]

        for i, k in enumerate(self.top_k):
            dcg_k = self._compute_dcg(ranked_relevance, k)
            idcg_k = self._compute_dcg(ideal_relevance, k)

            ndcg_k = torch.where(
                idcg_k > 0,
                dcg_k / idcg_k,
                torch.zeros_like(dcg_k),
            )

            self._sum_ndcg_per_k[i] += ndcg_k.sum().to(self._device)

        self._num_examples += y.shape[0]

    @sync_all_reduce("_sum_ndcg_per_k", "_num_examples")
    def compute(self) -> list[float]:
        if self._num_examples == 0:
            raise NotComputableError("NDCG must have at least one example.")

        return (self._sum_ndcg_per_k / self._num_examples).tolist()
```
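The expected `[0.0, 0.38..., 0.38..., 0.65...]` output of the first docstring example can be cross-checked with a plain-Python version of the same formula (a verification sketch only; the `dcg` helper is illustrative and assumes the default `exp_rank` gain):

```python
import math

def dcg(rels, k):
    # exp_rank gain: (2^rel - 1) / log2(position + 1), positions 1-indexed
    return sum((2 ** r - 1) / math.log2(i + 1) for i, r in enumerate(rels[:k], start=1))

# User 1 of the example: scores [4, 2, 3, 1], labels [0, 0, 1, 1].
# Sorting items by descending score puts the labels in order [0, 1, 0, 1];
# the ideal ordering of those labels is [1, 1, 0, 0].
ranked = [0, 1, 0, 1]
ideal = [1, 1, 0, 0]
# User 2 has no relevant items, so ignore_zero_hits=True drops it entirely.
for k in (1, 2, 3, 4):
    idcg = dcg(ideal, k)
    ndcg = dcg(ranked, k) / idcg if idcg > 0 else 0.0
    print(f"NDCG@{k} = {ndcg:.4f}")
```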

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -35,3 +35,4 @@ pandas
 gymnasium
 # temporary fix: E AttributeError: module 'mpmath' has no attribute 'rational'
 mpmath<1.4
+ranx
```
