-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathload_train_data.py
More file actions
25 lines (24 loc) · 804 Bytes
/
load_train_data.py
File metadata and controls
25 lines (24 loc) · 804 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import cupy as np
import pickle
def _read_cached_arrays(packed_file_path):
    """Return ``(x, y)`` from the pickle cache, or ``None`` if it is unusable."""
    # NOTE(security): pickle executes arbitrary code while loading. Only point
    # packed_file_path at cache files this program wrote itself.
    try:
        with open(packed_file_path, "rb") as f:
            cached = pickle.load(f)
        return cached["x"], cached["y"]
    except Exception:
        # The cache is only an optimisation: a missing, truncated or
        # wrongly-shaped file must fall back to the CSV, not crash.
        # (Unlike a bare ``except:``, this lets KeyboardInterrupt and
        # SystemExit propagate.)
        return None


def _parse_train_csv(data_file_path):
    """Parse the training CSV into ``(x, y_one_hot)`` arrays."""
    data = np.genfromtxt(
        data_file_path, delimiter=',', skip_header=1, dtype=np.float64
    )
    # genfromtxt collapses a file with a single data row to 1-D; restore the
    # (rows, columns) layout so the column slicing below still works.
    data = np.atleast_2d(data)

    # Column 0 is the label, the remaining 784 columns are the features.
    x = data[:, 1:]
    x = x.reshape((x.shape[0], 784, 1))

    # One-hot encode all labels with a single vectorised assignment instead
    # of a Python-level loop over every row.
    labels = data[:, 0].astype(np.int64)
    n_samples = labels.shape[0]
    one_hot = np.zeros((n_samples, 10))
    one_hot[np.arange(n_samples), labels] = 1
    return x, one_hot.reshape((n_samples, 10, 1))


def load_train_data(data_file_path, packed_file_path):
    """Load the training set, using a pickle cache to skip CSV parsing.

    Parsing the CSV is slow (the original author measured roughly 6 s), so
    the parsed arrays are stored in ``packed_file_path`` and reused by later
    calls.

    Args:
        data_file_path: Path to the comma-separated training file. It must
            have one header line, the label in the first column and 784
            feature values in the remaining columns.
        packed_file_path: Path of the pickle cache. It is read when usable
            and (re)written after the CSV has been parsed.

    Returns:
        A tuple ``(x, y)`` where ``x`` has shape ``(N, 784, 1)`` and ``y`` is
        the one-hot encoded label array of shape ``(N, 10, 1)``. Both code
        paths (cache hit and cache miss) return the same layout.

    Raises:
        OSError: If the CSV cannot be read or the cache cannot be written.
        ValueError: If the CSV does not have 784 feature columns per row.
    """
    cached = _read_cached_arrays(packed_file_path)
    if cached is not None:
        return cached

    x, y_one_hot = _parse_train_csv(data_file_path)
    with open(packed_file_path, "wb") as f:
        pickle.dump({"x": x, "y": y_one_hot}, f)
    # Return the one-hot labels (not the raw label column) so that the first
    # call yields exactly what later, cache-backed calls yield.
    return x, y_one_hot