-
Notifications
You must be signed in to change notification settings - Fork 268
Expand file tree
/
Copy pathutils.py
More file actions
66 lines (51 loc) · 1.98 KB
/
utils.py
File metadata and controls
66 lines (51 loc) · 1.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import gc
import importlib.util
import sys
from pathlib import Path
import pytest
class SampleTestError(Exception):
    """Signal that a sample script failed with an unexpected exception.

    ``run_example`` raises this, chaining the sample's original error as
    ``__cause__``, so test reports name the failing sample.
    """
def run_example(parent_dir: str, rel_path_to_example: str, env=None) -> None:
    """Import and run a sample script in-process as part of a pytest test.

    The file ``Path(parent_dir) / rel_path_to_example`` is loaded as a module
    named after its file stem. Its top-level code is executed, and its
    ``main()`` is called if it defines one. While the sample runs,
    ``parent_dir`` is appended to ``sys.path`` and ``sys.argv`` is set to
    ``[<full path of the sample>]``.

    Afterwards ``sys.path``, ``sys.argv`` and any ``sys.modules`` entry that
    previously existed under the sample's name are restored, whether the
    sample succeeded or failed.

    Args:
        parent_dir: Directory containing the samples (``str`` or
            ``os.PathLike``). It is put on ``sys.path`` so a sample can import
            sibling modules.
        rel_path_to_example: Path of the sample relative to ``parent_dir``.
        env: Accepted for backward compatibility; not used by this function.

    Raises:
        ImportError: If no module spec/loader can be created for the file, or
            if the sample fails an import that is not one of the optional
            ``cupy``/``torch`` dependencies.
        SampleTestError: If the sample raises any other ``Exception``. The
            original exception is chained as ``__cause__``.

    The current test is skipped (``pytest.skip``) instead of failed when the
    sample reports a missing ``cupy``/``torch`` module, or when the sample
    raises ``SystemExit``.
    """
    fullpath = Path(parent_dir) / rel_path_to_example
    module_name = fullpath.stem
    old_sys_path = sys.path.copy()
    old_argv = sys.argv
    # A sample whose stem matches an already-imported module must not evict
    # that module from sys.modules for the rest of the test session.
    _absent = object()
    old_module = sys.modules.get(module_name, _absent)
    try:
        sys.path.append(str(parent_dir))
        sys.argv = [str(fullpath)]
        # Collect metadata for file 'module_name' located at 'fullpath'.
        spec = importlib.util.spec_from_file_location(module_name, fullpath)
        if spec is None or spec.loader is None:
            raise ImportError(f"Failed to load spec for {rel_path_to_example}")
        module = importlib.util.module_from_spec(spec)
        # Register before executing, as in the importlib "import a source file
        # directly" recipe, so lookups by module name work during execution.
        sys.modules[module_name] = module
        # This runs top-level code.
        spec.loader.exec_module(module)
        # If the module has a main() function, call it.
        if hasattr(module, "main"):
            module.main()
    except ImportError as e:
        # Samples needing an optional dependency are skipped, not failed.
        for m in ("cupy", "torch"):
            if f"No module named '{m}'" in str(e):
                # pytest.skip raises, so nothing after it runs.
                pytest.skip(f"{m} not installed, skipping related tests")
        raise
    except SystemExit:
        # for samples that early return due to any missing requirements
        pytest.skip(f"skip {rel_path_to_example}")
    except Exception as e:
        msg = "\n"
        msg += f"Got error ({rel_path_to_example}):\n"
        msg += str(e)
        raise SampleTestError(msg) from e
    finally:
        # Restore in place so holders of a reference to sys.path see the reset.
        sys.path[:] = old_sys_path
        sys.argv = old_argv
        if old_module is _absent:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = old_module
        # further reduce the memory watermark
        gc.collect()