Skip to content

Commit 41c14b2

Browse files
rclough authored and claude committed
style: fix ruff N806 (uppercase vars) and formatting in ndcg
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 1ec556c commit 41c14b2

2 files changed

Lines changed: 25 additions & 29 deletions

File tree

src/torchmetrics/functional/retrieval/ndcg.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -27,38 +27,38 @@ def _tie_average_dcg(target: Tensor, preds: Tensor, discount: Tensor) -> Tensor:
2727
Float64 is used for accumulation to preserve numerical accuracy.
2828
2929
Args:
30-
target: ground truth relevances in **predicted** rank order, shape ``(B, L)``.
31-
preds: predicted scores in **predicted** rank order, shape ``(B, L)``.
32-
discount: per-rank discount values ``1 / log2(rank + 2)``, shape ``(L,)``.
30+
target: ground truth relevances in **predicted** rank order, shape ``(n_queries, n_docs)``.
31+
preds: predicted scores in **predicted** rank order, shape ``(n_queries, n_docs)``.
32+
discount: per-rank discount values ``1 / log2(rank + 2)``, shape ``(n_docs,)``.
3333
3434
Returns:
35-
DCG values, shape ``(B,)``, dtype float32.
35+
DCG values, shape ``(n_queries,)``, dtype float32.
3636
3737
"""
38-
B, L = target.shape
38+
n_queries, n_docs = target.shape
3939
device = target.device
4040

4141
# Detect tie-group boundaries: True at the first element of each new group
4242
new_grp = torch.cat(
4343
[
44-
torch.ones(B, 1, dtype=torch.bool, device=device),
44+
torch.ones(n_queries, 1, dtype=torch.bool, device=device),
4545
preds.diff(dim=-1).abs() > 0,
4646
],
4747
dim=-1,
48-
) # (B, L)
48+
) # (n_queries, n_docs)
4949

5050
# Per-element group id, unique across the batch
5151
gid = new_grp.long().cumsum(-1) - 1 # 0-based within each row
52-
gid = gid + torch.arange(B, device=device).unsqueeze(-1) * L
52+
gid = gid + torch.arange(n_queries, device=device).unsqueeze(-1) * n_docs
5353

5454
# Scatter: accumulate gains, discounts, and counts per group
5555
flat_id = gid.flatten()
5656
flat_gain = target.flatten().float()
57-
flat_disc = discount.unsqueeze(0).expand(B, -1).flatten().float()
57+
flat_disc = discount.unsqueeze(0).expand(n_queries, -1).flatten().float()
5858

59-
grp_gain = torch.zeros(B * L, dtype=torch.float32, device=device)
60-
grp_disc = torch.zeros(B * L, dtype=torch.float32, device=device)
61-
grp_cnt = torch.zeros(B * L, dtype=torch.int32, device=device)
59+
grp_gain = torch.zeros(n_queries * n_docs, dtype=torch.float32, device=device)
60+
grp_disc = torch.zeros(n_queries * n_docs, dtype=torch.float32, device=device)
61+
grp_cnt = torch.zeros(n_queries * n_docs, dtype=torch.int32, device=device)
6262

6363
grp_gain.scatter_add_(0, flat_id, flat_gain)
6464
grp_disc.scatter_add_(0, flat_id, flat_disc)
@@ -69,8 +69,8 @@ def _tie_average_dcg(target: Tensor, preds: Tensor, discount: Tensor) -> Tensor:
6969

7070
# Scatter only non-empty groups back to the batch dimension
7171
valid = grp_cnt > 0
72-
batch_idx = flat_id[valid] // L
73-
dcg = torch.zeros(B, dtype=torch.float64, device=device)
72+
batch_idx = flat_id[valid] // n_docs
73+
dcg = torch.zeros(n_queries, dtype=torch.float64, device=device)
7474
dcg.scatter_add_(0, batch_idx, contrib[valid])
7575
return dcg.float()
7676

@@ -79,31 +79,31 @@ def _dcg_sample_scores(target: Tensor, preds: Tensor, top_k: int, ignore_ties: b
7979
"""Compute DCG sample scores.
8080
8181
Args:
82-
target: ground truth relevances, shape ``(L,)`` or ``(B, L)``.
83-
preds: predicted scores, shape ``(L,)`` or ``(B, L)``.
82+
target: ground truth relevances, shape ``(n_docs,)`` or ``(n_queries, n_docs)``.
83+
preds: predicted scores, shape ``(n_docs,)`` or ``(n_queries, n_docs)``.
8484
top_k: consider only the top k elements.
8585
ignore_ties: If ``True``, ties are broken by order. If ``False``, ties are averaged.
8686
8787
Returns:
88-
DCG value(s): scalar for 1-D input, shape ``(B,)`` for batched input.
88+
DCG value(s): scalar for 1-D input, shape ``(n_queries,)`` for batched input.
8989
9090
"""
9191
batched = preds.dim() > 1
9292
if not batched:
9393
preds = preds.unsqueeze(0)
9494
target = target.unsqueeze(0)
9595

96-
L = preds.shape[-1]
96+
n_docs = preds.shape[-1]
9797

98-
# Use topk when k < L to avoid sorting the full list
99-
if top_k < L:
98+
# Use topk when k < n_docs to avoid sorting the full list
99+
if top_k < n_docs:
100100
order = preds.topk(top_k, dim=-1, sorted=True).indices
101-
L_eff = top_k
101+
n_docs_eff = top_k
102102
else:
103103
order = preds.argsort(dim=-1, descending=True, stable=True)
104-
L_eff = L
104+
n_docs_eff = n_docs
105105

106-
discount = 1.0 / torch.log2(torch.arange(L_eff, device=preds.device) + 2.0)
106+
discount = 1.0 / torch.log2(torch.arange(n_docs_eff, device=preds.device) + 2.0)
107107
p_sorted = preds.gather(-1, order)
108108
g_sorted = target.float().gather(-1, order)
109109

tests/unittests/retrieval/test_ndcg.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -240,9 +240,7 @@ def test_accuracy_vs_sklearn(batch_size: int, list_length: int, top_k: Optional[
240240
labels = (torch.randint(0, 2, (batch_size, list_length)) * 2 - 1).float() + 1.0
241241

242242
fast_result = retrieval_normalized_dcg(scores, labels, top_k=top_k).item()
243-
sklearn_result = float(
244-
np.mean([ndcg_score([t], [p], k=top_k) for t, p in zip(labels.numpy(), scores.numpy())])
245-
)
243+
sklearn_result = float(np.mean([ndcg_score([t], [p], k=top_k) for t, p in zip(labels.numpy(), scores.numpy())]))
246244

247245
assert abs(fast_result - sklearn_result) <= 1e-4, (
248246
f"nDCG differs from sklearn by {abs(fast_result - sklearn_result):.2e} "
@@ -280,9 +278,7 @@ def test_tie_handling_explicit():
280278
])
281279

282280
result = retrieval_normalized_dcg(scores, labels)
283-
sklearn_result = float(
284-
np.mean([ndcg_score([t], [p]) for t, p in zip(labels.numpy(), scores.numpy())])
285-
)
281+
sklearn_result = float(np.mean([ndcg_score([t], [p]) for t, p in zip(labels.numpy(), scores.numpy())]))
286282

287283
assert isinstance(result, torch.Tensor)
288284
assert 0.0 <= result.item() <= 1.0

0 commit comments

Comments (0)