|
251 | 251 | ] |
252 | 252 |
|
253 | 253 |
|
| 254 | +# single-channel y (class indices) with multi-channel y_pred (one-hot) |
| 255 | +TEST_CASE_MIXED_1 = [ |
| 256 | + { |
| 257 | + "y_pred": torch.tensor( |
| 258 | + [[[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [1.0, 0.0]]]] |
| 259 | + ), # (1, 3, 2, 2) one-hot |
| 260 | + "y": torch.tensor([[[[0.0, 1.0], [2.0, 1.0]]]]), # (1, 1, 2, 2) class indices |
| 261 | + "include_background": True, |
| 262 | + }, |
| 263 | + # class 0: y_gt=[[1,0],[0,0]], y_pred=[[0,1],[0,0]] -> dice=0.0 |
| 264 | + # class 1: y_gt=[[0,1],[0,1]], y_pred=[[0,0],[0,1]] -> dice=2/3 |
| 265 | + # class 2: y_gt=[[0,0],[1,0]], y_pred=[[1,0],[1,0]] -> dice=2/3 |
| 266 | + [[0.0000, 0.6667, 0.6667]], |
| 267 | +] |
| 268 | + |
| 269 | +# single-channel y_pred (argmaxed, with num_classes) with multi-channel y (one-hot) |
| 270 | +TEST_CASE_MIXED_2 = [ |
| 271 | + { |
| 272 | + "y_pred": torch.tensor([[[[2.0, 2.0], [2.0, 2.0]]]]), # (1, 1, 2, 2) all class 2 |
| 273 | + "y": torch.tensor( |
| 274 | + [[[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]] |
| 275 | + ), # (1, 3, 2, 2) one-hot, all background |
| 276 | + "include_background": True, |
| 277 | + "num_classes": 3, |
| 278 | + }, |
| 279 | + # class 0: y_gt=[1,1,1,1](4), y_pred=[0,0,0,0](0) -> dice=0.0 |
| 280 | + # class 1: y_gt=[0,0,0,0](0), y_pred=[0,0,0,0](0) -> dice=nan (ignore_empty default) |
| 281 | + # class 2: y_gt=[0,0,0,0](0), y_pred=[1,1,1,1](4) -> dice=nan (ignore_empty default) |
| 282 | + [[False, True, True]], # False=not-nan, True=nan |
| 283 | +] |
| 284 | + |
| 285 | +# single-channel y (class indices) with multi-channel y_pred, exclude background |
| 286 | +TEST_CASE_MIXED_3 = [ |
| 287 | + { |
| 288 | + "y_pred": torch.tensor( |
| 289 | + [ |
| 290 | + [[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]]], |
| 291 | + [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 0.0]]], |
| 292 | + ] |
| 293 | + ), # (2, 3, 2, 2) one-hot |
| 294 | + "y": torch.tensor( |
| 295 | + [ |
| 296 | + [[[0.0, 0.0], [0.0, 1.0]]], |
| 297 | + [[[0.0, 0.0], [0.0, 1.0]]], |
| 298 | + ] |
| 299 | + ), # (2, 1, 2, 2) class indices |
| 300 | + "include_background": False, |
| 301 | + }, |
| 302 | + # batch 0: class 1 y_gt=[[0,0],[0,1]], y_pred=[[0,0],[1,1]] -> dice=2/3 |
| 303 | + # class 2 y_gt=[[0,0],[0,0]], y_pred=[[1,0],[0,0]] -> dice=nan |
| 304 | + # batch 1: class 1 y_gt=[[0,0],[0,1]], y_pred=[[1,0],[0,0]] -> dice=0.0 |
| 305 | + # class 2 y_gt=[[0,0],[0,0]], y_pred=[[0,1],[1,0]] -> dice=nan |
| 306 | + [[False, True], [False, True]], # nan pattern |
| 307 | +] |
| 308 | + |
| 309 | + |
254 | 310 | class TestComputeMeanDice(unittest.TestCase): |
255 | 311 |
|
256 | | - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12]) |
| 312 | + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12, TEST_CASE_MIXED_1]) |
257 | 313 | def test_value(self, input_data, expected_value): |
258 | 314 | result = compute_dice(**input_data) |
259 | 315 | np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) |
260 | 316 | np.testing.assert_equal(result.device, input_data["y_pred"].device) |
261 | 317 |
|
262 | | - @parameterized.expand([TEST_CASE_3]) |
| 318 | + @parameterized.expand([TEST_CASE_3, TEST_CASE_MIXED_2, TEST_CASE_MIXED_3]) |
263 | 319 | def test_nans(self, input_data, expected_value): |
264 | 320 | result = compute_dice(**input_data) |
265 | 321 | self.assertTrue(np.allclose(np.isnan(result.cpu().numpy()), expected_value)) |
|
0 commit comments