Skip to content

Commit a4634f2

Browse files
committed
add support to h5 files
1. Add support for h5 files. 2. Edit Rescale, Pad and Rotate so that setting the output size to a 2D list is allowed for 3D images.
1 parent 70736d5 commit a4634f2

File tree

6 files changed

+143
-60
lines changed

6 files changed

+143
-60
lines changed

pymic/io/image_read_write.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def load_image_as_nd_array(image_name):
8181
if (image_name.endswith(".nii.gz") or image_name.endswith(".nii") or
8282
image_name.endswith(".mha")):
8383
image_dict = load_nifty_volume_as_4d_array(image_name)
84-
elif(image_name.endswith(".jpg") or image_name.endswith(".jpeg") or
85-
image_name.endswith(".tif") or image_name.endswith(".png")):
84+
elif(image_name.lower().endswith(".jpg") or image_name.lower().endswith(".jpeg") or
85+
image_name.lower().endswith(".tif") or image_name.lower().endswith(".png")):
8686
image_dict = load_rgb_image_as_3d_array(image_name)
8787
else:
8888
raise ValueError("unsupported image format: {0:}".format(image_name))

pymic/io/nifty_dataset.py

Lines changed: 97 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3,50 +3,91 @@
33

44
import logging
55
import os
6+
import h5py
67
import pandas as pd
78
import numpy as np
89
from torch.utils.data import Dataset
910
from pymic import TaskType
1011
from pymic.io.image_read_write import load_image_as_nd_array
1112

13+
def check_and_expand_dim(x, img_dim):
    """
    Check the input dimension and prepend a channel axis if necessary.

    :param x: (numpy.ndarray) The input array. For a 2D image this may be
        [H, W] or already [C, H, W]; for a 3D image, [D, H, W] or [C, D, H, W].
    :param img_dim: (int) Spatial dimension of the image, 2 or 3.
    :return: For 2D images, a 3D numpy array with a shape of [C, H, W];
        for 3D images, a 4D numpy array with a shape of [C, D, H, W].
    """
    # If the array's rank equals the spatial dimension, it has no channel
    # axis yet, so add a singleton channel in front. Arrays that already
    # carry a channel axis are returned unchanged.
    if(len(x.shape) == img_dim):
        x = np.expand_dims(x, axis = 0)
    return x
25+
1226
class NiftyDataset(Dataset):
1327
"""
1428
Dataset for loading images for segmentation. It generates 4D tensors with
1529
dimention order [C, D, H, W] for 3D images, and 3D tensors
1630
with dimention order [C, H, W] for 2D images.
1731
1832
:param root_dir: (str) Directory with all the images.
19-
:param csv_file: (str) Path to the csv file with image names.
20-
:param modal_num: (int) Number of modalities.
33+
:param csv: (str) Path to the csv file with image names. If it is None,
34+
the images will be those under root_dir. This only works for testing with
35+
a single input modality. If the images are stored in h5 files, the *.csv file
36+
only has one column, while for other types of images such as .nii.gz and .png,
37+
each column is for an input modality, and the last column is for label.
38+
:param modal_num: (int) Number of modalities. This is only used if the data_file is *.csv.
39+
:param image_dim: (int) Spatial dimension of the input image. This is only used for h5 files.
2140
:param with_label: (bool) Load the data with segmentation ground truth or not.
2241
:param transform: (list) List of transforms to be applied on a sample.
2342
The built-in transforms can listed in :mod:`pymic.transform.trans_dict`.
2443
"""
25-
# def __init__(self, root_dir, csv_file, modal_num = 1,
26-
def __init__(self, root_dir, csv_file, modal_num = 1, allow_missing_modal = False,
27-
with_label = False, transform=None, task = TaskType.SEGMENTATION):
44+
def __init__(self, root_dir, csv_file, modal_num = 1, image_dim = 3, allow_missing_modal = False,
45+
with_label = True, transform=None, task = TaskType.SEGMENTATION):
2846
self.root_dir = root_dir
29-
self.csv_items = pd.read_csv(csv_file)
47+
if(csv_file is not None):
48+
self.csv_items = pd.read_csv(csv_file)
49+
else:
50+
img_names = os.listdir(root_dir)
51+
img_names = [item for item in img_names if ("nii" in item or "jpg" in item or
52+
"jpeg" in item or "bmp" in item or "png" in item)]
53+
csv_dict = {"image":img_names}
54+
self.csv_items = pd.DataFrame.from_dict(csv_dict)
55+
3056
self.modal_num = modal_num
57+
self.image_dim = image_dim
3158
self.allow_emtpy= allow_missing_modal
3259
self.with_label = with_label
3360
self.transform = transform
3461
self.task = task
62+
self.h5files = False
3563
assert self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]
3664

37-
csv_keys = list(self.csv_items.keys())
38-
if('label' not in csv_keys):
65+
# check if the files are h5 images, and if the labels are provided.
66+
temp_name = self.csv_items.iloc[0, 0]
67+
logging.warning(temp_name)
68+
if(temp_name.endswith(".h5")):
69+
self.h5files = True
70+
temp_full_name = "{0:}/{1:}".format(self.root_dir, temp_name)
71+
h5f = h5py.File(temp_full_name, 'r')
72+
if('label' not in h5f):
73+
self.with_label = False
74+
else:
75+
csv_keys = list(self.csv_items.keys())
76+
if('label' not in csv_keys):
77+
self.with_label = False
78+
79+
self.image_weight_idx = None
80+
self.pixel_weight_idx = None
81+
if('image_weight' in csv_keys):
82+
self.image_weight_idx = csv_keys.index('image_weight')
83+
if('pixel_weight' in csv_keys):
84+
self.pixel_weight_idx = csv_keys.index('pixel_weight')
85+
if(not self.with_label):
3986
logging.warning("`label` section is not found in the csv file {0:}".format(
40-
csv_file) + "\n -- This is only allowed for self-supervised learning" +
87+
csv_file) + "or the corresponding h5 file." +
88+
"\n -- This is only allowed for self-supervised learning" +
4189
"\n -- when `SelfSuperviseLabel` is used in the transform, or when" +
4290
"\n -- loading the unlabeled data for preprocessing.")
43-
self.with_label = False
44-
self.image_weight_idx = None
45-
self.pixel_weight_idx = None
46-
if('image_weight' in csv_keys):
47-
self.image_weight_idx = csv_keys.index('image_weight')
48-
if('pixel_weight' in csv_keys):
49-
self.pixel_weight_idx = csv_keys.index('pixel_weight')
5091

5192
def __len__(self):
5293
return len(self.csv_items)
@@ -92,36 +133,46 @@ def __get_pixel_weight__(self, idx):
92133
def __getitem__(self, idx):
93134
names_list, image_list = [], []
94135
image_shape = None
95-
for i in range (self.modal_num):
96-
image_name = self.csv_items.iloc[idx, i]
97-
image_full_name = "{0:}/{1:}".format(self.root_dir, image_name)
98-
if(os.path.exists(image_full_name)):
99-
image_dict = load_image_as_nd_array(image_full_name)
100-
image_data = image_dict['data_array']
101-
elif(self.allow_emtpy and image_shape is not None):
102-
image_data = np.zeros(image_shape)
103-
else:
104-
raise KeyError("File not found: {0:}".format(image_full_name))
105-
if(i == 0):
106-
image_shape = image_data.shape
107-
names_list.append(image_name)
108-
image_list.append(image_data)
109-
image = np.concatenate(image_list, axis = 0)
110-
image = np.asarray(image, np.float32)
111-
112-
sample = {'image': image, 'names' : names_list,
113-
'origin':image_dict['origin'],
114-
'spacing': image_dict['spacing'],
115-
'direction':image_dict['direction']}
116-
if (self.with_label):
117-
sample['label'], label_name = self.__getlabel__(idx)
118-
sample['names'].append(label_name)
119-
assert(image.shape[1:] == sample['label'].shape[1:])
120-
if (self.image_weight_idx is not None):
121-
sample['image_weight'] = self.csv_items.iloc[idx, self.image_weight_idx]
122-
if (self.pixel_weight_idx is not None):
123-
sample['pixel_weight'] = self.__get_pixel_weight__(idx)
124-
assert(image.shape[1:] == sample['pixel_weight'].shape[1:])
136+
if(self.h5files):
137+
sample_name = self.csv_items.iloc[idx, 0]
138+
h5f = h5py.File(self.root_dir + '/' + sample_name, 'r')
139+
img = check_and_expand_dim(h5f['image'][:], self.image_dim)
140+
sample = {'image':img}
141+
if(self.with_label):
142+
lab = check_and_expand_dim(h5f['label'][:], self.image_dim)
143+
sample['label'] = lab
144+
sample['names'] = [sample_name]
145+
else:
146+
for i in range (self.modal_num):
147+
image_name = self.csv_items.iloc[idx, i]
148+
image_full_name = "{0:}/{1:}".format(self.root_dir, image_name)
149+
if(os.path.exists(image_full_name)):
150+
image_dict = load_image_as_nd_array(image_full_name)
151+
image_data = image_dict['data_array']
152+
elif(self.allow_emtpy and image_shape is not None):
153+
image_data = np.zeros(image_shape)
154+
else:
155+
raise KeyError("File not found: {0:}".format(image_full_name))
156+
if(i == 0):
157+
image_shape = image_data.shape
158+
names_list.append(image_name)
159+
image_list.append(image_data)
160+
image = np.concatenate(image_list, axis = 0)
161+
image = np.asarray(image, np.float32)
162+
163+
sample = {'image': image, 'names' : names_list,
164+
'origin':image_dict['origin'],
165+
'spacing': image_dict['spacing'],
166+
'direction':image_dict['direction']}
167+
if (self.with_label):
168+
sample['label'], label_name = self.__getlabel__(idx)
169+
sample['names'].append(label_name)
170+
assert(image.shape[1:] == sample['label'].shape[1:])
171+
if (self.image_weight_idx is not None):
172+
sample['image_weight'] = self.csv_items.iloc[idx, self.image_weight_idx]
173+
if (self.pixel_weight_idx is not None):
174+
sample['pixel_weight'] = self.__get_pixel_weight__(idx)
175+
assert(image.shape[1:] == sample['pixel_weight'].shape[1:])
125176
if self.transform:
126177
sample = self.transform(sample)
127178

pymic/net_run/agent_seg.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,18 @@ def get_stage_dataset_from_config(self, stage):
7171
modal_num = self.config['dataset'].get('modal_num', 1)
7272
allow_miss = self.config['dataset'].get('allow_missing_modal', False)
7373
stage_dir = self.config['dataset'].get('train_dir', None)
74-
if(stage == 'valid' and "valid_dir" in self.config['dataset']):
75-
stage_dir = self.config['dataset']['valid_dir']
76-
if(stage == 'test' and "test_dir" in self.config['dataset']):
77-
stage_dir = self.config['dataset']['test_dir']
74+
stage_dim = self.config['dataset'].get('train_dim', 3)
75+
if(stage == 'valid'): # and "valid_dir" in self.config['dataset']):
76+
stage_dir = self.config['dataset'].get('valid_dir', stage_dir)
77+
stage_dim = self.config['dataset'].get('valid_dim', stage_dim)
78+
if(stage == 'test'): # and "test_dir" in self.config['dataset']):
79+
stage_dir = self.config['dataset'].get('test_dir', stage_dir)
80+
stage_dim = self.config['dataset'].get('test_dim', stage_dim)
7881
logging.info("Creating dataset for {0:}".format(stage))
7982
dataset = NiftyDataset(root_dir = stage_dir,
8083
csv_file = csv_file,
8184
modal_num = modal_num,
85+
image_dim = stage_dim,
8286
allow_missing_modal = allow_miss,
8387
with_label= with_label,
8488
transform = data_transform,

pymic/transform/pad.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ def __call__(self, sample):
3838
image = sample['image']
3939
input_shape = image.shape
4040
input_dim = len(input_shape) - 1
41+
42+
if(input_dim == 3):
43+
if(len(self.output_size) == 2):
44+
# for 3D images, ignore the z-axis
45+
self.output_size = [input_shape[1]] + list(self.output_size)
4146
assert(len(self.output_size) == input_dim)
4247
if(self.ceil_mode):
4348
multiple = [int(math.ceil(float(input_shape[1+i])/self.output_size[i]))\

pymic/transform/rescale.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ class Rescale(AbstractTransform):
1717
following fields:
1818
1919
:param `Rescale_output_size`: (list/tuple or int) The output size along each spatial axis,
20-
such as [D, H, W] or [H, W]. If D is None, the input image is only reslcaled in 2D.
21-
If int, the smallest axis is matched to output_size keeping aspect ratio the same
22-
as the input.
20+
such as [D, H, W] or [H, W]. For 3D images, if D is None, or the length of the tuple/list is 2,
21+
the input image is only rescaled in 2D. If int, the smallest axis is matched to output_size
22+
keeping aspect ratio the same as the input.
2323
:param `Rescale_inverse`: (optional, bool)
2424
Is inverse transform needed for inference. Default is `True`.
2525
"""
@@ -38,6 +38,8 @@ def __call__(self, sample):
3838
output_size = self.output_size
3939
if(output_size[0] is None):
4040
output_size[0] = input_shape[1]
41+
if(input_dim == 3 and len(self.output_size) == 2):
42+
output_size = [input_shape[1]] + list(output_size)
4143
assert(len(output_size) == input_dim)
4244
else:
4345
min_edge = min(input_shape[1:])

pymic/transform/rotate.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,19 @@ class RandomRotate(AbstractTransform):
1919
2020
:param `RandomRotate_angle_range_d`: (list/tuple or None)
2121
Rotation angle (degree) range along depth axis (x-y plane), e.g., (-90, 90).
22+
The length of the list/tuple can be larger than 2, when `RandomRotate_discrete_mode` is True.
2223
If None, no rotation along this axis.
2324
:param `RandomRotate_angle_range_h`: (list/tuple or None)
2425
Rotation angle (degree) range along height axis (x-z plane), e.g., (-90, 90).
26+
The length of the list/tuple can be larger than 2, when `RandomRotate_discrete_mode` is True.
2527
If None, no rotation along this axis. Only used for 3D images.
2628
:param `RandomRotate_angle_range_w`: (list/tuple or None)
2729
Rotation angle (degree) range along width axis (y-z plane), e.g., (-90, 90).
30+
The length of the list/tuple can be larger than 2, when `RandomRotate_discrete_mode` is True.
2831
If None, no rotation along this axis. Only used for 3D images.
32+
:param `RandomRotate_discrete_mode`: (optional, bool) Whether the rotate angles
33+
are discrete values in the angle range. For example, if you only want to rotate
34+
the images with a fixed set of angles like (90, 180, 270), then set discrete_mode to True.
2935
:param `RandomRotate_probability`: (optional, float)
3036
The probability of applying RandomRotate. Default is 0.5.
3137
:param `RandomRotate_inverse`: (optional, bool)
@@ -36,8 +42,11 @@ def __init__(self, params):
3642
self.angle_range_d = params['RandomRotate_angle_range_d'.lower()]
3743
self.angle_range_h = params.get('RandomRotate_angle_range_h'.lower(), None)
3844
self.angle_range_w = params.get('RandomRotate_angle_range_w'.lower(), None)
45+
self.discrete_mode = params.get('RandomRotate_discrete_mode'.lower(), False)
3946
self.prob = params.get('RandomRotate_probability'.lower(), 0.5)
4047
self.inverse = params.get('RandomRotate_inverse'.lower(), True)
48+
if(len(self.angle_range_d) > 2):
49+
assert(self.discrete_mode)
4150

4251
def __apply_transformation(self, image, transform_param_list, order = 1):
4352
"""
@@ -63,15 +72,27 @@ def __call__(self, sample):
6372

6473
transform_param_list = []
6574
if(self.angle_range_d is not None):
66-
angle_d = np.random.uniform(self.angle_range_d[0], self.angle_range_d[1])
75+
if(self.discrete_mode):
76+
idx = random.randint(0, len(self.angle_range_d) - 1)
77+
angle_d = self.angle_range_d[idx]
78+
else:
79+
angle_d = np.random.uniform(self.angle_range_d[0], self.angle_range_d[1])
6780
transform_param_list.append([angle_d, (-1, -2)])
6881
if(input_dim == 3):
6982
if(self.angle_range_h is not None):
70-
angle_h = np.random.uniform(self.angle_range_h[0], self.angle_range_h[1])
71-
transform_param_list.append([angle_h, (-1, -3)])
83+
if(self.discrete_mode):
84+
idx = random.randint(0, len(self.angle_range_h) - 1)
85+
angle_h = self.angle_range_h[idx]
86+
else:
87+
angle_h = np.random.uniform(self.angle_range_h[0], self.angle_range_h[1])
88+
transform_param_list.append([angle_h, (-1, -3)])
7289
if(self.angle_range_w is not None):
73-
angle_w = np.random.uniform(self.angle_range_w[0], self.angle_range_w[1])
74-
transform_param_list.append([angle_w, (-2, -3)])
90+
if(self.discrete_mode):
91+
idx = random.randint(0, len(self.angle_range_w) - 1)
92+
angle_w = self.angle_range_w[idx]
93+
else:
94+
angle_w = np.random.uniform(self.angle_range_w[0], self.angle_range_w[1])
95+
transform_param_list.append([angle_w, (-2, -3)])
7596
assert(len(transform_param_list) > 0)
7697
# select a random transform from the possible list rather than
7798
# use a combination for higher efficiency

0 commit comments

Comments
 (0)