diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index c85d49e606c..b571420db09 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -441,13 +441,21 @@ def _multiclass_stat_scores_update( idx = target != ignore_index preds = preds[idx] target = target[idx] - unique_mapping = target.to(torch.long) * num_classes + preds.to(torch.long) - bins = _bincount(unique_mapping, minlength=num_classes**2) - confmat = bins.reshape(num_classes, num_classes) - tp = confmat.diag() - fp = confmat.sum(0) - tp - fn = confmat.sum(1) - tp - tn = confmat.sum() - (fp + fn + tp) + if num_classes < 1000: + unique_mapping = target.to(torch.long) * num_classes + preds.to(torch.long) + bins = _bincount(unique_mapping, minlength=num_classes**2) + confmat = bins.reshape(num_classes, num_classes) + tp = confmat.diag() + fp = confmat.sum(0) - tp + fn = confmat.sum(1) - tp + tn = confmat.sum() - (fp + fn + tp) + else: + # The above approach requires num_classes**2 memory. For large num_classes, we can calculate the + # statistics separately using linear memory. 
+ tp = _bincount(preds[target == preds], minlength=num_classes) + fp = _bincount(preds, minlength=num_classes) - tp + fn = _bincount(target, minlength=num_classes) - tp + tn = target.numel() - (tp + fp + fn) return tp, fp, tn, fn diff --git a/tests/unittests/classification/test_accuracy.py b/tests/unittests/classification/test_accuracy.py index 423af64c1af..540507eebf8 100644 --- a/tests/unittests/classification/test_accuracy.py +++ b/tests/unittests/classification/test_accuracy.py @@ -408,6 +408,25 @@ def test_multiclass_accuracy_gpu_sync_points_uptodate( ) +def test_multiclass_accuracy_large_num_classes(): + """Test that accuracy is correct when num_classes>=1000, exercising the linear-space code path.""" + num_classes = 1_000_000 + n = 500 + generator = torch.Generator().manual_seed(42) + target = torch.randint(0, num_classes, (n,), generator=generator) + preds = torch.randint(0, num_classes, (n,), generator=generator) + + # We have so many classes that it's most likely the accuracy is 0 in this test, so we artificially + # set 20% of the predictions to be correct. 
+ artificially_correct = torch.randperm(n, generator=generator)[: n // 5] + preds[artificially_correct] = target[artificially_correct] + + # Expected: fraction of exactly correct predictions + expected = (preds == target).float().mean() + result = multiclass_accuracy(preds, target, num_classes=num_classes, average="micro") + assert torch.isclose(result, expected), f"Expected {expected}, got {result}" + + _mc_k_target = torch.tensor([0, 1, 2]) _mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index 4e1f11bdc23..d7f8409cb48 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -476,6 +476,35 @@ def test_multiclass_overflow(): assert torch.allclose(res, torch.tensor(compare)) +# test both small and larger number of classes to test both linear and quadratic code paths +@pytest.mark.parametrize("num_classes", [2, 6, 2000, 1_000_000]) +@pytest.mark.parametrize("average", ["micro", "macro", None]) +def test_multiclass_stat_scores_large_num_classes(num_classes, average): + """Test that stat scores are correct when num_classes>=1000, exercising the linear-space code path.""" + n = 500 + generator = torch.Generator().manual_seed(42) + target = torch.randint(0, num_classes, (n,), generator=generator) + preds = torch.randint(0, num_classes, (n,), generator=generator) + + # We have so many classes that it's most likely tp=0 in this test, so we artificially + # set 20% of the predictions to be correct. 
+ artificially_correct = torch.randperm(n, generator=generator)[: n // 5] + preds[artificially_correct] = target[artificially_correct] + + result = multiclass_stat_scores(preds, target, num_classes=num_classes, average=average) + + tp = torch.bincount(target[preds == target], minlength=num_classes) + fp = torch.bincount(preds, minlength=num_classes) - tp + fn = torch.bincount(target, minlength=num_classes) - tp + tn = n - tp - fp - fn + expected = torch.stack([tp, fp, tn, fn, tp + fn], dim=1) + if average == "micro": + expected = expected.sum(0) + elif average == "macro": + expected = expected.float().mean(0) + assert torch.allclose(result, expected, atol=1e-4, rtol=1e-4) + + def _reference_sklearn_stat_scores_multilabel(preds, target, ignore_index, multidim_average, average): preds = preds.numpy() target = target.numpy()