Skip to content

Commit 3c27d3f

Browse files
committed
fix: Handle mixed single/multi-channel y and y_pred in DiceHelper
Convert y_pred and y to boolean independently based on each tensor's own channel count, fixing incorrect Dice values when formats differ (e.g. single-channel class indices paired with multi-channel one-hot). Add test cases covering both mixed-format combinations. Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent 5d1456e commit 3c27d3f

File tree

2 files changed

+67
-13
lines changed

2 files changed

+67
-13
lines changed

monai/metrics/meandice.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -307,23 +307,21 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
307307
batch_size = y_pred.shape[0]
308308
device = y_pred.device
309309

310-
# Convert to boolean for computation
310+
# Convert y_pred to boolean (handle single-channel class indices vs multi-channel one-hot independently)
311311
if y_pred.shape[1] == 1 and n_pred_ch > 1:
312-
# Single-channel class indices: convert to one-hot
313312
y_pred_bool = torch.zeros(batch_size, n_pred_ch, *y_pred.shape[2:], dtype=torch.bool, device=device)
314-
y_bool = torch.zeros(batch_size, n_pred_ch, *y.shape[2:], dtype=torch.bool, device=device)
315-
316313
for c in range(n_pred_ch):
317314
y_pred_bool[:, c] = (y_pred[:, 0] == c)
318-
y_bool[:, c] = (y[:, 0] == c)
319315
else:
320-
# One-hot format: cast to bool
321316
y_pred_bool = y_pred.bool()
322-
if y.shape[1] == 1 and y_pred.shape[1] > 1:
323-
# Expand y to match y_pred channels
324-
y_bool = (y == 1).expand(batch_size, n_pred_ch, *y.shape[2:])
325-
else:
326-
y_bool = y.bool()
317+
318+
# Convert y to boolean (independent of y_pred format)
319+
if y.shape[1] == 1 and n_pred_ch > 1:
320+
y_bool = torch.zeros(batch_size, n_pred_ch, *y.shape[2:], dtype=torch.bool, device=device)
321+
for c in range(n_pred_ch):
322+
y_bool[:, c] = (y[:, 0] == c)
323+
else:
324+
y_bool = y.bool()
327325

328326
# Flatten spatial dimensions for vectorized computation: (batch, channels, -1)
329327
y_pred_flat = y_pred_bool.reshape(batch_size, n_pred_ch, -1).float()

tests/metrics/test_compute_meandice.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,15 +251,71 @@
251251
]
252252

253253

254+
# single-channel y (class indices) with multi-channel y_pred (one-hot)
255+
TEST_CASE_MIXED_1 = [
256+
{
257+
"y_pred": torch.tensor(
258+
[[[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [1.0, 0.0]]]]
259+
), # (1, 3, 2, 2) one-hot
260+
"y": torch.tensor([[[[0.0, 1.0], [2.0, 1.0]]]]), # (1, 1, 2, 2) class indices
261+
"include_background": True,
262+
},
263+
# class 0: y_gt=[[1,0],[0,0]], y_pred=[[0,1],[0,0]] -> dice=0.0
264+
# class 1: y_gt=[[0,1],[0,1]], y_pred=[[0,0],[0,1]] -> dice=2/3
265+
# class 2: y_gt=[[0,0],[1,0]], y_pred=[[1,0],[1,0]] -> dice=2/3
266+
[[0.0000, 0.6667, 0.6667]],
267+
]
268+
269+
# single-channel y_pred (argmaxed, with num_classes) with multi-channel y (one-hot)
270+
TEST_CASE_MIXED_2 = [
271+
{
272+
"y_pred": torch.tensor([[[[2.0, 2.0], [2.0, 2.0]]]]), # (1, 1, 2, 2) all class 2
273+
"y": torch.tensor(
274+
[[[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]]
275+
), # (1, 3, 2, 2) one-hot, all background
276+
"include_background": True,
277+
"num_classes": 3,
278+
},
279+
# class 0: y_gt=[1,1,1,1](4), y_pred=[0,0,0,0](0) -> dice=0.0
280+
# class 1: y_gt=[0,0,0,0](0), y_pred=[0,0,0,0](0) -> dice=nan (ignore_empty default)
281+
# class 2: y_gt=[0,0,0,0](0), y_pred=[1,1,1,1](4) -> dice=nan (ignore_empty default)
282+
[[False, True, True]], # False=not-nan, True=nan
283+
]
284+
285+
# single-channel y (class indices) with multi-channel y_pred, exclude background
286+
TEST_CASE_MIXED_3 = [
287+
{
288+
"y_pred": torch.tensor(
289+
[
290+
[[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]]],
291+
[[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 0.0]]],
292+
]
293+
), # (2, 3, 2, 2) one-hot
294+
"y": torch.tensor(
295+
[
296+
[[[0.0, 0.0], [0.0, 1.0]]],
297+
[[[0.0, 0.0], [0.0, 1.0]]],
298+
]
299+
), # (2, 1, 2, 2) class indices
300+
"include_background": False,
301+
},
302+
# batch 0: class 1 y_gt=[[0,0],[0,1]], y_pred=[[0,0],[1,1]] -> dice=2/3
303+
# class 2 y_gt=[[0,0],[0,0]], y_pred=[[1,0],[0,0]] -> dice=nan
304+
# batch 1: class 1 y_gt=[[0,0],[0,1]], y_pred=[[1,0],[0,0]] -> dice=0.0
305+
# class 2 y_gt=[[0,0],[0,0]], y_pred=[[0,1],[1,0]] -> dice=nan
306+
[[False, True], [False, True]], # nan pattern
307+
]
308+
309+
254310
class TestComputeMeanDice(unittest.TestCase):
255311

256-
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12])
312+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12, TEST_CASE_MIXED_1])
257313
def test_value(self, input_data, expected_value):
258314
result = compute_dice(**input_data)
259315
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
260316
np.testing.assert_equal(result.device, input_data["y_pred"].device)
261317

262-
@parameterized.expand([TEST_CASE_3])
318+
@parameterized.expand([TEST_CASE_3, TEST_CASE_MIXED_2, TEST_CASE_MIXED_3])
263319
def test_nans(self, input_data, expected_value):
264320
result = compute_dice(**input_data)
265321
self.assertTrue(np.allclose(np.isnan(result.cpu().numpy()), expected_value))

0 commit comments

Comments
 (0)