-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathvalidation.py
More file actions
50 lines (35 loc) · 1.44 KB
/
validation.py
File metadata and controls
50 lines (35 loc) · 1.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
import re
import numpy as np
import tensorflow as tf
from data_generator import input_fn
import sys
sys.path.append('../../')
from delay_model import RouteNet_Fermi
# Matches a decimal number such as "12.34" in a checkpoint filename. The first
# match is read as the checkpoint's error value (called "mre" below; presumably
# the validation mean relative error written by the training script - confirm).
_CKPT_MRE_RE = re.compile(r"\d+\.\d+")


def _find_best_checkpoint(ckpt_dir):
    """Return ``(prefix, mre)`` for the lowest-error checkpoint in *ckpt_dir*.

    Every regular file in *ckpt_dir* whose name contains a decimal number is
    considered. The first such number in the name is taken as its error. The
    returned prefix has the ``.index`` / ``.data`` / ``-00000-of-00001``
    parts removed so it can be passed to ``model.load_weights``. On equal
    errors the entry listed later by ``os.listdir`` (arbitrary order) wins.

    Raises:
        FileNotFoundError: if no file name in *ckpt_dir* contains a decimal
            number, i.e. there is no checkpoint to load.
    """
    best = None
    best_mre = float('inf')
    for fname in os.listdir(ckpt_dir):
        if not os.path.isfile(os.path.join(ckpt_dir, fname)):
            continue
        match = _CKPT_MRE_RE.search(fname)
        if match is None:
            continue
        mre = float(match.group())
        if mre <= best_mre:
            # Both "<prefix>.index" and "<prefix>.data-00000-of-00001" reduce
            # to the same checkpoint prefix.
            best = (fname.replace('.index', '')
                         .replace('.data', '')
                         .replace('-00000-of-00001', ''))
            best_mre = mre
    if best is None:
        # Previously this surfaced later as a TypeError from os.path.join.
        raise FileNotFoundError(
            f"No checkpoint with an error value in its name found in {ckpt_dir!r}"
        )
    return best, best_mre


# For each traffic model: rebuild the network, restore its lowest-error
# checkpoint, run it over the test set and save the result to a .npy file.
for tm in ['constant_bitrate', 'onoff', 'autocorrelated', 'modulated', 'all_multiplexed']:
    TEST_PATH = f'../../data/traffic_models/{tm}/test'

    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    model = RouteNet_Fermi()
    loss_object = tf.keras.losses.MeanAbsolutePercentageError()
    model.compile(loss=loss_object,
                  optimizer=optimizer,
                  run_eagerly=False)

    ckpt_dir = f'./ckpt_dir_{tm}'
    best, best_mre = _find_best_checkpoint(ckpt_dir)
    print("BEST CHECKPOINT FOUND FOR {}: {}".format(tm.upper(), best))
    model.load_weights(os.path.join(ckpt_dir, best))

    ds_test = input_fn(TEST_PATH, shuffle=False)
    ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)
    # NOTE(review): Model.evaluate returns the test loss (here the MAPE), not
    # per-sample predictions, so the "predictions_delay_*.npy" file holds that
    # loss value. If per-sample delay predictions are intended, this should be
    # model.predict(ds_test) - confirm before relying on the saved file.
    predictions = model.evaluate(ds_test, verbose=1)
    np.save(f'predictions_delay_{tm}.npy', np.squeeze(predictions))