Skip to content

Commit d30de96

Browse files
Add MCC loss (#8785)
Add Matthews Correlation Coefficient (MCC) loss. Fixes #8784. ### Description Add the Matthews Correlation Coefficient (MCC) loss function to monai.losses. Unlike Dice and Tversky losses which only use TP, FP, and FN, the MCC loss considers all four entries of the confusion matrix (TP, TN, FP, FN), making it effective for class-imbalanced segmentation tasks. The loss was proposed in Abhishek & Hamarneh (IEEE ISBI 2021), has been cited 75 times, and has been adopted by Segmentation Models PyTorch (smp). The implementation follows MONAI conventions, supporting sigmoid, softmax, to_onehot_y, include_background, batch, and reduction parameters. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Kumar Abhishek <7644965+kakumarabhishek@users.noreply.github.com> Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: Kumar Abhishek <7644965+kakumarabhishek@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 46b2c0b commit d30de96

File tree

4 files changed

+348
-0
lines changed

4 files changed

+348
-0
lines changed

docs/source/losses.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,11 @@ Segmentation Losses
9898
.. autoclass:: NACLLoss
9999
:members:
100100

101+
`MCCLoss`
102+
~~~~~~~~~
103+
.. autoclass:: MCCLoss
104+
:members:
105+
101106
Registration Losses
102107
-------------------
103108

monai/losses/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from .giou_loss import BoxGIoULoss, giou
3737
from .hausdorff_loss import HausdorffDTLoss, LogHausdorffDTLoss
3838
from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
39+
from .mcc_loss import MCCLoss
3940
from .multi_scale import MultiScaleLoss
4041
from .nacl_loss import NACLLoss
4142
from .perceptual import PerceptualLoss

monai/losses/mcc_loss.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import warnings
15+
from collections.abc import Callable
16+
17+
import torch
18+
from torch.nn.modules.loss import _Loss
19+
20+
from monai.networks import one_hot
21+
from monai.utils import LossReduction
22+
23+
24+
class MCCLoss(_Loss):
    """
    Matthews Correlation Coefficient (MCC) loss, computed as ``1 - MCC``.

    Whereas Dice and Tversky losses are built from TP, FP and FN only, the MCC uses every
    entry of the confusion matrix:

        ``MCC = (TP * TN - FP * FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))``

    which makes it effective for class-imbalanced segmentation tasks where the background
    dominates the image. The soft confusion-matrix entries are:

    - ``TP = sum(input * target)``
    - ``TN = sum((1 - input) * (1 - target))``
    - ``FP = sum(input * (1 - target))``
    - ``FN = sum((1 - input) * target)``

    The data `input` (BNHW[D] where N is number of classes) is compared with ground truth
    `target` (BNHW[D]). Axis N of `input` should hold per-class probabilities; when passing
    logits, set `sigmoid=True` or `softmax=True`, or supply `other_act`. The same axis of
    `target` can be 1 or N (one-hot format).

    The original paper:

        Abhishek, K. and Hamarneh, G. (2021) Matthews Correlation Coefficient Loss for Deep
        Convolutional Networks: Application to Skin Lesion Segmentation. IEEE ISBI, pp. 225-229.
        (https://doi.org/10.1109/ISBI48211.2021.9433782)

    """

    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        sigmoid: bool = False,
        softmax: bool = False,
        other_act: Callable | None = None,
        reduction: LossReduction | str = LossReduction.MEAN,
        smooth_nr: float = 0.0,
        smooth_dr: float = 1e-5,
        batch: bool = False,
    ) -> None:
        """
        Args:
            include_background: when False, channel index 0 (background category) is dropped
                before computing the loss. Useful when small foreground structures would
                otherwise be overwhelmed by the background signal.
            to_onehot_y: whether to convert the ``target`` into one-hot format, using the
                number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
            sigmoid: if True, apply a sigmoid function to the prediction.
            softmax: if True, apply a softmax function to the prediction.
            other_act: callable executing an alternative activation, e.g.
                ``other_act = torch.tanh``. Defaults to ``None``.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.

            smooth_nr: a small constant added to the numerator to avoid zero.
            smooth_dr: a small constant added to the denominator to avoid nan.
            batch: when True, sum the confusion-matrix entries over the batch dimension
                before computing MCC; when False (default), MCC is computed independently
                for each item in the batch before any `reduction`.

        Raises:
            TypeError: When ``other_act`` is not an ``Optional[Callable]``.
            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
                Incompatible values.

        """
        super().__init__(reduction=LossReduction(reduction).value)
        if other_act is not None and not callable(other_act):
            raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
        # at most one way of activating the prediction may be selected
        if sum((sigmoid, softmax, other_act is not None)) > 1:
            raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.sigmoid = sigmoid
        self.softmax = softmax
        self.other_act = other_act
        self.smooth_nr = float(smooth_nr)
        self.smooth_dr = float(smooth_dr)
        self.batch = batch

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD], where N is the number of classes.
            target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

        Raises:
            AssertionError: When input and target (after one hot transform if set)
                have different shapes.
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

        Example:
            >>> from monai.losses.mcc_loss import MCCLoss
            >>> import torch
            >>> B, C, H, W = 7, 1, 3, 2
            >>> input = torch.rand(B, C, H, W)
            >>> target = torch.randint(low=0, high=2, size=(B, C, H, W)).float()
            >>> self = MCCLoss(reduction='none')
            >>> loss = self(input, target)
        """
        if self.sigmoid:
            input = torch.sigmoid(input)

        num_channels = input.shape[1]
        single_channel = num_channels == 1

        if self.softmax:
            if single_channel:
                warnings.warn("single channel prediction, `softmax=True` ignored.")
            else:
                input = torch.softmax(input, 1)

        if self.other_act is not None:
            input = self.other_act(input)

        if self.to_onehot_y:
            if single_channel:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=num_channels)

        if not self.include_background:
            if single_channel:
                warnings.warn("single channel prediction, `include_background=False` ignored.")
            else:
                target = target[:, 1:]
                input = input[:, 1:]

        if target.shape != input.shape:
            raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

        # sum over spatial axes only; prepend the batch axis when self.batch is set
        sum_dims: list[int] = list(range(2, input.dim()))
        if self.batch:
            sum_dims.insert(0, 0)

        inv_input = 1.0 - input
        inv_target = 1.0 - target

        # Soft confusion matrix entries (Eq. 5 in the paper).
        tp = (input * target).sum(dim=sum_dims)
        tn = (inv_input * inv_target).sum(dim=sum_dims)
        fp = (input * inv_target).sum(dim=sum_dims)
        fn = (inv_input * target).sum(dim=sum_dims)

        # MCC (Eq. 3) and loss (Eq. 4).
        mcc = (tp * tn - fp * fn + self.smooth_nr) / torch.sqrt(
            (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + self.smooth_dr
        )
        score: torch.Tensor = 1.0 - mcc

        # fp == fn == 0 means a perfect prediction, but the denominator product collapses
        # towards 0 when tp == 0 or tn == 0, giving mcc ~ 0 instead of 1 — force loss 0 there.
        score = torch.where((fp == 0) & (fn == 0), torch.zeros_like(score), score)

        if self.reduction == LossReduction.MEAN.value:
            return torch.mean(score)
        if self.reduction == LossReduction.SUM.value:
            return torch.sum(score)
        if self.reduction == LossReduction.NONE.value:
            return score
        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

tests/losses/test_mcc_loss.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import numpy as np
17+
import torch
18+
from parameterized import parameterized
19+
20+
from monai.losses import MCCLoss
21+
from tests.test_utils import test_script_save
22+
23+
# Each case is [MCCLoss constructor kwargs, forward() keyword inputs, expected loss value].
# Expected values were computed against the reference MCC definition; shapes noted per case.
TEST_CASES = [
    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2), sigmoid
        {"include_background": True, "sigmoid": True},
        {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])},
        0.733197,
    ],
    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2), sigmoid
        {"include_background": True, "sigmoid": True},
        {
            "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]),
            "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),
        },
        1.0,
    ],
    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2), perfect prediction (loss 0)
        {"include_background": True},
        {"input": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])},
        0.0,
    ],
    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2), worst case (inverted prediction, loss 2 = 1 - (-1))
        {"include_background": True},
        {"input": torch.tensor([[[[0.0, 1.0], [1.0, 0.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])},
        2.0,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, exclude background, one-hot
        {"include_background": False, "to_onehot_y": True},
        {
            "input": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),
            "target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),
        },
        0.0,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, sigmoid, one-hot
        {"include_background": True, "to_onehot_y": True, "sigmoid": True},
        {
            "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
            "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        },
        0.836617,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, sigmoid, one-hot, batch=True
        {"include_background": True, "to_onehot_y": True, "sigmoid": True, "batch": True},
        {
            "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
            "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        },
        0.845961,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, sigmoid, one-hot, reduction=sum
        {"include_background": True, "to_onehot_y": True, "sigmoid": True, "reduction": "sum"},
        {
            "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
            "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        },
        3.346468,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, softmax, one-hot
        {"include_background": True, "to_onehot_y": True, "softmax": True},
        {
            "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
            "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        },
        0.730736,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, softmax, one-hot, reduction=none (per-item/per-channel scores)
        {"include_background": True, "to_onehot_y": True, "softmax": True, "reduction": "none"},
        {
            "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
            "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        },
        [[0.461472, 0.461472], [1.0, 1.0]],
    ],
    [  # shape: (1, 1, 3, 3), (1, 1, 3, 3), all-ones perfect prediction (tn == 0 corner case)
        {"include_background": True},
        {"input": torch.ones(1, 1, 3, 3), "target": torch.ones(1, 1, 3, 3)},
        0.0,
    ],
    [  # shape: (1, 1, 3, 3), (1, 1, 3, 3), all-zeros perfect prediction (tp == 0 corner case)
        {"include_background": True},
        {"input": torch.zeros(1, 1, 3, 3), "target": torch.zeros(1, 1, 3, 3)},
        0.0,
    ],
    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2), other_act=torch.tanh
        {"include_background": True, "other_act": torch.tanh},
        {
            "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]),
            "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),
        },
        1.0,
    ],
]
114+
115+
116+
class TestMCCLoss(unittest.TestCase):
    """Unit tests for ``monai.losses.MCCLoss``."""

    @parameterized.expand(TEST_CASES)
    def test_shape(self, input_param, input_data, expected_val):
        # Check the computed loss value (and implicitly the output shape for reduction="none")
        # against the precomputed expectation for each parameterized case.
        result = MCCLoss(**input_param).forward(**input_data)
        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4)

    def test_ill_shape(self):
        # Mismatched input/target shapes raise AssertionError; bad reduction values raise ValueError.
        loss = MCCLoss()
        with self.assertRaisesRegex(AssertionError, ""):
            loss.forward(torch.ones((2, 2, 3)), torch.ones((4, 5, 6)))
        chn_input = torch.ones((1, 1, 3))
        chn_target = torch.ones((1, 1, 3))
        with self.assertRaisesRegex(ValueError, ""):
            MCCLoss(reduction="unknown")(chn_input, chn_target)
        with self.assertRaisesRegex(ValueError, ""):
            MCCLoss(reduction=None)(chn_input, chn_target)

    def test_ill_opts(self):
        # Constructor rejects conflicting activations and a non-callable other_act.
        with self.assertRaisesRegex(ValueError, ""):
            MCCLoss(sigmoid=True, softmax=True)
        with self.assertRaisesRegex(TypeError, ""):
            MCCLoss(other_act="tanh")

    @parameterized.expand([(False, False, False), (False, True, False), (False, False, True)])
    def test_input_warnings(self, include_background, softmax, to_onehot_y):
        # Single-channel predictions make include_background=False / softmax / to_onehot_y
        # no-ops; each combination should emit a warning rather than fail.
        chn_input = torch.ones((1, 1, 3))
        chn_target = torch.ones((1, 1, 3))
        with self.assertWarns(Warning):
            loss = MCCLoss(include_background=include_background, softmax=softmax, to_onehot_y=to_onehot_y)
            loss.forward(chn_input, chn_target)

    def test_script(self):
        # The loss module must be TorchScript-compatible (save/load round trip).
        loss = MCCLoss()
        test_input = torch.ones(2, 1, 8, 8)
        test_script_save(loss, test_input, test_input)


if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)