Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 76 additions & 29 deletions src/torchmetrics/functional/retrieval/ndcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,53 +19,100 @@
from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def _tie_average_dcg(target: Tensor, preds: Tensor, discount_cumsum: Tensor) -> Tensor:
"""Translated version of sklearns `_tie_average_dcg` function.
def _tie_average_dcg(target: Tensor, preds: Tensor, discount: Tensor) -> Tensor:
"""Compute DCG for tied predictions using scatter operations.

Replaces the ``torch.unique`` approach with ``diff`` + ``scatter_add_``, which is
significantly faster on GPU (``torch.unique`` is ~15x slower on GPU than CPU).
Float64 is used for accumulation to preserve numerical accuracy.

Args:
target: ground truth about each document relevance.
preds: estimated probabilities of each document to be relevant.
discount_cumsum: cumulative sum of the discount.
target: ground truth relevances in **predicted** rank order, shape ``(n_queries, n_docs)``.
preds: predicted scores in **predicted** rank order, shape ``(n_queries, n_docs)``.
discount: per-rank discount values ``1 / log2(rank + 2)``, shape ``(n_docs,)``.

Returns:
The cumulative gain of the tied elements.
DCG values, shape ``(n_queries,)``, dtype float32.

"""
_, inv, counts = torch.unique(-preds, return_inverse=True, return_counts=True)
ranked = torch.zeros_like(counts, dtype=torch.float32)
ranked.scatter_add_(0, inv, target.to(dtype=ranked.dtype))
ranked = ranked / counts
groups = counts.cumsum(dim=0) - 1
discount_sums = torch.zeros_like(counts, dtype=torch.float32)
discount_sums[0] = discount_cumsum[groups[0]]
discount_sums[1:] = discount_cumsum[groups].diff()
return (ranked * discount_sums).sum()
n_queries, n_docs = target.shape
device = target.device

# Detect tie-group boundaries: True at the first element of each new group
new_grp = torch.cat(
[
torch.ones(n_queries, 1, dtype=torch.bool, device=device),
preds.diff(dim=-1).abs() > 0,
],
dim=-1,
) # (n_queries, n_docs)

# Per-element group id, unique across the batch
gid = new_grp.long().cumsum(-1) - 1 # 0-based within each row
gid = gid + torch.arange(n_queries, device=device).unsqueeze(-1) * n_docs

# Scatter: accumulate gains, discounts, and counts per group
flat_id = gid.flatten()
flat_gain = target.flatten().float()
flat_disc = discount.unsqueeze(0).expand(n_queries, -1).flatten().float()

grp_gain = torch.zeros(n_queries * n_docs, dtype=torch.float32, device=device)
grp_disc = torch.zeros(n_queries * n_docs, dtype=torch.float32, device=device)
grp_cnt = torch.zeros(n_queries * n_docs, dtype=torch.int32, device=device)

grp_gain.scatter_add_(0, flat_id, flat_gain)
grp_disc.scatter_add_(0, flat_id, flat_disc)
grp_cnt.scatter_add_(0, flat_id, torch.ones_like(flat_id, dtype=torch.int32))

# Float64 accumulation for numerical parity with sklearn / reference implementations
contrib = grp_gain.double() * (grp_disc.double() / grp_cnt.clamp(min=1).double())

# Scatter only non-empty groups back to the batch dimension
valid = grp_cnt > 0
batch_idx = flat_id[valid] // n_docs
dcg = torch.zeros(n_queries, dtype=torch.float64, device=device)
dcg.scatter_add_(0, batch_idx, contrib[valid])
return dcg.float()


def _dcg_sample_scores(target: Tensor, preds: Tensor, top_k: int, ignore_ties: bool) -> Tensor:
"""Translated version of sklearns `_dcg_sample_scores` function.
"""Compute DCG sample scores.

Args:
target: ground truth about each document relevance.
preds: estimated probabilities of each document to be relevant.
top_k: consider only the top k elements
ignore_ties: If True, ties are ignored. If False, ties are averaged.
target: ground truth relevances, shape ``(n_docs,)`` or ``(n_queries, n_docs)``.
preds: predicted scores, shape ``(n_docs,)`` or ``(n_queries, n_docs)``.
top_k: consider only the top k elements.
ignore_ties: If ``True``, ties are broken by order. If ``False``, ties are averaged.

Returns:
The cumulative gain
DCG value(s): scalar for 1-D input, shape ``(n_queries,)`` for batched input.

"""
discount = 1.0 / (torch.log2(torch.arange(target.shape[-1], device=target.device) + 2.0))
discount[top_k:] = 0.0
batched = preds.dim() > 1
if not batched:
preds = preds.unsqueeze(0)
target = target.unsqueeze(0)

n_docs = preds.shape[-1]

# Use topk when k < n_docs to avoid sorting the full list
if top_k < n_docs:
order = preds.topk(top_k, dim=-1, sorted=True).indices
n_docs_eff = top_k
else:
order = preds.argsort(dim=-1, descending=True, stable=True)
n_docs_eff = n_docs

discount = 1.0 / torch.log2(torch.arange(n_docs_eff, device=preds.device) + 2.0)
p_sorted = preds.gather(-1, order)
g_sorted = target.float().gather(-1, order)

if ignore_ties:
ranking = preds.argsort(descending=True)
ranked = target[ranking]
cumulative_gain = (discount * ranked).sum()
dcg = (discount * g_sorted).sum(-1, dtype=torch.float64).float()
else:
discount_cumsum = discount.cumsum(dim=-1)
cumulative_gain = _tie_average_dcg(target, preds, discount_cumsum)
return cumulative_gain
dcg = _tie_average_dcg(g_sorted, p_sorted, discount)

return dcg if batched else dcg.squeeze(0)


def retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
Expand Down
100 changes: 100 additions & 0 deletions tests/unittests/retrieval/test_ndcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,103 @@ def test_corner_case_with_tied_scores():
retrieval_normalized_dcg(preds, target, top_k=k),
torch.tensor([ndcg_score(target, preds, k=k)], dtype=torch.float32),
)


# ---- Tests for vectorized GPU-efficient implementation (issue #2287) ----


@pytest.mark.parametrize(
    ("batch_size", "list_length", "top_k"),
    [
        (1, 50, None),
        (1, 100, 10),
        (8, 50, None),
        (8, 100, 50),
        (32, 100, None),
        (32, 500, 200),
        (128, 100, 10),
        (128, 500, None),
    ],
)
def test_accuracy_vs_sklearn(batch_size: int, list_length: int, top_k: Optional[int]):
    """Batched nDCG must agree with sklearn to within 1e-4 for every configuration.

    See issue: https://github.com/Lightning-AI/torchmetrics/issues/2287.

    """
    torch.manual_seed(42)
    scores = torch.randn(batch_size, list_length)
    labels = (torch.randint(0, 2, (batch_size, list_length)) * 2 - 1).float() + 1.0

    fast_result = retrieval_normalized_dcg(scores, labels, top_k=top_k).item()
    per_query = [ndcg_score([t], [p], k=top_k) for t, p in zip(labels.numpy(), scores.numpy())]
    sklearn_result = float(np.mean(per_query))

    assert abs(fast_result - sklearn_result) <= 1e-4, (
        f"nDCG differs from sklearn by {abs(fast_result - sklearn_result):.2e} "
        f"(B={batch_size}, L={list_length}, k={top_k})"
    )


def test_batched_input_matches_per_query():
    """A 2-D batch must yield the same mean nDCG as averaging each query computed alone.

    See issue: https://github.com/Lightning-AI/torchmetrics/issues/2287.

    """
    torch.manual_seed(42)
    preds = torch.randn(16, 50)
    target = (torch.randint(0, 2, (16, 50)) * 2 - 1).float() + 1.0

    singles = [retrieval_normalized_dcg(p, t) for p, t in zip(preds, target)]
    expected = torch.stack(singles).mean()

    assert torch.allclose(retrieval_normalized_dcg(preds, target), expected, atol=1e-5)


def test_tie_handling_explicit():
    """Tie-averaged DCG must match sklearn on inputs with explicit score ties.

    See issue: https://github.com/Lightning-AI/torchmetrics/issues/2287.

    """
    # Row 0 contains two pairs of tied scores; row 1 contains a three-way tie.
    scores = torch.tensor([[1.0, 1.0, 0.5, 0.5, 0.1], [0.8, 0.8, 0.8, 0.2, 0.1]])
    labels = torch.tensor([[1.0, 0.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0, 0.0]])

    result = retrieval_normalized_dcg(scores, labels)
    reference = [ndcg_score([t], [p]) for t, p in zip(labels.numpy(), scores.numpy())]
    sklearn_result = float(np.mean(reference))

    assert isinstance(result, torch.Tensor)
    assert 0.0 <= result.item() <= 1.0
    assert abs(result.item() - sklearn_result) <= 1e-4


def test_all_zeros_target():
    """Queries with no relevant documents (all-zero target) must return 0, not NaN."""
    preds = torch.randn(4, 20)
    relevance = torch.zeros(4, 20)
    assert retrieval_normalized_dcg(preds, relevance).item() == 0.0


def test_perfect_ranking():
    """When predictions reproduce the ideal ordering, nDCG must be exactly 1.0."""
    labels = torch.tensor([[3.0, 2.0, 1.0, 0.0, 0.0]]).repeat(4, 1)
    scores = labels.clone()  # predictions already in ideal order
    result = retrieval_normalized_dcg(scores, labels)
    assert torch.allclose(result, torch.tensor(1.0), atol=1e-5)


@pytest.mark.parametrize("top_k", [1, 10, 50, None])
def test_top_k_valid_range(top_k: Optional[int]):
    """Every top_k setting must produce an nDCG inside [0, 1]."""
    torch.manual_seed(0)
    preds = torch.randn(8, 100)
    relevance = torch.randint(0, 3, (8, 100)).float()
    value = retrieval_normalized_dcg(preds, relevance, top_k=top_k).item()
    assert 0.0 <= value <= 1.0
Loading