-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathmain.py
More file actions
63 lines (47 loc) · 1.88 KB
/
main.py
File metadata and controls
63 lines (47 loc) · 1.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import tensorflow as tf
from data_generator import input_fn
import sys
sys.path.append('../../')
from delay_model import RouteNet_Fermi
# Train, checkpoint and evaluate one RouteNet_Fermi model per traffic model.
# For each traffic model `tm` the loop:
#   1. builds train / validation / test pipelines from ../../data/traffic_models/<tm>/,
#   2. restores the most recent checkpoint in ./ckpt_dir_<tm> if one exists,
#   3. trains for 50 epochs of 2000 steps, saving weights after every epoch,
#   4. evaluates on the test split and prints the result.
for tm in ['constant_bitrate', 'onoff', 'autocorrelated', 'modulated', 'all_multiplexed']:
    # A brand-new model is built on every iteration; reset Keras' global state
    # so models from earlier iterations do not keep accumulating memory.
    tf.keras.backend.clear_session()

    TRAIN_PATH = f'../../data/traffic_models/{tm}/train'
    # NOTE(review): validation and test both read the 'test' split, so val_loss
    # is not independent of the final test score — confirm whether a separate
    # validation split exists for these datasets.
    VALIDATION_PATH = f'../../data/traffic_models/{tm}/test'
    TEST_PATH = f'../../data/traffic_models/{tm}/test'

    # prefetch() is applied last so it overlaps input production with every
    # training step, including across repeat() boundaries.
    ds_train = input_fn(TRAIN_PATH, shuffle=True)
    ds_train = ds_train.repeat()
    ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

    ds_validation = input_fn(VALIDATION_PATH, shuffle=False)
    ds_validation = ds_validation.prefetch(tf.data.experimental.AUTOTUNE)

    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    model = RouteNet_Fermi()
    loss_object = tf.keras.losses.MeanAbsolutePercentageError()
    model.compile(loss=loss_object,
                  optimizer=optimizer,
                  run_eagerly=False)

    ckpt_dir = f'./ckpt_dir_{tm}'
    latest = tf.train.latest_checkpoint(ckpt_dir)
    if latest is not None:
        print("Found a pretrained model, restoring...")
        # NOTE(review): only the weights are restored; fit() below still runs
        # a full 50 epochs numbered from 1 rather than resuming the count.
        model.load_weights(latest)
    else:
        print("Starting training from scratch...")

    # One weights-only checkpoint per epoch, named "<epoch>-<val_loss>".
    filepath = os.path.join(ckpt_dir, "{epoch:02d}-{val_loss:.2f}")
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=filepath,
        verbose=1,
        mode="min",
        monitor='val_loss',
        save_best_only=False,
        save_weights_only=True,
        save_freq='epoch')

    # ds_train repeats indefinitely, so steps_per_epoch defines an "epoch".
    # use_multiprocessing is intentionally not passed: Keras documents it as
    # used only for generator / keras.utils.Sequence input, so it had no
    # effect on a tf.data.Dataset.
    model.fit(ds_train,
              epochs=50,
              steps_per_epoch=2000,
              validation_data=ds_validation,
              validation_steps=200,
              callbacks=[cp_callback])

    ds_test = input_fn(TEST_PATH, shuffle=False)
    ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)
    test_results = model.evaluate(ds_test)
    # Label the result: the loop runs several models back to back.
    print(f"[{tm}] test results: {test_results}")