1+ # -*- coding: utf-8 -*-
2+ from __future__ import print_function , division
3+ import logging
4+ import numpy as np
5+ import random
6+ import time
7+ import torch
8+ from pymic .loss .seg .util import get_soft_label
9+ from pymic .loss .seg .util import reshape_prediction_and_ground_truth
10+ from pymic .loss .seg .util import get_classwise_dice
11+ from pymic .loss .seg .dice import DiceLoss
12+ from pymic .loss .seg .ce import CrossEntropyLoss
13+ from pymic .net_run .weak_sup import WSLSegAgent
14+ from pymic .util .ramps import get_rampup_ratio
15+
class WSLDMSPS(WSLSegAgent):
    """
    Weakly supervised segmentation based on Dynamically Mixed Pseudo Labels Supervision.

    * Reference: Meng Han, Xiangde Luo, Xiangjiang Xie, Wenjun Liao, Shichuan Zhang, Tao Song,
      Guotai Wang, Shaoting Zhang. DMSPS: Dynamically mixed soft pseudo-label supervision for
      scribble-supervised medical image segmentation.
      `Medical Image Analysis 2024. <https://www.sciencedirect.com/science/article/pii/S1361841524001993>`_

    :param config: (dict) A dictionary containing the configuration.
    :param stage: (str) One of the stage in `train` (default), `inference` or `test`.

    .. note::

        In the configuration dictionary, in addition to the four sections (`dataset`,
        `network`, `training` and `inference`) used in fully supervised learning, an
        extra section `weakly_supervised_learning` is needed. See :doc:`usage.wsl` for details.
    """
    def __init__(self, config, stage = 'train'):
        net_type = config['network']['net_type']
        if(net_type not in ['UNet2D_DualBranch', 'UNet3D_DualBranch']):
            # Bug fix: the message previously said "WSL_DMPLS", which is a
            # different (earlier) method; this class implements DMSPS.
            raise ValueError("""For WSLDMSPS, a dual branch network is expected. \
                It only supports UNet2D_DualBranch and UNet3D_DualBranch currently.""")
        super(WSLDMSPS, self).__init__(config, stage)

    def training(self):
        """
        Run `iter_valid` training iterations of DMSPS and return the averaged
        training scalars.

        :return: (dict) A dictionary with keys `loss`, `loss_sup`, `loss_reg`,
            `regular_w`, `avg_fg_dice`, `class_dice`, `data_time`,
            `forward_time`, `loss_time` and `backward_time`.
        """
        class_num  = self.config['network']['class_num']
        iter_valid = self.config['training']['iter_valid']
        wsl_cfg    = self.config['weakly_supervised_learning']
        iter_max   = self.config['training']['iter_max']
        rampup_start = wsl_cfg.get('rampup_start', 0)
        rampup_end   = wsl_cfg.get('rampup_end', iter_max)
        pseudo_loss_type = wsl_cfg.get('pseudo_sup_loss', 'ce_loss')
        if(pseudo_loss_type not in ('dice_loss', 'ce_loss')):
            raise ValueError("""For pseudo supervision loss, only dice_loss and ce_loss \
                are supported.""")
        # The pseudo-supervision loss object is configuration-selected and
        # stateless, so build it once here instead of once per iteration
        # (it used to be re-constructed inside the loop).
        loss_calculator = DiceLoss() if pseudo_loss_type == 'dice_loss' else CrossEntropyLoss()
        train_loss, train_loss_sup, train_loss_reg = 0, 0, 0
        train_dice_list = []
        data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0
        self.net.train()
        for it in range(iter_valid):
            t0 = time.time()
            try:
                data = next(self.trainIter)
            except StopIteration:
                # restart the data loader when an epoch is exhausted
                self.trainIter = iter(self.train_loader)
                data = next(self.trainIter)
            t1 = time.time()
            # get the inputs and the (partial, scribble-based) label probabilities
            inputs = self.convert_tensor_type(data['image'])
            y      = self.convert_tensor_type(data['label_prob'])

            inputs, y = inputs.to(self.device), y.to(self.device)

            # zero the parameter gradients
            self.optimizer.zero_grad()

            # forward pass; the dual-branch network returns two predictions
            outputs1, outputs2 = self.net(inputs)
            t2 = time.time()

            # supervised loss on annotated pixels, averaged over both branches
            loss_sup1 = self.get_loss_value(data, outputs1, y)
            loss_sup2 = self.get_loss_value(data, outputs2, y)
            loss_sup  = 0.5 * (loss_sup1 + loss_sup2)

            # get the soft pseudo label as a dynamically (randomly) weighted
            # mixture of the two branches' softmax outputs; detach() keeps
            # gradients from flowing through the pseudo label itself
            outputs_soft1 = torch.softmax(outputs1, dim = 1)
            outputs_soft2 = torch.softmax(outputs2, dim = 1)
            beta = random.random()
            pseudo_lab = beta * outputs_soft1.detach() + (1.0 - beta) * outputs_soft2.detach()

            # pseudo label supervision loss: each branch is supervised by the
            # same mixed soft pseudo label
            loss_dict1 = {"prediction": outputs1, "ground_truth": pseudo_lab}
            loss_dict2 = {"prediction": outputs2, "ground_truth": pseudo_lab}
            loss_reg = 0.5 * (loss_calculator(loss_dict1) + loss_calculator(loss_dict2))

            # ramp up the weight of the pseudo supervision term over training
            rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid")
            regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio
            loss = loss_sup + regular_w * loss_reg
            t3 = time.time()
            loss.backward()
            t4 = time.time()
            self.optimizer.step()

            train_loss     = train_loss + loss.item()
            train_loss_sup = train_loss_sup + loss_sup.item()
            train_loss_reg = train_loss_reg + loss_reg.item()
            # get dice evaluation for each class, based on the first branch
            if(isinstance(outputs1, (tuple, list))):
                outputs1 = outputs1[0]
            p_argmax = torch.argmax(outputs1, dim = 1, keepdim = True)
            p_soft   = get_soft_label(p_argmax, class_num, self.tensor_type)
            p_soft, y = reshape_prediction_and_ground_truth(p_soft, y)
            dice_list = get_classwise_dice(p_soft, y)
            train_dice_list.append(dice_list.cpu().numpy())

            data_time = data_time + t1 - t0
            gpu_time  = gpu_time  + t2 - t1
            loss_time = loss_time + t3 - t2
            back_time = back_time + t4 - t3
        # NOTE(review): assumes iter_valid >= 1, otherwise `regular_w` below is
        # undefined and the averages divide by zero — confirm config guarantees this.
        train_avg_loss     = train_loss / iter_valid
        train_avg_loss_sup = train_loss_sup / iter_valid
        train_avg_loss_reg = train_loss_reg / iter_valid
        train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
        # average foreground dice excludes class 0 (background)
        train_avg_dice = train_cls_dice[1:].mean()

        train_scalers = {'loss': train_avg_loss, 'loss_sup': train_avg_loss_sup,
            'loss_reg': train_avg_loss_reg, 'regular_w': regular_w,
            'avg_fg_dice': train_avg_dice, 'class_dice': train_cls_dice,
            'data_time': data_time, 'forward_time': gpu_time,
            'loss_time': loss_time, 'backward_time': back_time}
        return train_scalers