Skip to content

Commit 3cbdd53

Browse files
authored
Merge pull request #1382 from LearnerLiyf/dev-postgresql
Add the dataset for cifar100
2 parents ef2ba0c + bb68ae6 commit 3cbdd53

1 file changed

Lines changed: 81 additions & 0 deletions

File tree

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
#
19+
20+
try:
21+
import pickle
22+
except ImportError:
23+
import cPickle as pickle
24+
25+
import numpy as np
26+
import os
27+
import sys
28+
29+
30+
def load_dataset(filepath):
    """Read one pickled CIFAR-100 batch file and return its images and labels.

    Args:
        filepath: path of the pickled batch file; the unpickled object must
            provide a 'data' array and a 'fine_labels' sequence.

    Returns:
        A tuple ``(image, label)``: ``image`` is a uint8 array reshaped to
        (N, 3, 32, 32) and ``label`` is a uint8 array of shape (N, 1).
    """
    # NOTE(review): pickle executes arbitrary code while loading; only point
    # this at dataset files obtained from a trusted source.
    with open(filepath, 'rb') as stream:
        try:
            batch = pickle.load(stream, encoding='latin1')
        except TypeError:
            # Fallback for pickle implementations whose load() rejects the
            # `encoding` keyword (e.g. the cPickle import fallback).
            batch = pickle.load(stream)

    pixels = batch['data'].astype(dtype=np.uint8).reshape((-1, 3, 32, 32))
    fine = np.asarray(batch['fine_labels'], dtype=np.uint8)
    return pixels, fine.reshape(fine.size, 1)
41+
42+
43+
def load_train_data(dir_path='/tmp/cifar-100-python'):
    """Load the training split found under ``dir_path``.

    Args:
        dir_path: directory that contains the ``train`` batch file.

    Returns:
        A tuple ``(images, labels)`` converted to float32 and int32 arrays
        respectively.
    """
    train_file = check_dataset_exist(dir_path + "/train")
    raw_images, raw_labels = load_dataset(train_file)
    images = np.array(raw_images, dtype=np.float32)
    labels = np.array(raw_labels, dtype=np.int32)
    return images, labels
46+
47+
48+
def load_test_data(dir_path='/tmp/cifar-100-python'):
    """Load the test split found under ``dir_path``.

    Args:
        dir_path: directory that contains the ``test`` batch file.

    Returns:
        A tuple ``(images, labels)`` converted to float32 and int32 arrays
        respectively.
    """
    test_file = check_dataset_exist(dir_path + "/test")
    raw_images, raw_labels = load_dataset(test_file)
    images = np.array(raw_images, dtype=np.float32)
    labels = np.array(raw_labels, dtype=np.int32)
    return images, labels
51+
52+
53+
def check_dataset_exist(dirpath):
    """Return ``dirpath`` if it exists; otherwise print a hint and exit.

    Args:
        dirpath: filesystem path that is expected to exist.

    Returns:
        ``dirpath`` unchanged when ``os.path.exists`` reports it present.

    Raises:
        SystemExit: with status 1 when the path does not exist.
    """
    if not os.path.exists(dirpath):
        print(
            'Please download the cifar100 dataset using python data/download_cifar100.py'
        )
        # A missing dataset is a failure: exit with a non-zero status so
        # shells and calling scripts do not mistake it for success.
        sys.exit(1)
    return dirpath
60+
61+
62+
def normalize(train_x, val_x):
    """Scale pixel values to [0, 1] and standardize every channel in place.

    Both arrays are indexed as ``[sample, channel, row, col]`` and are
    modified in place, so they must already hold a floating-point dtype
    (in-place true division is not defined for integer arrays).

    Args:
        train_x: training images with one channel per ``mean``/``std`` entry
            on axis 1.
        val_x: validation images with the same layout.

    Returns:
        The same two array objects, ``(train_x, val_x)``, after normalization.
    """
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]
    train_x /= 255
    val_x /= 255
    # Iterate over every (mean, std) pair; a fixed range(0, 2) would leave
    # the last channel scaled but not standardized.
    for ch, (ch_mean, ch_std) in enumerate(zip(mean, std)):
        for batch in (train_x, val_x):
            batch[:, ch, :, :] -= ch_mean
            batch[:, ch, :, :] /= ch_std
    return train_x, val_x
73+
74+
75+
def load():
    """Load, normalize and return the train and test splits.

    Both splits are read from their default directories, the images are
    passed through ``normalize`` and the label arrays are flattened to 1-D.

    Returns:
        A tuple ``(train_x, train_y, val_x, val_y)``.
    """
    train_images, train_labels = load_train_data()
    val_images, val_labels = load_test_data()
    train_images, val_images = normalize(train_images, val_images)
    return (
        train_images,
        train_labels.flatten(),
        val_images,
        val_labels.flatten(),
    )

0 commit comments

Comments
 (0)