Skip to content

Commit 4590b30

Browse files
committed
update dataset and transform
Fix code for loading h5 images; update transforms.
1 parent a4634f2 commit 4590b30

File tree

5 files changed

+94
-61
lines changed

5 files changed

+94
-61
lines changed

pymic/io/nifty_dataset.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@ class NiftyDataset(Dataset):
4141
:param transform: (list) List of transforms to be applied on a sample.
4242
The built-in transforms can listed in :mod:`pymic.transform.trans_dict`.
4343
"""
44-
def __init__(self, root_dir, csv_file, modal_num = 1, image_dim = 3, allow_missing_modal = False,
45-
with_label = True, transform=None, task = TaskType.SEGMENTATION):
44+
def __init__(self, root_dir, csv_file, modal_num = 1, image_dim = 3,
45+
allow_missing_modal = False, label_key = "label",
46+
transform=None, task = TaskType.SEGMENTATION):
4647
self.root_dir = root_dir
4748
if(csv_file is not None):
4849
self.csv_items = pd.read_csv(csv_file)
@@ -56,10 +57,11 @@ def __init__(self, root_dir, csv_file, modal_num = 1, image_dim = 3, allow_missi
5657
self.modal_num = modal_num
5758
self.image_dim = image_dim
5859
self.allow_emtpy= allow_missing_modal
59-
self.with_label = with_label
60+
self.label_key = label_key
6061
self.transform = transform
6162
self.task = task
6263
self.h5files = False
64+
self.with_label = True
6365
assert self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]
6466

6567
# check if the files are h5 images, and if the labels are provided.
@@ -69,11 +71,11 @@ def __init__(self, root_dir, csv_file, modal_num = 1, image_dim = 3, allow_missi
6971
self.h5files = True
7072
temp_full_name = "{0:}/{1:}".format(self.root_dir, temp_name)
7173
h5f = h5py.File(temp_full_name, 'r')
72-
if('label' not in h5f):
74+
if(self.label_key not in h5f):
7375
self.with_label = False
7476
else:
7577
csv_keys = list(self.csv_items.keys())
76-
if('label' not in csv_keys):
78+
if(self.label_key not in csv_keys):
7779
self.with_label = False
7880

7981
self.image_weight_idx = None
@@ -84,7 +86,7 @@ def __init__(self, root_dir, csv_file, modal_num = 1, image_dim = 3, allow_missi
8486
self.pixel_weight_idx = csv_keys.index('pixel_weight')
8587
if(not self.with_label):
8688
logging.warning("`label` section is not found in the csv file {0:}".format(
87-
csv_file) + "or the corresponding h5 file." +
89+
csv_file) + " or the corresponding h5 file." +
8890
"\n -- This is only allowed for self-supervised learning" +
8991
"\n -- when `SelfSuperviseLabel` is used in the transform, or when" +
9092
"\n -- loading the unlabeled data for preprocessing.")
@@ -94,7 +96,7 @@ def __len__(self):
9496

9597
def __getlabel__(self, idx):
9698
csv_keys = list(self.csv_items.keys())
97-
label_idx = csv_keys.index('label')
99+
label_idx = csv_keys.index(self.label_key)
98100
label_name = self.csv_items.iloc[idx, label_idx]
99101
label_name_full = "{0:}/{1:}".format(self.root_dir, label_name)
100102
label = load_image_as_nd_array(label_name_full)['data_array']
@@ -139,8 +141,8 @@ def __getitem__(self, idx):
139141
img = check_and_expand_dim(h5f['image'][:], self.image_dim)
140142
sample = {'image':img}
141143
if(self.with_label):
142-
lab = check_and_expand_dim(h5f['label'][:], self.image_dim)
143-
sample['label'] = lab
144+
lab = check_and_expand_dim(h5f[self.label_key][:], self.image_dim)
145+
sample['label'] = np.asarray(lab, np.float32)
144146
sample['names'] = [sample_name]
145147
else:
146148
for i in range (self.modal_num):

pymic/net_run/agent_seg.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,19 +72,22 @@ def get_stage_dataset_from_config(self, stage):
7272
allow_miss = self.config['dataset'].get('allow_missing_modal', False)
7373
stage_dir = self.config['dataset'].get('train_dir', None)
7474
stage_dim = self.config['dataset'].get('train_dim', 3)
75+
stage_lab_key = self.config['dataset'].get('train_label_key', 'label')
7576
if(stage == 'valid'): # and "valid_dir" in self.config['dataset']):
7677
stage_dir = self.config['dataset'].get('valid_dir', stage_dir)
7778
stage_dim = self.config['dataset'].get('valid_dim', stage_dim)
79+
stage_lab_key = self.config['dataset'].get('valid_label_key', 'label')
7880
if(stage == 'test'): # and "test_dir" in self.config['dataset']):
7981
stage_dir = self.config['dataset'].get('test_dir', stage_dir)
8082
stage_dim = self.config['dataset'].get('test_dim', stage_dim)
83+
stage_lab_key = self.config['dataset'].get('test_label_key', 'label')
8184
logging.info("Creating dataset for {0:}".format(stage))
8285
dataset = NiftyDataset(root_dir = stage_dir,
8386
csv_file = csv_file,
8487
modal_num = modal_num,
8588
image_dim = stage_dim,
8689
allow_missing_modal = allow_miss,
87-
with_label= with_label,
90+
label_key = stage_lab_key,
8891
transform = data_transform,
8992
task = self.task_type)
9093
return dataset

pymic/net_run/weak_sup/wsl_dmsps.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
import random
66
import time
77
import torch
8+
from PIL import Image
89
from pymic.loss.seg.util import get_soft_label
910
from pymic.loss.seg.util import reshape_prediction_and_ground_truth
1011
from pymic.loss.seg.util import get_classwise_dice
1112
from pymic.loss.seg.dice import DiceLoss
1213
from pymic.loss.seg.ce import CrossEntropyLoss
14+
# from torch.nn.modules.loss import CrossEntropyLoss as TorchCELoss
1315
from pymic.net_run.weak_sup import WSLSegAgent
1416
from pymic.util.ramps import get_rampup_ratio
1517

@@ -33,11 +35,11 @@ class WSLDMSPS(WSLSegAgent):
3335
"""
3436
def __init__(self, config, stage = 'train'):
3537
net_type = config['network']['net_type']
36-
if net_type not in ['UNet2D_DualBranch', 'UNet3D_DualBranch']:
37-
raise ValueError("""For WSL_DMPLS, a dual branch network is expected. \
38-
It only supports UNet2D_DualBranch and UNet3D_DualBranch currently.""")
38+
# if net_type not in ['UNet2D_DualBranch', 'UNet3D_DualBranch']:
39+
# raise ValueError("""For WSL_DMPLS, a dual branch network is expected. \
40+
# It only supports UNet2D_DualBranch and UNet3D_DualBranch currently.""")
3941
super(WSLDMSPS, self).__init__(config, stage)
40-
42+
4143
def training(self):
4244
class_num = self.config['network']['class_num']
4345
iter_valid = self.config['training']['iter_valid']
@@ -49,10 +51,12 @@ def training(self):
4951
if (pseudo_loss_type not in ('dice_loss', 'ce_loss')):
5052
raise ValueError("""For pseudo supervision loss, only dice_loss and ce_loss \
5153
are supported.""")
54+
pseudo_loss_func = CrossEntropyLoss() if pseudo_loss_type == 'ce_loss' else DiceLoss()
5255
train_loss, train_loss_sup, train_loss_reg = 0, 0, 0
5356
train_dice_list = []
5457
data_time, gpu_time, loss_time, back_time = 0, 0, 0, 0
5558
self.net.train()
59+
# ce_loss = CrossEntropyLoss()
5660
for it in range(iter_valid):
5761
t0 = time.time()
5862
try:
@@ -66,7 +70,6 @@ def training(self):
6670
y = self.convert_tensor_type(data['label_prob'])
6771

6872
inputs, y = inputs.to(self.device), y.to(self.device)
69-
7073
# zero the parameter gradients
7174
self.optimizer.zero_grad()
7275

@@ -78,23 +81,26 @@ def training(self):
7881
loss_sup2 = self.get_loss_value(data, outputs2, y)
7982
loss_sup = 0.5 * (loss_sup1 + loss_sup2)
8083

81-
# get pseudo label with dynamical mix
84+
# torch_ce_loss = TorchCELoss(ignore_index=class_num)
85+
# torch_ce_loss2 = TorchCELoss()
86+
# loss_ce1 = torch_ce_loss(outputs1, label[:].long())
87+
# loss_ce2 = torch_ce_loss(outputs2, label[:].long())
88+
# loss_sup = 0.5 * (loss_ce1 + loss_ce2)
89+
90+
# get pseudo label with dynamic mixture
8291
outputs_soft1 = torch.softmax(outputs1, dim=1)
8392
outputs_soft2 = torch.softmax(outputs2, dim=1)
84-
beta = random.random()
85-
pseudo_lab = beta*outputs_soft1.detach() + (1.0-beta)*outputs_soft2.detach()
86-
# pseudo_lab = torch.argmax(pseudo_lab, dim = 1, keepdim = True)
87-
# pseudo_lab = get_soft_label(pseudo_lab, class_num, self.tensor_type)
88-
89-
# calculate the pseudo label supervision loss
90-
loss_calculator = DiceLoss() if pseudo_loss_type == 'dice_loss' else CrossEntropyLoss()
91-
loss_dict1 = {"prediction":outputs1, 'ground_truth':pseudo_lab}
92-
loss_dict2 = {"prediction":outputs2, 'ground_truth':pseudo_lab}
93-
loss_reg = 0.5 * (loss_calculator(loss_dict1) + loss_calculator(loss_dict2))
93+
alpha = random.random()
94+
soft_pseudo_label = alpha * outputs_soft1.detach() + (1.0-alpha) * outputs_soft2.detach()
95+
# loss_reg = 0.5*(torch_ce_loss2(outputs_soft1, soft_pseudo_label) +torch_ce_loss2(outputs_soft2, soft_pseudo_label) )
9496

97+
loss_dict1 = {"prediction":outputs_soft1, 'ground_truth':soft_pseudo_label}
98+
loss_dict2 = {"prediction":outputs_soft2, 'ground_truth':soft_pseudo_label}
99+
loss_reg = 0.5 * (pseudo_loss_func(loss_dict1) + pseudo_loss_func(loss_dict2))
100+
95101
rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid")
96-
regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio
97-
loss = loss_sup + regular_w*loss_reg
102+
regular_w = wsl_cfg.get('regularize_w', 8.0) * rampup_ratio
103+
loss = loss_sup + regular_w*loss_reg
98104
t3 = time.time()
99105
loss.backward()
100106
t4 = time.time()
@@ -127,5 +133,4 @@ def training(self):
127133
'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice,
128134
'data_time': data_time, 'forward_time':gpu_time,
129135
'loss_time':loss_time, 'backward_time':back_time}
130-
return train_scalers
131-
136+
return train_scalers

pymic/transform/rotate.py

Lines changed: 50 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,13 @@ class RandomRotate(AbstractTransform):
1919
2020
:param `RandomRotate_angle_range_d`: (list/tuple or None)
2121
Rotation angle (degree) range along depth axis (x-y plane), e.g., (-90, 90).
22-
The length of the list/tuple can be larger than 2, when `RandomRotate_discrete_mode` is True.
2322
If None, no rotation along this axis.
2423
:param `RandomRotate_angle_range_h`: (list/tuple or None)
2524
Rotation angle (degree) range along height axis (x-z plane), e.g., (-90, 90).
26-
The length of the list/tuple can be larger than 2, when `RandomRotate_discrete_mode` is True.
2725
If None, no rotation along this axis. Only used for 3D images.
2826
:param `RandomRotate_angle_range_w`: (list/tuple or None)
2927
Rotation angle (degree) range along width axis (y-z plane), e.g., (-90, 90).
30-
The length of the list/tuple can be larger than 2, when `RandomRotate_discrete_mode` is True.
3128
If None, no rotation along this axis. Only used for 3D images.
32-
:param `RandomRotate_discrete_mode`: (optional, bool) Whether the rotate angles
33-
are discrete values in rangle range. For example, if you only want to rotate
34-
the images with a fixed set of angles like (90, 180, 270), then set discrete_mode mode as True.
3529
:param `RandomRotate_probability`: (optional, float)
3630
The probability of applying RandomRotate. Default is 0.5.
3731
:param `RandomRotate_inverse`: (optional, bool)
@@ -42,11 +36,8 @@ def __init__(self, params):
4236
self.angle_range_d = params['RandomRotate_angle_range_d'.lower()]
4337
self.angle_range_h = params.get('RandomRotate_angle_range_h'.lower(), None)
4438
self.angle_range_w = params.get('RandomRotate_angle_range_w'.lower(), None)
45-
self.discrete_mode = params.get('RandomRotate_discrete_mode'.lower(), False)
4639
self.prob = params.get('RandomRotate_probability'.lower(), 0.5)
4740
self.inverse = params.get('RandomRotate_inverse'.lower(), True)
48-
if(len(self.angle_range_d) > 2):
49-
assert(self.discrete_mode)
5041

5142
def __apply_transformation(self, image, transform_param_list, order = 1):
5243
"""
@@ -61,38 +52,21 @@ def __apply_transformation(self, image, transform_param_list, order = 1):
6152
return image
6253

6354
def __call__(self, sample):
64-
# if(random.random() > self.prob):
65-
# sample['RandomRotate_triggered'] = False
66-
# return sample
67-
# else:
68-
# sample['RandomRotate_triggered'] = True
6955
image = sample['image']
7056
input_shape = image.shape
7157
input_dim = len(input_shape) - 1
7258

7359
transform_param_list = []
7460
if(self.angle_range_d is not None):
75-
if(self.discrete_mode):
76-
idx = random.randint(0, len(self.angle_range_d) - 1)
77-
angle_d = self.angle_range_d[idx]
78-
else:
79-
angle_d = np.random.uniform(self.angle_range_d[0], self.angle_range_d[1])
61+
angle_d = np.random.uniform(self.angle_range_d[0], self.angle_range_d[1])
8062
transform_param_list.append([angle_d, (-1, -2)])
8163
if(input_dim == 3):
8264
if(self.angle_range_h is not None):
83-
if(self.discrete_mode):
84-
idx = random.randint(0, len(self.angle_range_h) - 1)
85-
angle_h = self.angle_range_h[idx]
86-
else:
87-
angle_h = np.random.uniform(self.angle_range_h[0], self.angle_range_h[1])
88-
transform_param_list.append([angle_h, (-1, -3)])
65+
angle_h = np.random.uniform(self.angle_range_h[0], self.angle_range_h[1])
66+
transform_param_list.append([angle_h, (-1, -3)])
8967
if(self.angle_range_w is not None):
90-
if(self.discrete_mode):
91-
idx = random.randint(0, len(self.angle_range_w) - 1)
92-
angle_w = self.angle_range_w[idx]
93-
else:
94-
angle_w = np.random.uniform(self.angle_range_w[0], self.angle_range_w[1])
95-
transform_param_list.append([angle_w, (-2, -3)])
68+
angle_w = np.random.uniform(self.angle_range_w[0], self.angle_range_w[1])
69+
transform_param_list.append([angle_w, (-2, -3)])
9670
assert(len(transform_param_list) > 0)
9771
# select a random transform from the possible list rather than
9872
# use a combination for higher efficiency
@@ -123,4 +97,49 @@ def inverse_transform_for_prediction(self, sample):
12397
transform_param_list[i][0] = - transform_param_list[i][0]
12498
sample['predict'] = self.__apply_transformation(sample['predict'] ,
12599
transform_param_list, 1)
100+
return sample
101+
102+
class RandomRot90(AbstractTransform):
103+
"""
104+
Random rotate an image in x-y plane with angles in [90, 180, 270].
105+
106+
The arguments should be written in the `params` dictionary, and it has the
107+
following fields:
108+
109+
:param `RandomRot90_probability`: (optional, float)
110+
The probability of applying RandomRot90. Default is 0.75.
111+
:param `RandomRot90_inverse`: (optional, bool)
112+
Is inverse transform needed for inference. Default is `True`.
113+
"""
114+
def __init__(self, params):
115+
super(RandomRot90, self).__init__(params)
116+
self.prob = params.get('RandomRot90_probability'.lower(), 0.75)
117+
self.inverse = params.get('RandomRot90_inverse'.lower(), True)
118+
119+
def __call__(self, sample):
120+
if(random.random() > self.prob):
121+
sample['RandomRot90_triggered'] = False
122+
sample['RandomRot90_Param'] = 0
123+
return sample
124+
else:
125+
sample['RandomRot90_triggered'] = True
126+
image = sample['image']
127+
rote_k = random.randint(1, 3)
128+
sample['RandomRot90_Param'] = rote_k
129+
image_t = np.rot90(image, rote_k, (-2, -1))
130+
sample['image'] = image_t
131+
if('label' in sample and \
132+
self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]):
133+
sample['label'] = np.rot90(sample['label'], rote_k, (-2, -1))
134+
if('pixel_weight' in sample and \
135+
self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]):
136+
sample['pixel_weight'] = np.rot90(sample['pixel_weight'], rote_k, (-2, -1))
137+
return sample
138+
139+
def inverse_transform_for_prediction(self, sample):
140+
if(not sample['RandomRot90_triggered']):
141+
return sample
142+
rote_k = sample['RandomRot90_Param']
143+
rote_i = 4 - rote_k
144+
sample['predict'] = np.rot90(sample['predict'], rote_i, (-2, -1))
126145
return sample

pymic/transform/trans_dict.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
'RandomRescale': RandomRescale,
2626
'RandomFlip': RandomFlip,
2727
'RandomRotate': RandomRotate,
28+
'RandomRot90': RandomRot90,
2829
'ReduceLabelDim': ReduceLabelDim,
2930
'Rescale': Rescale,
3031
'SelfSuperviseLabel': SelfSuperviseLabel,
@@ -43,6 +44,7 @@
4344
from pymic.transform.normalize import *
4445
from pymic.transform.crop import *
4546
from pymic.transform.crop4dino import Crop4Dino
47+
from pymic.transform.crop4voco import Crop4VoCo
4648
from pymic.transform.crop4vox2vec import Crop4Vox2Vec
4749
from pymic.transform.crop4vf import Crop4VolumeFusion, VolumeFusion, VolumeFusionShuffle
4850
from pymic.transform.label_convert import *
@@ -57,6 +59,7 @@
5759
'CropHumanRegion': CropHumanRegion,
5860
'CenterCrop': CenterCrop,
5961
'Crop4Dino': Crop4Dino,
62+
'Crop4VoCo': Crop4VoCo,
6063
'Crop4Vox2Vec': Crop4Vox2Vec,
6164
'Crop4VolumeFusion': Crop4VolumeFusion,
6265
'GrayscaleToRGB': GrayscaleToRGB,
@@ -83,6 +86,7 @@
8386
'RandomTranspose': RandomTranspose,
8487
'RandomFlip': RandomFlip,
8588
'RandomRotate': RandomRotate,
89+
'RandomRot90': RandomRot90,
8690
'ReduceLabelDim': ReduceLabelDim,
8791
'Rescale': Rescale,
8892
'Resample': Resample,

0 commit comments

Comments
 (0)