-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathtest_periodic_eval_callback.py
More file actions
202 lines (170 loc) · 7.69 KB
/
test_periodic_eval_callback.py
File metadata and controls
202 lines (170 loc) · 7.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
from ast import arg
import os
import json
import pytest
from dicee.config import Namespace
from dicee.executer import Execute
class TestPeriodicEvalCallback:
    """Regression tests for periodic (mid-training) evaluation.

    Every test trains a small Keci model on ``KGs/UMLS`` with the CPU trainer,
    then inspects ``eval_report_n_epochs.json`` (and, for the checkpointing
    test, the ``models_n_epochs`` folder) inside the experiment folder
    returned by ``Execute(args).start()``.
    """

    @staticmethod
    def _make_args(**overrides):
        """Build the shared Keci/UMLS training configuration.

        Args:
            **overrides: Extra or replacement attributes set on the Namespace
                (e.g. ``num_epochs``, ``eval_every_n_epochs``).

        Returns:
            A ``Namespace`` ready to pass to ``Execute``.
        """
        args = Namespace()
        args.model = 'Keci'
        args.p = 0
        args.q = 1
        args.optim = 'Adam'
        args.scoring_technique = "KvsAll"
        args.dataset_dir = "KGs/UMLS"
        args.backend = "pandas"
        args.num_epochs = 10
        args.batch_size = 1024
        args.lr = 0.1
        args.embedding_dim = 32
        args.trainer = 'torchCPUTrainer'
        for name, value in overrides.items():
            setattr(args, name, value)
        return args

    @staticmethod
    def _expected_eval_epochs(num_epochs, every_n=None, at_epochs=(), skip_final=True):
        """Return the set of epochs expected in the periodic evaluation report.

        Args:
            num_epochs: Total number of training epochs.
            every_n: If truthy, every multiple of ``every_n`` up to and
                including ``num_epochs`` is expected.
            at_epochs: Explicit epochs that are also expected.
            skip_final: Drop ``num_epochs`` from the result. The final epoch
                is skipped by the periodic callback because it is evaluated
                at the end of training instead.

        Returns:
            A set of ints.
        """
        epochs = set(at_epochs)
        if every_n:
            epochs.update(range(every_n, num_epochs + 1, every_n))
        if skip_final:
            epochs.discard(num_epochs)
        return epochs

    @staticmethod
    def _load_eval_report(result):
        """Assert the periodic evaluation report exists and return it.

        Args:
            result: Mapping returned by ``Execute(...).start()``; only its
                ``'path_experiment_folder'`` entry is read.

        Returns:
            The parsed report, asserted to be a dict.
        """
        report_path = os.path.join(result['path_experiment_folder'], 'eval_report_n_epochs.json')
        assert os.path.exists(report_path)
        with open(report_path, 'r', encoding='utf-8') as f:
            report = json.load(f)
        assert isinstance(report, dict)
        return report

    @pytest.mark.filterwarnings('ignore::UserWarning')
    def test_eval_every_n_epochs(self):
        """Evaluate every ``eval_every_n_epochs`` epochs (final epoch excluded)."""
        args = self._make_args(num_epochs=10, eval_every_n_epochs=3)
        result = Execute(args).start()
        report = self._load_eval_report(result)
        expected = self._expected_eval_epochs(args.num_epochs, every_n=args.eval_every_n_epochs)
        # Compare the exact epoch keys, not just their count, so evaluating
        # the wrong epochs is caught too.
        assert set(report) == {f"epoch_{epoch}_eval" for epoch in expected}

    @pytest.mark.filterwarnings('ignore::UserWarning')
    def test_eval_at_epochs(self):
        """Evaluate only at the explicit epochs listed in ``eval_at_epochs``."""
        args = self._make_args(num_epochs=10, eval_at_epochs=[3, 5, 8])
        result = Execute(args).start()
        report = self._load_eval_report(result)
        expected = self._expected_eval_epochs(args.num_epochs, at_epochs=args.eval_at_epochs)
        assert set(report) == {f"epoch_{epoch}_eval" for epoch in expected}

    @pytest.mark.filterwarnings('ignore::UserWarning')
    def test_eval_every_n_epochs_and_at_epochs(self):
        """Combine ``eval_every_n_epochs`` with explicit ``eval_at_epochs``."""
        args = self._make_args(num_epochs=12, eval_every_n_epochs=4, eval_at_epochs=[3, 7, 10])
        result = Execute(args).start()
        report = self._load_eval_report(result)
        expected = self._expected_eval_epochs(
            args.num_epochs,
            every_n=args.eval_every_n_epochs,
            at_epochs=args.eval_at_epochs,
        )
        # Union of both schedules, each epoch exactly once, no extras.
        assert set(report) == {f"epoch_{epoch}_eval" for epoch in expected}

    @pytest.mark.filterwarnings('ignore::UserWarning')
    def test_eval_every_n_epochs_with_save_model(self):
        """Periodic evaluation also writes one checkpoint per evaluated epoch."""
        args = self._make_args(num_epochs=12, eval_every_n_epochs=4, save_every_n_epochs=True)
        result = Execute(args).start()
        self._load_eval_report(result)
        checkpoints_dir = os.path.join(result['path_experiment_folder'], 'models_n_epochs')
        assert os.path.exists(checkpoints_dir)
        expected = self._expected_eval_epochs(args.num_epochs, every_n=args.eval_every_n_epochs)
        pt_files = [name for name in os.listdir(checkpoints_dir) if name.endswith('.pt')]
        assert len(pt_files) == len(expected)
        for epoch in expected:
            assert f"model_at_epoch_{epoch}.pt" in pt_files

    @pytest.mark.filterwarnings('ignore::UserWarning')
    def test_eval_every_n_epochs_eval_model(self):
        """Periodic evaluation uses ``n_epochs_eval_model`` splits, independent of ``eval_model``."""
        args = self._make_args(
            num_epochs=12,
            eval_model='train',
            eval_every_n_epochs=4,
            n_epochs_eval_model='test_val',
        )
        result = Execute(args).start()
        report = self._load_eval_report(result)
        requested_splits = args.n_epochs_eval_model.split('_')
        # The final epoch is only left to the end-of-training evaluation when
        # that evaluation already covers every split requested periodically.
        final_eval_covers_splits = all(
            split in args.eval_model.split('_') for split in requested_splits
        )
        expected = self._expected_eval_epochs(
            args.num_epochs,
            every_n=args.eval_every_n_epochs,
            skip_final=final_eval_covers_splits,
        )
        for epoch in expected:
            assert f"epoch_{epoch}_eval" in report
        for epoch_report in report.values():
            assert isinstance(epoch_report, dict)
            # Split names are compared case-insensitively.
            available_modes = {mode.lower() for mode in epoch_report}
            for split in requested_splits:
                assert split.lower() in available_modes