
Commit cc21cef

Add MCC loss
Add implementation of Matthews Correlation Coefficient (MCC)-based loss.
Add tests for MCC loss.
Add entry for MCC loss in documentation.

Signed-off-by: Kumar Abhishek <7644965+kakumarabhishek@users.noreply.github.com>
1 parent daaedaa commit cc21cef

4 files changed

Lines changed: 348 additions & 0 deletions


docs/source/losses.rst

Lines changed: 5 additions & 0 deletions
@@ -98,6 +98,11 @@ Segmentation Losses
 .. autoclass:: NACLLoss
     :members:
 
+`MCCLoss`
+~~~~~~~~~
+.. autoclass:: MCCLoss
+    :members:
+
 Registration Losses
 -------------------

monai/losses/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@
 from .focal_loss import FocalLoss
 from .giou_loss import BoxGIoULoss, giou
 from .hausdorff_loss import HausdorffDTLoss, LogHausdorffDTLoss
+from .mcc_loss import MCCLoss
 from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
 from .multi_scale import MultiScaleLoss
 from .nacl_loss import NACLLoss
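
With this re-export in place, the loss becomes importable from the package namespace. A minimal usage sketch (not part of the diff; the constructor arguments mirror the test cases further below):

from monai.losses import MCCLoss

# logits from a multi-class segmentation network would typically be passed
# through softmax, with the integer label map converted to one-hot inside the loss
loss_fn = MCCLoss(softmax=True, to_onehot_y=True)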

monai/losses/mcc_loss.py

Lines changed: 188 additions & 0 deletions
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from collections.abc import Callable

import torch
from torch.nn.modules.loss import _Loss

from monai.networks import one_hot
from monai.utils import LossReduction


class MCCLoss(_Loss):
    """
    Compute the Matthews Correlation Coefficient (MCC) loss between two tensors.

    Unlike the Dice and Tversky losses, which use only TP, FP, and FN, the MCC loss considers all four entries
    of the confusion matrix (TP, TN, FP, FN), making it effective for class-imbalanced segmentation tasks
    where the background dominates the image. The loss is computed as ``1 - MCC``, where
    ``MCC = (TP * TN - FP * FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))``.

    The soft confusion matrix entries are computed as:

        - ``TP = sum(input * target)``
        - ``TN = sum((1 - input) * (1 - target))``
        - ``FP = sum(input * (1 - target))``
        - ``FN = sum((1 - input) * target)``

    The data `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]).

    Note that axis N of `input` is expected to be logits or probabilities for each class; if passing logits as
    input, `sigmoid=True` or `softmax=True` must be set, or `other_act` must be specified. The same axis of
    `target` can be 1 or N (one-hot format).

    The original paper:

        Abhishek, K. and Hamarneh, G. (2021) Matthews Correlation Coefficient Loss for Deep Convolutional
        Networks: Application to Skin Lesion Segmentation. IEEE ISBI, pp. 225-229.
        (https://doi.org/10.1109/ISBI48211.2021.9433782)

    """

    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        sigmoid: bool = False,
        softmax: bool = False,
        other_act: Callable | None = None,
        reduction: LossReduction | str = LossReduction.MEAN,
        smooth_nr: float = 0.0,
        smooth_dr: float = 1e-5,
        batch: bool = False,
    ) -> None:
        """
        Args:
            include_background: if False, channel index 0 (background category) is excluded from the calculation.
                If the non-background segmentations are small compared to the total image size, they can get
                overwhelmed by the signal from the background, so excluding it in such cases helps convergence.
            to_onehot_y: whether to convert the ``target`` into the one-hot format,
                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
            sigmoid: if True, apply a sigmoid function to the prediction.
            softmax: if True, apply a softmax function to the prediction.
            other_act: callable function to execute other activation layers. Defaults to ``None``. For example:
                ``other_act = torch.tanh``.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.

            smooth_nr: a small constant added to the numerator to avoid zero.
            smooth_dr: a small constant added to the denominator to avoid nan.
            batch: whether to sum the confusion matrix entries over the batch dimension before computing MCC.
                Defaults to False, i.e. MCC is computed independently for each item in the batch
                before any `reduction`.

        Raises:
            TypeError: When ``other_act`` is not an ``Optional[Callable]``.
            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
                Incompatible values.

        """
        super().__init__(reduction=LossReduction(reduction).value)
        if other_act is not None and not callable(other_act):
            raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
        if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
            raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.sigmoid = sigmoid
        self.softmax = softmax
        self.other_act = other_act
        self.smooth_nr = float(smooth_nr)
        self.smooth_dr = float(smooth_dr)
        self.batch = batch

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD], where N is the number of classes.
            target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

        Raises:
            AssertionError: When input and target (after one-hot transform if set)
                have different shapes.
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

        Example:
            >>> from monai.losses.mcc_loss import MCCLoss
            >>> import torch
            >>> B, C, H, W = 7, 1, 3, 2
            >>> input = torch.rand(B, C, H, W)
            >>> target = torch.randint(low=0, high=2, size=(B, C, H, W)).float()
            >>> self = MCCLoss(reduction='none')
            >>> loss = self(input, target)
        """
        if self.sigmoid:
            input = torch.sigmoid(input)

        n_pred_ch = input.shape[1]
        if self.softmax:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `softmax=True` ignored.")
            else:
                input = torch.softmax(input, 1)

        if self.other_act is not None:
            input = self.other_act(input)

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `include_background=False` ignored.")
            else:
                target = target[:, 1:]
                input = input[:, 1:]

        if target.shape != input.shape:
            raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

        # reducing only spatial dimensions (not batch nor channels)
        reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
        if self.batch:
            reduce_axis = [0] + reduce_axis

        # Soft confusion matrix entries (Eq. 5 in the paper).
        tp = torch.sum(input * target, dim=reduce_axis)
        tn = torch.sum((1.0 - input) * (1.0 - target), dim=reduce_axis)
        fp = torch.sum(input * (1.0 - target), dim=reduce_axis)
        fn = torch.sum((1.0 - input) * target, dim=reduce_axis)

        # MCC (Eq. 3) and loss (Eq. 4).
        numerator = tp * tn - fp * fn + self.smooth_nr
        denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + self.smooth_dr)

        mcc = numerator / denominator
        score: torch.Tensor = 1.0 - mcc

        # When fp = fn = 0, the prediction is perfect, but the denominator product
        # tends to 0 when tp = 0 or tn = 0, giving mcc ~ 0 instead of 1.
        perfect = (fp == 0) & (fn == 0)
        score = torch.where(perfect, torch.zeros_like(score), score)

        if self.reduction == LossReduction.SUM.value:
            return torch.sum(score)
        if self.reduction == LossReduction.NONE.value:
            return score
        if self.reduction == LossReduction.MEAN.value:
            return torch.mean(score)
        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
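
To make the docstring's formula concrete, here is a minimal sanity-check sketch (not part of the commit); the probability values are invented for illustration. It recomputes the soft confusion matrix by hand and compares against MCCLoss with its default smoothing terms (smooth_nr=0.0, smooth_dr=1e-5).

import torch
from monai.losses import MCCLoss

# Illustrative 1x1x2x2 example: soft predictions vs. a binary target.
pred = torch.tensor([[[[0.9, 0.1], [0.2, 0.8]]]])
target = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])

# Soft confusion matrix entries, exactly as listed in the class docstring.
tp = (pred * target).sum()              # 1.7
tn = ((1 - pred) * (1 - target)).sum()  # 1.7
fp = (pred * (1 - target)).sum()        # 0.3
fn = ((1 - pred) * target).sum()        # 0.3

# MCC and loss, using the default smoothing constants.
mcc = (tp * tn - fp * fn) / torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + 1e-5)
manual_loss = 1.0 - mcc

print(manual_loss)              # ~0.3
print(MCCLoss()(pred, target))  # should agree up to floating-point error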

tests/losses/test_mcc_loss.py

Lines changed: 154 additions & 0 deletions
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.losses import MCCLoss
from tests.test_utils import test_script_save

TEST_CASES = [
    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2), sigmoid
        {"include_background": True, "sigmoid": True},
        {"input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])},
        0.733197,
    ],
    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2), sigmoid
        {"include_background": True, "sigmoid": True},
        {
            "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]),
            "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),
        },
        1.0,
    ],
    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2), perfect prediction
        {"include_background": True},
        {"input": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])},
        0.0,
    ],
    [  # shape: (1, 1, 2, 2), (1, 1, 2, 2), worst case (inverted)
        {"include_background": True},
        {"input": torch.tensor([[[[0.0, 1.0], [1.0, 0.0]]]]), "target": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])},
        2.0,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, exclude background, one-hot
        {"include_background": False, "to_onehot_y": True},
        {
            "input": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),
            "target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),
        },
        0.0,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, sigmoid, one-hot
        {"include_background": True, "to_onehot_y": True, "sigmoid": True},
        {
            "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
            "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        },
        0.836617,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, sigmoid, one-hot, batch=True
        {"include_background": True, "to_onehot_y": True, "sigmoid": True, "batch": True},
        {
            "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
            "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        },
        0.845961,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, sigmoid, one-hot, reduction=sum
        {"include_background": True, "to_onehot_y": True, "sigmoid": True, "reduction": "sum"},
        {
            "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
            "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        },
        3.346468,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, softmax, one-hot
        {"include_background": True, "to_onehot_y": True, "softmax": True},
        {
            "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
            "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        },
        0.730736,
    ],
    [  # shape: (2, 2, 3), (2, 1, 3), multi-class, softmax, one-hot, reduction=none
        {"include_background": True, "to_onehot_y": True, "softmax": True, "reduction": "none"},
        {
            "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
            "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
        },
        [[0.461472, 0.461472], [1.0, 1.0]],
    ],
    [  # shape: (1, 1, 3, 3), (1, 1, 3, 3), all-ones perfect prediction
        {"include_background": True},
        {"input": torch.ones(1, 1, 3, 3), "target": torch.ones(1, 1, 3, 3)},
        0.0,
    ],
    [  # shape: (1, 1, 3, 3), (1, 1, 3, 3), all-zeros perfect prediction
        {"include_background": True},
        {"input": torch.zeros(1, 1, 3, 3), "target": torch.zeros(1, 1, 3, 3)},
        0.0,
    ],
    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2), other_act=torch.tanh
        {"include_background": True, "other_act": torch.tanh},
        {
            "input": torch.tensor([[[[1.0, -1.0], [-1.0, 1.0]]], [[[1.0, -1.0], [-1.0, 1.0]]]]),
            "target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),
        },
        1.0,
    ],
]


class TestMCCLoss(unittest.TestCase):
    @parameterized.expand(TEST_CASES)
    def test_shape(self, input_param, input_data, expected_val):
        result = MCCLoss(**input_param).forward(**input_data)
        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4)

    def test_ill_shape(self):
        loss = MCCLoss()
        with self.assertRaisesRegex(AssertionError, ""):
            loss.forward(torch.ones((2, 2, 3)), torch.ones((4, 5, 6)))
        chn_input = torch.ones((1, 1, 3))
        chn_target = torch.ones((1, 1, 3))
        with self.assertRaisesRegex(ValueError, ""):
            MCCLoss(reduction="unknown")(chn_input, chn_target)
        with self.assertRaisesRegex(ValueError, ""):
            MCCLoss(reduction=None)(chn_input, chn_target)

    def test_ill_opts(self):
        with self.assertRaisesRegex(ValueError, ""):
            MCCLoss(sigmoid=True, softmax=True)
        with self.assertRaisesRegex(TypeError, ""):
            MCCLoss(other_act="tanh")

    @parameterized.expand([(False, False, False), (False, True, False), (False, False, True)])
    def test_input_warnings(self, include_background, softmax, to_onehot_y):
        chn_input = torch.ones((1, 1, 3))
        chn_target = torch.ones((1, 1, 3))
        with self.assertWarns(Warning):
            loss = MCCLoss(include_background=include_background, softmax=softmax, to_onehot_y=to_onehot_y)
            loss.forward(chn_input, chn_target)

    def test_script(self):
        loss = MCCLoss()
        test_input = torch.ones(2, 1, 8, 8)
        test_script_save(loss, test_input, test_input)


if __name__ == "__main__":
    unittest.main()
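
Assuming a standard development checkout with the repository root on the Python path (the module imports tests.test_utils), this file can be run on its own with, for example, python -m unittest tests.losses.test_mcc_loss, or through the project's usual test runner.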
