Skip to content

Commit a9bc1a6

Browse files
authored
Harden long-run production provenance (#1002)
1 parent 1b55c9a commit a9bc1a6

5 files changed

Lines changed: 347 additions & 3 deletions

File tree

changelog.d/1002.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Harden long-run production provenance checks for Modal source packaging and output artifacts.

modal_app/long_term_projection.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,44 @@ def _local_git_sha() -> str:
6969
return result.stdout.strip()
7070

7171

72+
def _local_git_dirty() -> bool:
73+
try:
74+
result = subprocess.run(
75+
["git", "status", "--porcelain"],
76+
cwd=_local,
77+
check=True,
78+
capture_output=True,
79+
text=True,
80+
)
81+
except (OSError, subprocess.CalledProcessError):
82+
return True
83+
return bool(result.stdout.strip())
84+
85+
86+
def _validate_local_source(source_sha: str, *, allow_dirty_source: bool) -> None:
87+
if allow_dirty_source:
88+
return
89+
if _local_git_dirty():
90+
raise ValueError(
91+
"The local policyengine-us-data checkout has uncommitted changes. "
92+
"Commit and pass its SHA with --source-sha before running Modal, or "
93+
"rerun with --allow-dirty-source for an explicitly non-publishable "
94+
"experiment."
95+
)
96+
local_sha = _local_git_sha()
97+
if not local_sha:
98+
raise ValueError(
99+
"Could not resolve the local policyengine-us-data git SHA; pass "
100+
"--allow-dirty-source only for an explicitly non-publishable experiment."
101+
)
102+
if local_sha != source_sha:
103+
raise ValueError(
104+
"The requested source_sha does not match the local checkout that Modal "
105+
f"will package: {source_sha} != {local_sha}. Check out the exact "
106+
"source SHA before running production."
107+
)
108+
109+
72110
def _append_optional_value(
73111
command: list[str],
74112
flag: str,
@@ -359,11 +397,16 @@ def main(
359397
support_augmentation_sanitize_worker_non_target_income: bool = False,
360398
support_augmentation_sanitize_clone_non_target_income: bool = False,
361399
spawn: bool = False,
400+
allow_dirty_source: bool = False,
362401
) -> None:
363402
if not source_sha:
364403
source_sha = os.environ.get("GITHUB_SHA", "") or _local_git_sha()
365404
if not source_sha:
366405
raise ValueError("source_sha is required; pass --source-sha.")
406+
_validate_local_source(
407+
source_sha,
408+
allow_dirty_source=allow_dirty_source,
409+
)
367410
run_id = sanitize_run_id(run_id)
368411
kwargs = {
369412
"years": years,

policyengine_us_data/datasets/cps/long_term/run_long_term_production.py

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import subprocess
99
import sys
1010
from datetime import UTC, datetime
11-
from importlib import metadata
11+
from importlib import import_module, metadata
1212
from pathlib import Path
1313

1414
from policyengine_us_data.datasets.cps.long_term.run_household_projection_parallel import (
@@ -27,6 +27,10 @@
2727
DEFAULT_HF_REPO = "policyengine/policyengine-us-data"
2828
DEFAULT_ARTIFACT_PREFIX = "long_term"
2929
DEFAULT_TAX_ASSUMPTION = "trustees-2025-core-thresholds-v1"
30+
PACKAGE_VERSION_MODULES = {
31+
"policyengine-us-data": "policyengine_us_data.__version__",
32+
"policyengine_us_data": "policyengine_us_data.__version__",
33+
}
3034

3135

3236
def _git_sha() -> str:
@@ -47,7 +51,73 @@ def _package_version(package_name: str) -> str | None:
4751
try:
4852
return metadata.version(package_name)
4953
except metadata.PackageNotFoundError:
50-
return None
54+
version_module = PACKAGE_VERSION_MODULES.get(package_name)
55+
if version_module is None:
56+
return None
57+
try:
58+
module = import_module(version_module)
59+
except ImportError:
60+
return None
61+
return getattr(module, "__version__", None)
62+
63+
64+
def _write_json(path: Path, payload: dict) -> None:
65+
path.write_text(
66+
json.dumps(payload, indent=2, sort_keys=True) + "\n",
67+
encoding="utf-8",
68+
)
69+
70+
71+
def stamp_projection_provenance(
72+
*,
73+
output_dir: Path,
74+
source_sha: str,
75+
run_id: str,
76+
) -> None:
77+
"""Stamp run provenance into artifacts created by the year runner."""
78+
metadata_paths = sorted(output_dir.glob("*.h5.metadata.json"))
79+
if not metadata_paths:
80+
raise FileNotFoundError(f"No year metadata sidecars found in {output_dir}.")
81+
82+
for metadata_path in metadata_paths:
83+
payload = json.loads(metadata_path.read_text(encoding="utf-8"))
84+
if source_sha:
85+
payload["source_sha"] = source_sha
86+
if run_id:
87+
payload["run_id"] = run_id
88+
_write_json(metadata_path, payload)
89+
90+
h5_path = Path(str(metadata_path).removesuffix(".metadata.json"))
91+
if not h5_path.exists():
92+
raise FileNotFoundError(f"Missing H5 artifact for {metadata_path}.")
93+
with _open_h5_append(h5_path) as h5_file:
94+
if source_sha:
95+
h5_file.attrs["source_sha"] = source_sha
96+
if run_id:
97+
h5_file.attrs["run_id"] = run_id
98+
99+
calibration_manifest_path = output_dir / "calibration_manifest.json"
100+
if not calibration_manifest_path.exists():
101+
raise FileNotFoundError(
102+
f"Missing calibration manifest: {calibration_manifest_path}"
103+
)
104+
manifest = json.loads(calibration_manifest_path.read_text(encoding="utf-8"))
105+
if source_sha:
106+
manifest["source_sha"] = source_sha
107+
if run_id:
108+
manifest["run_id"] = run_id
109+
for dataset in manifest.get("datasets", {}).values():
110+
if source_sha:
111+
dataset["source_sha"] = source_sha
112+
if run_id:
113+
dataset["run_id"] = run_id
114+
_write_json(calibration_manifest_path, manifest)
115+
116+
117+
def _open_h5_append(path: Path):
118+
import h5py
119+
120+
return h5py.File(path, "a")
51121

52122

53123
def _add_optional_value(
@@ -198,7 +268,7 @@ def write_manifest(
198268
"run_url": os.environ.get("US_DATA_GITHUB_RUN_URL", ""),
199269
},
200270
"package_versions": {
201-
"policyengine-us-data": _package_version("policyengine_us_data"),
271+
"policyengine-us-data": _package_version("policyengine-us-data"),
202272
"policyengine-us": _package_version("policyengine-us"),
203273
"policyengine-core": _package_version("policyengine-core"),
204274
},
@@ -339,6 +409,11 @@ def main() -> int:
339409
print("Running long-run projection command:")
340410
print(" ".join(command))
341411
subprocess.run(command, check=True)
412+
stamp_projection_provenance(
413+
output_dir=output_dir,
414+
source_sha=source_sha,
415+
run_id=run_id,
416+
)
342417

343418
artifacts = collect_artifacts(output_dir, args.artifact_prefix)
344419
manifest_path = write_manifest(

tests/unit/test_long_term_calibration_contract.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
import hashlib
44
import json
55
import subprocess
6+
import sys
67
from argparse import Namespace
8+
from importlib import metadata
9+
from pathlib import Path
710
from types import SimpleNamespace
811
import numpy as np
912
import pytest
@@ -16,6 +19,7 @@
1619
from policyengine_us_data.datasets.cps.long_term import (
1720
prototype_synthetic_2100_support as synthetic_support_module,
1821
)
22+
from policyengine_us_data.datasets.cps.long_term import run_long_term_production
1923
from policyengine_us_data.datasets.cps.long_term.calibration import (
2024
assess_nonnegative_feasibility,
2125
build_calibration_audit,
@@ -88,7 +92,9 @@
8892
year_output_dir,
8993
)
9094
from policyengine_us_data.datasets.cps.long_term.run_long_term_production import (
95+
_package_version,
9196
build_projection_command,
97+
stamp_projection_provenance,
9298
)
9399

94100

@@ -2118,6 +2124,170 @@ def test_long_term_production_command_carries_2100_contract(tmp_path):
21182124
assert "--allow-validation-failures" in command
21192125

21202126

2127+
def test_long_term_production_stamps_source_sha_into_projection_artifacts(tmp_path):
2128+
import h5py
2129+
2130+
h5_path = tmp_path / "2075.h5"
2131+
with h5py.File(h5_path, "w") as h5_file:
2132+
h5_file.create_dataset("household_weight/2075", data=[1.0])
2133+
metadata_path = tmp_path / "2075.h5.metadata.json"
2134+
metadata_path.write_text(
2135+
json.dumps(
2136+
{
2137+
"year": 2075,
2138+
"calibration_audit": {"calibration_quality": "exact"},
2139+
}
2140+
),
2141+
encoding="utf-8",
2142+
)
2143+
manifest_path = tmp_path / "calibration_manifest.json"
2144+
manifest_path.write_text(
2145+
json.dumps({"datasets": {"2075": {"h5": "2075.h5"}}}),
2146+
encoding="utf-8",
2147+
)
2148+
2149+
stamp_projection_provenance(
2150+
output_dir=tmp_path,
2151+
source_sha="abc123",
2152+
run_id="run-123",
2153+
)
2154+
2155+
metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
2156+
manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
2157+
assert metadata["source_sha"] == "abc123"
2158+
assert metadata["run_id"] == "run-123"
2159+
with h5py.File(h5_path) as h5_file:
2160+
assert h5_file.attrs["source_sha"] == "abc123"
2161+
assert h5_file.attrs["run_id"] == "run-123"
2162+
assert manifest["source_sha"] == "abc123"
2163+
assert manifest["run_id"] == "run-123"
2164+
assert manifest["datasets"]["2075"]["source_sha"] == "abc123"
2165+
assert manifest["datasets"]["2075"]["run_id"] == "run-123"
2166+
2167+
2168+
def test_long_term_production_main_uploads_stamped_artifacts(
2169+
tmp_path,
2170+
monkeypatch,
2171+
):
2172+
import h5py
2173+
2174+
output_dir = tmp_path / "out"
2175+
captured_command = []
2176+
2177+
def fake_run(command, check):
2178+
del check
2179+
captured_command.extend(command)
2180+
command_output_dir = Path(command[command.index("--output-dir") + 1])
2181+
command_output_dir.mkdir(parents=True, exist_ok=True)
2182+
with h5py.File(command_output_dir / "2075.h5", "w") as h5_file:
2183+
h5_file.create_dataset("household_weight/2075", data=[1.0])
2184+
(command_output_dir / "2075.h5.metadata.json").write_text(
2185+
json.dumps(
2186+
{
2187+
"year": 2075,
2188+
"base_dataset_path": "hf://example/base.h5",
2189+
"profile": {"name": "ss-payroll-tob"},
2190+
"calibration_audit": {"calibration_quality": "exact"},
2191+
}
2192+
),
2193+
encoding="utf-8",
2194+
)
2195+
(command_output_dir / "calibration_manifest.json").write_text(
2196+
json.dumps({"datasets": {"2075": {"h5": "2075.h5"}}}),
2197+
encoding="utf-8",
2198+
)
2199+
2200+
uploaded = []
2201+
2202+
def fake_upload(
2203+
*,
2204+
artifacts,
2205+
output_dir,
2206+
args,
2207+
run_id,
2208+
source_sha,
2209+
):
2210+
uploaded.extend(path.name for path in artifacts if path.suffix == ".json")
2211+
metadata_payload = json.loads(
2212+
(output_dir / "2075.h5.metadata.json").read_text(encoding="utf-8")
2213+
)
2214+
manifest_payload = json.loads(
2215+
(output_dir / "calibration_manifest.json").read_text(encoding="utf-8")
2216+
)
2217+
assert metadata_payload["source_sha"] == "abc123"
2218+
assert manifest_payload["source_sha"] == "abc123"
2219+
assert run_id == "run-123"
2220+
assert source_sha == "abc123"
2221+
assert args.upload_to_hf_staging is True
2222+
return len(artifacts)
2223+
2224+
build_info = PolicyEngineUSBuildInfo(
2225+
version="1.693.4",
2226+
locked_version="1.693.4",
2227+
package_file_sha256="file-sha",
2228+
package_tree_sha256="tree-sha",
2229+
)
2230+
monkeypatch.setattr(
2231+
run_long_term_production.subprocess,
2232+
"run",
2233+
fake_run,
2234+
)
2235+
monkeypatch.setattr(
2236+
run_long_term_production,
2237+
"assert_locked_policyengine_us_version",
2238+
lambda: build_info,
2239+
)
2240+
monkeypatch.setattr(
2241+
run_long_term_production,
2242+
"upload_artifacts",
2243+
fake_upload,
2244+
)
2245+
monkeypatch.setattr(run_long_term_production, "_git_sha", lambda: "abc123")
2246+
monkeypatch.setenv("HUGGING_FACE_TOKEN", "token")
2247+
monkeypatch.setattr(
2248+
sys,
2249+
"argv",
2250+
[
2251+
"run_long_term_production.py",
2252+
"--years",
2253+
"2075",
2254+
"--jobs",
2255+
"1",
2256+
"--output-dir",
2257+
str(output_dir),
2258+
"--run-id",
2259+
"run-123",
2260+
"--source-sha",
2261+
"abc123",
2262+
"--upload-to-hf-staging",
2263+
],
2264+
)
2265+
2266+
assert run_long_term_production.main() == 0
2267+
2268+
assert captured_command
2269+
assert "2075.h5.metadata.json" in uploaded
2270+
assert "calibration_manifest.json" in uploaded
2271+
with h5py.File(output_dir / "2075.h5") as h5_file:
2272+
assert h5_file.attrs["source_sha"] == "abc123"
2273+
assert h5_file.attrs["run_id"] == "run-123"
2274+
2275+
2276+
def test_long_term_production_reads_source_tree_data_package_version(monkeypatch):
2277+
from policyengine_us_data.__version__ import __version__
2278+
2279+
def fail_metadata_version(package_name):
2280+
raise metadata.PackageNotFoundError(package_name)
2281+
2282+
monkeypatch.setattr(
2283+
"policyengine_us_data.datasets.cps.long_term."
2284+
"run_long_term_production.metadata.version",
2285+
fail_metadata_version,
2286+
)
2287+
2288+
assert _package_version("policyengine-us-data") == __version__
2289+
2290+
21212291
def test_parallel_projection_validate_forwarded_args_rejects_wrapper_flags():
21222292
with pytest.raises(ValueError, match="--output-dir"):
21232293
validate_forwarded_args(["--output-dir", "/tmp/out"])

0 commit comments

Comments
 (0)