Skip to content

Commit 1ec556c

Browse files
rclough and claude committed
perf(retrieval): fix NDCG GPU performance by replacing torch.unique in tie averaging
torch.unique is ~15x slower on GPU than CPU, causing nDCG to run up to 2.65x slower on GPU than CPU. Replace with a diff + scatter_add_ strategy that is efficient on both CPU and GPU. Key changes to the algorithm (based on the optimized implementation proposed in #2287): - _tie_average_dcg: takes pre-sorted inputs, uses diff + scatter_add_ instead of torch.unique; float64 accumulation for numerical parity with sklearn; int32 group counts; valid-group masking before scatter - _dcg_sample_scores: handles sorting (with topk fast-path when k < L), gather, and discount creation; delegates tie averaging to the above - retrieval_normalized_dcg: unchanged public API; now correctly handles both 1-D (single query) and 2-D (batched) inputs Tests added: - test_accuracy_vs_sklearn: parametrized across 8 (batch, length, top_k) configs, tolerance 1e-4 matching reference implementation parity - test_batched_input_matches_per_query: 2-D result == mean of 1-D calls - test_tie_handling_explicit: explicit tie configurations vs sklearn - test_all_zeros_target: all-irrelevant queries return 0.0, not NaN - test_perfect_ranking: ideal predictions return nDCG == 1.0 - test_top_k_valid_range: results in [0, 1] for all top_k values Fixes: #2287 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent c5be2f2 commit 1ec556c

2 files changed

Lines changed: 138 additions & 64 deletions

File tree

src/torchmetrics/functional/retrieval/ndcg.py

Lines changed: 49 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -24,58 +24,55 @@ def _tie_average_dcg(target: Tensor, preds: Tensor, discount: Tensor) -> Tensor:
2424
2525
Replaces the ``torch.unique`` approach with ``diff`` + ``scatter_add_``, which is
2626
significantly faster on GPU (``torch.unique`` is ~15x slower on GPU than CPU).
27+
Float64 is used for accumulation to preserve numerical accuracy.
2728
2829
Args:
29-
target: ground truth relevances, shape ``(L,)`` or ``(B, L)``.
30-
preds: predicted scores, shape ``(L,)`` or ``(B, L)``.
30+
target: ground truth relevances in **predicted** rank order, shape ``(B, L)``.
31+
preds: predicted scores in **predicted** rank order, shape ``(B, L)``.
3132
discount: per-rank discount values ``1 / log2(rank + 2)``, shape ``(L,)``.
3233
3334
Returns:
34-
DCG value(s): scalar for 1-D input, shape ``(B,)`` for batched input.
35+
DCG values, shape ``(B,)``, dtype float32.
3536
3637
"""
37-
batched = preds.dim() > 1
38-
B = preds.shape[0] if batched else 1
39-
L = preds.shape[-1]
40-
41-
if not batched:
42-
preds = preds.unsqueeze(0)
43-
target = target.unsqueeze(0)
44-
45-
# Sort each row by descending predicted score
46-
order = preds.argsort(dim=-1, descending=True, stable=True)
47-
p_sorted = preds.gather(-1, order)
48-
g_sorted = target.float().gather(-1, order)
38+
B, L = target.shape
39+
device = target.device
4940

5041
# Detect tie-group boundaries: True at the first element of each new group
5142
new_grp = torch.cat(
5243
[
53-
torch.ones(B, 1, dtype=torch.bool, device=preds.device),
54-
p_sorted.diff(dim=-1) != 0,
44+
torch.ones(B, 1, dtype=torch.bool, device=device),
45+
preds.diff(dim=-1).abs() > 0,
5546
],
5647
dim=-1,
5748
) # (B, L)
5849

59-
# Per-element group id, made unique across the batch
50+
# Per-element group id, unique across the batch
6051
gid = new_grp.long().cumsum(-1) - 1 # 0-based within each row
61-
gid = gid + torch.arange(B, device=preds.device).unsqueeze(-1) * L
52+
gid = gid + torch.arange(B, device=device).unsqueeze(-1) * L
6253

6354
# Scatter: accumulate gains, discounts, and counts per group
64-
flat_gid = gid.flatten()
65-
flat_gain = g_sorted.flatten().float()
55+
flat_id = gid.flatten()
56+
flat_gain = target.flatten().float()
6657
flat_disc = discount.unsqueeze(0).expand(B, -1).flatten().float()
6758

68-
grp_gain = torch.zeros(B * L, dtype=torch.float32, device=preds.device)
69-
grp_disc = torch.zeros(B * L, dtype=torch.float32, device=preds.device)
70-
grp_cnt = torch.zeros(B * L, dtype=torch.long, device=preds.device)
59+
grp_gain = torch.zeros(B * L, dtype=torch.float32, device=device)
60+
grp_disc = torch.zeros(B * L, dtype=torch.float32, device=device)
61+
grp_cnt = torch.zeros(B * L, dtype=torch.int32, device=device)
7162

72-
grp_gain.scatter_add_(0, flat_gid, flat_gain)
73-
grp_disc.scatter_add_(0, flat_gid, flat_disc)
74-
grp_cnt.scatter_add_(0, flat_gid, torch.ones_like(flat_gid))
63+
grp_gain.scatter_add_(0, flat_id, flat_gain)
64+
grp_disc.scatter_add_(0, flat_id, flat_disc)
65+
grp_cnt.scatter_add_(0, flat_id, torch.ones_like(flat_id, dtype=torch.int32))
7566

76-
contrib = grp_gain * grp_disc / grp_cnt.float().clamp(min=1)
77-
dcg = contrib.view(B, L).sum(-1) # (B,)
78-
return dcg if batched else dcg.squeeze(0)
67+
# Float64 accumulation for numerical parity with sklearn / reference implementations
68+
contrib = grp_gain.double() * (grp_disc.double() / grp_cnt.clamp(min=1).double())
69+
70+
# Scatter only non-empty groups back to the batch dimension
71+
valid = grp_cnt > 0
72+
batch_idx = flat_id[valid] // L
73+
dcg = torch.zeros(B, dtype=torch.float64, device=device)
74+
dcg.scatter_add_(0, batch_idx, contrib[valid])
75+
return dcg.float()
7976

8077

8178
def _dcg_sample_scores(target: Tensor, preds: Tensor, top_k: int, ignore_ties: bool) -> Tensor:
@@ -91,16 +88,31 @@ def _dcg_sample_scores(target: Tensor, preds: Tensor, top_k: int, ignore_ties: b
9188
DCG value(s): scalar for 1-D input, shape ``(B,)`` for batched input.
9289
9390
"""
94-
L = target.shape[-1]
95-
discount = 1.0 / torch.log2(torch.arange(L, device=target.device) + 2.0)
96-
discount[top_k:] = 0.0
91+
batched = preds.dim() > 1
92+
if not batched:
93+
preds = preds.unsqueeze(0)
94+
target = target.unsqueeze(0)
95+
96+
L = preds.shape[-1]
97+
98+
# Use topk when k < L to avoid sorting the full list
99+
if top_k < L:
100+
order = preds.topk(top_k, dim=-1, sorted=True).indices
101+
L_eff = top_k
102+
else:
103+
order = preds.argsort(dim=-1, descending=True, stable=True)
104+
L_eff = L
105+
106+
discount = 1.0 / torch.log2(torch.arange(L_eff, device=preds.device) + 2.0)
107+
p_sorted = preds.gather(-1, order)
108+
g_sorted = target.float().gather(-1, order)
97109

98110
if ignore_ties:
99-
ranking = preds.argsort(dim=-1, descending=True)
100-
ranked = target.float().gather(-1, ranking)
101-
return (discount * ranked).sum(-1)
111+
dcg = (discount * g_sorted).sum(-1, dtype=torch.float64).float()
112+
else:
113+
dcg = _tie_average_dcg(g_sorted, p_sorted, discount)
102114

103-
return _tie_average_dcg(target, preds, discount)
115+
return dcg if batched else dcg.squeeze(0)
104116

105117

106118
def retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:

tests/unittests/retrieval/test_ndcg.py

Lines changed: 89 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -214,40 +214,102 @@ def test_corner_case_with_tied_scores():
214214
)
215215

216216

217+
# ---- Tests for vectorized GPU-efficient implementation (issue #2287) ----
218+
219+
220+
@pytest.mark.parametrize(
221+
("batch_size", "list_length", "top_k"),
222+
[
223+
(1, 50, None),
224+
(1, 100, 10),
225+
(8, 50, None),
226+
(8, 100, 50),
227+
(32, 100, None),
228+
(32, 500, 200),
229+
(128, 100, 10),
230+
(128, 500, None),
231+
],
232+
)
233+
def test_accuracy_vs_sklearn(batch_size: int, list_length: int, top_k: Optional[int]):
234+
"""Batched nDCG must stay within 1e-4 of sklearn across configs.
235+
236+
See issue: https://github.com/Lightning-AI/torchmetrics/issues/2287.
237+
"""
238+
torch.manual_seed(42)
239+
scores = torch.randn(batch_size, list_length)
240+
labels = (torch.randint(0, 2, (batch_size, list_length)) * 2 - 1).float() + 1.0
241+
242+
fast_result = retrieval_normalized_dcg(scores, labels, top_k=top_k).item()
243+
sklearn_result = float(
244+
np.mean([ndcg_score([t], [p], k=top_k) for t, p in zip(labels.numpy(), scores.numpy())])
245+
)
246+
247+
assert abs(fast_result - sklearn_result) <= 1e-4, (
248+
f"nDCG differs from sklearn by {abs(fast_result - sklearn_result):.2e} "
249+
f"(B={batch_size}, L={list_length}, k={top_k})"
250+
)
251+
252+
217253
def test_batched_input_matches_per_query():
218-
"""Batched 2-D input must give the same mean nDCG as averaging per-query results.
254+
"""Batched 2-D input must give the same mean nDCG as averaging per-query 1-D results.
219255
220256
See issue: https://github.com/Lightning-AI/torchmetrics/issues/2287.
221257
"""
222-
preds = torch.tensor([
223-
[0.1, 0.2, 0.3, 4.0, 70.0],
224-
[0.5, 0.5, 0.1, 0.9, 0.2],
225-
[1.0, 0.0, 0.5, 0.5, 0.3],
258+
torch.manual_seed(42)
259+
preds = torch.randn(16, 50)
260+
target = (torch.randint(0, 2, (16, 50)) * 2 - 1).float() + 1.0
261+
262+
per_query = torch.stack([retrieval_normalized_dcg(preds[i], target[i]) for i in range(preds.shape[0])])
263+
batched = retrieval_normalized_dcg(preds, target)
264+
265+
assert torch.allclose(batched, per_query.mean(), atol=1e-5)
266+
267+
268+
def test_tie_handling_explicit():
269+
"""Tie-averaged DCG must match sklearn on inputs with explicit score ties.
270+
271+
See issue: https://github.com/Lightning-AI/torchmetrics/issues/2287.
272+
"""
273+
scores = torch.tensor([
274+
[1.0, 1.0, 0.5, 0.5, 0.1], # two pairs of ties
275+
[0.8, 0.8, 0.8, 0.2, 0.1], # three-way tie
226276
])
227-
target = torch.tensor([
228-
[10, 0, 0, 1, 5],
229-
[0, 1, 2, 3, 4],
230-
[5, 0, 1, 2, 3],
277+
labels = torch.tensor([
278+
[1.0, 0.0, 1.0, 0.0, 0.0],
279+
[1.0, 0.0, 0.0, 1.0, 0.0],
231280
])
232281

233-
# Per-query average (existing 1-D API)
234-
per_query = torch.stack([
235-
retrieval_normalized_dcg(preds[i], target[i]) for i in range(preds.shape[0])
236-
])
237-
expected_mean = per_query.mean()
282+
result = retrieval_normalized_dcg(scores, labels)
283+
sklearn_result = float(
284+
np.mean([ndcg_score([t], [p]) for t, p in zip(labels.numpy(), scores.numpy())])
285+
)
238286

239-
# Batched 2-D call
240-
batched_result = retrieval_normalized_dcg(preds, target)
287+
assert isinstance(result, torch.Tensor)
288+
assert 0.0 <= result.item() <= 1.0
289+
assert abs(result.item() - sklearn_result) <= 1e-4
241290

242-
assert torch.allclose(batched_result, expected_mean, atol=1e-5), (
243-
f"Batched result {batched_result} differs from per-query mean {expected_mean}"
244-
)
245291

246-
# Also verify against sklearn for each query
247-
for i in range(preds.shape[0]):
248-
p = preds[i].unsqueeze(0).numpy()
249-
t = target[i].unsqueeze(0).numpy()
250-
sklearn_val = torch.tensor(ndcg_score(t, p), dtype=torch.float32)
251-
assert torch.allclose(per_query[i], sklearn_val, atol=1e-5), (
252-
f"Query {i}: got {per_query[i]}, expected {sklearn_val}"
253-
)
292+
def test_all_zeros_target():
293+
"""All-irrelevant queries (target all zero) must return 0, not NaN."""
294+
scores = torch.randn(4, 20)
295+
labels = torch.zeros(4, 20)
296+
result = retrieval_normalized_dcg(scores, labels)
297+
assert result.item() == 0.0
298+
299+
300+
def test_perfect_ranking():
301+
"""A perfectly-ranked list must return nDCG == 1.0."""
302+
labels = torch.tensor([[3.0, 2.0, 1.0, 0.0, 0.0]] * 4)
303+
scores = labels.clone() # predictions match ideal order
304+
result = retrieval_normalized_dcg(scores, labels)
305+
assert torch.allclose(result, torch.tensor(1.0), atol=1e-5)
306+
307+
308+
@pytest.mark.parametrize("top_k", [1, 10, 50, None])
309+
def test_top_k_valid_range(top_k: Optional[int]):
310+
"""Results must be in [0, 1] for all top_k values."""
311+
torch.manual_seed(0)
312+
scores = torch.randn(8, 100)
313+
labels = torch.randint(0, 3, (8, 100)).float()
314+
result = retrieval_normalized_dcg(scores, labels, top_k=top_k)
315+
assert 0.0 <= result.item() <= 1.0

0 commit comments

Comments (0)