2121
2222linear_sum_assignment , _ = optional_import ("scipy.optimize" , name = "linear_sum_assignment" )
2323
24- __all__ = ["PanopticQualityMetric" , "compute_panoptic_quality" ]
24+ __all__ = ["PanopticQualityMetric" , "compute_panoptic_quality" , "compute_mean_iou" ]
2525
2626
2727class PanopticQualityMetric (CumulativeIterationMetric ):
@@ -55,6 +55,8 @@ class PanopticQualityMetric(CumulativeIterationMetric):
5555 If set `match_iou_threshold` < 0.5, this function uses Munkres assignment to find the
5656 maximal amount of unique pairing.
5757 smooth_numerator: a small constant added to the numerator to avoid zero.
58+ return_confusion_matrix: if True, returns raw confusion matrix values (tp, fp, fn, iou_sum)
59+ instead of computed metrics. Default is False.
5860
5961 """
6062
@@ -65,19 +67,22 @@ def __init__(
6567 reduction : MetricReduction | str = MetricReduction .MEAN_BATCH ,
6668 match_iou_threshold : float = 0.5 ,
6769 smooth_numerator : float = 1e-6 ,
70+ return_confusion_matrix : bool = False ,
6871 ) -> None :
6972 super ().__init__ ()
7073 self .num_classes = num_classes
7174 self .reduction = reduction
7275 self .match_iou_threshold = match_iou_threshold
7376 self .smooth_numerator = smooth_numerator
7477 self .metric_name = ensure_tuple (metric_name )
78+ self .return_confusion_matrix = return_confusion_matrix
7579
7680 def _compute_tensor (self , y_pred : torch .Tensor , y : torch .Tensor ) -> torch .Tensor : # type: ignore[override]
7781 """
7882 Args:
79- y_pred: Predictions. It must be in the form of B2HW and have integer type. The first channel and the
80- second channel represent the instance predictions and classification predictions respectively.
83+ y_pred: Predictions. It must be in the form of B2HW (2D) or B2HWD (3D) and have integer type.
84+ The first channel and the second channel represent the instance predictions and classification
85+ predictions respectively.
8186 y: ground truth. It must have the same shape as `y_pred` and have integer type. The first channel and the
8287 second channel represent the instance labels and classification labels respectively.
8388 Values in the second channel of `y_pred` and `y` should be in the range of 0 to `self.num_classes`,
@@ -86,7 +91,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
8691 Raises:
8792 ValueError: when `y_pred` and `y` have different shapes.
8893 ValueError: when `y_pred` and `y` have != 2 channels.
89- ValueError: when `y_pred` and `y` have != 4 dimensions.
94+ ValueError: when `y_pred` and `y` have != 4 or 5 dimensions.
9095
9196 """
9297 if y_pred .shape != y .shape :
@@ -98,8 +103,10 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
98103 )
99104
100105 dims = y_pred .ndimension ()
101- if dims != 4 :
102- raise ValueError (f"y_pred should have 4 dimensions (batch, 2, h, w), got { dims } ." )
106+ if dims not in (4 , 5 ):
107+ raise ValueError (
108+ f"y_pred should have 4 dimensions (batch, 2, h, w) or 5 dimensions (batch, 2, h, w, d), got { dims } ."
109+ )
103110
104111 batch_size = y_pred .shape [0 ]
105112
@@ -131,13 +138,22 @@ def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Ten
131138 available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
132139 ``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.
133140
141+ Returns:
142+ If `return_confusion_matrix` is True, returns the raw confusion matrix [tp, fp, fn, iou_sum].
143+ Otherwise, returns the computed metric(s) based on `metric_name`.
144+
134145 """
135146 data = self .get_buffer ()
136147 if not isinstance (data , torch .Tensor ):
137148 raise ValueError ("the data to aggregate must be PyTorch Tensor." )
138149
139150 # do metric reduction
140151 f , _ = do_metric_reduction (data , reduction or self .reduction )
152+
153+ if self .return_confusion_matrix :
154+ # Return raw confusion matrix values
155+ return f
156+
141157 tp , fp , fn , iou_sum = f [..., 0 ], f [..., 1 ], f [..., 2 ], f [..., 3 ]
142158 results = []
143159 for metric_name in self .metric_name :
@@ -169,7 +185,7 @@ def compute_panoptic_quality(
169185 calculate PQ, and returning them directly enables further calculation over all images.
170186
171187 Args:
172- pred: input data to compute, it must be in the form of HW and have integer type.
188+ pred: input data to compute, it must be in the form of HW (2D) or HWD (3D) and have integer type.
173189 gt: ground truth. It must have the same shape as `pred` and have integer type.
174190 metric_name: output metric. The value can be "pq", "sq" or "rq".
175191 remap: whether to remap `pred` and `gt` to ensure contiguous ordering of instance id.
@@ -294,3 +310,24 @@ def _check_panoptic_metric_name(metric_name: str) -> str:
294310 if metric_name in ["recognition_quality" , "rq" ]:
295311 return "rq"
296312 raise ValueError (f"metric name: { metric_name } is wrong, please use 'pq', 'sq' or 'rq'." )
313+
314+
def compute_mean_iou(confusion_matrix: torch.Tensor, smooth_numerator: float = 1e-6) -> torch.Tensor:
    """Compute the mean IoU of matched segments from raw confusion matrix values.

    Args:
        confusion_matrix: tensor with shape (..., 4) where the last dimension contains
            [tp, fp, fn, iou_sum], as returned by `compute_panoptic_quality` with
            `output_confusion_matrix=True` or by `PanopticQualityMetric.aggregate` when the
            metric was created with `return_confusion_matrix=True`.
        smooth_numerator: a small constant added to `tp` in the denominator so that entries
            without any true positive do not trigger a division by zero. The name is kept
            for consistency with the other panoptic quality helpers in this module.

    Returns:
        Tensor with shape `confusion_matrix.shape[:-1]` holding `iou_sum / (tp + smooth_numerator)`.

    Raises:
        ValueError: when `confusion_matrix` is 0-dimensional or its last dimension is not of size 4.

    """
    # A 0-dim tensor has no last dimension; test for it first so that the caller gets
    # the documented ValueError instead of an IndexError from `shape[-1]`.
    if confusion_matrix.ndim == 0 or confusion_matrix.shape[-1] != 4:
        raise ValueError(
            "confusion_matrix should have shape (..., 4) with [tp, fp, fn, iou_sum], "
            f"got shape {confusion_matrix.shape}."
        )
    # Only the true-positive count and the summed IoU are needed; fp and fn are ignored.
    tp, iou_sum = confusion_matrix[..., 0], confusion_matrix[..., 3]
    return iou_sum / (tp + smooth_numerator)
0 commit comments