-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathmain.py
More file actions
39 lines (31 loc) · 1.03 KB
/
main.py
File metadata and controls
39 lines (31 loc) · 1.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import tensorflow as tf
import numpy as np
from encoder_decoder import EncoderDecoder
from test import test_model
def load_dataset(path, filename, threshold=200):
    """Load sequences from a ``.npy`` file, binarize them, and split into inputs/targets.

    Args:
        path: Directory containing the file. It is joined to ``filename`` with
            ``os.path.join``, so a trailing separator is optional
            (``"../input/nexraddata/"`` and ``"../input/nexraddata"`` both work).
        filename: Name of the ``.npy`` file inside ``path``.
        threshold: Values ``>= threshold`` become 1 and all other values become 0.
            Defaults to 200, the cut-off previously hard-coded here.

    Returns:
        Tuple ``(X, Y)`` of ``tf.float32`` tensors. ``X`` is time steps ``[0, 10)``
        and ``Y`` is time steps ``[10, 21)`` of every sequence. Each tensor is
        shaped ``(N, T, 50, 50, 4)``.
    """
    import os

    train_data = np.load(os.path.join(path, filename))
    # Keep the first two axes (sequence, time step) and fold the remaining
    # elements of each step into a 50 x 50 x 4 block.
    # NOTE(review): the original comment said "patch size 4 x 4", but a last
    # axis of 4 matches 2 x 2 patches. Also, a plain reshape only groups
    # consecutive elements in memory order. Confirm this is the intended
    # layout.
    train_data = train_data.reshape(
        train_data.shape[0], train_data.shape[1], 50, 50, 4
    )
    # Binarize with a single comparison. The old two-step masked assignment
    # (set < t to 0, then >= t to 1) turned negatives into 1 when t <= 0.
    # The source dtype is kept so peak memory does not grow.
    train_data = (train_data >= threshold).astype(train_data.dtype)
    print(train_data.shape)
    # First 10 time steps are the model input.
    X = train_data[:, :10]
    # NOTE(review): 10:21 yields up to 11 steps. With exactly 20 steps per
    # sequence this equals 10:20; otherwise confirm which was intended.
    Y = train_data[:, 10:21]
    X = tf.convert_to_tensor(X, dtype=tf.float32)
    Y = tf.convert_to_tensor(Y, dtype=tf.float32)
    return X, Y
def main():
    """Train the encoder-decoder on the radar sequences, then evaluate it.

    Steps:
    1. Load the binarized dataset with ``load_dataset``.
    2. Build an ``EncoderDecoder`` sized to the per-step shape of the inputs.
    3. Train on sequences 0-699, passing sequences 700-799 as the extra pair.
    4. Hand the model and the full dataset to ``test_model``.
    """
    inputs, targets = load_dataset("../input/nexraddata/", 'data.npy')

    # (height, width, channels) of a single time step.
    frame_shape = (inputs.shape[2], inputs.shape[3], inputs.shape[4])
    # NOTE(review): positional constructor arguments are undocumented here.
    # 2 presumably matches the two-entry filter/kernel lists; confirm the
    # meaning of 16 against EncoderDecoder.
    model = EncoderDecoder(
        2,
        [64, 48],
        [(3, 3), (3, 3)],
        16,
        frame_shape,
        './training_checkpoints',
    )
    # Disabled: model.restore() -- presumably reloads saved weights; not verified here.

    train_x, train_y = inputs[:700], targets[:700]
    val_x, val_y = inputs[700:800], targets[700:800]
    # NOTE(review): 400 is presumably the epoch count, and the last pair is
    # assumed to be validation data. Confirm against EncoderDecoder.train.
    model.train(train_x, train_y, 400, val_x, val_y)

    # NOTE(review): test_model receives every sequence, including the 0-799
    # range used above. Confirm it evaluates only a held-out slice.
    test_model(model, inputs, targets)


if __name__ == "__main__":
    main()