Skip to content

Commit fd88812

Browse files
Add deterministic structural check scripts (#948)
* Add structural check scripts for exports and ARCHITECTURE inventory Introduce AST-based checkers that catch export wiring drift and stale experiment inventory tables without importing causalpy, wire them into prek/Makefile, and fix the missing SyntheticDifferenceInDifferences export in experiments/__init__.py. Closes #947 Co-authored-by: Cursor <cursoragent@cursor.com> * Check structural notes in architecture inventory Co-authored-by: Cursor <cursoragent@cursor.com> --------- Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent c43af84 commit fd88812

9 files changed

Lines changed: 811 additions & 2 deletions

.pre-commit-config.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,15 @@ repos:
136136
language: python
137137
files: ^causalpy/.*\.py$
138138
exclude: ^causalpy/tests/
139+
- id: check-public-exports
140+
name: Check public API export wiring
141+
entry: python scripts/check_public_exports.py --check
142+
language: python
143+
pass_filenames: false
144+
files: ^(causalpy/__init__\.py|causalpy/experiments/.*|causalpy/checks/.*)$
145+
- id: check-architecture-inventory
146+
name: Check ARCHITECTURE.md experiment inventory
147+
entry: python scripts/check_architecture_inventory.py --check
148+
language: python
149+
pass_filenames: false
150+
files: ^(ARCHITECTURE\.md|causalpy/experiments/.*)$

ARCHITECTURE.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,4 @@ Copy the closest existing experiment or model and follow the `BaseExperiment` co
8989
- Raise `FormulaException`, `DataException`, or `BadIndexException` from `causalpy.custom_exceptions` for formula, data, and index errors
9090
- Avoid backwards-compat shims for APIs introduced in the same PR
9191

92-
**Keeping it current:** When you add, remove, or structurally change an experiment class, PyMC model, backend dispatch path, or data contract, update this file in the same PR.
92+
**Keeping it current:** When you add, remove, or structurally change an experiment class, PyMC model, backend dispatch path, or data contract, update this file in the same PR. Export wiring and the experiment inventory table are enforced by `scripts/check_public_exports.py` and `scripts/check_architecture_inventory.py` (also run via prek); run `make check-exports` / `make check-architecture` locally if needed.

Makefile

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ PACKAGE_DIR = causalpy
88
# COMMANDS #
99
#################################################################################
1010

11-
.PHONY: init setup lint check_lint test uml html cleandocs doctest run_notebooks_full help
11+
.PHONY: init setup lint check_lint check-exports check-architecture test uml html cleandocs doctest run_notebooks_full help
1212

1313
init: ## Install the package in editable mode
1414
python -m pip install -e . --no-deps
@@ -28,6 +28,12 @@ check_lint: ## Check code formatting and linting without making changes
2828
ruff format --diff --check .
2929
interrogate .
3030

31+
check-exports: ## Verify experiment/check public API export wiring
32+
python scripts/check_public_exports.py --check
33+
34+
check-architecture: ## Verify ARCHITECTURE.md experiment inventory matches code
35+
python scripts/check_architecture_inventory.py --check
36+
3137
doctest: ## Run doctests for the causalpy module
3238
python -m pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py
3339

causalpy/experiments/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .regression_kink import RegressionKink
2525
from .staggered_did import StaggeredDifferenceInDifferences
2626
from .synthetic_control import SyntheticControl
27+
from .synthetic_difference_in_differences import SyntheticDifferenceInDifferences
2728

2829
__all__ = [
2930
"DifferenceInDifferences",
@@ -37,4 +38,5 @@
3738
"RegressionKink",
3839
"StaggeredDifferenceInDifferences",
3940
"SyntheticControl",
41+
"SyntheticDifferenceInDifferences",
4042
]
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# Copyright 2026 - 2026 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for ``scripts/check_architecture_inventory.py``."""
15+
16+
from __future__ import annotations
17+
18+
import importlib.util
19+
import subprocess
20+
import sys
21+
from pathlib import Path
22+
23+
import pytest
24+
25+
REPO_ROOT = Path(__file__).resolve().parents[2]
26+
SCRIPT_PATH = REPO_ROOT / "scripts" / "check_architecture_inventory.py"
27+
SCRIPTS_DIR = REPO_ROOT / "scripts"
28+
ARCHITECTURE_PATH = REPO_ROOT / "ARCHITECTURE.md"
29+
30+
31+
def _load_script_module():
32+
if str(SCRIPTS_DIR) not in sys.path:
33+
sys.path.insert(0, str(SCRIPTS_DIR))
34+
spec = importlib.util.spec_from_file_location(
35+
"check_architecture_inventory", SCRIPT_PATH
36+
)
37+
assert spec is not None and spec.loader is not None
38+
module = importlib.util.module_from_spec(spec)
39+
sys.modules[spec.name] = module
40+
spec.loader.exec_module(module)
41+
return module
42+
43+
44+
@pytest.fixture(scope="module")
45+
def script_module():
46+
return _load_script_module()
47+
48+
49+
def test_repo_architecture_inventory_is_current(script_module) -> None:
50+
"""The repository ``ARCHITECTURE.md`` inventory should match the code."""
51+
assert script_module.check_inventory(ARCHITECTURE_PATH) == []
52+
53+
54+
def test_parses_architecture_inventory_table(script_module) -> None:
55+
"""The parser should read experiment rows from ``ARCHITECTURE.md``."""
56+
rows = script_module._parse_architecture_inventory(ARCHITECTURE_PATH)
57+
assert "SyntheticDifferenceInDifferences" in rows
58+
assert rows["SyntheticDifferenceInDifferences"]["backends"] == "OLS + Bayes"
59+
60+
61+
def test_introspected_backends_match_known_experiments(script_module) -> None:
62+
"""Introspection should surface backend support and plot stubs."""
63+
inventory = script_module.introspected_inventory()
64+
assert inventory["RegressionKink"].backends == "Bayes only"
65+
assert inventory["PanelRegression"].default_model is None
66+
assert inventory["InstrumentalVariable"].plot_is_stub is True
67+
68+
69+
def test_print_markdown_includes_all_experiments(script_module) -> None:
70+
"""``--print-markdown`` output should include every experiment class."""
71+
markdown = script_module.print_markdown(ARCHITECTURE_PATH)
72+
assert "| Class | Method | Backends | Notable quirk |" in markdown
73+
assert "`SyntheticDifferenceInDifferences`" in markdown
74+
75+
76+
def test_cli_exits_zero_when_inventory_is_current() -> None:
77+
"""CLI ``--check`` should succeed on the current repository."""
78+
result = subprocess.run(
79+
[sys.executable, str(SCRIPT_PATH), "--check"],
80+
capture_output=True,
81+
text=True,
82+
check=False,
83+
cwd=REPO_ROOT,
84+
)
85+
assert result.returncode == 0
86+
assert result.stdout == ""
87+
88+
89+
def test_cli_detects_backend_drift(tmp_path: Path, script_module) -> None:
90+
"""Backend mismatches in the doc table should be reported as errors."""
91+
architecture = tmp_path / "ARCHITECTURE.md"
92+
architecture.write_text(
93+
"\n".join(
94+
[
95+
"## Experiment Inventory",
96+
"",
97+
"| Class | Method | Backends | Notable quirk |",
98+
"|-------|--------|----------|---------------|",
99+
"| `RegressionKink` | RKD | OLS + Bayes | wrong |",
100+
]
101+
)
102+
)
103+
errors = script_module.check_inventory(architecture)
104+
assert any(
105+
"RegressionKink" in line and "backends mismatch" in line for line in errors
106+
)
107+
108+
109+
def test_cli_detects_missing_model_required_note(tmp_path: Path, script_module) -> None:
110+
"""Model-required inventory notes should match ``_default_model_class``."""
111+
architecture = tmp_path / "ARCHITECTURE.md"
112+
architecture.write_text(
113+
"\n".join(
114+
[
115+
"## Experiment Inventory",
116+
"",
117+
"| Class | Method | Backends | Notable quirk |",
118+
"|-------|--------|----------|---------------|",
119+
"| `PanelRegression` | Panel FE | OLS + Bayes | missing note |",
120+
]
121+
)
122+
)
123+
errors = script_module.check_inventory(architecture)
124+
assert any(
125+
"PanelRegression" in line and "model-required note mismatch" in line
126+
for line in errors
127+
)
128+
129+
130+
def test_cli_detects_stale_model_required_note(tmp_path: Path, script_module) -> None:
131+
"""Default-model experiments should not be documented as model-required."""
132+
architecture = tmp_path / "ARCHITECTURE.md"
133+
architecture.write_text(
134+
"\n".join(
135+
[
136+
"## Experiment Inventory",
137+
"",
138+
"| Class | Method | Backends | Notable quirk |",
139+
"|-------|--------|----------|---------------|",
140+
"| `RegressionKink` | RKD | Bayes only | model required |",
141+
]
142+
)
143+
)
144+
errors = script_module.check_inventory(architecture)
145+
assert any(
146+
"RegressionKink" in line and "model-required note mismatch" in line
147+
for line in errors
148+
)
149+
150+
151+
def test_cli_detects_plot_stub_note_drift(tmp_path: Path, script_module) -> None:
152+
"""Unified-plot inventory notes should match explicit ``plot()`` stubs."""
153+
architecture = tmp_path / "ARCHITECTURE.md"
154+
architecture.write_text(
155+
"\n".join(
156+
[
157+
"## Experiment Inventory",
158+
"",
159+
"| Class | Method | Backends | Notable quirk |",
160+
"|-------|--------|----------|---------------|",
161+
"| `InstrumentalVariable` | IV/2SLS | Bayes only | missing note |",
162+
"| `RegressionKink` | RKD | Bayes only | no unified `plot()` |",
163+
]
164+
)
165+
)
166+
errors = script_module.check_inventory(architecture)
167+
assert any(
168+
"InstrumentalVariable" in line and "plot-stub note mismatch" in line
169+
for line in errors
170+
)
171+
assert any(
172+
"RegressionKink" in line and "plot-stub note mismatch" in line
173+
for line in errors
174+
)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright 2026 - 2026 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for ``scripts/check_public_exports.py``."""
15+
16+
from __future__ import annotations
17+
18+
import importlib.util
19+
import subprocess
20+
import sys
21+
from pathlib import Path
22+
23+
import pytest
24+
25+
REPO_ROOT = Path(__file__).resolve().parents[2]
26+
SCRIPT_PATH = REPO_ROOT / "scripts" / "check_public_exports.py"
27+
SCRIPTS_DIR = REPO_ROOT / "scripts"
28+
29+
30+
def _load_script_module():
31+
if str(SCRIPTS_DIR) not in sys.path:
32+
sys.path.insert(0, str(SCRIPTS_DIR))
33+
spec = importlib.util.spec_from_file_location("check_public_exports", SCRIPT_PATH)
34+
assert spec is not None and spec.loader is not None
35+
module = importlib.util.module_from_spec(spec)
36+
spec.loader.exec_module(module)
37+
return module
38+
39+
40+
@pytest.fixture(scope="module")
41+
def script_module():
42+
return _load_script_module()
43+
44+
45+
def test_repo_export_wiring_is_current(script_module) -> None:
46+
"""The repository export wiring should pass the checker."""
47+
assert script_module.check_exports() == []
48+
49+
50+
def test_detects_synthetic_did_in_experiment_exports(script_module) -> None:
51+
"""Synthetic DiD must be exported from ``experiments/__init__.py``."""
52+
discovered = script_module.discover_experiment_class_names(
53+
REPO_ROOT / "causalpy" / "experiments"
54+
)
55+
assert "SyntheticDifferenceInDifferences" in discovered
56+
57+
experiments_init = REPO_ROOT / "causalpy" / "experiments" / "__init__.py"
58+
_, exp_imports = script_module._parse_init_exports(experiments_init)
59+
assert "SyntheticDifferenceInDifferences" in exp_imports
60+
61+
62+
def test_cli_exits_zero_when_exports_are_current() -> None:
63+
"""CLI ``--check`` should succeed on the current repository."""
64+
result = subprocess.run(
65+
[sys.executable, str(SCRIPT_PATH), "--check"],
66+
capture_output=True,
67+
text=True,
68+
check=False,
69+
cwd=REPO_ROOT,
70+
)
71+
assert result.returncode == 0
72+
assert result.stdout == ""
73+
74+
75+
def test_cli_requires_check_flag() -> None:
76+
"""The CLI should require an explicit ``--check`` flag."""
77+
result = subprocess.run(
78+
[sys.executable, str(SCRIPT_PATH)],
79+
capture_output=True,
80+
text=True,
81+
check=False,
82+
cwd=REPO_ROOT,
83+
)
84+
assert result.returncode != 0
85+
assert "--check is required" in result.stderr

0 commit comments

Comments
 (0)