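MetricCollection raises AttributeError when it contains multiple MeanAveragePrecision metrics.
Steps to reproduce: run the script below. `metric.update(preds, target)` fails with the traceback that follows it. Expected behavior: the update succeeds and `metric.compute()` returns results for both metrics in the collection.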
from torch import tensor
import torch
from torchmetrics.detection import MeanAveragePrecision
from torchmetrics.collections import MetricCollection
# Reproduction script: two MeanAveragePrecision metrics (each evaluating both
# "segm" and "bbox") combined in a single MetricCollection. Calling
# ``metric.update`` raises
# ``AttributeError: 'tuple' object has no attribute 'shape'`` from
# MetricCollection._equal_metric_states (see the traceback that follows).

# 5x5 binary masks; the predicted and target masks overlap only partially.
mask_pred = [
    [0, 0, 0, 0, 0],
    [0, 0, 1, 1, 0],
    [0, 0, 1, 1, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
]
mask_tgt = [
    [0, 0, 0, 0, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 1, 1, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 0, 0],
]

# A single image with two predicted instances and two ground-truth instances.
preds = [
    dict(
        masks=tensor([mask_pred, mask_pred], dtype=torch.bool),
        boxes=tensor([[258.0, 41.0, 606.0, 285.0], [258.0, 41.0, 606.0, 285.0]]),
        scores=tensor([0.536, 0.4]),
        labels=tensor([0, 0]),
    )
]
target = [
    dict(
        masks=tensor([mask_tgt, mask_tgt], dtype=torch.bool),
        boxes=tensor([[214.0, 41.0, 562.0, 285.0], [214.0, 41.0, 562.0, 285.0]]),
        labels=tensor([0, 1]),
    )
]

# mAP with iou_thresholds left at its default.
metric_map = MeanAveragePrecision(iou_type=["segm", "bbox"], class_metrics=True)
# mAP at a single IoU threshold. IoU values lie in [0, 1], so "50%" must be
# written as 0.5; a threshold of 50 could never be met by any match.
metric_map_50 = MeanAveragePrecision(
    iou_type=["segm", "bbox"], iou_thresholds=[0.5], class_metrics=True
)
metric = MetricCollection({"iou_all": metric_map, "iou_50": metric_map_50})
# This call raises the AttributeError shown in the traceback while
# MetricCollection merges compute groups (_merge_compute_groups).
# NOTE(review): possible workaround, untested here:
# MetricCollection({...}, compute_groups=False).
metric.update(preds, target)
res = metric.compute()
metric.update(preds, target)
File "/home/user/miniconda3/envs/cv-classification2/lib/python3.10/site-packages/torchmetrics/collections.py", line 263, in update
self._merge_compute_groups()
File "/home/user/miniconda3/envs/cv-classification2/lib/python3.10/site-packages/torchmetrics/collections.py", line 286, in _merge_compute_groups
if self._equal_metric_states(metric1, metric2):
File "/home/user/miniconda3/envs/cv-classification2/lib/python3.10/site-packages/torchmetrics/collections.py", line 332, in _equal_metric_states
and not (all(s1.shape == s2.shape and allclose(s1, s2) for s1, s2 in zip(state1, state2)))
File "/home/user/miniconda3/envs/cv-classification2/lib/python3.10/site-packages/torchmetrics/collections.py", line 332, in <genexpr>
and not (all(s1.shape == s2.shape and allclose(s1, s2) for s1, s2 in zip(state1, state2)))
AttributeError: 'tuple' object has no attribute 'shape'
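🐛 Bug
`MetricCollection` does not work with `MeanAveragePrecision`. `update` fails inside `MetricCollection._merge_compute_groups`, because `_equal_metric_states` calls `.shape` on every element of the compared states, and at least one element of a `MeanAveragePrecision` state is a `tuple` rather than a tensor.
To Reproduce
Run the sample script above.
Additional context
Error message: see the traceback above (`AttributeError: 'tuple' object has no attribute 'shape'`). A possible workaround, not yet verified, is to disable compute groups with `MetricCollection(..., compute_groups=False)`.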