-
Notifications
You must be signed in to change notification settings - Fork 608
Expand file tree
/
Copy path test_argument_parser.py
More file actions
411 lines (349 loc) · 13.9 KB
/
test_argument_parser.py
File metadata and controls
411 lines (349 loc) · 13.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Unittests for argument parser."""
import re
import unittest
from argparse import (
Namespace,
)
from contextlib import (
redirect_stderr,
)
from io import (
StringIO,
)
from typing import (
TYPE_CHECKING,
Any,
)
from deepmd.main import (
get_ll,
parse_args,
)
if TYPE_CHECKING:
    # Only evaluated by static type checkers: none of these names exist at
    # runtime, which is why annotations below refer to them as strings.
    from typing import TypedDict

    class _DATARequired(TypedDict):
        # expected type(s) of the parsed attribute; passed to ``isinstance``,
        # so a tuple of any length is accepted
        type: type | tuple[type, ...]

    class DATA(_DATARequired, total=False):
        """Specification of one command-line argument used by the tests."""

        # raw value placed on the command line; whitespace separates tokens
        value: Any
        # parsed value to compare against when it differs from ``value``
        expected: Any
        # namespace attribute name when it differs from the option name
        dest: str

    TEST_DICT = dict[str, DATA]
def build_args(args: "TEST_DICT", command: str) -> list[str]:
    """Build an argument list shaped like the one ``sys.argv`` hands to argparse.

    Parameters
    ----------
    args : TEST_DICT
        argument specifications; keys starting with a dash are optional
        arguments, all other keys are positional arguments
    command : str
        subcommand placed first, selecting the subparser

    Returns
    -------
    list[str]
        the subcommand followed by every flag and value as separate strings
    """
    argv: list[str] = [command]
    for name, spec in args.items():
        # positional arguments contribute only their value, never their name
        flag = [name] if name.startswith("-") else []
        # specs without a value (e.g. 'store_true' / 'count' actions) emit the
        # bare flag; values are split on whitespace into separate tokens
        values = str(spec["value"]).split() if "value" in spec else []
        argv.extend(flag + values)
    return argv
class TestParserOutput(unittest.TestCase):
    """Test if parser correctly parses supplied arguments."""

    def attr_and_type_check(
        self, namespace: Namespace, mapping: "TEST_DICT", command: str, test_value: bool
    ) -> None:
        """Check that `argparse.Namespace` attribute types and values are as expected.

        For every argument, first check that the attribute exists, then that it
        has the expected type and, if requested, that it has the expected value.

        Parameters
        ----------
        namespace : Namespace
            `argparse.Namespace` object output from parser
        mapping : TEST_DICT
            mapping of argument names and their types and values
        command : str
            first argument that sets subparser
        test_value : bool
            whether to test for value match
        """
        # the subcommand itself is always checked as the ``command`` attribute
        mapping = {"command": {"type": str, "value": command}, **mapping}
        for argument, test_data in mapping.items():
            expected_type = test_data["type"]
            # if data has different destination attribute, use it
            name = test_data.get("dest", argument)
            # strip the first one/two hyphens; argparse turns the remaining
            # hyphens into underscores when naming the attribute
            attribute = re.sub(r"^-{1,2}", "", name).replace("-", "_")
            self.assertTrue(
                hasattr(namespace, attribute),
                msg=f"Namespace object does not have expected attribute: {attribute}",
            )
            actual = getattr(namespace, attribute)
            self.assertIsInstance(
                actual,
                expected_type,
                msg=f"Namespace attribute '{attribute}' is of wrong type, expected: "
                f"{expected_type}, got: {type(actual)}",
            )
            if test_value and "value" in test_data:
                # "expected" overrides "value" when the parser converts the raw
                # command-line string (e.g. "INFO" -> 20)
                expected = test_data.get("expected", test_data["value"])
                self.assertEqual(
                    expected,
                    actual,
                    msg=f"Got wrong parsed value, expected: {expected}, got "
                    f"{actual}",
                )

    def _parse(self, cmd_args: list[str], mapping: "TEST_DICT", label: str) -> Namespace:
        """Parse ``cmd_args``; re-raise a parser exit with captured stderr attached.

        Parameters
        ----------
        cmd_args : list[str]
            argument list passed to ``parse_args``
        mapping : TEST_DICT
            specification the arguments were built from (reported on failure)
        label : str
            text inserted before "arguments" in the failure message

        Raises
        ------
        SystemExit
            If the parser exits, with its stderr output in the message
        """
        buffer = StringIO()
        try:
            with redirect_stderr(buffer):
                return parse_args(cmd_args)
        except SystemExit as e:
            raise SystemExit(
                f"Encountered exception when parsing {label}arguments ->\n\n"
                f"{buffer.getvalue()}\n"
                f"passed in arguments were: {cmd_args}\n"
                f"built from dict {mapping}"
            ) from e

    def run_test(self, *, command: str, mapping: "TEST_DICT") -> None:
        """Run test first for specified arguments and then for default.

        Parameters
        ----------
        command : str
            first argument that sets subparser
        mapping : TEST_DICT
            mapping of argument names and their types and values

        Raises
        ------
        SystemExit
            If parser for some reason fails
        NotImplementedError
            If a positional argument has a type with no placeholder value
        """
        # test passed in arguments
        namespace = self._parse(build_args(mapping, command), mapping, "")
        self.attr_and_type_check(namespace, mapping, command, test_value=True)

        # positional arguments are required, so supply a placeholder for each
        required = []
        for argument, data in mapping.items():
            if argument.startswith("-"):
                continue
            arg_type = data["type"][0] if isinstance(data["type"], tuple) else data["type"]
            if arg_type is str:
                required.append("STRING")
            elif arg_type in (int, float):
                required.append("11111")
            else:
                raise NotImplementedError(
                    f"Option for type: {arg_type} not implemented, please do so!"
                )

        # test default values: only types are checked, not values
        namespace = self._parse([command, *required], mapping, "DEFAULT ")
        self.attr_and_type_check(namespace, mapping, command, test_value=False)

    def test_no_command(self) -> None:
        """Test that parser does not fail and sets no command when none is given."""
        self.assertIsNone(parse_args([]).command)

    def test_wrong_command(self) -> None:
        """Test that parser fails if an unknown command is passed in."""
        with self.assertRaises(SystemExit):
            parse_args(["RANDOM_WRONG_COMMAND"])

    def test_parser_log(self) -> None:
        """Check if logging associated attributes are present in specified parsers."""
        ARGS = {
            "--log-level": {"type": int, "value": "INFO", "expected": 20},
            "--log-path": {"type": (str, type(None)), "value": "LOGFILE"},
        }
        for parser in (
            "transfer",
            "train",
            "freeze",
            "test",
            "compress",
            "doc-train-input",
            "model-devi",
        ):
            # train additionally requires its positional INPUT argument
            if parser == "train":
                args = {"INPUT": {"type": str, "value": "INFILE"}, **ARGS}
            else:
                args = ARGS
            with self.subTest(parser=parser):
                self.run_test(command=parser, mapping=args)

    def test_parser_mpi(self) -> None:
        """Check if mpi-log attribute is present in specified parsers."""
        ARGS = {"--mpi-log": {"type": str, "value": "master"}}
        for parser in ("train", "compress"):
            # compare by equality: ``parser in ("train")`` was a substring test
            if parser == "train":
                args = {"INPUT": {"type": str, "value": "INFILE"}, **ARGS}
            else:
                args = ARGS
            with self.subTest(parser=parser):
                self.run_test(command=parser, mapping=args)

    def test_parser_transfer(self) -> None:
        """Test transfer subparser."""
        ARGS = {
            "--raw-model": {"type": str, "value": "INFILE.PB"},
            "--old-model": {"type": str, "value": "OUTFILE.PB"},
            "--output": {"type": str, "value": "OUTPUT"},
        }
        self.run_test(command="transfer", mapping=ARGS)

    def test_parser_train_init_model(self) -> None:
        """Test train init-model subparser."""
        ARGS = {
            "INPUT": {"type": str, "value": "INFILE"},
            "--init-model": {"type": (str, type(None)), "value": "SYSTEM_DIR"},
            "--output": {"type": str, "value": "OUTPUT"},
        }
        self.run_test(command="train", mapping=ARGS)

    def test_parser_train_restart(self) -> None:
        """Test train restart subparser."""
        ARGS = {
            "INPUT": {"type": str, "value": "INFILE"},
            "--restart": {"type": (str, type(None)), "value": "RESTART"},
            "--output": {"type": str, "value": "OUTPUT"},
        }
        self.run_test(command="train", mapping=ARGS)

    def test_parser_train_init_frz_model(self) -> None:
        """Test train init-frz-model subparser."""
        ARGS = {
            "INPUT": {"type": str, "value": "INFILE"},
            "--init-frz-model": {"type": (str, type(None)), "value": "INIT_FRZ_MODEL"},
            "--output": {"type": str, "value": "OUTPUT"},
        }
        self.run_test(command="train", mapping=ARGS)

    def test_parser_train_finetune(self) -> None:
        """Test train finetune subparser."""
        ARGS = {
            "INPUT": {"type": str, "value": "INFILE"},
            "--finetune": {"type": (str, type(None)), "value": "FINETUNE"},
            "--output": {"type": str, "value": "OUTPUT"},
        }
        self.run_test(command="train", mapping=ARGS)

    def test_parser_train_wrong_subcommand(self) -> None:
        """Test that train fails when both --init-model and --restart are passed."""
        ARGS = {
            "INPUT": {"type": str, "value": "INFILE"},
            "--init-model": {"type": (str, type(None)), "value": "SYSTEM_DIR"},
            "--restart": {"type": (str, type(None)), "value": "RESTART"},
            "--output": {"type": str, "value": "OUTPUT"},
        }
        with self.assertRaises(SystemExit):
            self.run_test(command="train", mapping=ARGS)

    def test_parser_train_allow_ref(self) -> None:
        """Test train --allow-ref option."""
        args = parse_args(["train", "INFILE", "--allow-ref"])
        self.assertTrue(args.allow_ref)
        args_default = parse_args(["train", "INFILE"])
        self.assertFalse(args_default.allow_ref)

    def test_parser_freeze(self) -> None:
        """Test freeze subparser."""
        ARGS = {
            "--checkpoint-folder": {"type": str, "value": "FOLDER"},
            "--output": {"type": str, "value": "FROZEN.PB"},
            "--node-names": {"type": (str, type(None)), "value": "NODES"},
        }
        self.run_test(command="freeze", mapping=ARGS)

    def test_parser_test(self) -> None:
        """Test test subparser."""
        ARGS = {
            "--model": {"type": str, "value": "MODEL.PB"},
            "--system": {"type": str, "value": "SYSTEM_DIR"},
            "--numb-test": {"type": int, "value": 1},
            "--rand-seed": {"type": (int, type(None)), "value": 12321},
            "--detail-file": {"type": (str, type(None)), "value": "TARGET.FILE"},
            "--atomic": {"type": bool},
        }
        self.run_test(command="test", mapping=ARGS)

    def test_parser_test_train_data(self) -> None:
        """Test test subparser with train-data."""
        ARGS = {
            "--model": {"type": str, "value": "MODEL.PB"},
            "--train-data": {
                "type": (str, type(None)),
                "value": "INPUT.JSON",
                "dest": "train_json",
            },
        }
        self.run_test(command="test", mapping=ARGS)

    def test_parser_test_valid_data(self) -> None:
        """Test test subparser with valid-data."""
        ARGS = {
            "--model": {"type": str, "value": "MODEL.PB"},
            "--valid-data": {
                "type": (str, type(None)),
                "value": "INPUT.JSON",
                "dest": "valid_json",
            },
        }
        self.run_test(command="test", mapping=ARGS)

    def test_parser_compress(self) -> None:
        """Test compress subparser."""
        ARGS = {
            "--output": {"type": str, "value": "OUTFILE"},
            "--extrapolate": {"type": int, "value": 5},
            "--step": {"type": float, "value": 0.1},
            "--frequency": {"type": int, "value": -1},
            "--checkpoint-folder": {"type": str, "value": "."},
        }
        self.run_test(command="compress", mapping=ARGS)

    def test_parser_doc(self) -> None:
        """Test doc subparser."""
        ARGS = {
            "--out-type": {"type": str, "value": "rst"},
        }
        self.run_test(command="doc-train-input", mapping=ARGS)

    def test_parser_model_devi(self) -> None:
        """Test model-devi subparser."""
        ARGS = {
            "--models": {
                "type": list,
                "value": "GRAPH.000.pb GRAPH.001.pb",
                "expected": ["GRAPH.000.pb", "GRAPH.001.pb"],
            },
            "--system": {"type": str, "value": "SYSTEM_DIR"},
            "--output": {"type": str, "value": "OUTFILE"},
            "--frequency": {"type": int, "value": 1},
        }
        self.run_test(command="model-devi", mapping=ARGS)

    def test_get_log_level(self) -> None:
        """Test conversion of log level names and verbosity digits to numeric levels."""
        MAPPING = {
            "DEBUG": 10,
            "INFO": 20,
            "WARNING": 30,
            "ERROR": 40,
            "3": 10,
            "2": 20,
            "1": 30,
            "0": 40,
        }
        for input_val, expected_result in MAPPING.items():
            with self.subTest(input_val=input_val):
                result = get_ll(input_val)
                self.assertEqual(
                    result,
                    expected_result,
                    msg=f"Expected: {expected_result} result for input value: {input_val} "
                    f"but got {result}",
                )