MetricCollection not working with MeanAveragePrecision #3335

@Inkorak

🐛 Bug

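Calling update on a MetricCollection that contains two MeanAveragePrecision metrics raises AttributeError: 'tuple' object has no attribute 'shape'. The error comes from the compute-group merge step (_merge_compute_groups), which compares the states of the two metrics.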

To Reproduce

Running the script below raises the error at the metric.update(preds, target) call. Expected result: update and compute complete without error.

from torch import tensor
import torch
from torchmetrics.detection import MeanAveragePrecision
from torchmetrics.collections import MetricCollection

mask_pred = [
    [0, 0, 0, 0, 0],
    [0, 0, 1, 1, 0],
    [0, 0, 1, 1, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
]
mask_tgt = [
    [0, 0, 0, 0, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 1, 1, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 0, 0],
]
preds = [
    dict(
        masks=tensor([mask_pred, mask_pred], dtype=torch.bool),
        boxes=tensor([[258.0, 41.0, 606.0, 285.0], [258.0, 41.0, 606.0, 285.0]]),
        scores=tensor([0.536, 0.4]),
        labels=tensor([0, 0]),
    )
]
target = [
    dict(
        masks=tensor([mask_tgt, mask_tgt], dtype=torch.bool),
        boxes=tensor([[214.0, 41.0, 562.0, 285.0], [214.0, 41.0, 562.0, 285.0]]),
        labels=tensor([0, 1]),
    )
]

metric_map = MeanAveragePrecision(iou_type=["segm", "bbox"], class_metrics=True)
# IoU thresholds are fractions in [0, 1], so an IoU of 50% is 0.5, not 50
metric_map_50 = MeanAveragePrecision(
    iou_type=["segm", "bbox"], iou_thresholds=[0.5], class_metrics=True
)
metric = MetricCollection({"iou_all": metric_map, "iou_50": metric_map_50})
metric.update(preds, target)

res = metric.compute()

Environment

  • TorchMetrics version: 1.8.2

Additional context

Error message

    metric.update(preds, target)
  File "/home/user/miniconda3/envs/cv-classification2/lib/python3.10/site-packages/torchmetrics/collections.py", line 263, in update
    self._merge_compute_groups()
  File "/home/user/miniconda3/envs/cv-classification2/lib/python3.10/site-packages/torchmetrics/collections.py", line 286, in _merge_compute_groups
    if self._equal_metric_states(metric1, metric2):
  File "/home/user/miniconda3/envs/cv-classification2/lib/python3.10/site-packages/torchmetrics/collections.py", line 332, in _equal_metric_states
    and not (all(s1.shape == s2.shape and allclose(s1, s2) for s1, s2 in zip(state1, state2)))
  File "/home/user/miniconda3/envs/cv-classification2/lib/python3.10/site-packages/torchmetrics/collections.py", line 332, in <genexpr>
    and not (all(s1.shape == s2.shape and allclose(s1, s2) for s1, s2 in zip(state1, state2)))
AttributeError: 'tuple' object has no attribute 'shape'
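
The failing line assumes every element of a list state has a .shape attribute, while the AttributeError shows that at least one element here is a plain tuple. The sketch below reproduces that failure mode with the standard library only and shows one way a comparison could guard against it. Everything in it is hypothetical: FakeTensor, naive_states_equal, guarded_states_equal, and the made-up mask payload are illustrations, not torchmetrics code and not the library's fix.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor: just values plus a shape."""

    values: tuple

    @property
    def shape(self):
        return (len(self.values),)


def naive_states_equal(state1, state2):
    # Same pattern as the failing line: assumes every element has `.shape`.
    return all(
        s1.shape == s2.shape and s1.values == s2.values
        for s1, s2 in zip(state1, state2)
    )


def guarded_states_equal(state1, state2):
    # Compare tensor-like elements by shape and values, anything else with ==.
    if len(state1) != len(state2):
        return False
    for s1, s2 in zip(state1, state2):
        if isinstance(s1, FakeTensor) and isinstance(s2, FakeTensor):
            if s1.shape != s2.shape or s1.values != s2.values:
                return False
        elif s1 != s2:
            return False
    return True


# List states whose elements are tuples, as the AttributeError indicates.
# The payload is invented for illustration.
mask_state_a = [((5, 5), b"payload")]
mask_state_b = [((5, 5), b"payload")]

try:
    naive_states_equal(mask_state_a, mask_state_b)
except AttributeError as err:
    print(err)  # same kind of failure as in the traceback

print(guarded_states_equal(mask_state_a, mask_state_b))  # True
print(guarded_states_equal([FakeTensor((1, 2))], [FakeTensor((1, 3))]))  # False
```

Since the traceback shows the failure inside _merge_compute_groups, constructing the collection with compute_groups=False should skip that step and may serve as a workaround. That is an expectation based on the traceback, not something verified in this report.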

Labels: bug / fix (Something isn't working), help wanted (Extra attention is needed)