Skip to content

Commit 65beb58

Browse files
Add per_component parameter to DiceMetric and DiceHelper classes (#8774)
Fixes #8733 ### Description This PR adds support for connected component-based Dice metric calculation to the existing DiceMetric and DiceHelper classes. ### Changes * Added per_component: bool = False to both DiceMetric and DiceHelper constructors * Implemented compute_cc_dice method that calculates Dice scores for each connected component individually * Voronoi regions: Added compute_voronoi_regions_fast method for efficient connected component assignment without external cc3d dependency * Added input shape validation requiring 4D or 5D binary segmentation, i.e. (B, 2, H, W) or (B, 2, D, H, W), with 2 channels (background + foreground) when per_component=True * Updated first_ch calculation to properly exclude background channel when using per-component mode ### Reference * https://arxiv.org/abs/2410.18684 * https://github.com/alexanderjaus/CC-Metrics ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Vijay Vignesh Prasad Rao <vijayvigneshp02@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d0650f3 commit 65beb58

3 files changed

Lines changed: 213 additions & 2 deletions

File tree

monai/metrics/meandice.py

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,17 @@
1313

1414
import torch
1515

16-
from monai.metrics.utils import do_metric_reduction
16+
from monai.metrics.utils import compute_voronoi_regions_fast, do_metric_reduction
1717
from monai.utils import MetricReduction, deprecated_arg
18+
from monai.utils.module import optional_import
1819

1920
from .metric import CumulativeIterationMetric
2021

22+
scipy_ndimage, has_scipy_ndimage = optional_import("scipy.ndimage")
23+
cupy, has_cupy = optional_import("cupy")
24+
cupy_ndimage, has_cupy_ndimage = optional_import("cupyx.scipy.ndimage")
25+
26+
2127
__all__ = ["DiceMetric", "compute_dice", "DiceHelper"]
2228

2329

@@ -41,6 +47,18 @@ class DiceMetric(CumulativeIterationMetric):
4147
image size they can get overwhelmed by the signal from the background. This assumes the shape of both prediction
4248
and ground truth is BCHW[D].
4349
50+
The `per_component=True` approach computes the Dice metric on a per-connected component basis in the ground truth segmentation,
51+
ensuring equal weighting for each component regardless of its size. This method eliminates biases in traditional metrics,
52+
providing a more balanced evaluation, particularly in scenarios where object size does not correlate with clinical relevance.
53+
This provides a more granular evaluation of segmentation quality, especially useful when dealing with fragmented or
54+
disconnected objects in the foreground.
55+
Note:
56+
- The input prediction (`y_pred`) and ground truth (`y`) must both have 2 channels (foreground/background),
57+
with binary segmentation (0 for background, 1 for foreground). That is, this assumes the shape of both prediction
58+
and ground truth is B2HW[D].
59+
- This method cannot be used with multiclass segmentation.
60+
For more information, refer to the original paper: https://arxiv.org/abs/2410.18684
61+
4462
The typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
4563
4664
Further information can be found in the official
@@ -95,6 +113,9 @@ class DiceMetric(CumulativeIterationMetric):
95113
If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True,
96114
the index begins at "0", otherwise at "1". It can also take a list of label names.
97115
The outcome will then be returned as a dictionary.
116+
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
117+
computed for each connected component in the ground truth, and then averaged. This requires binary
118+
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.
98119
99120
"""
100121

@@ -106,6 +127,7 @@ def __init__(
106127
ignore_empty: bool = True,
107128
num_classes: int | None = None,
108129
return_with_label: bool | list[str] = False,
130+
per_component: bool = False,
109131
) -> None:
110132
super().__init__()
111133
self.include_background = include_background
@@ -114,13 +136,15 @@ def __init__(
114136
self.ignore_empty = ignore_empty
115137
self.num_classes = num_classes
116138
self.return_with_label = return_with_label
139+
self.per_component = per_component
117140
self.dice_helper = DiceHelper(
118141
include_background=self.include_background,
119142
reduction=MetricReduction.NONE,
120143
get_not_nans=False,
121144
apply_argmax=False,
122145
ignore_empty=self.ignore_empty,
123146
num_classes=self.num_classes,
147+
per_component=self.per_component,
124148
)
125149

126150
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
@@ -175,6 +199,7 @@ def compute_dice(
175199
include_background: bool = True,
176200
ignore_empty: bool = True,
177201
num_classes: int | None = None,
202+
per_component: bool = False,
178203
) -> torch.Tensor:
179204
"""
180205
Computes Dice score metric for a batch of predictions. This performs the same computation as
@@ -192,6 +217,9 @@ def compute_dice(
192217
num_classes: number of input channels (always including the background). When this is ``None``,
193218
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
194219
single-channel class indices and the number of classes is not automatically inferred from data.
220+
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
221+
computed for each connected component in the ground truth, and then averaged. This requires binary
222+
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.
195223
196224
Returns:
197225
Dice scores per batch and per class, (shape: [batch_size, num_classes]).
@@ -204,6 +232,7 @@ def compute_dice(
204232
apply_argmax=False,
205233
ignore_empty=ignore_empty,
206234
num_classes=num_classes,
235+
per_component=per_component,
207236
)(y_pred=y_pred, y=y)
208237

209238

@@ -246,6 +275,9 @@ class DiceHelper:
246275
num_classes: number of input channels (always including the background). When this is ``None``,
247276
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
248277
single-channel class indices and the number of classes is not automatically inferred from data.
278+
per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
279+
computed for each connected component in the ground truth, and then averaged. This requires binary
280+
segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.
249281
"""
250282

251283
@deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax")
@@ -262,6 +294,7 @@ def __init__(
262294
num_classes: int | None = None,
263295
sigmoid: bool | None = None,
264296
softmax: bool | None = None,
297+
per_component: bool = False,
265298
) -> None:
266299
# handling deprecated arguments
267300
if sigmoid is not None:
@@ -277,6 +310,50 @@ def __init__(
277310
self.activate = activate
278311
self.ignore_empty = ignore_empty
279312
self.num_classes = num_classes
313+
self.per_component = per_component
314+
315+
def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Compute per-component Dice for a single batch item.

    Every voxel is assigned to the Voronoi region of its nearest ground-truth connected
    component; a Dice score is computed inside each region and the scores are averaged.

    Args:
        y_pred (torch.Tensor): Predictions with shape (1, 2, D, H, W) or (1, 2, H, W).
        y (torch.Tensor): Ground truth with shape (1, 2, D, H, W) or (1, 2, H, W).

    Returns:
        torch.Tensor: One-element tensor holding the mean Dice over connected components.
        When the ground truth has no foreground the value is NaN if ``self.ignore_empty``
        is set, otherwise 1.0 for an empty prediction and 0.0 for a non-empty one.
    """
    if y_pred.ndim == y.ndim:
        # Collapse the two channels into an index map (0 = background, 1 = foreground).
        y_pred_idx = torch.argmax(y_pred, dim=1)
        y_idx = torch.argmax(y, dim=1)
    else:
        y_pred_idx = y_pred
        y_idx = y

    if y_idx[0].sum() == 0:
        # Empty ground truth: there are no components to score. The NaN marker must reach
        # the caller untouched (no nan_to_num here), otherwise "ignored" items would be
        # silently counted as a Dice of 0.
        if self.ignore_empty:
            score = float("nan")
        elif y_pred_idx.sum() == 0:
            score = 1.0
        else:
            score = 0.0
        return torch.tensor([score], device=y_idx.device)

    cc_assignment = compute_voronoi_regions_fast(y_idx[0])
    if cc_assignment.device != y_idx.device:
        cc_assignment = cc_assignment.to(y_idx.device)

    # `inv` maps every voxel to a dense region index in [0, number of regions).
    uniq, inv = torch.unique(cc_assignment.reshape(-1), return_inverse=True)
    # Two-bit confusion code per voxel: 0 = TN, 1 = FP, 2 = FN, 3 = TP.
    code = (y_idx.reshape(-1) << 1) | y_pred_idx.reshape(-1)
    # One bincount over (region, code) pairs yields the confusion counts of every region.
    idx = (inv << 2) | code
    hist = torch.bincount(idx, minlength=uniq.numel() * 4).reshape(-1, 4)
    fp, fn, tp = hist[:, 1], hist[:, 2], hist[:, 3]
    denom = 2 * tp + fp + fn
    dice_scores = torch.where(
        denom > 0, (2 * tp).float() / denom.float(), torch.tensor(1.0, device=denom.device)
    )
    return dice_scores.mean().reshape(1)
280357

281358
def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
282359
"""
@@ -305,6 +382,9 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
305382
y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...).
306383
the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
307384
y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
385+
386+
Raises:
387+
ValueError: when the shapes of `y_pred` and `y` are not compatible for the per-component computation.
308388
"""
309389
_apply_argmax, _threshold = self.apply_argmax, self.threshold
310390
if self.num_classes is None:
@@ -322,15 +402,31 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
322402
y_pred = torch.sigmoid(y_pred)
323403
y_pred = y_pred > 0.5
324404

325-
first_ch = 0 if self.include_background else 1
405+
if self.per_component:
406+
if y_pred.ndim not in (4, 5) or y.ndim not in (4, 5) or y_pred.shape[1] != 2 or y.shape[1] != 2:
407+
same_rank = y_pred.ndim == y.ndim and y_pred.ndim in (4, 5)
408+
binary_channels = y_pred.shape[1] == 2 and y.shape[1] == 2
409+
same_shape = y_pred.shape == y.shape
410+
if not (same_rank and binary_channels and same_shape):
411+
raise ValueError(
412+
"per_component requires matching 4D/5D binary tensors "
413+
"(B, 2, H, W) or (B, 2, D, H, W). "
414+
f"Got y_pred={tuple(y_pred.shape)}, y={tuple(y.shape)}."
415+
)
416+
417+
first_ch = 0 if self.include_background and not self.per_component else 1
326418
data = []
327419
for b in range(y_pred.shape[0]):
420+
if self.per_component:
421+
data.append(self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0)).reshape(-1))
422+
continue
328423
c_list = []
329424
for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
330425
x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
331426
x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
332427
c_list.append(self.compute_channel(x_pred, x))
333428
data.append(torch.stack(c_list))
429+
334430
data = torch.stack(data, dim=0).contiguous() # type: ignore
335431

336432
f, not_nans = do_metric_reduction(data, self.reduction) # type: ignore

monai/metrics/utils.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@
3939
distance_transform_edt, _ = optional_import("scipy.ndimage", name="distance_transform_edt")
4040
distance_transform_cdt, _ = optional_import("scipy.ndimage", name="distance_transform_cdt")
4141

42+
scipy_ndimage, has_scipy_ndimage = optional_import("scipy.ndimage")
43+
cupy, has_cupy = optional_import("cupy")
44+
cupy_ndimage, has_cupy_ndimage = optional_import("cupyx.scipy.ndimage")
45+
4246
__all__ = [
4347
"ignore_background",
4448
"do_metric_reduction",
@@ -462,6 +466,59 @@ def prepare_spacing(
462466
)
463467

464468

469+
def compute_voronoi_regions_fast(labels: np.ndarray | torch.Tensor) -> torch.Tensor:
    """
    Voronoi assignment to connected components (single EDT) without cc3d.
    Returns the ID of the nearest component for each voxel.

    Voxels with ``labels > 0`` are grouped into connected components (connectivity equal to
    the number of dimensions) and every voxel receives the ID of the closest component.
    The computation uses ``cupy`` / ``cupyx.scipy.ndimage`` when ``labels`` is a CUDA tensor
    and both packages are available; otherwise it uses numpy / ``scipy.ndimage`` on CPU.

    Args:
        labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds.

    Raises:
        RuntimeError: when the CPU path is taken and `scipy.ndimage` is not available.

    Returns:
        torch.Tensor: Voronoi region IDs (int32) on CPU, with the same shape as ``labels``.
        All zeros when ``labels`` contains no seed.
    """
    use_gpu = isinstance(labels, torch.Tensor) and labels.is_cuda and has_cupy and has_cupy_ndimage
    if use_gpu:
        xp = cupy
        ndi = cupy_ndimage
        x = cupy.asarray(labels.detach())
    else:
        # Validate availability BEFORE touching any attribute of `scipy_ndimage`: a failed
        # optional import is a placeholder that raises on attribute access, which would
        # pre-empt this clearer error message.
        if not has_scipy_ndimage:
            raise RuntimeError("scipy.ndimage is required for per_component Dice computation.")
        xp = np
        ndi = scipy_ndimage
        if isinstance(labels, torch.Tensor):
            warnings.warn(
                "Voronoi computation is running on CPU. "
                "To accelerate, move the input tensor to GPU and ensure 'cupy' with 'cupyx.scipy.ndimage' is installed."
            )
            # detach() so tensors that require grad can still be converted to numpy.
            x = labels.detach().cpu().numpy()
        else:
            x = np.asarray(labels)

    rank = conn_rank = x.ndim
    structure = ndi.generate_binary_structure(rank=rank, connectivity=conn_rank)
    cc, num = ndi.label(x > 0, structure=structure)
    if num == 0:
        # Build the result from the shape only, so this works for numpy and cupy arrays alike.
        return torch.zeros(tuple(x.shape), dtype=torch.int32)

    # The EDT measures distance to the nearest zero, so components are zeros, the rest ones.
    edt_input = xp.ones(cc.shape, dtype=xp.uint8)
    edt_input[cc > 0] = 0
    # Only the coordinates of the nearest component voxel are needed, not the distances.
    indices = ndi.distance_transform_edt(edt_input, sampling=None, return_distances=False, return_indices=True)
    voronoi = cc[tuple(indices)]
    if use_gpu:
        voronoi = cupy.asnumpy(voronoi)
    return torch.as_tensor(voronoi, dtype=torch.int32)
520+
521+
465522
ENCODING_KERNEL = {2: [[8, 4], [2, 1]], 3: [[[128, 64], [32, 16]], [[8, 4], [2, 1]]]}
466523

467524

tests/metrics/test_compute_meandice.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
from parameterized import parameterized
1919

2020
from monai.metrics import DiceHelper, DiceMetric, compute_dice
21+
from monai.utils.module import optional_import
22+
23+
_, has_ndimage = optional_import("scipy.ndimage")
24+
_, has_cupy_ndimage = optional_import("cupyx.scipy.ndimage")
2125

2226
_device = "cuda:0" if torch.cuda.is_available() else "cpu"
2327
# keep background
@@ -250,6 +254,42 @@
250254
{"label_1": 0.4000, "label_2": 0.6667},
251255
]
252256

257+
# Testcase for per_component DiceMetric - 3D input
258+
y = torch.zeros((5, 2, 64, 64, 64), device=_device)
259+
y_hat = torch.zeros((5, 2, 64, 64, 64), device=_device)
260+
261+
y[0, 1, 20:25, 20:25, 20:25] = 1
262+
y[0, 1, 40:45, 40:45, 40:45] = 1
263+
y[0, 0] = 1 - y[0, 1]
264+
265+
y_hat[0, 1, 21:26, 21:26, 21:26] = 1
266+
y_hat[0, 1, 41:46, 39:44, 41:46] = 1
267+
y_hat[0, 0] = 1 - y_hat[0, 1]
268+
269+
TEST_CASE_16 = [
270+
{"per_component": True, "ignore_empty": False},
271+
{"y": y, "y_pred": y_hat},
272+
[[0.5120], [1.0], [1.0], [1.0], [1.0]],
273+
]
274+
275+
# Testcase for per_component DiceMetric - 2D input
276+
y = torch.zeros((5, 2, 64, 64), device=_device)
277+
y_hat = torch.zeros((5, 2, 64, 64), device=_device)
278+
279+
y[0, 1, 20:25, 20:25] = 1
280+
y[0, 1, 40:45, 40:45] = 1
281+
y[0, 0] = 1 - y[0, 1]
282+
283+
y_hat[0, 1, 21:26, 21:26] = 1
284+
y_hat[0, 1, 41:46, 39:44] = 1
285+
y_hat[0, 0] = 1 - y_hat[0, 1]
286+
287+
TEST_CASE_17 = [
288+
{"per_component": True, "ignore_empty": False},
289+
{"y": y, "y_pred": y_hat},
290+
[[0.6400], [1.0], [1.0], [1.0], [1.0]],
291+
]
292+
253293

254294
class TestComputeMeanDice(unittest.TestCase):
255295

@@ -301,6 +341,24 @@ def test_nans_class(self, params, input_data, expected_value):
301341
else:
302342
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
303343

344+
# CC DiceMetric tests
345+
@parameterized.expand([TEST_CASE_16, TEST_CASE_17])
@unittest.skipUnless(has_ndimage, "Requires scipy.ndimage.")
def test_cc_dice_value_nogpu(self, params, input_data, expected_value):
    """Check per-component Dice values; inputs are moved to CPU when cupyx.scipy.ndimage is missing."""
    metric = DiceMetric(**params)
    if has_cupy_ndimage:
        batch = input_data
    else:
        batch = {"y": input_data["y"].cpu(), "y_pred": input_data["y_pred"].cpu()}
    metric(**batch)
    scores = metric.aggregate(reduction="none")
    np.testing.assert_allclose(scores.cpu().numpy(), expected_value, atol=1e-4)
356+
357+
@unittest.skipUnless(has_ndimage, "Requires scipy.ndimage.")
def test_channel_dimensions(self):
    """A three-channel input must be rejected when per_component is enabled."""
    bad_shape = [3, 3, 144, 144]
    with self.assertRaises(ValueError):
        DiceMetric(per_component=True)(torch.ones(bad_shape), torch.ones(bad_shape))
361+
304362

305363
if __name__ == "__main__":
306364
unittest.main()

0 commit comments

Comments
 (0)