diff --git a/src/torchmetrics/functional/retrieval/ndcg.py b/src/torchmetrics/functional/retrieval/ndcg.py index d381718c793..a8bf8241887 100644 --- a/src/torchmetrics/functional/retrieval/ndcg.py +++ b/src/torchmetrics/functional/retrieval/ndcg.py @@ -19,53 +19,100 @@ from torchmetrics.utilities.checks import _check_retrieval_functional_inputs -def _tie_average_dcg(target: Tensor, preds: Tensor, discount_cumsum: Tensor) -> Tensor: - """Translated version of sklearns `_tie_average_dcg` function. +def _tie_average_dcg(target: Tensor, preds: Tensor, discount: Tensor) -> Tensor: + """Compute DCG for tied predictions using scatter operations. + + Replaces the ``torch.unique`` approach with ``diff`` + ``scatter_add_``, which is + significantly faster on GPU (``torch.unique`` is ~15x slower on GPU than CPU). + Float64 is used for accumulation to preserve numerical accuracy. Args: - target: ground truth about each document relevance. - preds: estimated probabilities of each document to be relevant. - discount_cumsum: cumulative sum of the discount. + target: ground truth relevances in **predicted** rank order, shape ``(n_queries, n_docs)``. + preds: predicted scores in **predicted** rank order, shape ``(n_queries, n_docs)``. + discount: per-rank discount values ``1 / log2(rank + 2)``, shape ``(n_docs,)``. Returns: - The cumulative gain of the tied elements. + DCG values, shape ``(n_queries,)``, dtype float32. 
""" - _, inv, counts = torch.unique(-preds, return_inverse=True, return_counts=True) - ranked = torch.zeros_like(counts, dtype=torch.float32) - ranked.scatter_add_(0, inv, target.to(dtype=ranked.dtype)) - ranked = ranked / counts - groups = counts.cumsum(dim=0) - 1 - discount_sums = torch.zeros_like(counts, dtype=torch.float32) - discount_sums[0] = discount_cumsum[groups[0]] - discount_sums[1:] = discount_cumsum[groups].diff() - return (ranked * discount_sums).sum() + n_queries, n_docs = target.shape + device = target.device + + # Detect tie-group boundaries: True at the first element of each new group + new_grp = torch.cat( + [ + torch.ones(n_queries, 1, dtype=torch.bool, device=device), + preds.diff(dim=-1).abs() > 0, + ], + dim=-1, + ) # (n_queries, n_docs) + + # Per-element group id, unique across the batch + gid = new_grp.long().cumsum(-1) - 1 # 0-based within each row + gid = gid + torch.arange(n_queries, device=device).unsqueeze(-1) * n_docs + + # Scatter: accumulate gains, discounts, and counts per group + flat_id = gid.flatten() + flat_gain = target.flatten().float() + flat_disc = discount.unsqueeze(0).expand(n_queries, -1).flatten().float() + + grp_gain = torch.zeros(n_queries * n_docs, dtype=torch.float32, device=device) + grp_disc = torch.zeros(n_queries * n_docs, dtype=torch.float32, device=device) + grp_cnt = torch.zeros(n_queries * n_docs, dtype=torch.int32, device=device) + + grp_gain.scatter_add_(0, flat_id, flat_gain) + grp_disc.scatter_add_(0, flat_id, flat_disc) + grp_cnt.scatter_add_(0, flat_id, torch.ones_like(flat_id, dtype=torch.int32)) + + # Float64 accumulation for numerical parity with sklearn / reference implementations + contrib = grp_gain.double() * (grp_disc.double() / grp_cnt.clamp(min=1).double()) + + # Scatter only non-empty groups back to the batch dimension + valid = grp_cnt > 0 + batch_idx = flat_id[valid] // n_docs + dcg = torch.zeros(n_queries, dtype=torch.float64, device=device) + dcg.scatter_add_(0, batch_idx, 
contrib[valid]) + return dcg.float() def _dcg_sample_scores(target: Tensor, preds: Tensor, top_k: int, ignore_ties: bool) -> Tensor: - """Translated version of sklearns `_dcg_sample_scores` function. + """Compute DCG sample scores. Args: - target: ground truth about each document relevance. - preds: estimated probabilities of each document to be relevant. - top_k: consider only the top k elements - ignore_ties: If True, ties are ignored. If False, ties are averaged. + target: ground truth relevances, shape ``(n_docs,)`` or ``(n_queries, n_docs)``. + preds: predicted scores, shape ``(n_docs,)`` or ``(n_queries, n_docs)``. + top_k: consider only the top k elements. + ignore_ties: If ``True``, ties are broken by order. If ``False``, ties are averaged. Returns: - The cumulative gain + DCG value(s): scalar for 1-D input, shape ``(n_queries,)`` for batched input. """ - discount = 1.0 / (torch.log2(torch.arange(target.shape[-1], device=target.device) + 2.0)) - discount[top_k:] = 0.0 + batched = preds.dim() > 1 + if not batched: + preds = preds.unsqueeze(0) + target = target.unsqueeze(0) + + n_docs = preds.shape[-1] + + # Use topk when k < n_docs to avoid sorting the full list + if top_k < n_docs: + order = preds.topk(top_k, dim=-1, sorted=True).indices + n_docs_eff = top_k + else: + order = preds.argsort(dim=-1, descending=True, stable=True) + n_docs_eff = n_docs + + discount = 1.0 / torch.log2(torch.arange(n_docs_eff, device=preds.device) + 2.0) + p_sorted = preds.gather(-1, order) + g_sorted = target.float().gather(-1, order) if ignore_ties: - ranking = preds.argsort(descending=True) - ranked = target[ranking] - cumulative_gain = (discount * ranked).sum() + dcg = (discount * g_sorted).sum(-1, dtype=torch.float64).float() else: - discount_cumsum = discount.cumsum(dim=-1) - cumulative_gain = _tie_average_dcg(target, preds, discount_cumsum) - return cumulative_gain + dcg = _tie_average_dcg(g_sorted, p_sorted, discount) + + return dcg if batched else dcg.squeeze(0) def 
# ---- Tests for vectorized GPU-efficient implementation (issue #2287) ----


@pytest.mark.parametrize(
    ("batch_size", "list_length", "top_k"),
    [
        (1, 50, None),
        (1, 100, 10),
        (8, 50, None),
        (8, 100, 50),
        (32, 100, None),
        (32, 500, 200),
        (128, 100, 10),
        (128, 500, None),
    ],
)
def test_accuracy_vs_sklearn(batch_size: int, list_length: int, top_k: Optional[int]):
    """Batched nDCG must stay within 1e-4 of sklearn across configs.

    See issue: https://github.com/Lightning-AI/torchmetrics/issues/2287.

    """
    torch.manual_seed(42)
    pred_scores = torch.randn(batch_size, list_length)
    relevances = (torch.randint(0, 2, (batch_size, list_length)) * 2 - 1).float() + 1.0

    # Metric under test vs. a per-row sklearn reference averaged over the batch.
    fast_result = retrieval_normalized_dcg(pred_scores, relevances, top_k=top_k).item()
    row_scores = [ndcg_score([t], [p], k=top_k) for t, p in zip(relevances.numpy(), pred_scores.numpy())]
    sklearn_result = float(np.mean(row_scores))

    assert abs(fast_result - sklearn_result) <= 1e-4, (
        f"nDCG differs from sklearn by {abs(fast_result - sklearn_result):.2e} "
        f"(B={batch_size}, L={list_length}, k={top_k})"
    )


def test_batched_input_matches_per_query():
    """Batched 2-D input must give the same mean nDCG as averaging per-query 1-D results.

    See issue: https://github.com/Lightning-AI/torchmetrics/issues/2287.

    """
    torch.manual_seed(42)
    preds = torch.randn(16, 50)
    target = (torch.randint(0, 2, (16, 50)) * 2 - 1).float() + 1.0

    one_by_one = torch.stack([retrieval_normalized_dcg(p, t) for p, t in zip(preds, target)])
    all_at_once = retrieval_normalized_dcg(preds, target)

    assert torch.allclose(all_at_once, one_by_one.mean(), atol=1e-5)


def test_tie_handling_explicit():
    """Tie-averaged DCG must match sklearn on inputs with explicit score ties.

    See issue: https://github.com/Lightning-AI/torchmetrics/issues/2287.

    """
    scores = torch.tensor([
        [1.0, 1.0, 0.5, 0.5, 0.1],  # two pairs of ties
        [0.8, 0.8, 0.8, 0.2, 0.1],  # three-way tie
    ])
    labels = torch.tensor([
        [1.0, 0.0, 1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 1.0, 0.0],
    ])

    metric_val = retrieval_normalized_dcg(scores, labels)
    reference = float(np.mean([ndcg_score([t], [p]) for t, p in zip(labels.numpy(), scores.numpy())]))

    assert isinstance(metric_val, torch.Tensor)
    assert 0.0 <= metric_val.item() <= 1.0
    assert abs(metric_val.item() - reference) <= 1e-4


def test_all_zeros_target():
    """All-irrelevant queries (target all zero) must return 0, not NaN."""
    scores = torch.randn(4, 20)
    labels = torch.zeros(4, 20)
    assert retrieval_normalized_dcg(scores, labels).item() == 0.0


def test_perfect_ranking():
    """A perfectly-ranked list must return nDCG == 1.0."""
    labels = torch.tensor([[3.0, 2.0, 1.0, 0.0, 0.0]] * 4)
    scores = labels.clone()  # predictions match ideal order
    assert torch.allclose(retrieval_normalized_dcg(scores, labels), torch.tensor(1.0), atol=1e-5)


@pytest.mark.parametrize("top_k", [1, 10, 50, None])
def test_top_k_valid_range(top_k: Optional[int]):
    """Results must be in [0, 1] for all top_k values."""
    torch.manual_seed(0)
    scores = torch.randn(8, 100)
    labels = torch.randint(0, 3, (8, 100)).float()
    assert 0.0 <= retrieval_normalized_dcg(scores, labels, top_k=top_k).item() <= 1.0