@@ -27,38 +27,38 @@ def _tie_average_dcg(target: Tensor, preds: Tensor, discount: Tensor) -> Tensor:
2727 Float64 is used for accumulation to preserve numerical accuracy.
2828
2929 Args:
30- target: ground truth relevances in **predicted** rank order, shape ``(B, L )``.
31- preds: predicted scores in **predicted** rank order, shape ``(B, L )``.
32- discount: per-rank discount values ``1 / log2(rank + 2)``, shape ``(L ,)``.
30+ target: ground truth relevances in **predicted** rank order, shape ``(n_queries, n_docs )``.
31+ preds: predicted scores in **predicted** rank order, shape ``(n_queries, n_docs )``.
32+ discount: per-rank discount values ``1 / log2(rank + 2)``, shape ``(n_docs ,)``.
3333
3434 Returns:
35- DCG values, shape ``(B ,)``, dtype float32.
35+ DCG values, shape ``(n_queries ,)``, dtype float32.
3636
3737 """
38- B , L = target .shape
38+ n_queries , n_docs = target .shape
3939 device = target .device
4040
4141 # Detect tie-group boundaries: True at the first element of each new group
4242 new_grp = torch .cat (
4343 [
44- torch .ones (B , 1 , dtype = torch .bool , device = device ),
44+ torch .ones (n_queries , 1 , dtype = torch .bool , device = device ),
4545 preds .diff (dim = - 1 ).abs () > 0 ,
4646 ],
4747 dim = - 1 ,
48- ) # (B, L )
48+ ) # (n_queries, n_docs )
4949
5050 # Per-element group id, unique across the batch
5151 gid = new_grp .long ().cumsum (- 1 ) - 1 # 0-based within each row
52- gid = gid + torch .arange (B , device = device ).unsqueeze (- 1 ) * L
52+ gid = gid + torch .arange (n_queries , device = device ).unsqueeze (- 1 ) * n_docs
5353
5454 # Scatter: accumulate gains, discounts, and counts per group
5555 flat_id = gid .flatten ()
5656 flat_gain = target .flatten ().float ()
57- flat_disc = discount .unsqueeze (0 ).expand (B , - 1 ).flatten ().float ()
57+ flat_disc = discount .unsqueeze (0 ).expand (n_queries , - 1 ).flatten ().float ()
5858
59- grp_gain = torch .zeros (B * L , dtype = torch .float32 , device = device )
60- grp_disc = torch .zeros (B * L , dtype = torch .float32 , device = device )
61- grp_cnt = torch .zeros (B * L , dtype = torch .int32 , device = device )
59+ grp_gain = torch .zeros (n_queries * n_docs , dtype = torch .float32 , device = device )
60+ grp_disc = torch .zeros (n_queries * n_docs , dtype = torch .float32 , device = device )
61+ grp_cnt = torch .zeros (n_queries * n_docs , dtype = torch .int32 , device = device )
6262
6363 grp_gain .scatter_add_ (0 , flat_id , flat_gain )
6464 grp_disc .scatter_add_ (0 , flat_id , flat_disc )
@@ -69,8 +69,8 @@ def _tie_average_dcg(target: Tensor, preds: Tensor, discount: Tensor) -> Tensor:
6969
7070 # Scatter only non-empty groups back to the batch dimension
7171 valid = grp_cnt > 0
72- batch_idx = flat_id [valid ] // L
73- dcg = torch .zeros (B , dtype = torch .float64 , device = device )
72+ batch_idx = flat_id [valid ] // n_docs
73+ dcg = torch .zeros (n_queries , dtype = torch .float64 , device = device )
7474 dcg .scatter_add_ (0 , batch_idx , contrib [valid ])
7575 return dcg .float ()
7676
@@ -79,31 +79,31 @@ def _dcg_sample_scores(target: Tensor, preds: Tensor, top_k: int, ignore_ties: b
7979 """Compute DCG sample scores.
8080
8181 Args:
82- target: ground truth relevances, shape ``(L ,)`` or ``(B, L )``.
83- preds: predicted scores, shape ``(L ,)`` or ``(B, L )``.
82+ target: ground truth relevances, shape ``(n_docs ,)`` or ``(n_queries, n_docs )``.
83+ preds: predicted scores, shape ``(n_docs ,)`` or ``(n_queries, n_docs )``.
8484 top_k: consider only the top k elements.
8585 ignore_ties: If ``True``, ties are broken by order. If ``False``, ties are averaged.
8686
8787 Returns:
88- DCG value(s): scalar for 1-D input, shape ``(B ,)`` for batched input.
88+ DCG value(s): scalar for 1-D input, shape ``(n_queries ,)`` for batched input.
8989
9090 """
9191 batched = preds .dim () > 1
9292 if not batched :
9393 preds = preds .unsqueeze (0 )
9494 target = target .unsqueeze (0 )
9595
96- L = preds .shape [- 1 ]
96+ n_docs = preds .shape [- 1 ]
9797
98- # Use topk when k < L to avoid sorting the full list
99- if top_k < L :
98+ # Use topk when k < n_docs to avoid sorting the full list
99+ if top_k < n_docs :
100100 order = preds .topk (top_k , dim = - 1 , sorted = True ).indices
101- L_eff = top_k
101+ n_docs_eff = top_k
102102 else :
103103 order = preds .argsort (dim = - 1 , descending = True , stable = True )
104- L_eff = L
104+ n_docs_eff = n_docs
105105
106- discount = 1.0 / torch .log2 (torch .arange (L_eff , device = preds .device ) + 2.0 )
106+ discount = 1.0 / torch .log2 (torch .arange (n_docs_eff , device = preds .device ) + 2.0 )
107107 p_sorted = preds .gather (- 1 , order )
108108 g_sorted = target .float ().gather (- 1 , order )
109109
0 commit comments