Skip to content

Commit 4baccae

Browse files
committed
update files for DMSPS
1 parent 4590b30 commit 4baccae

File tree

8 files changed

+137
-33
lines changed

8 files changed

+137
-33
lines changed

pymic/io/image_read_write.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def save_array_as_nifty_volume(data, image_name, reference_name = None, spacing
9999
:param spacing: (list or tuple) the spacing of a volume data when `reference_name` is not provided.
100100
"""
101101
img = sitk.GetImageFromArray(data)
102-
if(reference_name is not None):
102+
if((reference_name is not None) and (not reference_name.endswith(".h5"))):
103103
img_ref = sitk.ReadImage(reference_name)
104104
#img.CopyInformation(img_ref)
105105
img.SetSpacing(img_ref.GetSpacing())
@@ -141,11 +141,15 @@ def save_nd_array_as_image(data, image_name, reference_name = None, spacing = [1
141141
"""
142142
data_dim = len(data.shape)
143143
assert(data_dim == 2 or data_dim == 3)
144+
if(image_name.endswith(".h5")):
145+
if(data_dim == 3):
146+
image_name = image_name.replace(".h5", ".nii.gz")
147+
else:
148+
image_name = image_name.replace(".h5", ".png")
144149
if (image_name.endswith(".nii.gz") or image_name.endswith(".nii") or
145150
image_name.endswith(".mha")):
146151
assert(data_dim == 3)
147152
save_array_as_nifty_volume(data, image_name, reference_name, spacing)
148-
149153
elif(image_name.endswith(".jpg") or image_name.endswith(".jpeg") or
150154
image_name.endswith(".tif") or image_name.endswith(".png")):
151155
assert(data_dim == 2)

pymic/loss/loss_dict_seg.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pymic.loss.seg.dice import DiceLoss, FocalDiceLoss, \
2727
NoiseRobustDiceLoss, BinaryDiceLoss, GroupDiceLoss
2828
from pymic.loss.seg.exp_log import ExpLogLoss
29+
from pymic.loss.seg.ars_tversky import ARSTverskyLoss
2930
from pymic.loss.seg.mse import MSELoss, MAELoss
3031
from pymic.loss.seg.slsr import SLSRLoss
3132

@@ -35,6 +36,7 @@
3536
'DiceLoss': DiceLoss,
3637
'BinaryDiceLoss': BinaryDiceLoss,
3738
'FocalDiceLoss': FocalDiceLoss,
39+
'ARSTverskyLoss': ARSTverskyLoss,
3840
'NoiseRobustDiceLoss': NoiseRobustDiceLoss,
3941
'GroupDiceLoss': GroupDiceLoss,
4042
'ExpLogLoss': ExpLogLoss,

pymic/loss/seg/ce.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def forward(self, loss_input_dict):
3636
soft_y = reshape_tensor_to_2D(soft_y)
3737

3838
# for numeric stability
39-
predict = predict * 0.999 + 5e-4
39+
# predict = predict * (1-1e-10) + 0.5e-10
4040
ce = - soft_y* torch.log(predict)
4141
if(cls_w is not None):
4242
ce = torch.sum(ce*cls_w, dim = 1)
@@ -46,7 +46,7 @@ def forward(self, loss_input_dict):
4646
ce = torch.mean(ce)
4747
else:
4848
pix_w = torch.squeeze(reshape_tensor_to_2D(pix_w))
49-
ce = torch.sum(pix_w * ce) / (pix_w.sum() + 1e-5)
49+
ce = torch.sum(pix_w * ce) / (pix_w.sum() + 1e-10)
5050
return ce
5151

5252
class GeneralizedCELoss(AbstractSegLoss):

pymic/net/net2d/unet2d_multi_decoder.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,16 @@ def forward(self, x):
6363
output2 = torch.reshape(output2, new_shape)
6464
output2 = torch.transpose(output2, 1, 2)
6565

66-
if(self.training):
67-
return output1, output2
68-
else:
69-
if(self.output_mode == "average"):
70-
return (output1 + output2)/2
71-
elif(self.output_mode == "first"):
72-
return output1
73-
else:
74-
return output2
66+
return output1, output2
67+
# if(self.training):
68+
# return output1, output2
69+
# else:
70+
# if(self.output_mode == "average"):
71+
# return (output1 + output2)/2
72+
# elif(self.output_mode == "first"):
73+
# return output1
74+
# else:
75+
# return output2
7576

7677
class UNet2D_TriBranch(nn.Module):
7778
"""

pymic/net_run/weak_sup/wsl_dmsps.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# -*- coding: utf-8 -*-
22
from __future__ import print_function, division
33
import logging
4+
import os
45
import numpy as np
56
import random
67
import time
78
import torch
8-
from PIL import Image
9+
import scipy
10+
from pymic.io.image_read_write import save_nd_array_as_image
911
from pymic.loss.seg.util import get_soft_label
1012
from pymic.loss.seg.util import reshape_prediction_and_ground_truth
1113
from pymic.loss.seg.util import get_classwise_dice
@@ -133,4 +135,68 @@ def training(self):
133135
'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice,
134136
'data_time': data_time, 'forward_time':gpu_time,
135137
'loss_time':loss_time, 'backward_time':back_time}
136-
return train_scalers
138+
return train_scalers
139+
140+
def save_outputs(self, data):
141+
"""
142+
Save prediction output.
143+
144+
:param data: (dictionary) A data dictionary with prediction result and other
145+
information such as input image name.
146+
"""
147+
output_dir = self.config['testing']['output_dir']
148+
test_mode = self.config['testing'].get('dmsps_test_mode', 0)
149+
uct_threshold = self.config['testing'].get('dmsps_uncertainty_threshold', 0.1)
150+
# DMSPS_test_mode == 0: only save the segmentation label for the main decoder
151+
# DMSPS_test_mode == 1: save all the results, including the probability map of each decoder,
152+
# the uncertainty map, and the confident predictions
153+
if(not os.path.exists(output_dir)):
154+
os.makedirs(output_dir, exist_ok=True)
155+
156+
names, pred = data['names'], data['predict']
157+
pred0, pred1 = pred
158+
prob0 = scipy.special.softmax(pred0, axis = 1)
159+
prob1 = scipy.special.softmax(pred1, axis = 1)
160+
prob_mean = (prob0 + prob1) / 2
161+
lab0 = np.asarray(np.argmax(prob0, axis = 1), np.uint8)
162+
lab1 = np.asarray(np.argmax(prob1, axis = 1), np.uint8)
163+
lab_mean = np.asarray(np.argmax(prob_mean, axis = 1), np.uint8)
164+
165+
# save the output and (optionally) probability predictions
166+
test_dir = self.config['dataset'].get('test_dir', None)
167+
if(test_dir is None):
168+
test_dir = self.config['dataset']['train_dir']
169+
img_name = names[0][0].split('/')[-1]
170+
print(img_name)
171+
lab0_name = img_name
172+
if(".h5" in lab0_name):
173+
lab0_name = lab0_name.replace(".h5", ".nii.gz")
174+
save_nd_array_as_image(lab0[0], output_dir + "/" + lab0_name, test_dir + '/' + names[0][0])
175+
if(test_mode == 1):
176+
lab1_name = lab0_name.replace(".nii.gz", "_predaux.nii.gz")
177+
save_nd_array_as_image(lab1[0], output_dir + "/" + lab1_name, test_dir + '/' + names[0][0])
178+
C = pred0.shape[1]
179+
uct = -1.0 * np.sum(prob_mean * np.log(prob_mean), axis=1, keepdims=False)/ np.log(C)
180+
uct_name = lab0_name.replace(".nii.gz", "_uncertainty.nii.gz")
181+
save_nd_array_as_image(uct[0], output_dir + "/" + uct_name, test_dir + '/' + names[0][0])
182+
conf_mask = uct < uct_threshold
183+
conf_lab = conf_mask * lab_mean + (1 - conf_mask)*4
184+
conf_lab_name = lab0_name.replace(".nii.gz", "_seeds_expand.nii.gz")
185+
186+
# get the largest connected component in each slice for each class
187+
D, H, W = conf_lab[0].shape
188+
from pymic.util.image_process import get_largest_k_components
189+
for d in range(D):
190+
lab2d = conf_lab[0][d]
191+
for c in range(C):
192+
lab2d_c = lab2d == c
193+
mask_c = get_largest_k_components(lab2d_c, k = 1)
194+
diff = lab2d_c != mask_c
195+
if(np.sum(diff) > 0):
196+
lab2d[diff] = C
197+
conf_lab[0][d] = lab2d
198+
save_nd_array_as_image(conf_lab[0], output_dir + "/" + conf_lab_name, test_dir + '/' + img_name)
199+
200+
201+
202+

pymic/transform/rescale.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,22 @@ def inverse_transform_for_prediction(self, sample):
7272
origin_shape = json.loads(sample['Rescale_origin_shape'])
7373
origin_dim = len(origin_shape) - 1
7474
predict = sample['predict']
75-
input_shape = predict.shape
76-
scale = [(origin_shape[1:][i] + 0.0)/input_shape[2:][i] for \
77-
i in range(origin_dim)]
78-
scale = [1.0, 1.0] + scale
7975

80-
output_predict = ndimage.interpolation.zoom(predict, scale, order = 1)
76+
if(isinstance(predict, tuple) or isinstance(predict, list)):
77+
output_predict = []
78+
for predict_i in predict:
79+
input_shape = predict_i.shape
80+
scale = [(origin_shape[1:][i] + 0.0)/input_shape[2:][i] for \
81+
i in range(origin_dim)]
82+
scale = [1.0, 1.0] + scale
83+
output_predict_i = ndimage.interpolation.zoom(predict_i, scale, order = 1)
84+
output_predict.append(output_predict_i)
85+
else:
86+
input_shape = predict.shape
87+
scale = [(origin_shape[1:][i] + 0.0)/input_shape[2:][i] for \
88+
i in range(origin_dim)]
89+
scale = [1.0, 1.0] + scale
90+
output_predict = ndimage.interpolation.zoom(predict, scale, order = 1)
8191
sample['predict'] = output_predict
8292
return sample
8393

pymic/util/evaluation_seg.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,10 @@ def get_binary_evaluation_score(s_volume, g_volume, spacing, metric):
212212
score = binary_iou(s_volume,g_volume)
213213
elif(metric_lower == 'assd'):
214214
score = binary_assd(s_volume, g_volume, spacing)
215+
score = min(score, 20) # to reject outliers
215216
elif(metric_lower == "hd95"):
216217
score = binary_hd95(s_volume, g_volume, spacing)
218+
score = min(score, 50) # to reject outliers
217219
elif(metric_lower == "rve"):
218220
score = binary_relative_volume_error(s_volume, g_volume)
219221
elif(metric_lower == "volume"):
@@ -269,8 +271,8 @@ def evaluation(config):
269271
:param label_fuse: (option, bool) If true, fuse the labels in the `label_list`
270272
as the foreground, and other labels as the background. Default is False.
271273
:param organ_name: (str) The name of the organ for segmentation.
272-
:param ground_truth_folder_root: (str) The root dir of ground truth images.
273-
:param segmentation_folder_root: (str or list) The root dir of segmentation images.
274+
:param ground_truth_folder: (str) The root dir of ground truth images.
275+
:param segmentation_folder: (str or list) The root dir of segmentation images.
274276
When a list is given, each list element should be the root dir of the results of one method.
275277
:param evaluation_image_pair: (str) The csv file that provide the segmentation
276278
images and the corresponding ground truth images.
@@ -366,23 +368,23 @@ def main():
366368
367369
"""
368370
parser = argparse.ArgumentParser()
369-
parser.add_argument("-cfg", help="configuration file for evaluation",
371+
parser.add_argument("--cfg", help="configuration file for evaluation",
370372
required=False, default=None)
371-
parser.add_argument("-metric", help="evaluation metrics, e.g., dice, or [dice, assd]",
373+
parser.add_argument("--metric", help="evaluation metrics, e.g., dice, or [dice, assd]",
372374
required=False, default=None)
373-
parser.add_argument("-cls_num", help="number of classes",
375+
parser.add_argument("--cls_num", help="number of classes",
374376
required=False, default=None)
375-
parser.add_argument("-cls_index", help="The class index for evaluation, e.g., 255, [1, 2]",
377+
parser.add_argument("--cls_index", help="The class index for evaluation, e.g., 255, [1, 2]",
376378
required=False, default=None)
377-
parser.add_argument("-gt_dir", help="path of folder for ground truth",
379+
parser.add_argument("--gt_dir", help="path of folder for ground truth",
378380
required=False, default=None)
379-
parser.add_argument("-seg_dir", help="path of folder for segmentation",
381+
parser.add_argument("--seg_dir", help="path of folder for segmentation",
380382
required=False, default=None)
381-
parser.add_argument("-name_pair", help="the .csv file for name mapping in case"
383+
parser.add_argument("--name_pair", help="the .csv file for name mapping in case"
382384
" the names of one case are different in the gt_dir "
383385
" and seg_dir",
384386
required=False, default=None)
385-
parser.add_argument("-out", help="the output .csv file name",
387+
parser.add_argument("--out", help="the output .csv file name",
386388
required=False, default=None)
387389
args = parser.parse_args()
388390
print(args)
@@ -402,5 +404,3 @@ def main():
402404

403405
if __name__ == '__main__':
404406
main()
405-
406-
main()

pymic/util/parse_config.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,22 @@ def parse_config(args):
9696
args_key = getattr(args, key)
9797
if(args_key is not None):
9898
val_str = args_key
99+
print(section, key, val_str)
99100
if(len(val_str)>0):
100101
val = parse_value_from_string(val_str)
101102
output[section][key] = val
102103
else:
103104
val = None
104-
print(section, key, val)
105+
106+
for key in ["train_dir", "train_csv", "valid_csv", "test_dir", "test_csv"]:
107+
if key in args and getattr(args, key) is not None:
108+
output["dataset"][key] = parse_value_from_string(getattr(args, key))
109+
for key in ["ckpt_dir", "iter_max", "gpus"]:
110+
if key in args and getattr(args, key) is not None:
111+
output["training"][key] = parse_value_from_string(getattr(args, key))
112+
for key in ["output_dir", "ckpt_mode", "ckpt_name"]:
113+
if key in args and getattr(args, key) is not None:
114+
output["testing"][key] = parse_value_from_string(getattr(args, key))
105115
return output
106116

107117
def synchronize_config(config):
@@ -133,6 +143,17 @@ def synchronize_config(config):
133143
if('RandomResizedCrop' in transform and \
134144
'RandomResizedCrop_output_size'.lower() not in data_cfg):
135145
data_cfg['RandomResizedCrop_output_size'.lower()] = patch_size
146+
if('testing' in config):
147+
test_cfg = config['testing']
148+
sliding_window_enable = test_cfg.get("sliding_window_enable", False)
149+
if(sliding_window_enable):
150+
sliding_window_size = test_cfg.get("sliding_window_size", None)
151+
if(sliding_window_size is None):
152+
test_cfg["sliding_window_size"] = patch_size
153+
sliding_window_stride = test_cfg.get("sliding_window_stride", None)
154+
if(sliding_window_stride is None):
155+
test_cfg["sliding_window_stride"] = [item // 2 for item in patch_size]
156+
config['testing'] = test_cfg
136157
config['dataset'] = data_cfg
137158
# config['network'] = net_cfg
138159
return config

0 commit comments

Comments
 (0)