Skip to content

Commit 588d796

Browse files
committed
Create the data folder for the healthcare model zoo
1 parent cf7f55c commit 588d796

1 file changed

Lines changed: 122 additions & 0 deletions

File tree

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
19+
try:
20+
import pickle
21+
except ImportError:
22+
import cPickle as pickle
23+
24+
import numpy as np
25+
import os
26+
import sys
27+
from PIL import Image
28+
29+
30+
# need to save to specific local directories
31+
def load_train_data(dir_path="/tmp/malaria", resize_size=(128, 128)):
32+
dir_path = check_dataset_exist(dirpath=dir_path)
33+
path_train_label_1 = os.path.join(dir_path, "training_set/Parasitized")
34+
path_train_label_0 = os.path.join(dir_path, "training_set/Uninfected")
35+
train_label_1 = load_image_path(os.listdir(path_train_label_1))
36+
train_label_0 = load_image_path(os.listdir(path_train_label_0))
37+
labels = []
38+
Images = np.empty((len(train_label_1) + len(train_label_0),
39+
3, resize_size[0], resize_size[1]), dtype=np.uint8)
40+
for i in range(len(train_label_0)):
41+
image_path = os.path.join(path_train_label_0, train_label_0[i])
42+
temp_image = np.array(Image.open(image_path).resize(
43+
resize_size).convert("RGB")).transpose(2, 0, 1)
44+
Images[i] = temp_image
45+
labels.append(0)
46+
for i in range(len(train_label_1)):
47+
image_path = os.path.join(path_train_label_1, train_label_1[i])
48+
temp_image = np.array(Image.open(image_path).resize(
49+
resize_size).convert("RGB")).transpose(2, 0, 1)
50+
Images[i + len(train_label_0)] = temp_image
51+
labels.append(1)
52+
53+
Images = np.array(Images, dtype=np.float32)
54+
labels = np.array(labels, dtype=np.int32)
55+
return Images, labels
56+
57+
58+
# need to save to specific local directories
59+
def load_test_data(dir_path='/tmp/malaria', resize_size=(128, 128)):
60+
dir_path = check_dataset_exist(dirpath=dir_path)
61+
path_test_label_1 = os.path.join(dir_path, "testing_set/Parasitized")
62+
path_test_label_0 = os.path.join(dir_path, "testing_set/Uninfected")
63+
test_label_1 = load_image_path(os.listdir(path_test_label_1))
64+
test_label_0 = load_image_path(os.listdir(path_test_label_0))
65+
labels = []
66+
Images = np.empty((len(test_label_1) + len(test_label_0),
67+
3, resize_size[0], resize_size[1]), dtype=np.uint8)
68+
for i in range(len(test_label_0)):
69+
image_path = os.path.join(path_test_label_0, test_label_0[i])
70+
temp_image = np.array(Image.open(image_path).resize(
71+
resize_size).convert("RGB")).transpose(2, 0, 1)
72+
Images[i] = temp_image
73+
labels.append(0)
74+
for i in range(len(test_label_1)):
75+
image_path = os.path.join(path_test_label_1, test_label_1[i])
76+
temp_image = np.array(Image.open(image_path).resize(
77+
resize_size).convert("RGB")).transpose(2, 0, 1)
78+
Images[i + len(test_label_0)] = temp_image
79+
labels.append(1)
80+
81+
Images = np.array(Images, dtype=np.float32)
82+
labels = np.array(labels, dtype=np.int32)
83+
return Images, labels
84+
85+
86+
def load_image_path(list):
87+
new_list = []
88+
for image_path in list:
89+
if (image_path.endswith(".png") or image_path.endswith(".jpg")):
90+
new_list.append(image_path)
91+
return new_list
92+
93+
94+
def check_dataset_exist(dirpath):
95+
if not os.path.exists(dirpath):
96+
print(
97+
'Please download the malaria dataset first'
98+
)
99+
sys.exit(0)
100+
return dirpath
101+
102+
103+
def normalize(train_x, val_x):
104+
mean = [0.5339, 0.4180, 0.4460] # mean for malaria dataset
105+
std = [0.3329, 0.2637, 0.2761] # std for malaria dataset
106+
train_x /= 255
107+
val_x /= 255
108+
for ch in range(0, 2):
109+
train_x[:, ch, :, :] -= mean[ch]
110+
train_x[:, ch, :, :] /= std[ch]
111+
val_x[:, ch, :, :] -= mean[ch]
112+
val_x[:, ch, :, :] /= std[ch]
113+
return train_x, val_x
114+
115+
116+
def load(dir_path):
117+
train_x, train_y = load_train_data(dir_path=dir_path)
118+
val_x, val_y = load_test_data(dir_path=dir_path)
119+
train_x, val_x = normalize(train_x, val_x)
120+
train_y = train_y.flatten()
121+
val_y = val_y.flatten()
122+
return train_x, train_y, val_x, val_y

0 commit comments

Comments
 (0)