Skip to content

Commit 0d1b0be

Browse files
Add import smoke tests to PR CI (#19091) (#19091)
1 parent e4ab34d commit 0d1b0be

4 files changed

Lines changed: 692 additions & 1 deletion

File tree

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
from __future__ import annotations
9+
10+
"""Validate that backend Python modules can be imported.
11+
12+
The workflow passes backend-specific paths and package prefixes so the same
13+
checker can be reused for different backends.
14+
"""
15+
16+
import argparse
17+
import importlib
18+
import sys
19+
from pathlib import Path
20+
21+
22+
def parse_args() -> argparse.Namespace:
23+
parser = argparse.ArgumentParser()
24+
parser.add_argument(
25+
"--name",
26+
required=True,
27+
help="Display name for log messages, for example `QNN`.",
28+
)
29+
parser.add_argument(
30+
"--package-root",
31+
required=True,
32+
help="Path to the backend package root, relative to ExecuTorch root.",
33+
)
34+
parser.add_argument(
35+
"--package-prefix",
36+
required=True,
37+
help="Python package prefix, for example `executorch.backends.qualcomm`.",
38+
)
39+
parser.add_argument(
40+
"--skip-segment",
41+
action="append",
42+
default=["fb", "test", "tests"],
43+
help="Package path segment to skip while walking modules.",
44+
)
45+
return parser.parse_args()
46+
47+
48+
def resolve_executorch_root() -> Path:
49+
for parent in Path(__file__).resolve().parents:
50+
if (parent / "backends").is_dir() and (parent / "examples").is_dir():
51+
return parent
52+
raise RuntimeError(
53+
f"Could not locate ExecuTorch root from {Path(__file__).resolve()}"
54+
)
55+
56+
57+
def resolve_directory(executorch_root: Path, relative_path: str) -> Path:
58+
directory = executorch_root / relative_path
59+
if not directory.is_dir():
60+
raise RuntimeError(
61+
f"Directory `{relative_path}` was not found under {executorch_root}"
62+
)
63+
return directory
64+
65+
66+
def normalize_package_prefix(package_prefix: str) -> str:
67+
return package_prefix[:-1] if package_prefix.endswith(".") else package_prefix
68+
69+
70+
def should_skip_path(path: Path, skip_segments: list[str]) -> bool:
71+
if any(segment in path.parts for segment in skip_segments):
72+
return True
73+
74+
stem = path.stem
75+
return any(
76+
stem == segment or stem.startswith(f"{segment}_") for segment in skip_segments
77+
)
78+
79+
80+
def discover_modules(
81+
package_root: Path,
82+
package_prefix: str,
83+
skip_segments: list[str],
84+
) -> list[str]:
85+
modules = []
86+
for path in sorted(package_root.rglob("*.py")):
87+
relative_path = path.relative_to(package_root)
88+
if should_skip_path(relative_path, skip_segments):
89+
continue
90+
91+
if relative_path.name == "__init__.py":
92+
module_suffix = ".".join(relative_path.parent.parts)
93+
if module_suffix:
94+
modules.append(f"{package_prefix}.{module_suffix}")
95+
else:
96+
modules.append(package_prefix)
97+
continue
98+
99+
modules.append(
100+
f"{package_prefix}.{'.'.join(relative_path.with_suffix('').parts)}"
101+
)
102+
return modules
103+
104+
105+
def main() -> None:
106+
args = parse_args()
107+
executorch_root = resolve_executorch_root()
108+
package_root = resolve_directory(executorch_root, args.package_root)
109+
package_prefix = normalize_package_prefix(args.package_prefix)
110+
111+
failures: list[tuple[str, str, str]] = []
112+
modules = discover_modules(package_root, package_prefix, args.skip_segment)
113+
total_modules = len(modules)
114+
if total_modules == 0:
115+
print(f"No {args.name} Python modules found under {package_root}")
116+
sys.exit(1)
117+
118+
for index, name in enumerate(modules, 1):
119+
print(f"[{index}/{total_modules}] importing {name}", flush=True)
120+
try:
121+
importlib.import_module(name)
122+
except Exception as error:
123+
failures.append((name, type(error).__name__, str(error)))
124+
125+
if failures:
126+
print(f"{len(failures)}/{total_modules} {args.name} import failure(s):")
127+
for name, error_type, message in failures:
128+
print(f" FAIL: {name} -- {error_type}: {message}")
129+
sys.exit(1)
130+
131+
print(f"All {total_modules} {args.name} modules imported successfully")
132+
133+
134+
if __name__ == "__main__":
135+
main()

0 commit comments

Comments
 (0)