Commit 1ec556c
perf(retrieval): fix NDCG GPU performance by replacing torch.unique in tie averaging
torch.unique is ~15x slower on GPU than CPU, causing nDCG to run up to
2.65x slower on GPU than CPU. Replace with a diff + scatter_add_ strategy
that is efficient on both CPU and GPU.
Key changes to the algorithm (based on the optimized implementation
proposed in #2287):
- _tie_average_dcg: takes pre-sorted inputs, uses diff + scatter_add_
instead of torch.unique; float64 accumulation for numerical parity
with sklearn; int32 group counts; valid-group masking before scatter
- _dcg_sample_scores: handles sorting (with topk fast-path when k < L),
gather, and discount creation; delegates tie averaging to the above
- retrieval_normalized_dcg: unchanged public API; now correctly handles
both 1-D (single query) and 2-D (batched) inputs
Tests added:
- test_accuracy_vs_sklearn: parametrized across 8 (batch, length, top_k)
configs, tolerance 1e-4 matching reference implementation parity
- test_batched_input_matches_per_query: 2-D result == mean of 1-D calls
- test_tie_handling_explicit: explicit tie configurations vs sklearn
- test_all_zeros_target: all-irrelevant queries return 0.0, not NaN
- test_perfect_ranking: ideal predictions return nDCG == 1.0
- test_top_k_valid_range: results in [0, 1] for all top_k values
Fixes: #2287
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>1 parent c5be2f2 commit 1ec556c
2 files changed
Lines changed: 138 additions & 64 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
24 | 24 | | |
25 | 25 | | |
26 | 26 | | |
| 27 | + | |
27 | 28 | | |
28 | 29 | | |
29 | | - | |
30 | | - | |
| 30 | + | |
| 31 | + | |
31 | 32 | | |
32 | 33 | | |
33 | 34 | | |
34 | | - | |
| 35 | + | |
35 | 36 | | |
36 | 37 | | |
37 | | - | |
38 | | - | |
39 | | - | |
40 | | - | |
41 | | - | |
42 | | - | |
43 | | - | |
44 | | - | |
45 | | - | |
46 | | - | |
47 | | - | |
48 | | - | |
| 38 | + | |
| 39 | + | |
49 | 40 | | |
50 | 41 | | |
51 | 42 | | |
52 | 43 | | |
53 | | - | |
54 | | - | |
| 44 | + | |
| 45 | + | |
55 | 46 | | |
56 | 47 | | |
57 | 48 | | |
58 | 49 | | |
59 | | - | |
| 50 | + | |
60 | 51 | | |
61 | | - | |
| 52 | + | |
62 | 53 | | |
63 | 54 | | |
64 | | - | |
65 | | - | |
| 55 | + | |
| 56 | + | |
66 | 57 | | |
67 | 58 | | |
68 | | - | |
69 | | - | |
70 | | - | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
71 | 62 | | |
72 | | - | |
73 | | - | |
74 | | - | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
75 | 66 | | |
76 | | - | |
77 | | - | |
78 | | - | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
79 | 76 | | |
80 | 77 | | |
81 | 78 | | |
| |||
91 | 88 | | |
92 | 89 | | |
93 | 90 | | |
94 | | - | |
95 | | - | |
96 | | - | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
97 | 109 | | |
98 | 110 | | |
99 | | - | |
100 | | - | |
101 | | - | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
102 | 114 | | |
103 | | - | |
| 115 | + | |
104 | 116 | | |
105 | 117 | | |
106 | 118 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
214 | 214 | | |
215 | 215 | | |
216 | 216 | | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
217 | 253 | | |
218 | | - | |
| 254 | + | |
219 | 255 | | |
220 | 256 | | |
221 | 257 | | |
222 | | - | |
223 | | - | |
224 | | - | |
225 | | - | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
226 | 276 | | |
227 | | - | |
228 | | - | |
229 | | - | |
230 | | - | |
| 277 | + | |
| 278 | + | |
| 279 | + | |
231 | 280 | | |
232 | 281 | | |
233 | | - | |
234 | | - | |
235 | | - | |
236 | | - | |
237 | | - | |
| 282 | + | |
| 283 | + | |
| 284 | + | |
| 285 | + | |
238 | 286 | | |
239 | | - | |
240 | | - | |
| 287 | + | |
| 288 | + | |
| 289 | + | |
241 | 290 | | |
242 | | - | |
243 | | - | |
244 | | - | |
245 | 291 | | |
246 | | - | |
247 | | - | |
248 | | - | |
249 | | - | |
250 | | - | |
251 | | - | |
252 | | - | |
253 | | - | |
| 292 | + | |
| 293 | + | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
| 297 | + | |
| 298 | + | |
| 299 | + | |
| 300 | + | |
| 301 | + | |
| 302 | + | |
| 303 | + | |
| 304 | + | |
| 305 | + | |
| 306 | + | |
| 307 | + | |
| 308 | + | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
| 314 | + | |
| 315 | + | |
0 commit comments