Skip to content

Commit 9481ec3

Browse files
committed
add DMSPS
add DMSPS for weakly supervised segmentation; add adaptive region-specific Tversky loss
1 parent a279271 commit 9481ec3

File tree

4 files changed

+208
-4
lines changed

4 files changed

+208
-4
lines changed

pymic/loss/seg/ars_tversky.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
2+
# -*- coding: utf-8 -*-
3+
from __future__ import print_function, division
4+
5+
import torch.nn as nn
6+
from pymic.loss.seg.abstract import AbstractSegLoss
7+
8+
class ARSTverskyLoss(AbstractSegLoss):
    """
    The Adaptive Region-Specific (Tversky) Loss in this paper:

    * Y. Chen et al.: Adaptive Region-Specific Loss for Improved Medical Image Segmentation.
      `IEEE TPAMI 2023. <https://ieeexplore.ieee.org/document/10163830>`_

    The arguments should be written in the `params` dictionary, and it has the
    following fields:

    :param `ARSTversky_patch_size`: (list) the patch size.
    :param `ARSTversky_a`: the lowest weight for FP or FN (default 0.3)
    :param `ARSTversky_b`: the gap between lowest and highest weight (default 0.4)
    """
    def __init__(self, params):
        super(ARSTverskyLoss, self).__init__(params)
        self.patch_size = params['ARSTversky_patch_size'.lower()]
        self.a = params.get('ARSTversky_a'.lower(), 0.3)
        self.b = params.get('ARSTversky_b'.lower(), 0.4)

        self.dim = len(self.patch_size)
        assert self.dim in [2, 3], "The num of dim must be 2 or 3."
        # Average pooling with stride == kernel size yields the mean of each
        # statistic over non-overlapping patches (one value per region).
        pool_cls = nn.AvgPool3d if self.dim == 3 else nn.AvgPool2d
        self.pool = pool_cls(kernel_size=self.patch_size, stride=self.patch_size)

    def forward(self, loss_input_dict):
        """
        Compute the region-specific Tversky loss.

        `loss_input_dict` must contain 'prediction' (a tensor, or a
        list/tuple whose first element is used) and 'ground_truth'
        (soft label tensor of the same shape as the prediction).
        """
        pred   = loss_input_dict['prediction']
        soft_y = loss_input_dict['ground_truth']

        if isinstance(pred, (list, tuple)):
            pred = pred[0]
        if self.acti_func is not None:
            pred = self.get_activated_prediction(pred, self.acti_func)

        smooth = 1e-5
        # each trailing spatial extent must be an exact multiple of the patch size
        axis_names = ("y", "x") if self.dim == 2 else ("z", "y", "x")
        for idx, (p_size, axis) in enumerate(zip(self.patch_size, axis_names)):
            assert pred.shape[idx - self.dim] % p_size == 0, \
                "image size % patch size must be 0 in dimension " + axis

        # voxel-wise soft confusion terms
        true_pos  = pred * soft_y
        false_pos = pred * (1 - soft_y)
        false_neg = (1 - pred) * soft_y

        # per-region (patch-averaged) statistics
        r_tp = self.pool(true_pos)
        r_fp = self.pool(false_pos)
        r_fn = self.pool(false_neg)

        # region-adaptive weights: alpha grows with the local FP share and
        # beta with the local FN share; with defaults alpha + beta ~ 2a + b
        denom = r_fp + r_fn + smooth
        alpha = self.a + self.b * (r_fp + smooth) / denom
        beta  = self.a + self.b * (r_fn + smooth) / denom

        region_tversky = (r_tp + smooth) / (r_tp + alpha * r_fp + beta * r_fn + smooth)
        return (1 - region_tversky).mean()

pymic/net_run/weak_sup/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
from pymic.net_run.weak_sup.wsl_tv import WSLTotalVariation
77
from pymic.net_run.weak_sup.wsl_ustm import WSLUSTM
88
from pymic.net_run.weak_sup.wsl_dmpls import WSLDMPLS
9+
from pymic.net_run.weak_sup.wsl_dmsps import WSLDMSPS
910

1011
# Map from a weakly-supervised method name to its agent class.
WSLMethodDict = {'EntropyMinimization': WSLEntropyMinimization,
                 'GatedCRF': WSLGatedCRF,
                 'MumfordShah': WSLMumfordShah,
                 'TotalVariation': WSLTotalVariation,
                 'USTM': WSLUSTM,
                 'DMPLS': WSLDMPLS,
                 'DMSPS': WSLDMSPS}

pymic/net_run/weak_sup/wsl_dmpls.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
import random
66
import time
77
import torch
8-
from torch.optim import lr_scheduler
98
from pymic.loss.seg.util import get_soft_label
109
from pymic.loss.seg.util import reshape_prediction_and_ground_truth
1110
from pymic.loss.seg.util import get_classwise_dice
1211
from pymic.loss.seg.dice import DiceLoss
12+
from pymic.loss.seg.ce import CrossEntropyLoss
1313
from pymic.net_run.weak_sup import WSLSegAgent
1414
from pymic.util.ramps import get_rampup_ratio
1515

@@ -42,9 +42,13 @@ def training(self):
4242
class_num = self.config['network']['class_num']
4343
iter_valid = self.config['training']['iter_valid']
4444
wsl_cfg = self.config['weakly_supervised_learning']
45-
iter_max = self.config['training']['iter_max']
45+
iter_max = self.config['training']['iter_max']
4646
rampup_start = wsl_cfg.get('rampup_start', 0)
4747
rampup_end = wsl_cfg.get('rampup_end', iter_max)
48+
pseudo_loss_type = wsl_cfg.get('pseudo_sup_loss', 'dice_loss')
49+
if (pseudo_loss_type not in ('dice_loss', 'ce_loss')):
50+
raise ValueError("""For pseudo supervision loss, only dice_loss and ce_loss \
51+
are supported.""")
4852
train_loss, train_loss_sup, train_loss_reg = 0, 0, 0
4953
train_dice_list = []
5054
data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0
@@ -83,7 +87,7 @@ def training(self):
8387
pseudo_lab = get_soft_label(pseudo_lab, class_num, self.tensor_type)
8488

8589
# calculate the pseudo label supervision loss
86-
loss_calculator = DiceLoss()
90+
loss_calculator = DiceLoss() if pseudo_loss_type == 'dice_loss' else CrossEntropyLoss()
8791
loss_dict1 = {"prediction":outputs1, 'ground_truth':pseudo_lab}
8892
loss_dict2 = {"prediction":outputs2, 'ground_truth':pseudo_lab}
8993
loss_reg = 0.5 * (loss_calculator(loss_dict1) + loss_calculator(loss_dict2))
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import print_function, division
3+
import logging
4+
import numpy as np
5+
import random
6+
import time
7+
import torch
8+
from pymic.loss.seg.util import get_soft_label
9+
from pymic.loss.seg.util import reshape_prediction_and_ground_truth
10+
from pymic.loss.seg.util import get_classwise_dice
11+
from pymic.loss.seg.dice import DiceLoss
12+
from pymic.loss.seg.ce import CrossEntropyLoss
13+
from pymic.net_run.weak_sup import WSLSegAgent
14+
from pymic.util.ramps import get_rampup_ratio
15+
16+
class WSLDMSPS(WSLSegAgent):
    """
    Weakly supervised segmentation based on Dynamically Mixed Soft Pseudo Labels Supervision.

    * Reference: Meng Han, Xiangde Luo, Xiangjiang Xie, Wenjun Liao, Shichuan Zhang, Tao Song,
      Guotai Wang, Shaoting Zhang. DMSPS: Dynamically mixed soft pseudo-label supervision for
      scribble-supervised medical image segmentation.
      `Medical Image Analysis 2024. <https://www.sciencedirect.com/science/article/pii/S1361841524001993>`_

    :param config: (dict) A dictionary containing the configuration.
    :param stage: (str) One of the stage in `train` (default), `inference` or `test`.

    .. note::

        In the configuration dictionary, in addition to the four sections (`dataset`,
        `network`, `training` and `inference`) used in fully supervised learning, an
        extra section `weakly_supervised_learning` is needed. See :doc:`usage.wsl` for details.
    """
    def __init__(self, config, stage = 'train'):
        net_type = config['network']['net_type']
        if net_type not in ['UNet2D_DualBranch', 'UNet3D_DualBranch']:
            # DMSPS mixes the predictions of two decoder branches, so a
            # dual-branch network is required.
            # (Fixed copy-paste from the DMPLS agent: message said WSL_DMPLS.)
            raise ValueError("""For WSL_DMSPS, a dual branch network is expected. \
It only supports UNet2D_DualBranch and UNet3D_DualBranch currently.""")
        super(WSLDMSPS, self).__init__(config, stage)

    def training(self):
        """
        Run `iter_valid` training iterations and return a dictionary of
        averaged training scalars (losses, foreground/class-wise Dice on the
        annotated labels, and per-stage timing).
        """
        class_num  = self.config['network']['class_num']
        iter_valid = self.config['training']['iter_valid']
        wsl_cfg    = self.config['weakly_supervised_learning']
        iter_max   = self.config['training']['iter_max']
        rampup_start = wsl_cfg.get('rampup_start', 0)
        rampup_end   = wsl_cfg.get('rampup_end', iter_max)
        pseudo_loss_type = wsl_cfg.get('pseudo_sup_loss', 'ce_loss')
        if (pseudo_loss_type not in ('dice_loss', 'ce_loss')):
            raise ValueError("""For pseudo supervision loss, only dice_loss and ce_loss \
are supported.""")
        # The loss object is configuration-independent, so build it once
        # instead of once per iteration.
        loss_calculator = DiceLoss() if pseudo_loss_type == 'dice_loss' else CrossEntropyLoss()

        train_loss, train_loss_sup, train_loss_reg = 0, 0, 0
        train_dice_list = []
        data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0
        self.net.train()
        for it in range(iter_valid):
            t0 = time.time()
            try:
                data = next(self.trainIter)
            except StopIteration:
                # restart the data loader when an epoch is exhausted
                self.trainIter = iter(self.train_loader)
                data = next(self.trainIter)
            t1 = time.time()
            # get the inputs
            inputs = self.convert_tensor_type(data['image'])
            y      = self.convert_tensor_type(data['label_prob'])
            inputs, y = inputs.to(self.device), y.to(self.device)

            # zero the parameter gradients
            self.optimizer.zero_grad()

            # forward + backward + optimize
            outputs1, outputs2 = self.net(inputs)
            t2 = time.time()

            # scribble (partial) supervision loss, averaged over both branches
            loss_sup1 = self.get_loss_value(data, outputs1, y)
            loss_sup2 = self.get_loss_value(data, outputs2, y)
            loss_sup  = 0.5 * (loss_sup1 + loss_sup2)

            # get pseudo label with dynamical mix of the two branch outputs;
            # unlike DMPLS, the mixed probabilities are deliberately kept SOFT
            # (no argmax / one-hot conversion) before supervision
            outputs_soft1 = torch.softmax(outputs1, dim=1)
            outputs_soft2 = torch.softmax(outputs2, dim=1)
            beta = random.random()
            pseudo_lab = beta*outputs_soft1.detach() + (1.0-beta)*outputs_soft2.detach()

            # calculate the pseudo label supervision loss on both branches
            loss_dict1 = {"prediction":outputs1, 'ground_truth':pseudo_lab}
            loss_dict2 = {"prediction":outputs2, 'ground_truth':pseudo_lab}
            loss_reg = 0.5 * (loss_calculator(loss_dict1) + loss_calculator(loss_dict2))

            # ramp up the weight of the pseudo supervision term over training
            rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid")
            regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio
            loss = loss_sup + regular_w*loss_reg
            t3 = time.time()
            loss.backward()
            t4 = time.time()
            self.optimizer.step()

            train_loss     = train_loss + loss.item()
            train_loss_sup = train_loss_sup + loss_sup.item()
            train_loss_reg = train_loss_reg + loss_reg.item()
            # get dice evaluation for each class in annotated images
            if(isinstance(outputs1, (tuple, list))):
                outputs1 = outputs1[0]
            p_argmax = torch.argmax(outputs1, dim = 1, keepdim = True)
            p_soft = get_soft_label(p_argmax, class_num, self.tensor_type)
            p_soft, y = reshape_prediction_and_ground_truth(p_soft, y)
            dice_list = get_classwise_dice(p_soft, y)
            train_dice_list.append(dice_list.cpu().numpy())

            data_time = data_time + t1 - t0
            gpu_time  = gpu_time  + t2 - t1
            loss_time = loss_time + t3 - t2
            back_time = back_time + t4 - t3
        train_avg_loss     = train_loss / iter_valid
        train_avg_loss_sup = train_loss_sup / iter_valid
        train_avg_loss_reg = train_loss_reg / iter_valid
        train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
        # average foreground dice (class 0 is assumed to be background)
        train_avg_dice = train_cls_dice[1:].mean()

        train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup,
            'loss_reg':train_avg_loss_reg, 'regular_w':regular_w,
            'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice,
            'data_time': data_time, 'forward_time':gpu_time,
            'loss_time':loss_time, 'backward_time':back_time}
        return train_scalers
131+

0 commit comments

Comments
 (0)