Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
b5107da
Add parameter to DiceMetric and DiceHelper classes
VijayVignesh1 Mar 12, 2026
ccca77a
Adding per_component information to inline docstring
VijayVignesh1 Mar 13, 2026
c110e2a
fixing indentation and formatting
VijayVignesh1 Mar 13, 2026
41e52c1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2026
34a6817
Adding optional import for scipy and fixing issues raised by coderabb…
VijayVignesh1 Mar 13, 2026
8d412a1
fixing indentation and formatting
VijayVignesh1 Mar 13, 2026
cb433a8
Adding optional import for scipy and fixing issues raised by coderabb…
VijayVignesh1 Mar 13, 2026
ba2e0b3
Adding unittest skipUnless for scipy.ndimage and resolving mypy bug
VijayVignesh1 Mar 13, 2026
d9bfb5d
Adding unittest skip only to test cc functions and resolving shape ch…
VijayVignesh1 Mar 13, 2026
6f2155c
Extending per_component calculations to both 2D and 3D
VijayVignesh1 Mar 16, 2026
ba05438
Adding cupy support to per_component calculations
VijayVignesh1 Mar 16, 2026
4e6def7
Debugging extra axis bug and testcase bugs
VijayVignesh1 Mar 16, 2026
925e431
moving compute_voronoi_regions_fast to utils, removing hardcoded conn…
VijayVignesh1 Mar 31, 2026
28a2944
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 31, 2026
24a17e9
fixing linting issues
VijayVignesh1 Mar 31, 2026
a74f2cf
Merge branch 'dev' into 8733-per-component-dice-metric
VijayVignesh1 Apr 21, 2026
b3bcba4
Improving the docstring explanation
VijayVignesh1 Apr 29, 2026
3d19dd1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2026
a169c08
Merge branch 'dev' into 8733-per-component-dice-metric
VijayVignesh1 Apr 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 95 additions & 1 deletion monai/metrics/meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@

from __future__ import annotations

import numpy as np
import torch
from scipy.ndimage import distance_transform_edt, generate_binary_structure
from scipy.ndimage import label as sn_label
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

from monai.metrics.utils import do_metric_reduction
from monai.utils import MetricReduction, deprecated_arg
Expand Down Expand Up @@ -95,6 +98,9 @@ class DiceMetric(CumulativeIterationMetric):
If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True,
the index begins at "0", otherwise at "1". It can also take a list of label names.
The outcome will then be returned as a dictionary.
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
computed for each connected component in the ground truth, and then averaged. This requires 5D binary
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.

"""

Expand All @@ -106,6 +112,7 @@ def __init__(
ignore_empty: bool = True,
num_classes: int | None = None,
return_with_label: bool | list[str] = False,
per_component: bool = False,
) -> None:
super().__init__()
self.include_background = include_background
Expand All @@ -114,13 +121,15 @@ def __init__(
self.ignore_empty = ignore_empty
self.num_classes = num_classes
self.return_with_label = return_with_label
self.per_component = per_component
self.dice_helper = DiceHelper(
include_background=self.include_background,
reduction=MetricReduction.NONE,
get_not_nans=False,
apply_argmax=False,
ignore_empty=self.ignore_empty,
num_classes=self.num_classes,
per_component=self.per_component,
)

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
Expand Down Expand Up @@ -175,6 +184,7 @@ def compute_dice(
include_background: bool = True,
ignore_empty: bool = True,
num_classes: int | None = None,
per_component: bool = False,
) -> torch.Tensor:
"""
Computes Dice score metric for a batch of predictions. This performs the same computation as
Expand All @@ -192,6 +202,9 @@ def compute_dice(
num_classes: number of input channels (always including the background). When this is ``None``,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
computed for each connected component in the ground truth, and then averaged. This requires 5D binary
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.

Returns:
Dice scores per batch and per class, (shape: [batch_size, num_classes]).
Expand All @@ -204,6 +217,7 @@ def compute_dice(
apply_argmax=False,
ignore_empty=ignore_empty,
num_classes=num_classes,
per_component=per_component,
)(y_pred=y_pred, y=y)


Expand Down Expand Up @@ -246,6 +260,9 @@ class DiceHelper:
num_classes: number of input channels (always including the background). When this is ``None``,
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
single-channel class indices and the number of classes is not automatically inferred from data.
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
computed for each connected component in the ground truth, and then averaged. This requires 5D binary
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.
"""

@deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax")
Expand All @@ -262,6 +279,7 @@ def __init__(
num_classes: int | None = None,
sigmoid: bool | None = None,
softmax: bool | None = None,
per_component: bool = False,
) -> None:
# handling deprecated arguments
if sigmoid is not None:
Expand All @@ -277,6 +295,73 @@ def __init__(
self.activate = activate
self.ignore_empty = ignore_empty
self.num_classes = num_classes
self.per_component = per_component

def compute_voronoi_regions_fast(self, labels, connectivity=26, sampling=None):
    """
    Assign every voxel to its nearest connected component (a Voronoi partition of the
    image) using a single Euclidean distance transform on CPU, without cc3d.

    Args:
        labels: input label map as a numpy array or torch tensor (CUDA tensors are
            moved to CPU); values > 0 are treated as seeds and grouped into
            connected components.
        connectivity: 3D neighbourhood used to link seeds: 6 (faces), 18 (+edges) or
            26 (+corners). Unrecognised values fall back to 26. For 2D inputs the
            structure order is clamped to the input rank, so 2D label maps work too.
        sampling: voxel spacing forwarded to ``scipy.ndimage.distance_transform_edt``
            for anisotropic distances.

    Returns:
        A CPU ``torch.Tensor`` with the same spatial shape as ``labels`` holding the
        1-based ID of the nearest component for each voxel, or an all-zero int32
        tensor when ``labels`` has no foreground.
    """
    # Accept torch tensors (including CUDA) as well as array-likes.
    if isinstance(labels, torch.Tensor):
        x = labels.detach().cpu().numpy()
    else:
        x = np.asarray(labels)
    # Map the 3D connectivity convention (6/18/26) onto scipy's structure order and
    # clamp to the input rank so 2D inputs no longer crash (previously rank=3 was
    # hard-coded, which fails for 2D label maps).
    order = {6: 1, 18: 2, 26: 3}.get(connectivity, 3)
    structure = generate_binary_structure(rank=x.ndim, connectivity=min(order, x.ndim))
    cc, num = sn_label(x > 0, structure=structure)
    if num == 0:
        # No seeds at all: every voxel belongs to no component.
        return torch.zeros(x.shape, dtype=torch.int32)
    # EDT over the background mask with return_indices gives, per voxel, the
    # coordinates of its nearest seed voxel; indexing ``cc`` with those coordinates
    # yields that seed's component ID.
    background = np.ones(cc.shape, dtype=np.uint8)
    background[cc > 0] = 0
    indices = distance_transform_edt(background, sampling=sampling, return_distances=False, return_indices=True)
    voronoi = cc[tuple(indices)]
    return torch.from_numpy(voronoi)
Comment thread
VijayVignesh1 marked this conversation as resolved.
Outdated

def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Compute the mean per-connected-component Dice score for a single batch item.

    The ground truth is partitioned into Voronoi regions around its connected
    components; a Dice score is computed inside each region and the scores are
    averaged, so every component contributes equally regardless of size.

    Args:
        y_pred: predictions with shape (1, C, spatial...); when its rank equals
            ``y``'s, class-index maps are obtained via argmax over dim 1, otherwise
            it is assumed to already be an index map.
        y: ground truth in the same layout as ``y_pred``.

    Returns:
        A 1-element float tensor holding the mean per-component Dice score.
    """
    if y_pred.ndim == y.ndim:
        # One-hot / channel inputs: reduce to class-index maps.
        y_pred_idx = torch.argmax(y_pred, dim=1)
        y_idx = torch.argmax(y, dim=1)
    else:
        y_pred_idx = y_pred
        y_idx = y

    if y_idx[0].sum() == 0:
        # Empty ground truth: score 1 only when the prediction is empty as well.
        score = 1.0 if y_pred_idx.sum() == 0 else 0.0
        return torch.tensor([score], device=y_idx.device)

    # Assign every voxel to its nearest ground-truth component.
    cc_assignment = self.compute_voronoi_regions_fast(y_idx[0])
    # The Voronoi map is produced on CPU; move it next to the inputs so the integer
    # packing below does not mix devices (previously this crashed for CUDA inputs).
    cc_assignment = cc_assignment.to(y_idx.device)

    uniq, inv = torch.unique(cc_assignment.view(-1), return_inverse=True)
    nof_components = uniq.numel()
    # Pack (component, gt, pred) into one integer per voxel and histogram it: the
    # low two bits encode (gt << 1) | pred, i.e. TN=0, FP=1, FN=2, TP=3 per region.
    code = (y_idx.view(-1) << 1) | y_pred_idx.view(-1)
    idx = (inv << 2) | code
    hist = torch.bincount(idx, minlength=nof_components * 4).reshape(-1, 4)
    fp, fn, tp = hist[:, 1], hist[:, 2], hist[:, 3]
    denom = 2 * tp + fp + fn
    # A region empty in both prediction and ground truth counts as a perfect match.
    dice_scores = torch.where(denom > 0, (2 * tp).float() / denom.float(), torch.tensor(1.0, device=denom.device))
    return dice_scores.mean().reshape(1)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -322,15 +407,24 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
y_pred = torch.sigmoid(y_pred)
y_pred = y_pred > 0.5

first_ch = 0 if self.include_background else 1
if self.per_component and (len(y_pred.shape) != 5 or y_pred.shape[1] != 2):
raise ValueError(
f"per_component requires 5D binary segmentation with 2 channels (background + foreground). "
f"Got shape {y_pred.shape}, expected shape (B, 2, D, H, W)."
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

first_ch = 0 if self.include_background and not self.per_component else 1
Comment thread
VijayVignesh1 marked this conversation as resolved.
data = []
for b in range(y_pred.shape[0]):
c_list = []
for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
c_list.append(self.compute_channel(x_pred, x))
if self.per_component:
c_list = [self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0))]
data.append(torch.stack(c_list))
Comment thread
coderabbitai[bot] marked this conversation as resolved.

data = torch.stack(data, dim=0).contiguous() # type: ignore

f, not_nans = do_metric_reduction(data, self.reduction) # type: ignore
Expand Down
37 changes: 37 additions & 0 deletions tests/metrics/test_compute_meandice.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,31 @@
{"label_1": 0.4000, "label_2": 0.6667},
]

def _build_per_component_case():
    """Build the input volumes for TEST_CASE_16: two-channel 5D one-hot tensors
    where batch item 0 holds two small foreground cubes and the prediction shifts
    each cube by one voxel (the second by more, lowering its component's Dice).
    Batch items 1-4 are empty in both tensors. Returns the {"y", "y_pred"} dict.
    """
    y = torch.zeros((5, 2, 64, 64, 64))
    y[0, 1, 20:25, 20:25, 20:25] = 1
    y[0, 1, 40:45, 40:45, 40:45] = 1
    y[0, 0] = 1 - y[0, 1]
    y_pred = torch.zeros((5, 2, 64, 64, 64))
    y_pred[0, 1, 21:26, 21:26, 21:26] = 1
    y_pred[0, 1, 41:46, 39:44, 41:46] = 1
    y_pred[0, 0] = 1 - y_pred[0, 1]
    return {"y": y, "y_pred": y_pred}


# params, input_data, expected per-component Dice (item 0 averages its two
# components to 0.5120; empty-vs-empty items score 1.0).
TEST_CASE_16 = [
    {"per_component": True},
    _build_per_component_case(),
    [[[0.5120]], [[1.0]], [[1.0]], [[1.0]], [[1.0]]],
]


class TestComputeMeanDice(unittest.TestCase):

Expand Down Expand Up @@ -301,6 +326,18 @@ def test_nans_class(self, params, input_data, expected_value):
else:
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)

# CC DiceMetric tests
@parameterized.expand([TEST_CASE_16])
def test_cc_dice_value(self, params, input_data, expected_value):
    """Per-component DiceMetric should match the precomputed component-wise scores."""
    metric = DiceMetric(**params)
    metric(**input_data)
    aggregated = metric.aggregate(reduction="none")
    np.testing.assert_allclose(aggregated.cpu().numpy(), expected_value, atol=1e-4)

def test_input_dimensions(self):
    """per_component must reject inputs that are not 5D with 2 channels."""
    metric = DiceMetric(per_component=True)
    pred = torch.ones([3, 3, 144, 144])
    gt = torch.ones([3, 3, 145, 145])
    with self.assertRaises(ValueError):
        metric(pred, gt)


if __name__ == "__main__":
unittest.main()
Loading