From c5be2f27171c38ca9fa0ffe5b3f08e838ee0f443 Mon Sep 17 00:00:00 2001
From: rclough
Date: Tue, 31 Mar 2026 17:48:24 +0000
Subject: [PATCH 1/4] perf(retrieval): fix NDCG GPU performance by replacing torch.unique in tie averaging

torch.unique is ~15x slower on GPU than CPU, causing nDCG to run up to
2.65x slower on GPU than CPU. Replace the torch.unique-based tie-averaging
approach in _tie_average_dcg with a diff + scatter_add_ strategy that is
efficient on both CPU and GPU.

The refactored _dcg_sample_scores also uses gather so that both 1-D
(single query) and 2-D (batched queries) inputs are handled correctly,
making retrieval_normalized_dcg usable with batched inputs directly.

Fixes: https://github.com/Lightning-AI/torchmetrics/issues/2287

Co-Authored-By: Claude Sonnet 4.6
---
 src/torchmetrics/functional/retrieval/ndcg.py | 93 +++++++++++++------
 tests/unittests/retrieval/test_ndcg.py        | 39 ++++++++
 2 files changed, 103 insertions(+), 29 deletions(-)

diff --git a/src/torchmetrics/functional/retrieval/ndcg.py b/src/torchmetrics/functional/retrieval/ndcg.py
index d381718c793..5d90e889498 100644
--- a/src/torchmetrics/functional/retrieval/ndcg.py
+++ b/src/torchmetrics/functional/retrieval/ndcg.py
@@ -19,53 +19,88 @@
 from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


-def _tie_average_dcg(target: Tensor, preds: Tensor, discount_cumsum: Tensor) -> Tensor:
-    """Translated version of sklearns `_tie_average_dcg` function.
+def _tie_average_dcg(target: Tensor, preds: Tensor, discount: Tensor) -> Tensor:
+    """Compute DCG for tied predictions using scatter operations.
+
+    Replaces the ``torch.unique`` approach with ``diff`` + ``scatter_add_``, which is
+    significantly faster on GPU (``torch.unique`` is ~15x slower on GPU than CPU).

     Args:
-        target: ground truth about each document relevance.
-        preds: estimated probabilities of each document to be relevant.
-        discount_cumsum: cumulative sum of the discount.
+        target: ground truth relevances, shape ``(L,)`` or ``(B, L)``.
+        preds: predicted scores, shape ``(L,)`` or ``(B, L)``.
+        discount: per-rank discount values ``1 / log2(rank + 2)``, shape ``(L,)``.

     Returns:
-        The cumulative gain of the tied elements.
+        DCG value(s): scalar for 1-D input, shape ``(B,)`` for batched input.
""" - _, inv, counts = torch.unique(-preds, return_inverse=True, return_counts=True) - ranked = torch.zeros_like(counts, dtype=torch.float32) - ranked.scatter_add_(0, inv, target.to(dtype=ranked.dtype)) - ranked = ranked / counts - groups = counts.cumsum(dim=0) - 1 - discount_sums = torch.zeros_like(counts, dtype=torch.float32) - discount_sums[0] = discount_cumsum[groups[0]] - discount_sums[1:] = discount_cumsum[groups].diff() - return (ranked * discount_sums).sum() + batched = preds.dim() > 1 + B = preds.shape[0] if batched else 1 + L = preds.shape[-1] + + if not batched: + preds = preds.unsqueeze(0) + target = target.unsqueeze(0) + + # Sort each row by descending predicted score + order = preds.argsort(dim=-1, descending=True, stable=True) + p_sorted = preds.gather(-1, order) + g_sorted = target.float().gather(-1, order) + + # Detect tie-group boundaries: True at the first element of each new group + new_grp = torch.cat( + [ + torch.ones(B, 1, dtype=torch.bool, device=preds.device), + p_sorted.diff(dim=-1) != 0, + ], + dim=-1, + ) # (B, L) + + # Per-element group id, made unique across the batch + gid = new_grp.long().cumsum(-1) - 1 # 0-based within each row + gid = gid + torch.arange(B, device=preds.device).unsqueeze(-1) * L + + # Scatter: accumulate gains, discounts, and counts per group + flat_gid = gid.flatten() + flat_gain = g_sorted.flatten().float() + flat_disc = discount.unsqueeze(0).expand(B, -1).flatten().float() + + grp_gain = torch.zeros(B * L, dtype=torch.float32, device=preds.device) + grp_disc = torch.zeros(B * L, dtype=torch.float32, device=preds.device) + grp_cnt = torch.zeros(B * L, dtype=torch.long, device=preds.device) + + grp_gain.scatter_add_(0, flat_gid, flat_gain) + grp_disc.scatter_add_(0, flat_gid, flat_disc) + grp_cnt.scatter_add_(0, flat_gid, torch.ones_like(flat_gid)) + + contrib = grp_gain * grp_disc / grp_cnt.float().clamp(min=1) + dcg = contrib.view(B, L).sum(-1) # (B,) + return dcg if batched else dcg.squeeze(0) def _dcg_sample_scores(target: Tensor, preds: Tensor, top_k: int, ignore_ties: bool) -> Tensor: - """Translated version of sklearns `_dcg_sample_scores` function. + """Compute DCG sample scores. Args: - target: ground truth about each document relevance. - preds: estimated probabilities of each document to be relevant. - top_k: consider only the top k elements - ignore_ties: If True, ties are ignored. If False, ties are averaged. + target: ground truth relevances, shape ``(L,)`` or ``(B, L)``. + preds: predicted scores, shape ``(L,)`` or ``(B, L)``. + top_k: consider only the top k elements. + ignore_ties: If ``True``, ties are broken by order. If ``False``, ties are averaged. Returns: - The cumulative gain + DCG value(s): scalar for 1-D input, shape ``(B,)`` for batched input. 
""" - discount = 1.0 / (torch.log2(torch.arange(target.shape[-1], device=target.device) + 2.0)) + L = target.shape[-1] + discount = 1.0 / torch.log2(torch.arange(L, device=target.device) + 2.0) discount[top_k:] = 0.0 if ignore_ties: - ranking = preds.argsort(descending=True) - ranked = target[ranking] - cumulative_gain = (discount * ranked).sum() - else: - discount_cumsum = discount.cumsum(dim=-1) - cumulative_gain = _tie_average_dcg(target, preds, discount_cumsum) - return cumulative_gain + ranking = preds.argsort(dim=-1, descending=True) + ranked = target.float().gather(-1, ranking) + return (discount * ranked).sum(-1) + + return _tie_average_dcg(target, preds, discount) def retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor: diff --git a/tests/unittests/retrieval/test_ndcg.py b/tests/unittests/retrieval/test_ndcg.py index 5c8fe20a2f7..32f02787811 100644 --- a/tests/unittests/retrieval/test_ndcg.py +++ b/tests/unittests/retrieval/test_ndcg.py @@ -212,3 +212,42 @@ def test_corner_case_with_tied_scores(): retrieval_normalized_dcg(preds, target, top_k=k), torch.tensor([ndcg_score(target, preds, k=k)], dtype=torch.float32), ) + + +def test_batched_input_matches_per_query(): + """Batched 2-D input must give the same mean nDCG as averaging per-query results. + + See issue: https://github.com/Lightning-AI/torchmetrics/issues/2287. + """ + preds = torch.tensor([ + [0.1, 0.2, 0.3, 4.0, 70.0], + [0.5, 0.5, 0.1, 0.9, 0.2], + [1.0, 0.0, 0.5, 0.5, 0.3], + ]) + target = torch.tensor([ + [10, 0, 0, 1, 5], + [0, 1, 2, 3, 4], + [5, 0, 1, 2, 3], + ]) + + # Per-query average (existing 1-D API) + per_query = torch.stack([ + retrieval_normalized_dcg(preds[i], target[i]) for i in range(preds.shape[0]) + ]) + expected_mean = per_query.mean() + + # Batched 2-D call + batched_result = retrieval_normalized_dcg(preds, target) + + assert torch.allclose(batched_result, expected_mean, atol=1e-5), ( + f"Batched result {batched_result} differs from per-query mean {expected_mean}" + ) + + # Also verify against sklearn for each query + for i in range(preds.shape[0]): + p = preds[i].unsqueeze(0).numpy() + t = target[i].unsqueeze(0).numpy() + sklearn_val = torch.tensor(ndcg_score(t, p), dtype=torch.float32) + assert torch.allclose(per_query[i], sklearn_val, atol=1e-5), ( + f"Query {i}: got {per_query[i]}, expected {sklearn_val}" + ) From 1ec556cccc4ef149b13f5fdc72894a4bdfab61de Mon Sep 17 00:00:00 2001 From: rclough Date: Tue, 31 Mar 2026 17:57:29 +0000 Subject: [PATCH 2/4] perf(retrieval): fix NDCG GPU performance by replacing torch.unique in tie averaging torch.unique is ~15x slower on GPU than CPU, causing nDCG to run up to 2.65x slower on GPU than CPU. Replace with a diff + scatter_add_ strategy that is efficient on both CPU and GPU. 

Key changes to the algorithm (based on the optimized implementation proposed in #2287):
- _tie_average_dcg: takes pre-sorted inputs, uses diff + scatter_add_ instead of torch.unique; float64 accumulation for numerical parity with sklearn; int32 group counts; valid-group masking before scatter
- _dcg_sample_scores: handles sorting (with topk fast-path when k < L), gather, and discount creation; delegates tie averaging to the above
- retrieval_normalized_dcg: unchanged public API; now correctly handles both 1-D (single query) and 2-D (batched) inputs

Tests added:
- test_accuracy_vs_sklearn: parametrized across 8 (batch, length, top_k) configs, tolerance 1e-4 matching reference implementation parity
- test_batched_input_matches_per_query: 2-D result == mean of 1-D calls
- test_tie_handling_explicit: explicit tie configurations vs sklearn
- test_all_zeros_target: all-irrelevant queries return 0.0, not NaN
- test_perfect_ranking: ideal predictions return nDCG == 1.0
- test_top_k_valid_range: results in [0, 1] for all top_k values

Fixes: https://github.com/Lightning-AI/torchmetrics/issues/2287

Co-Authored-By: Claude Sonnet 4.6
---
 src/torchmetrics/functional/retrieval/ndcg.py |  86 +++++++------
 tests/unittests/retrieval/test_ndcg.py        | 116 ++++++++++++++----
 2 files changed, 138 insertions(+), 64 deletions(-)

diff --git a/src/torchmetrics/functional/retrieval/ndcg.py b/src/torchmetrics/functional/retrieval/ndcg.py
index 5d90e889498..7baee989096 100644
--- a/src/torchmetrics/functional/retrieval/ndcg.py
+++ b/src/torchmetrics/functional/retrieval/ndcg.py
@@ -24,58 +24,55 @@ def _tie_average_dcg(target: Tensor, preds: Tensor, discount: Tensor) -> Tensor:

     Replaces the ``torch.unique`` approach with ``diff`` + ``scatter_add_``, which is
     significantly faster on GPU (``torch.unique`` is ~15x slower on GPU than CPU).
+    Float64 is used for accumulation to preserve numerical accuracy.

     Args:
-        target: ground truth relevances, shape ``(L,)`` or ``(B, L)``.
-        preds: predicted scores, shape ``(L,)`` or ``(B, L)``.
+        target: ground truth relevances in **predicted** rank order, shape ``(B, L)``.
+        preds: predicted scores in **predicted** rank order, shape ``(B, L)``.
         discount: per-rank discount values ``1 / log2(rank + 2)``, shape ``(L,)``.

     Returns:
-        DCG value(s): scalar for 1-D input, shape ``(B,)`` for batched input.
+        DCG values, shape ``(B,)``, dtype float32.
""" - batched = preds.dim() > 1 - B = preds.shape[0] if batched else 1 - L = preds.shape[-1] - - if not batched: - preds = preds.unsqueeze(0) - target = target.unsqueeze(0) - - # Sort each row by descending predicted score - order = preds.argsort(dim=-1, descending=True, stable=True) - p_sorted = preds.gather(-1, order) - g_sorted = target.float().gather(-1, order) + B, L = target.shape + device = target.device # Detect tie-group boundaries: True at the first element of each new group new_grp = torch.cat( [ - torch.ones(B, 1, dtype=torch.bool, device=preds.device), - p_sorted.diff(dim=-1) != 0, + torch.ones(B, 1, dtype=torch.bool, device=device), + preds.diff(dim=-1).abs() > 0, ], dim=-1, ) # (B, L) - # Per-element group id, made unique across the batch + # Per-element group id, unique across the batch gid = new_grp.long().cumsum(-1) - 1 # 0-based within each row - gid = gid + torch.arange(B, device=preds.device).unsqueeze(-1) * L + gid = gid + torch.arange(B, device=device).unsqueeze(-1) * L # Scatter: accumulate gains, discounts, and counts per group - flat_gid = gid.flatten() - flat_gain = g_sorted.flatten().float() + flat_id = gid.flatten() + flat_gain = target.flatten().float() flat_disc = discount.unsqueeze(0).expand(B, -1).flatten().float() - grp_gain = torch.zeros(B * L, dtype=torch.float32, device=preds.device) - grp_disc = torch.zeros(B * L, dtype=torch.float32, device=preds.device) - grp_cnt = torch.zeros(B * L, dtype=torch.long, device=preds.device) + grp_gain = torch.zeros(B * L, dtype=torch.float32, device=device) + grp_disc = torch.zeros(B * L, dtype=torch.float32, device=device) + grp_cnt = torch.zeros(B * L, dtype=torch.int32, device=device) - grp_gain.scatter_add_(0, flat_gid, flat_gain) - grp_disc.scatter_add_(0, flat_gid, flat_disc) - grp_cnt.scatter_add_(0, flat_gid, torch.ones_like(flat_gid)) + grp_gain.scatter_add_(0, flat_id, flat_gain) + grp_disc.scatter_add_(0, flat_id, flat_disc) + grp_cnt.scatter_add_(0, flat_id, torch.ones_like(flat_id, dtype=torch.int32)) - contrib = grp_gain * grp_disc / grp_cnt.float().clamp(min=1) - dcg = contrib.view(B, L).sum(-1) # (B,) - return dcg if batched else dcg.squeeze(0) + # Float64 accumulation for numerical parity with sklearn / reference implementations + contrib = grp_gain.double() * (grp_disc.double() / grp_cnt.clamp(min=1).double()) + + # Scatter only non-empty groups back to the batch dimension + valid = grp_cnt > 0 + batch_idx = flat_id[valid] // L + dcg = torch.zeros(B, dtype=torch.float64, device=device) + dcg.scatter_add_(0, batch_idx, contrib[valid]) + return dcg.float() def _dcg_sample_scores(target: Tensor, preds: Tensor, top_k: int, ignore_ties: bool) -> Tensor: @@ -91,16 +88,31 @@ def _dcg_sample_scores(target: Tensor, preds: Tensor, top_k: int, ignore_ties: b DCG value(s): scalar for 1-D input, shape ``(B,)`` for batched input. 
""" - L = target.shape[-1] - discount = 1.0 / torch.log2(torch.arange(L, device=target.device) + 2.0) - discount[top_k:] = 0.0 + batched = preds.dim() > 1 + if not batched: + preds = preds.unsqueeze(0) + target = target.unsqueeze(0) + + L = preds.shape[-1] + + # Use topk when k < L to avoid sorting the full list + if top_k < L: + order = preds.topk(top_k, dim=-1, sorted=True).indices + L_eff = top_k + else: + order = preds.argsort(dim=-1, descending=True, stable=True) + L_eff = L + + discount = 1.0 / torch.log2(torch.arange(L_eff, device=preds.device) + 2.0) + p_sorted = preds.gather(-1, order) + g_sorted = target.float().gather(-1, order) if ignore_ties: - ranking = preds.argsort(dim=-1, descending=True) - ranked = target.float().gather(-1, ranking) - return (discount * ranked).sum(-1) + dcg = (discount * g_sorted).sum(-1, dtype=torch.float64).float() + else: + dcg = _tie_average_dcg(g_sorted, p_sorted, discount) - return _tie_average_dcg(target, preds, discount) + return dcg if batched else dcg.squeeze(0) def retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor: diff --git a/tests/unittests/retrieval/test_ndcg.py b/tests/unittests/retrieval/test_ndcg.py index 32f02787811..dcfc159a6cb 100644 --- a/tests/unittests/retrieval/test_ndcg.py +++ b/tests/unittests/retrieval/test_ndcg.py @@ -214,40 +214,102 @@ def test_corner_case_with_tied_scores(): ) +# ---- Tests for vectorized GPU-efficient implementation (issue #2287) ---- + + +@pytest.mark.parametrize( + ("batch_size", "list_length", "top_k"), + [ + (1, 50, None), + (1, 100, 10), + (8, 50, None), + (8, 100, 50), + (32, 100, None), + (32, 500, 200), + (128, 100, 10), + (128, 500, None), + ], +) +def test_accuracy_vs_sklearn(batch_size: int, list_length: int, top_k: Optional[int]): + """Batched nDCG must stay within 1e-4 of sklearn across configs. + + See issue: https://github.com/Lightning-AI/torchmetrics/issues/2287. + """ + torch.manual_seed(42) + scores = torch.randn(batch_size, list_length) + labels = (torch.randint(0, 2, (batch_size, list_length)) * 2 - 1).float() + 1.0 + + fast_result = retrieval_normalized_dcg(scores, labels, top_k=top_k).item() + sklearn_result = float( + np.mean([ndcg_score([t], [p], k=top_k) for t, p in zip(labels.numpy(), scores.numpy())]) + ) + + assert abs(fast_result - sklearn_result) <= 1e-4, ( + f"nDCG differs from sklearn by {abs(fast_result - sklearn_result):.2e} " + f"(B={batch_size}, L={list_length}, k={top_k})" + ) + + def test_batched_input_matches_per_query(): - """Batched 2-D input must give the same mean nDCG as averaging per-query results. + """Batched 2-D input must give the same mean nDCG as averaging per-query 1-D results. See issue: https://github.com/Lightning-AI/torchmetrics/issues/2287. """ - preds = torch.tensor([ - [0.1, 0.2, 0.3, 4.0, 70.0], - [0.5, 0.5, 0.1, 0.9, 0.2], - [1.0, 0.0, 0.5, 0.5, 0.3], + torch.manual_seed(42) + preds = torch.randn(16, 50) + target = (torch.randint(0, 2, (16, 50)) * 2 - 1).float() + 1.0 + + per_query = torch.stack([retrieval_normalized_dcg(preds[i], target[i]) for i in range(preds.shape[0])]) + batched = retrieval_normalized_dcg(preds, target) + + assert torch.allclose(batched, per_query.mean(), atol=1e-5) + + +def test_tie_handling_explicit(): + """Tie-averaged DCG must match sklearn on inputs with explicit score ties. + + See issue: https://github.com/Lightning-AI/torchmetrics/issues/2287. 
+ """ + scores = torch.tensor([ + [1.0, 1.0, 0.5, 0.5, 0.1], # two pairs of ties + [0.8, 0.8, 0.8, 0.2, 0.1], # three-way tie ]) - target = torch.tensor([ - [10, 0, 0, 1, 5], - [0, 1, 2, 3, 4], - [5, 0, 1, 2, 3], + labels = torch.tensor([ + [1.0, 0.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 1.0, 0.0], ]) - # Per-query average (existing 1-D API) - per_query = torch.stack([ - retrieval_normalized_dcg(preds[i], target[i]) for i in range(preds.shape[0]) - ]) - expected_mean = per_query.mean() + result = retrieval_normalized_dcg(scores, labels) + sklearn_result = float( + np.mean([ndcg_score([t], [p]) for t, p in zip(labels.numpy(), scores.numpy())]) + ) - # Batched 2-D call - batched_result = retrieval_normalized_dcg(preds, target) + assert isinstance(result, torch.Tensor) + assert 0.0 <= result.item() <= 1.0 + assert abs(result.item() - sklearn_result) <= 1e-4 - assert torch.allclose(batched_result, expected_mean, atol=1e-5), ( - f"Batched result {batched_result} differs from per-query mean {expected_mean}" - ) - # Also verify against sklearn for each query - for i in range(preds.shape[0]): - p = preds[i].unsqueeze(0).numpy() - t = target[i].unsqueeze(0).numpy() - sklearn_val = torch.tensor(ndcg_score(t, p), dtype=torch.float32) - assert torch.allclose(per_query[i], sklearn_val, atol=1e-5), ( - f"Query {i}: got {per_query[i]}, expected {sklearn_val}" - ) +def test_all_zeros_target(): + """All-irrelevant queries (target all zero) must return 0, not NaN.""" + scores = torch.randn(4, 20) + labels = torch.zeros(4, 20) + result = retrieval_normalized_dcg(scores, labels) + assert result.item() == 0.0 + + +def test_perfect_ranking(): + """A perfectly-ranked list must return nDCG == 1.0.""" + labels = torch.tensor([[3.0, 2.0, 1.0, 0.0, 0.0]] * 4) + scores = labels.clone() # predictions match ideal order + result = retrieval_normalized_dcg(scores, labels) + assert torch.allclose(result, torch.tensor(1.0), atol=1e-5) + + +@pytest.mark.parametrize("top_k", [1, 10, 50, None]) +def test_top_k_valid_range(top_k: Optional[int]): + """Results must be in [0, 1] for all top_k values.""" + torch.manual_seed(0) + scores = torch.randn(8, 100) + labels = torch.randint(0, 3, (8, 100)).float() + result = retrieval_normalized_dcg(scores, labels, top_k=top_k) + assert 0.0 <= result.item() <= 1.0 From 41c14b2d82d23d086d1d378643839d1bf726bbcf Mon Sep 17 00:00:00 2001 From: rclough Date: Tue, 31 Mar 2026 18:14:16 +0000 Subject: [PATCH 3/4] style: fix ruff N806 (uppercase vars) and formatting in ndcg Co-Authored-By: Claude Sonnet 4.6 --- src/torchmetrics/functional/retrieval/ndcg.py | 46 +++++++++---------- tests/unittests/retrieval/test_ndcg.py | 8 +--- 2 files changed, 25 insertions(+), 29 deletions(-) diff --git a/src/torchmetrics/functional/retrieval/ndcg.py b/src/torchmetrics/functional/retrieval/ndcg.py index 7baee989096..a8bf8241887 100644 --- a/src/torchmetrics/functional/retrieval/ndcg.py +++ b/src/torchmetrics/functional/retrieval/ndcg.py @@ -27,38 +27,38 @@ def _tie_average_dcg(target: Tensor, preds: Tensor, discount: Tensor) -> Tensor: Float64 is used for accumulation to preserve numerical accuracy. Args: - target: ground truth relevances in **predicted** rank order, shape ``(B, L)``. - preds: predicted scores in **predicted** rank order, shape ``(B, L)``. - discount: per-rank discount values ``1 / log2(rank + 2)``, shape ``(L,)``. + target: ground truth relevances in **predicted** rank order, shape ``(n_queries, n_docs)``. + preds: predicted scores in **predicted** rank order, shape ``(n_queries, n_docs)``. 
+        discount: per-rank discount values ``1 / log2(rank + 2)``, shape ``(n_docs,)``.

     Returns:
-        DCG values, shape ``(B,)``, dtype float32.
+        DCG values, shape ``(n_queries,)``, dtype float32.

     """
-    B, L = target.shape
+    n_queries, n_docs = target.shape
     device = target.device

     # Detect tie-group boundaries: True at the first element of each new group
     new_grp = torch.cat(
         [
-            torch.ones(B, 1, dtype=torch.bool, device=device),
+            torch.ones(n_queries, 1, dtype=torch.bool, device=device),
             preds.diff(dim=-1).abs() > 0,
         ],
         dim=-1,
-    )  # (B, L)
+    )  # (n_queries, n_docs)

     # Per-element group id, unique across the batch
     gid = new_grp.long().cumsum(-1) - 1  # 0-based within each row
-    gid = gid + torch.arange(B, device=device).unsqueeze(-1) * L
+    gid = gid + torch.arange(n_queries, device=device).unsqueeze(-1) * n_docs

     # Scatter: accumulate gains, discounts, and counts per group
     flat_id = gid.flatten()
     flat_gain = target.flatten().float()
-    flat_disc = discount.unsqueeze(0).expand(B, -1).flatten().float()
+    flat_disc = discount.unsqueeze(0).expand(n_queries, -1).flatten().float()

-    grp_gain = torch.zeros(B * L, dtype=torch.float32, device=device)
-    grp_disc = torch.zeros(B * L, dtype=torch.float32, device=device)
-    grp_cnt = torch.zeros(B * L, dtype=torch.int32, device=device)
+    grp_gain = torch.zeros(n_queries * n_docs, dtype=torch.float32, device=device)
+    grp_disc = torch.zeros(n_queries * n_docs, dtype=torch.float32, device=device)
+    grp_cnt = torch.zeros(n_queries * n_docs, dtype=torch.int32, device=device)

     grp_gain.scatter_add_(0, flat_id, flat_gain)
     grp_disc.scatter_add_(0, flat_id, flat_disc)
@@ -69,8 +69,8 @@ def _tie_average_dcg(target: Tensor, preds: Tensor, discount: Tensor) -> Tensor

     # Scatter only non-empty groups back to the batch dimension
     valid = grp_cnt > 0
-    batch_idx = flat_id[valid] // L
-    dcg = torch.zeros(B, dtype=torch.float64, device=device)
+    batch_idx = flat_id[valid] // n_docs
+    dcg = torch.zeros(n_queries, dtype=torch.float64, device=device)
     dcg.scatter_add_(0, batch_idx, contrib[valid])
     return dcg.float()

@@ -79,13 +79,13 @@ def _dcg_sample_scores(target: Tensor, preds: Tensor, top_k: int, ignore_ties: b
     """Compute DCG sample scores.

     Args:
-        target: ground truth relevances, shape ``(L,)`` or ``(B, L)``.
-        preds: predicted scores, shape ``(L,)`` or ``(B, L)``.
+        target: ground truth relevances, shape ``(n_docs,)`` or ``(n_queries, n_docs)``.
+        preds: predicted scores, shape ``(n_docs,)`` or ``(n_queries, n_docs)``.
         top_k: consider only the top k elements.
         ignore_ties: If ``True``, ties are broken by order. If ``False``, ties are averaged.

     Returns:
-        DCG value(s): scalar for 1-D input, shape ``(B,)`` for batched input.
+        DCG value(s): scalar for 1-D input, shape ``(n_queries,)`` for batched input.
""" batched = preds.dim() > 1 @@ -93,17 +93,17 @@ def _dcg_sample_scores(target: Tensor, preds: Tensor, top_k: int, ignore_ties: b preds = preds.unsqueeze(0) target = target.unsqueeze(0) - L = preds.shape[-1] + n_docs = preds.shape[-1] - # Use topk when k < L to avoid sorting the full list - if top_k < L: + # Use topk when k < n_docs to avoid sorting the full list + if top_k < n_docs: order = preds.topk(top_k, dim=-1, sorted=True).indices - L_eff = top_k + n_docs_eff = top_k else: order = preds.argsort(dim=-1, descending=True, stable=True) - L_eff = L + n_docs_eff = n_docs - discount = 1.0 / torch.log2(torch.arange(L_eff, device=preds.device) + 2.0) + discount = 1.0 / torch.log2(torch.arange(n_docs_eff, device=preds.device) + 2.0) p_sorted = preds.gather(-1, order) g_sorted = target.float().gather(-1, order) diff --git a/tests/unittests/retrieval/test_ndcg.py b/tests/unittests/retrieval/test_ndcg.py index dcfc159a6cb..d3225aaeb8a 100644 --- a/tests/unittests/retrieval/test_ndcg.py +++ b/tests/unittests/retrieval/test_ndcg.py @@ -240,9 +240,7 @@ def test_accuracy_vs_sklearn(batch_size: int, list_length: int, top_k: Optional[ labels = (torch.randint(0, 2, (batch_size, list_length)) * 2 - 1).float() + 1.0 fast_result = retrieval_normalized_dcg(scores, labels, top_k=top_k).item() - sklearn_result = float( - np.mean([ndcg_score([t], [p], k=top_k) for t, p in zip(labels.numpy(), scores.numpy())]) - ) + sklearn_result = float(np.mean([ndcg_score([t], [p], k=top_k) for t, p in zip(labels.numpy(), scores.numpy())])) assert abs(fast_result - sklearn_result) <= 1e-4, ( f"nDCG differs from sklearn by {abs(fast_result - sklearn_result):.2e} " @@ -280,9 +278,7 @@ def test_tie_handling_explicit(): ]) result = retrieval_normalized_dcg(scores, labels) - sklearn_result = float( - np.mean([ndcg_score([t], [p]) for t, p in zip(labels.numpy(), scores.numpy())]) - ) + sklearn_result = float(np.mean([ndcg_score([t], [p]) for t, p in zip(labels.numpy(), scores.numpy())])) assert isinstance(result, torch.Tensor) assert 0.0 <= result.item() <= 1.0 From 0cc4c4212581bf11ef271498ad068adb4b191303 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 31 Mar 2026 18:14:58 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/retrieval/test_ndcg.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unittests/retrieval/test_ndcg.py b/tests/unittests/retrieval/test_ndcg.py index d3225aaeb8a..e4cf6e743f3 100644 --- a/tests/unittests/retrieval/test_ndcg.py +++ b/tests/unittests/retrieval/test_ndcg.py @@ -234,6 +234,7 @@ def test_accuracy_vs_sklearn(batch_size: int, list_length: int, top_k: Optional[ """Batched nDCG must stay within 1e-4 of sklearn across configs. See issue: https://github.com/Lightning-AI/torchmetrics/issues/2287. + """ torch.manual_seed(42) scores = torch.randn(batch_size, list_length) @@ -252,6 +253,7 @@ def test_batched_input_matches_per_query(): """Batched 2-D input must give the same mean nDCG as averaging per-query 1-D results. See issue: https://github.com/Lightning-AI/torchmetrics/issues/2287. + """ torch.manual_seed(42) preds = torch.randn(16, 50) @@ -267,6 +269,7 @@ def test_tie_handling_explicit(): """Tie-averaged DCG must match sklearn on inputs with explicit score ties. See issue: https://github.com/Lightning-AI/torchmetrics/issues/2287. + """ scores = torch.tensor([ [1.0, 1.0, 0.5, 0.5, 0.1], # two pairs of ties