@@ -408,6 +408,25 @@ def test_multiclass_accuracy_gpu_sync_points_uptodate(
408408 )
409409
410410
411+ def test_multiclass_accuracy_large_num_classes ():
412+ """Test that accuracy is correct when num_classes>=1000, exercising the linear-space code path."""
413+ num_classes = 1_000_000
414+ n = 500
415+ generator = torch .Generator ().manual_seed (42 )
416+ target = torch .randint (0 , num_classes , (n ,), generator = generator )
417+ preds = torch .randint (0 , num_classes , (n ,), generator = generator )
418+
419+ # We have so many classes that its most likely the accurary is 0 in this test, so we artificially
420+ # set 20% of the predictions to be correct.
421+ artificially_correct = torch .randperm (n , generator = generator )[: n // 5 ]
422+ preds [artificially_correct ] = target [artificially_correct ]
423+
424+ # Expected: fraction of exactly correct predictions
425+ expected = (preds == target ).float ().mean ()
426+ result = multiclass_accuracy (preds , target , num_classes = num_classes , average = "micro" )
427+ assert torch .isclose (result , expected ), f"Expected { expected } , got { result } "
428+
429+
411430_mc_k_target = torch .tensor ([0 , 1 , 2 ])
412431_mc_k_preds = torch .tensor ([[0.35 , 0.4 , 0.25 ], [0.1 , 0.5 , 0.4 ], [0.2 , 0.1 , 0.7 ]])
413432
0 commit comments