Skip to content

Commit c3835f2

Browse files
committed
Use linear space in multiclass stats computation when n_classes >= 1000
1 parent 5e8b01d commit c3835f2

3 files changed

Lines changed: 63 additions & 7 deletions

File tree

src/torchmetrics/functional/classification/stat_scores.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -441,13 +441,21 @@ def _multiclass_stat_scores_update(
441441
idx = target != ignore_index
442442
preds = preds[idx]
443443
target = target[idx]
444-
unique_mapping = target.to(torch.long) * num_classes + preds.to(torch.long)
445-
bins = _bincount(unique_mapping, minlength=num_classes**2)
446-
confmat = bins.reshape(num_classes, num_classes)
447-
tp = confmat.diag()
448-
fp = confmat.sum(0) - tp
449-
fn = confmat.sum(1) - tp
450-
tn = confmat.sum() - (fp + fn + tp)
444+
if num_classes < 1000:
445+
unique_mapping = target.to(torch.long) * num_classes + preds.to(torch.long)
446+
bins = _bincount(unique_mapping, minlength=num_classes**2)
447+
confmat = bins.reshape(num_classes, num_classes)
448+
tp = confmat.diag()
449+
fp = confmat.sum(0) - tp
450+
fn = confmat.sum(1) - tp
451+
tn = confmat.sum() - (fp + fn + tp)
452+
else:
453+
# The above approach requires num_classes**2 memory. For large num_classes, we can calculate the
454+
# statistics separately using linear memory.
455+
tp = _bincount(preds[target == preds], minlength=num_classes)
456+
fp = _bincount(preds, minlength=num_classes) - tp
457+
fn = _bincount(target, minlength=num_classes) - tp
458+
tn = target.numel() - (tp + fp + fn)
451459
return tp, fp, tn, fn
452460

453461

tests/unittests/classification/test_accuracy.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,25 @@ def test_multiclass_accuracy_gpu_sync_points_uptodate(
408408
)
409409

410410

411+
def test_multiclass_accuracy_large_num_classes():
    """Test that accuracy is correct when num_classes>=1000, exercising the linear-space code path."""
    num_classes = 1_000_000
    n = 500
    gen = torch.Generator().manual_seed(42)
    target = torch.randint(0, num_classes, (n,), generator=gen)
    preds = torch.randint(0, num_classes, (n,), generator=gen)

    # With this many classes a random prediction almost never matches the target, so it's most
    # likely the accuracy would be 0; artificially make 20% of the predictions correct.
    forced_correct = torch.randperm(n, generator=gen)[: n // 5]
    preds[forced_correct] = target[forced_correct]

    # Expected micro accuracy is simply the fraction of exactly correct predictions.
    expected = (preds == target).float().mean()
    result = multiclass_accuracy(preds, target, num_classes=num_classes, average="micro")
    assert torch.isclose(result, expected), f"Expected {expected}, got {result}"
428+
429+
411430
_mc_k_target = torch.tensor([0, 1, 2])
412431
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
413432

tests/unittests/classification/test_stat_scores.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,35 @@ def test_multiclass_overflow():
476476
assert torch.allclose(res, torch.tensor(compare))
477477

478478

479+
# test both small and larger number of classes to test both linear and quadratic code paths
@pytest.mark.parametrize("num_classes", [2, 6, 2000, 1_000_000])
@pytest.mark.parametrize("average", ["micro", "macro", None])
def test_multiclass_stat_scores_large_num_classes(num_classes, average):
    """Test stat scores across both small and large ``num_classes``.

    Values below 1000 exercise the quadratic (confusion-matrix) code path, while values of 1000
    and above exercise the linear-space code path.
    """
    n = 500
    generator = torch.Generator().manual_seed(42)
    target = torch.randint(0, num_classes, (n,), generator=generator)
    preds = torch.randint(0, num_classes, (n,), generator=generator)

    # With many classes it's most likely tp=0 in this test, so we artificially
    # set 20% of the predictions to be correct.
    artificially_correct = torch.randperm(n, generator=generator)[: n // 5]
    preds[artificially_correct] = target[artificially_correct]

    result = multiclass_stat_scores(preds, target, num_classes=num_classes, average=average)

    # Reference statistics computed independently with torch.bincount (linear memory):
    # per-class tp/fp/fn/tn plus support (tp + fn), matching the stat-scores output layout.
    tp = torch.bincount(target[preds == target], minlength=num_classes)
    fp = torch.bincount(preds, minlength=num_classes) - tp
    fn = torch.bincount(target, minlength=num_classes) - tp
    tn = n - tp - fp - fn
    expected = torch.stack([tp, fp, tn, fn, tp + fn], dim=1)
    if average == "micro":
        expected = expected.sum(0)
    elif average == "macro":
        expected = expected.float().mean(0)
    assert torch.allclose(result, expected, atol=1e-4, rtol=1e-4)
506+
507+
479508
def _reference_sklearn_stat_scores_multilabel(preds, target, ignore_index, multidim_average, average):
480509
preds = preds.numpy()
481510
target = target.numpy()

0 commit comments

Comments
 (0)