@@ -476,6 +476,35 @@ def test_multiclass_overflow():
476476 assert torch .allclose (res , torch .tensor (compare ))
477477
478478
479+ # test both small and larger number of classes to test both linear and quadratic code paths
480+ @pytest .mark .parametrize ("num_classes" , [2 , 6 , 2000 , 1_000_000 ])
481+ @pytest .mark .parametrize ("average" , ["micro" , "macro" , None ])
482+ def test_multiclass_stat_scores_large_num_classes (num_classes , average ):
483+ """Test that stat scores are correct when num_classes>=1000, exercising the linear-space code path."""
484+ n = 500
485+ generator = torch .Generator ().manual_seed (42 )
486+ target = torch .randint (0 , num_classes , (n ,), generator = generator )
487+ preds = torch .randint (0 , num_classes , (n ,), generator = generator )
488+
489+ # We have so many classes that it's most likely tp=0 in this test, so we artificially
490+ # set 20% of the predictions to be correct.
491+ artificially_correct = torch .randperm (n , generator = generator )[: n // 5 ]
492+ preds [artificially_correct ] = target [artificially_correct ]
493+
494+ result = multiclass_stat_scores (preds , target , num_classes = num_classes , average = average )
495+
496+ tp = torch .bincount (target [preds == target ], minlength = num_classes )
497+ fp = torch .bincount (preds , minlength = num_classes ) - tp
498+ fn = torch .bincount (target , minlength = num_classes ) - tp
499+ tn = n - tp - fp - fn
500+ expected = torch .stack ([tp , fp , tn , fn , tp + fn ], dim = 1 )
501+ if average == "micro" :
502+ expected = expected .sum (0 )
503+ elif average == "macro" :
504+ expected = expected .float ().mean (0 )
505+ assert torch .allclose (result , expected , atol = 1e-4 , rtol = 1e-4 )
506+
507+
479508def _reference_sklearn_stat_scores_multilabel (preds , target , ignore_index , multidim_average , average ):
480509 preds = preds .numpy ()
481510 target = target .numpy ()
0 commit comments