🐛 Bug
Not really a bug, but a limitation that can be solved easily.
The current implementation of `MulticlassAccuracy` uses memory that grows quadratically with the number of classes.
In my case I tried to use it with hundreds of thousands of classes and got a CUDA out-of-memory (OOM) error.
From looking at the code, it seems the full `num_classes × num_classes` confusion matrix is built only because it makes the code simpler and cleaner, but the same per-class statistics can easily be computed with linear memory.
Thanks in advance for the support ❤️
To Reproduce
```python
import torch
from torchmetrics.classification import MulticlassAccuracy

metric = MulticlassAccuracy(average='macro', num_classes=1_000_000)
metric.update(torch.tensor([1]), torch.tensor([1]))  # <-- OOM here
```
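For scale, here is a rough estimate of why the repro runs out of memory. It assumes int64 counts (the default output dtype of `torch.bincount`) and ignores intermediate tensors and allocator overhead; the value 300_000 is only an illustrative stand-in for "hundreds of thousands of classes", and the helper names are hypothetical:

```python
# Back-of-the-envelope memory estimate. Assumes int64 counts and ignores
# intermediate tensors and allocator overhead.
INT64_BYTES = 8


def confmat_bytes(num_classes: int) -> int:
    """Dense confusion matrix: one int64 bin per (target, pred) pair, O(C^2)."""
    return num_classes**2 * INT64_BYTES


def per_class_bytes(num_classes: int, num_vectors: int = 4) -> int:
    """tp, fp, tn, fn as separate length-C int64 vectors, O(C)."""
    return num_vectors * num_classes * INT64_BYTES


if __name__ == "__main__":
    for c in (300_000, 1_000_000):
        print(f"C={c:>9,}: confmat {confmat_bytes(c) / 1e9:>8,.0f} GB, "
              f"per-class {per_class_bytes(c) / 1e6:>4,.0f} MB")
```

With `num_classes=1_000_000` the confusion matrix alone would need about 8 TB, while four per-class vectors need about 32 MB.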
Additional context
Fix, in `_multiclass_stat_scores_update`:
```diff
- unique_mapping = target.to(torch.long) * num_classes + preds.to(torch.long)
- bins = _bincount(unique_mapping, minlength=num_classes**2)
- confmat = bins.reshape(num_classes, num_classes)
- tp = confmat.diag()
- fp = confmat.sum(0) - tp
- fn = confmat.sum(1) - tp
- tn = confmat.sum() - (fp + fn + tp)
+ tp = _bincount(preds[target == preds], minlength=num_classes)
+ fp = _bincount(preds, minlength=num_classes) - tp
+ fn = _bincount(target, minlength=num_classes) - tp
+ tn = target.numel() - (tp + fp + fn)
```
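A minimal standalone sketch to check that both formulations give identical `tp`, `fp`, `tn`, `fn`. It assumes plain `torch.bincount` in place of torchmetrics' internal `_bincount` helper and already-flattened 1-D integer `preds`/`target`; it deliberately omits the `ignore_index`, `top_k` and samplewise handling of the real function. The names `stat_scores_confmat` and `stat_scores_linear` are hypothetical:

```python
import torch


def stat_scores_confmat(preds, target, num_classes):
    """Current approach: dense confusion matrix, O(num_classes**2) memory."""
    unique_mapping = target.to(torch.long) * num_classes + preds.to(torch.long)
    bins = torch.bincount(unique_mapping, minlength=num_classes**2)
    confmat = bins.reshape(num_classes, num_classes)  # rows: target, cols: preds
    tp = confmat.diag()
    fp = confmat.sum(0) - tp  # predicted as class c, minus correct
    fn = confmat.sum(1) - tp  # truly class c, minus correct
    tn = confmat.sum() - (fp + fn + tp)
    return tp, fp, tn, fn


def stat_scores_linear(preds, target, num_classes):
    """Proposed approach: per-class counts only, O(num_classes) memory."""
    preds = preds.to(torch.long)
    target = target.to(torch.long)
    tp = torch.bincount(preds[target == preds], minlength=num_classes)
    fp = torch.bincount(preds, minlength=num_classes) - tp
    fn = torch.bincount(target, minlength=num_classes) - tp
    tn = target.numel() - (tp + fp + fn)
    return tp, fp, tn, fn


if __name__ == "__main__":
    torch.manual_seed(0)
    preds = torch.randint(0, 5, (1000,))
    target = torch.randint(0, 5, (1000,))
    for old, new in zip(stat_scores_confmat(preds, target, 5),
                        stat_scores_linear(preds, target, 5)):
        assert torch.equal(old, new)

    # The issue's 1_000_000-class repro only needs a few length-1e6 int64 vectors here.
    tp, fp, tn, fn = stat_scores_linear(torch.tensor([1]), torch.tensor([1]), 1_000_000)
    print(tp.shape, tp.sum().item(), tn.sum().item())
```

The identity `tn = N - (tp + fp + fn)` works because every sample lands in exactly one of the four cells for each class, so no pairwise (target, pred) counts are ever needed.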
#3342