|
| 1 | +"""Verify that hook name constants stay in sync with actual usage. |
| 2 | +
|
| 3 | +Uses Python's Abstract Syntax Tree (AST) module to parse source files and |
| 4 | +find every string literal passed to hook-dispatch functions, then checks |
| 5 | +that the declared constants match. |
| 6 | +""" |
| 7 | + |
| 8 | +import ast |
| 9 | +import pathlib |
| 10 | + |
| 11 | +from fromager.hooks import GLOBAL_HOOK_NAMES |
| 12 | +from fromager.overrides import OVERRIDE_HOOK_NAMES |
| 13 | + |
| 14 | +SRC_DIR = pathlib.Path(__file__).parent.parent / "src" / "fromager" |
| 15 | + |
| 16 | + |
| 17 | +def _called_function_name(node: ast.Call) -> str | None: |
| 18 | + """Return the simple name of the called function, or None.""" |
| 19 | + if isinstance(node.func, ast.Name): |
| 20 | + return node.func.id |
| 21 | + if isinstance(node.func, ast.Attribute): |
| 22 | + return node.func.attr |
| 23 | + return None |
| 24 | + |
| 25 | + |
| 26 | +def _collect_string_arg( |
| 27 | + source_files: list[pathlib.Path], |
| 28 | + func_names: set[str], |
| 29 | + arg_index: int, |
| 30 | +) -> set[str]: |
| 31 | + """Find every string literal passed at ``arg_index`` to calls of ``func_names``. |
| 32 | +
|
| 33 | + Scans the AST of each file for calls like ``func("hook_name", ...)`` |
| 34 | + and returns the set of string values found at the given position. |
| 35 | + """ |
| 36 | + found: set[str] = set() |
| 37 | + for path in source_files: |
| 38 | + tree = ast.parse(path.read_text(), filename=str(path)) |
| 39 | + for node in ast.walk(tree): |
| 40 | + if not isinstance(node, ast.Call): |
| 41 | + continue |
| 42 | + if _called_function_name(node) not in func_names: |
| 43 | + continue |
| 44 | + if len(node.args) <= arg_index: |
| 45 | + continue |
| 46 | + arg = node.args[arg_index] |
| 47 | + if isinstance(arg, ast.Constant) and isinstance(arg.value, str): |
| 48 | + found.add(arg.value) |
| 49 | + return found |
| 50 | + |
| 51 | + |
| 52 | +def test_override_hook_names_match_usage() -> None: |
| 53 | + """OVERRIDE_HOOK_NAMES must list every hook passed to |
| 54 | + find_override_method / find_and_invoke across the source tree.""" |
| 55 | + source_files = [ |
| 56 | + p |
| 57 | + for p in SRC_DIR.rglob("*.py") |
| 58 | + if p.name != "overrides.py" # skip the forwarding call (uses a variable) |
| 59 | + ] |
| 60 | + used = _collect_string_arg( |
| 61 | + source_files, |
| 62 | + {"find_and_invoke", "find_override_method"}, |
| 63 | + arg_index=1, |
| 64 | + ) |
| 65 | + registered = set(OVERRIDE_HOOK_NAMES) |
| 66 | + missing = used - registered |
| 67 | + extra = registered - used |
| 68 | + assert not missing, ( |
| 69 | + f"Hooks used in source but missing from OVERRIDE_HOOK_NAMES: {missing}" |
| 70 | + ) |
| 71 | + assert not extra, f"Hooks in OVERRIDE_HOOK_NAMES but not used in source: {extra}" |
| 72 | + |
| 73 | + |
| 74 | +def test_global_hook_names_match_usage() -> None: |
| 75 | + """GLOBAL_HOOK_NAMES must list every hook passed to _get_hooks in hooks.py.""" |
| 76 | + used = _collect_string_arg( |
| 77 | + [SRC_DIR / "hooks.py"], |
| 78 | + {"_get_hooks"}, |
| 79 | + arg_index=0, |
| 80 | + ) |
| 81 | + registered = set(GLOBAL_HOOK_NAMES) |
| 82 | + missing = used - registered |
| 83 | + extra = registered - used |
| 84 | + assert not missing, ( |
| 85 | + f"Hooks used in hooks.py but missing from GLOBAL_HOOK_NAMES: {missing}" |
| 86 | + ) |
| 87 | + assert not extra, f"Hooks in GLOBAL_HOOK_NAMES but not used in hooks.py: {extra}" |
0 commit comments