Skip to content

Commit c2612ea

Browse files
committed
style: fix import sorting with isort
Signed-off-by: Rusheel Sharma <rusheelhere@gmail.com>
1 parent 4457e37 commit c2612ea

4 files changed

Lines changed: 12 additions & 9 deletions

File tree

monai/losses/dice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323
from monai.losses.focal_loss import FocalLoss
2424
from monai.losses.spatial_mask import MaskedLoss
2525
from monai.losses.utils import compute_tp_fp_fn
26+
from monai.metrics.utils import create_ignore_mask
2627
from monai.networks import one_hot
2728
from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option
28-
from monai.metrics.utils import create_ignore_mask
2929

3030

3131
class DiceLoss(_Loss):

monai/losses/focal_loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
import torch.nn.functional as F
1919
from torch.nn.modules.loss import _Loss
2020

21+
from monai.metrics.utils import create_ignore_mask
2122
from monai.networks import one_hot
2223
from monai.utils import LossReduction
23-
from monai.metrics.utils import create_ignore_mask
2424

2525

2626
class FocalLoss(_Loss):

monai/losses/tversky.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
from torch.nn.modules.loss import _Loss
1919

2020
from monai.losses.utils import compute_tp_fp_fn
21+
from monai.metrics.utils import create_ignore_mask
2122
from monai.networks import one_hot
2223
from monai.utils import LossReduction
23-
from monai.metrics.utils import create_ignore_mask
2424

2525

2626
class TverskyLoss(_Loss):

monai/losses/unified_focal_loss.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
import torch
1717
from torch.nn.modules.loss import _Loss
1818

19+
from monai.metrics.utils import create_ignore_mask
1920
from monai.networks import one_hot
2021
from monai.utils import LossReduction
21-
from monai.metrics.utils import create_ignore_mask
2222

2323

2424
class AsymmetricFocalTverskyLoss(_Loss):
@@ -78,16 +78,19 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7878
mask = create_ignore_mask(original_y_true if original_y_true is not None else y_true, self.ignore_index)
7979
if mask is not None:
8080
mask = mask.expand_as(y_true)
81-
y_pred = y_pred * mask
82-
y_true = y_true * mask
8381

8482
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
8583
axis = list(range(2, len(y_pred.shape)))
8684

8785
# Calculate true positives (tp), false negatives (fn) and false positives (fp)
88-
tp = torch.sum(y_true * y_pred, dim=axis)
89-
fn = torch.sum(y_true * (1 - y_pred), dim=axis)
90-
fp = torch.sum((1 - y_true) * y_pred, dim=axis)
86+
if mask is not None:
87+
tp = torch.sum(y_true * y_pred * mask, dim=axis)
88+
fn = torch.sum(y_true * (1 - y_pred) * mask, dim=axis)
89+
fp = torch.sum((1 - y_true) * y_pred * mask, dim=axis)
90+
else:
91+
tp = torch.sum(y_true * y_pred, dim=axis)
92+
fn = torch.sum(y_true * (1 - y_pred), dim=axis)
93+
fp = torch.sum((1 - y_true) * y_pred, dim=axis)
9194
dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)
9295

9396
# Calculate losses separately for each class, enhancing both classes

0 commit comments

Comments
 (0)