Skip to content

Commit 4806a11

Browse files
Add import smoke tests to PR CI (#19091)
Summary: Add a PR QNN import job that validates backend module imports and statically checks internal imports for runnable Qualcomm example entrypoints. Also fix the stale `ExecutorchBackendConfig` import in the QAIHub stable diffusion example so the new check passes. Differential Revision: D102218906
1 parent 7b5dcc1 commit 4806a11

5 files changed

Lines changed: 501 additions & 3 deletions

File tree

Lines changed: 374 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,374 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Qualcomm Innovation Center, Inc.
3+
# All rights reserved
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
"""Validate internal imports for QNN example entrypoints.
9+
10+
Entrypoints are discovered dynamically from `examples/qualcomm` by looking for
11+
Python files that define a standard `if __name__ == "__main__"` block. This
12+
keeps the check focused on runnable scripts while avoiding a hardcoded list
13+
that drifts as examples are added, moved, or removed.
14+
"""
15+
16+
import ast
17+
import importlib.util
18+
import sys
19+
from pathlib import Path
20+
21+
22+
# Dotted-module prefix corresponding to files under `examples/qualcomm`.
EXAMPLE_MODULE_PREFIX = "executorch.examples.qualcomm."
23+
24+
25+
def resolve_examples_root():
    """Locate the `examples/qualcomm` directory by walking up from this file.

    Returns:
        The directory as a `Path`, or None when no ancestor contains it.
    """
    ancestors = Path(__file__).resolve().parents
    return next(
        (
            parent / "examples" / "qualcomm"
            for parent in ancestors
            if (parent / "examples" / "qualcomm").is_dir()
        ),
        None,
    )
31+
32+
33+
def is_main_guard(test: ast.AST) -> bool:
    """Return True when *test* is exactly the comparison `__name__ == "__main__"`."""
    if not isinstance(test, ast.Compare):
        return False
    # Must be a single `==` comparison (not chained, not another operator).
    single_eq = (
        len(test.ops) == 1
        and len(test.comparators) == 1
        and isinstance(test.ops[0], ast.Eq)
    )
    if not single_eq:
        return False
    left_is_dunder_name = (
        isinstance(test.left, ast.Name) and test.left.id == "__name__"
    )
    if not left_is_dunder_name:
        return False
    right = test.comparators[0]
    return isinstance(right, ast.Constant) and right.value == "__main__"
44+
45+
46+
def is_entrypoint(tree: ast.AST) -> bool:
    """Return True when *tree* contains an `if __name__ == "__main__"` guard anywhere."""
    return any(
        isinstance(node, ast.If) and is_main_guard(node.test)
        for node in ast.walk(tree)
    )
51+
52+
53+
def discover_entrypoints(examples_root: Path) -> list[str]:
    """Find runnable example scripts under *examples_root*.

    A script counts as an entrypoint when it contains an
    `if __name__ == "__main__"` guard; `__init__.py` files are skipped.

    Args:
        examples_root: Directory to scan recursively for `*.py` files.

    Returns:
        Sorted POSIX-style paths relative to *examples_root*.
    """
    entrypoints = []
    for path in sorted(examples_root.rglob("*.py")):
        if path.name == "__init__.py":
            continue
        # Read as UTF-8 explicitly: the default locale encoding varies across
        # CI hosts and would make parsing (and thus discovery) non-deterministic.
        tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
        if is_entrypoint(tree):
            entrypoints.append(path.relative_to(examples_root).as_posix())
    return entrypoints
62+
63+
64+
def module_base_path(repo_root: Path, module_name: str) -> Path:
    """Map a dotted `executorch.*` module name to its path under *repo_root*.

    The leading `executorch` segment is dropped because *repo_root* already
    is the `executorch` package directory.
    """
    _top, *segments = module_name.split(".")
    return repo_root.joinpath(*segments)
66+
67+
68+
def module_exists(repo_root: Path, module_name: str) -> bool:
    """Check whether *module_name* is backed by a package directory or `.py`
    file under *repo_root*, or failing that is importable in this environment."""
    candidate = module_base_path(repo_root, module_name)
    if candidate.is_dir() or candidate.with_suffix(".py").is_file():
        return True

    # Fall back to the import machinery for modules not laid out on disk
    # under repo_root (e.g. installed wheels, namespace packages).
    try:
        spec = importlib.util.find_spec(module_name)
    except (AttributeError, ImportError, ModuleNotFoundError, ValueError):
        return False
    return spec is not None
77+
78+
79+
def module_source_file(repo_root: Path, module_name: str):
    """Return the source file backing *module_name*.

    Prefers a plain module file (`name.py`) over a package
    (`name/__init__.py`); returns None when neither exists on disk.
    """
    base = module_base_path(repo_root, module_name)
    for candidate in (base.with_suffix(".py"), base / "__init__.py"):
        if candidate.is_file():
            return candidate
    return None
88+
89+
90+
def source_module_name(repo_root: Path, source_file: Path) -> str:
    """Derive the dotted `executorch.*` module name for *source_file*."""
    relative = source_file.relative_to(repo_root)
    # A package's __init__.py names the package itself; a module file drops
    # only its .py suffix.
    relative = (
        relative.parent
        if relative.name == "__init__.py"
        else relative.with_suffix("")
    )
    return ".".join(("executorch", *relative.parts))
97+
98+
99+
def target_names(node):
    """Collect the plain variable names bound by an assignment target.

    Handles bare names, tuple/list destructuring, and starred targets
    (`a, *rest = ...`). Attribute and subscript targets bind no new module
    names and are intentionally ignored.

    Args:
        node: An assignment-target AST node.

    Returns:
        The set of names the target binds.
    """
    names = set()
    if isinstance(node, ast.Name):
        names.add(node.id)
    elif isinstance(node, ast.Starred):
        # `a, *rest = ...` binds `rest`; previously this name was dropped,
        # producing false "unresolved internal import" failures downstream.
        names.update(target_names(node.value))
    elif isinstance(node, (ast.Tuple, ast.List)):
        for element in node.elts:
            names.update(target_names(element))
    return names
107+
108+
109+
def collect_exported_names(
    repo_root: Path,
    module_name: str,
    body: list[ast.stmt],
    export_cache: dict[Path, set[str]],
    is_package: bool,
) -> set[str]:
    """Statically approximate the names a module body makes importable.

    Walks *body* one statement at a time, collecting definitions (functions,
    classes), import aliases, and assignment targets. Recurses into
    control-flow blocks (`if`/`try`/loops/`with`/`match`) because names bound
    inside them at module top level are still module attributes at runtime.

    Args:
        repo_root: Repository root used to resolve `executorch.*` module paths.
        module_name: Dotted name of the module that owns *body*; anchors
            relative-import resolution.
        body: Statement list to scan (a module body or a nested block).
        export_cache: Memo shared across modules so star-import expansion
            parses each source file at most once.
        is_package: True when *body* belongs to a package `__init__.py`
            (changes the anchor package for relative imports).

    Returns:
        The set of names this body is considered to export.
    """
    names = set()

    for node in body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            names.add(node.name)
        elif isinstance(node, ast.Import):
            for alias in node.names:
                # `import a.b` binds `a`; `import a.b as c` binds `c`.
                names.add(alias.asname or alias.name.split(".")[0])
        elif isinstance(node, ast.ImportFrom):
            try:
                imported_module = resolve_from_module(
                    module_name, node, is_package=is_package
                )
            except ImportError:
                # Relative import escaping the package: treat as unknown
                # source module so the star-import branch below is skipped.
                imported_module = ""

            for alias in node.names:
                if alias.name == "*":
                    # Expand `from x import *` only for internal modules we
                    # can find on disk; external star-imports are opaque.
                    if imported_module.startswith("executorch.") and module_exists(
                        repo_root, imported_module
                    ):
                        source_file = module_source_file(repo_root, imported_module)
                        if source_file is not None:
                            names.update(
                                exported_names(repo_root, source_file, export_cache)
                            )
                    continue

                names.add(alias.asname or alias.name)
        elif isinstance(node, ast.Assign):
            for target in node.targets:
                names.update(target_names(target))
        elif isinstance(node, ast.AnnAssign):
            names.update(target_names(node.target))
        elif isinstance(node, ast.If):
            # Names defined on either branch may exist at runtime.
            names.update(
                collect_exported_names(
                    repo_root, module_name, node.body, export_cache, is_package
                )
            )
            names.update(
                collect_exported_names(
                    repo_root, module_name, node.orelse, export_cache, is_package
                )
            )
        elif isinstance(node, ast.Try):
            # Consider try/except/else/finally bodies all as potential definers
            # (common pattern: optional imports with fallbacks).
            names.update(
                collect_exported_names(
                    repo_root, module_name, node.body, export_cache, is_package
                )
            )
            for handler in node.handlers:
                names.update(
                    collect_exported_names(
                        repo_root, module_name, handler.body, export_cache, is_package
                    )
                )
            names.update(
                collect_exported_names(
                    repo_root, module_name, node.orelse, export_cache, is_package
                )
            )
            names.update(
                collect_exported_names(
                    repo_root, module_name, node.finalbody, export_cache, is_package
                )
            )
        elif isinstance(node, (ast.For, ast.AsyncFor, ast.While, ast.With, ast.AsyncWith)):
            names.update(
                collect_exported_names(
                    repo_root, module_name, node.body, export_cache, is_package
                )
            )
            # `with` has no orelse; getattr keeps one code path for all five.
            orelse = getattr(node, "orelse", [])
            names.update(
                collect_exported_names(
                    repo_root, module_name, orelse, export_cache, is_package
                )
            )
        elif isinstance(node, ast.Match):
            for case in node.cases:
                names.update(
                    collect_exported_names(
                        repo_root, module_name, case.body, export_cache, is_package
                    )
                )

    return names
204+
205+
206+
def exported_names(
    repo_root: Path, source_file: Path, export_cache: dict[Path, set[str]]
) -> set[str]:
    """Return the (memoized) set of names exported by *source_file*.

    Args:
        repo_root: Repository root used to derive the module's dotted name.
        source_file: The `.py` file to analyze.
        export_cache: Shared memo mapping source files to their export sets.

    Returns:
        The export set; the cached `set` object is returned on repeat calls.
    """
    cached_names = export_cache.get(source_file)
    if cached_names is not None:
        return cached_names

    names = set()
    # Register the (still-empty) set before recursing so mutually
    # star-importing modules terminate instead of recursing forever.
    export_cache[source_file] = names

    module_name = source_module_name(repo_root, source_file)
    # Parse as UTF-8 explicitly; the locale default is host-dependent.
    tree = ast.parse(source_file.read_text(encoding="utf-8"), filename=str(source_file))
    names.update(
        collect_exported_names(
            repo_root,
            module_name,
            tree.body,
            export_cache,
            source_file.name == "__init__.py",
        )
    )
    return names
228+
229+
230+
def resolve_from_module(
    module_name: str, node: ast.ImportFrom, is_package: bool = False
) -> str:
    """Resolve the absolute module targeted by a `from ... import` statement.

    Absolute imports pass through unchanged. Relative imports are resolved
    against *module_name* — the module itself when it is a package
    (`__init__.py`), otherwise its parent package.

    Raises:
        ImportError: When the relative level escapes the package hierarchy.
    """
    if not node.level:
        return node.module or ""
    anchor = module_name if is_package else module_name.rpartition(".")[0]
    dotted = "." * node.level + (node.module or "")
    return importlib.util.resolve_name(dotted, anchor)
238+
239+
240+
def validate_import_from(
    repo_root: Path,
    module_name: str,
    entrypoint: str,
    node: ast.ImportFrom,
    export_cache: dict[Path, set[str]],
) -> tuple[list[str], int]:
    """Validate one `from ... import ...` statement of an entrypoint script.

    Only imports resolving inside the `executorch.` namespace are checked;
    everything else is ignored (zero checks, zero failures). Each imported
    alias must be either a submodule on disk or a name statically exported
    by the source module.

    Args:
        repo_root: Repository root for resolving internal module paths.
        module_name: Dotted name of the entrypoint module (anchors relative
            imports; entrypoints are plain modules, never packages).
        entrypoint: Path string used only in failure messages.
        node: The `ast.ImportFrom` node to validate.
        export_cache: Shared memo passed through to `exported_names`.

    Returns:
        (failure messages, number of internal-import checks performed).
    """
    failures = []
    try:
        imported_module = resolve_from_module(module_name, node)
    except ImportError as error:
        failures.append(f"{entrypoint}:{node.lineno} relative import could not be resolved: {error}")
        return failures, 0

    # External (non-executorch) imports are out of scope for this checker.
    if not imported_module.startswith("executorch."):
        return failures, 0

    # The module lookup itself counts as one check.
    checks = 1
    if not module_exists(repo_root, imported_module):
        failures.append(
            f"{entrypoint}:{node.lineno} missing internal module `{imported_module}`"
        )
        return failures, checks

    source_file = module_source_file(repo_root, imported_module)
    exported = (
        exported_names(repo_root, source_file, export_cache) if source_file else set()
    )

    for alias in node.names:
        if alias.name == "*":
            # Star imports cannot name a missing symbol; nothing to check.
            continue
        # An alias may be a submodule rather than an attribute of the module.
        submodule_name = f"{imported_module}.{alias.name}"
        if module_exists(repo_root, submodule_name):
            checks += 1
            continue
        if source_file is None or alias.name not in exported:
            failures.append(
                f"{entrypoint}:{node.lineno} unresolved internal import "
                f"`{alias.name}` from `{imported_module}`"
            )
        checks += 1

    return failures, checks
284+
285+
286+
def validate_entrypoint(
    repo_root: Path,
    examples_root: Path,
    relative_path: str,
    export_cache: dict[Path, set[str]],
) -> tuple[list[str], int]:
    """Statically validate every `executorch.*` import in one example script.

    Args:
        repo_root: Repository root for resolving internal module paths.
        examples_root: The `examples/qualcomm` directory.
        relative_path: POSIX-style path of the script relative to
            *examples_root* (as produced by `discover_entrypoints`).
        export_cache: Shared memo passed through to `exported_names`.

    Returns:
        (failure messages, number of internal-import checks performed).
    """
    entrypoint_path = examples_root / relative_path
    if not entrypoint_path.is_file():
        # Entrypoints are discovered dynamically, so a miss here means the
        # file disappeared between discovery and validation.
        return [f"{relative_path}: entrypoint not found"], 0

    # Join the path parts explicitly: `str(Path(...)).replace("/", ".")`
    # leaves backslash separators untouched on Windows, producing a bogus
    # module name and breaking relative-import resolution.
    module_name = EXAMPLE_MODULE_PREFIX + ".".join(
        Path(relative_path).with_suffix("").parts
    )
    # Parse as UTF-8 explicitly; the locale default is host-dependent.
    tree = ast.parse(
        entrypoint_path.read_text(encoding="utf-8"), filename=str(entrypoint_path)
    )

    failures = []
    checks = 0
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                if alias.name.startswith("executorch."):
                    checks += 1
                    if not module_exists(repo_root, alias.name):
                        failures.append(
                            f"{relative_path}:{node.lineno} missing internal module `{alias.name}`"
                        )
        elif isinstance(node, ast.ImportFrom):
            import_failures, import_checks = validate_import_from(
                repo_root,
                module_name,
                relative_path,
                node,
                export_cache,
            )
            failures.extend(import_failures)
            checks += import_checks

    return failures, checks
325+
326+
327+
def main():
    """Discover QNN example entrypoints and validate their internal imports.

    Exits non-zero on any unresolved import, on setup problems, or when
    nothing at all was checked (which would indicate a broken checker).
    """
    # ast.Match handling and modern syntax below need 3.10+.
    if sys.version_info < (3, 10):
        print("Python 3.10+ is required to parse QNN example sources")
        sys.exit(1)

    examples_root = resolve_examples_root()
    if examples_root is None:
        print(f"QNN examples root not found from {Path(__file__).resolve()}")
        sys.exit(1)

    repo_root = examples_root.parent.parent
    entrypoints = discover_entrypoints(examples_root)
    if not entrypoints:
        print(f"No QNN example entrypoints found under {examples_root}")
        sys.exit(1)

    export_cache = {}
    problems = []
    checked = 0
    for script in entrypoints:
        script_problems, script_checks = validate_entrypoint(
            repo_root, examples_root, script, export_cache
        )
        problems.extend(script_problems)
        checked += script_checks

    # Zero checks means the checker matched nothing — fail loudly rather
    # than report a vacuous success.
    if checked == 0:
        print("No QNN example imports were checked")
        sys.exit(1)

    if problems:
        print(
            f"{len(problems)} unresolved internal import(s) "
            f"across {len(entrypoints)} QNN example entrypoint(s):"
        )
        for problem in problems:
            print(f" FAIL: {problem}")
        sys.exit(1)

    print(
        f"Validated {checked} internal import(s) across "
        f"{len(entrypoints)} QNN example entrypoint(s)"
    )
371+
372+
373+
# Script entry point (run directly, e.g. from a CI step).
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)