Skip to content

Commit c610c46

Browse files
committed
Use linear space in multiclass stats computation when n_classes>=1000
1 parent 5e8b01d commit c610c46

2 files changed

Lines changed: 34 additions & 7 deletions

File tree

src/torchmetrics/functional/classification/stat_scores.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -441,13 +441,21 @@ def _multiclass_stat_scores_update(
441441
idx = target != ignore_index
442442
preds = preds[idx]
443443
target = target[idx]
444-
unique_mapping = target.to(torch.long) * num_classes + preds.to(torch.long)
445-
bins = _bincount(unique_mapping, minlength=num_classes**2)
446-
confmat = bins.reshape(num_classes, num_classes)
447-
tp = confmat.diag()
448-
fp = confmat.sum(0) - tp
449-
fn = confmat.sum(1) - tp
450-
tn = confmat.sum() - (fp + fn + tp)
444+
if num_classes < 1000:
445+
unique_mapping = target.to(torch.long) * num_classes + preds.to(torch.long)
446+
bins = _bincount(unique_mapping, minlength=num_classes**2)
447+
confmat = bins.reshape(num_classes, num_classes)
448+
tp = confmat.diag()
449+
fp = confmat.sum(0) - tp
450+
fn = confmat.sum(1) - tp
451+
tn = confmat.sum() - (fp + fn + tp)
452+
else:
453+
# The above approach requires num_classes**2 memory. For large num_classes, we can calculate the
454+
# statistics separately using linear memory.
455+
tp = _bincount(preds[target == preds], minlength=num_classes)
456+
fp = _bincount(preds, minlength=num_classes) - tp
457+
fn = _bincount(target, minlength=num_classes) - tp
458+
tn = target.numel() - (tp + fp + fn)
451459
return tp, fp, tn, fn
452460

453461

tests/unittests/classification/test_accuracy.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,25 @@ def test_multiclass_accuracy_gpu_sync_points_uptodate(
408408
)
409409

410410

411+
def test_multiclass_accuracy_large_num_classes():
412+
"""Test that accuracy is correct when num_classes>=1000, exercising the linear-space code path."""
413+
num_classes = 1_000_000
414+
n = 500
415+
generator = torch.Generator().manual_seed(42)
416+
target = torch.randint(0, num_classes, (n,), generator=generator)
417+
preds = torch.randint(0, num_classes, (n,), generator=generator)
418+
419+
# We have so many classes that it's most likely the accuracy is 0 in this test, so we artificially
420+
# set 20% of the predictions to be correct.
421+
artificially_correct = torch.randperm(n, generator=generator)[: n // 5]
422+
preds[artificially_correct] = target[artificially_correct]
423+
424+
# Expected: fraction of exactly correct predictions
425+
expected = (preds == target).float().mean()
426+
result = multiclass_accuracy(preds, target, num_classes=num_classes, average="micro")
427+
assert torch.isclose(result, expected), f"Expected {expected}, got {result}"
428+
429+
411430
_mc_k_target = torch.tensor([0, 1, 2])
412431
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
413432

0 commit comments

Comments
 (0)