@@ -234,15 +234,14 @@ def test_accuracy_vs_sklearn(batch_size: int, list_length: int, top_k: Optional[
234234 """Batched nDCG must stay within 1e-4 of sklearn across configs.
235235
236236 See issue: https://github.com/Lightning-AI/torchmetrics/issues/2287.
237+
237238 """
238239 torch .manual_seed (42 )
239240 scores = torch .randn (batch_size , list_length )
240241 labels = (torch .randint (0 , 2 , (batch_size , list_length )) * 2 - 1 ).float () + 1.0
241242
242243 fast_result = retrieval_normalized_dcg (scores , labels , top_k = top_k ).item ()
243- sklearn_result = float (
244- np .mean ([ndcg_score ([t ], [p ], k = top_k ) for t , p in zip (labels .numpy (), scores .numpy ())])
245- )
244+ sklearn_result = float (np .mean ([ndcg_score ([t ], [p ], k = top_k ) for t , p in zip (labels .numpy (), scores .numpy ())]))
246245
247246 assert abs (fast_result - sklearn_result ) <= 1e-4 , (
248247 f"nDCG differs from sklearn by { abs (fast_result - sklearn_result ):.2e} "
@@ -254,6 +253,7 @@ def test_batched_input_matches_per_query():
254253 """Batched 2-D input must give the same mean nDCG as averaging per-query 1-D results.
255254
256255 See issue: https://github.com/Lightning-AI/torchmetrics/issues/2287.
256+
257257 """
258258 torch .manual_seed (42 )
259259 preds = torch .randn (16 , 50 )
@@ -269,6 +269,7 @@ def test_tie_handling_explicit():
269269 """Tie-averaged DCG must match sklearn on inputs with explicit score ties.
270270
271271 See issue: https://github.com/Lightning-AI/torchmetrics/issues/2287.
272+
272273 """
273274 scores = torch .tensor ([
274275 [1.0 , 1.0 , 0.5 , 0.5 , 0.1 ], # two pairs of ties
@@ -280,9 +281,7 @@ def test_tie_handling_explicit():
280281 ])
281282
282283 result = retrieval_normalized_dcg (scores , labels )
283- sklearn_result = float (
284- np .mean ([ndcg_score ([t ], [p ]) for t , p in zip (labels .numpy (), scores .numpy ())])
285- )
284+ sklearn_result = float (np .mean ([ndcg_score ([t ], [p ]) for t , p in zip (labels .numpy (), scores .numpy ())]))
286285
287286 assert isinstance (result , torch .Tensor )
288287 assert 0.0 <= result .item () <= 1.0
0 commit comments