|
16 | 16 | import torch |
17 | 17 | from torch.nn.modules.loss import _Loss |
18 | 18 |
|
| 19 | +from monai.metrics.utils import create_ignore_mask |
19 | 20 | from monai.networks import one_hot |
20 | 21 | from monai.utils import LossReduction |
21 | | -from monai.metrics.utils import create_ignore_mask |
22 | 22 |
|
23 | 23 |
|
24 | 24 | class AsymmetricFocalTverskyLoss(_Loss): |
@@ -78,16 +78,19 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: |
78 | 78 | mask = create_ignore_mask(original_y_true if original_y_true is not None else y_true, self.ignore_index) |
79 | 79 | if mask is not None: |
80 | 80 | mask = mask.expand_as(y_true) |
81 | | - y_pred = y_pred * mask |
82 | | - y_true = y_true * mask |
83 | 81 |
|
84 | 82 | y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) |
85 | 83 | axis = list(range(2, len(y_pred.shape))) |
86 | 84 |
|
87 | 85 | # Calculate true positives (tp), false negatives (fn) and false positives (fp) |
88 | | - tp = torch.sum(y_true * y_pred, dim=axis) |
89 | | - fn = torch.sum(y_true * (1 - y_pred), dim=axis) |
90 | | - fp = torch.sum((1 - y_true) * y_pred, dim=axis) |
| 86 | + if mask is not None: |
| 87 | + tp = torch.sum(y_true * y_pred * mask, dim=axis) |
| 88 | + fn = torch.sum(y_true * (1 - y_pred) * mask, dim=axis) |
| 89 | + fp = torch.sum((1 - y_true) * y_pred * mask, dim=axis) |
| 90 | + else: |
| 91 | + tp = torch.sum(y_true * y_pred, dim=axis) |
| 92 | + fn = torch.sum(y_true * (1 - y_pred), dim=axis) |
| 93 | + fp = torch.sum((1 - y_true) * y_pred, dim=axis) |
91 | 94 | dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon) |
92 | 95 |
|
93 | 96 | # Calculate losses separately for each class, enhancing both classes |
|
0 commit comments