|
3 | 3 | import hashlib |
4 | 4 | import json |
5 | 5 | import subprocess |
| 6 | +import sys |
6 | 7 | from argparse import Namespace |
| 8 | +from importlib import metadata |
| 9 | +from pathlib import Path |
7 | 10 | from types import SimpleNamespace |
8 | 11 | import numpy as np |
9 | 12 | import pytest |
|
16 | 19 | from policyengine_us_data.datasets.cps.long_term import ( |
17 | 20 | prototype_synthetic_2100_support as synthetic_support_module, |
18 | 21 | ) |
| 22 | +from policyengine_us_data.datasets.cps.long_term import run_long_term_production |
19 | 23 | from policyengine_us_data.datasets.cps.long_term.calibration import ( |
20 | 24 | assess_nonnegative_feasibility, |
21 | 25 | build_calibration_audit, |
|
88 | 92 | year_output_dir, |
89 | 93 | ) |
90 | 94 | from policyengine_us_data.datasets.cps.long_term.run_long_term_production import ( |
| 95 | + _package_version, |
91 | 96 | build_projection_command, |
| 97 | + stamp_projection_provenance, |
92 | 98 | ) |
93 | 99 |
|
94 | 100 |
|
@@ -2118,6 +2124,170 @@ def test_long_term_production_command_carries_2100_contract(tmp_path): |
2118 | 2124 | assert "--allow-validation-failures" in command |
2119 | 2125 |
|
2120 | 2126 |
|
| 2127 | +def test_long_term_production_stamps_source_sha_into_projection_artifacts(tmp_path): |
| 2128 | + import h5py |
| 2129 | + |
| 2130 | + h5_path = tmp_path / "2075.h5" |
| 2131 | + with h5py.File(h5_path, "w") as h5_file: |
| 2132 | + h5_file.create_dataset("household_weight/2075", data=[1.0]) |
| 2133 | + metadata_path = tmp_path / "2075.h5.metadata.json" |
| 2134 | + metadata_path.write_text( |
| 2135 | + json.dumps( |
| 2136 | + { |
| 2137 | + "year": 2075, |
| 2138 | + "calibration_audit": {"calibration_quality": "exact"}, |
| 2139 | + } |
| 2140 | + ), |
| 2141 | + encoding="utf-8", |
| 2142 | + ) |
| 2143 | + manifest_path = tmp_path / "calibration_manifest.json" |
| 2144 | + manifest_path.write_text( |
| 2145 | + json.dumps({"datasets": {"2075": {"h5": "2075.h5"}}}), |
| 2146 | + encoding="utf-8", |
| 2147 | + ) |
| 2148 | + |
| 2149 | + stamp_projection_provenance( |
| 2150 | + output_dir=tmp_path, |
| 2151 | + source_sha="abc123", |
| 2152 | + run_id="run-123", |
| 2153 | + ) |
| 2154 | + |
| 2155 | + metadata = json.loads(metadata_path.read_text(encoding="utf-8")) |
| 2156 | + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) |
| 2157 | + assert metadata["source_sha"] == "abc123" |
| 2158 | + assert metadata["run_id"] == "run-123" |
| 2159 | + with h5py.File(h5_path) as h5_file: |
| 2160 | + assert h5_file.attrs["source_sha"] == "abc123" |
| 2161 | + assert h5_file.attrs["run_id"] == "run-123" |
| 2162 | + assert manifest["source_sha"] == "abc123" |
| 2163 | + assert manifest["run_id"] == "run-123" |
| 2164 | + assert manifest["datasets"]["2075"]["source_sha"] == "abc123" |
| 2165 | + assert manifest["datasets"]["2075"]["run_id"] == "run-123" |
| 2166 | + |
| 2167 | + |
| 2168 | +def test_long_term_production_main_uploads_stamped_artifacts( |
| 2169 | + tmp_path, |
| 2170 | + monkeypatch, |
| 2171 | +): |
| 2172 | + import h5py |
| 2173 | + |
| 2174 | + output_dir = tmp_path / "out" |
| 2175 | + captured_command = [] |
| 2176 | + |
| 2177 | + def fake_run(command, check): |
| 2178 | + del check |
| 2179 | + captured_command.extend(command) |
| 2180 | + command_output_dir = Path(command[command.index("--output-dir") + 1]) |
| 2181 | + command_output_dir.mkdir(parents=True, exist_ok=True) |
| 2182 | + with h5py.File(command_output_dir / "2075.h5", "w") as h5_file: |
| 2183 | + h5_file.create_dataset("household_weight/2075", data=[1.0]) |
| 2184 | + (command_output_dir / "2075.h5.metadata.json").write_text( |
| 2185 | + json.dumps( |
| 2186 | + { |
| 2187 | + "year": 2075, |
| 2188 | + "base_dataset_path": "hf://example/base.h5", |
| 2189 | + "profile": {"name": "ss-payroll-tob"}, |
| 2190 | + "calibration_audit": {"calibration_quality": "exact"}, |
| 2191 | + } |
| 2192 | + ), |
| 2193 | + encoding="utf-8", |
| 2194 | + ) |
| 2195 | + (command_output_dir / "calibration_manifest.json").write_text( |
| 2196 | + json.dumps({"datasets": {"2075": {"h5": "2075.h5"}}}), |
| 2197 | + encoding="utf-8", |
| 2198 | + ) |
| 2199 | + |
| 2200 | + uploaded = [] |
| 2201 | + |
| 2202 | + def fake_upload( |
| 2203 | + *, |
| 2204 | + artifacts, |
| 2205 | + output_dir, |
| 2206 | + args, |
| 2207 | + run_id, |
| 2208 | + source_sha, |
| 2209 | + ): |
| 2210 | + uploaded.extend(path.name for path in artifacts if path.suffix == ".json") |
| 2211 | + metadata_payload = json.loads( |
| 2212 | + (output_dir / "2075.h5.metadata.json").read_text(encoding="utf-8") |
| 2213 | + ) |
| 2214 | + manifest_payload = json.loads( |
| 2215 | + (output_dir / "calibration_manifest.json").read_text(encoding="utf-8") |
| 2216 | + ) |
| 2217 | + assert metadata_payload["source_sha"] == "abc123" |
| 2218 | + assert manifest_payload["source_sha"] == "abc123" |
| 2219 | + assert run_id == "run-123" |
| 2220 | + assert source_sha == "abc123" |
| 2221 | + assert args.upload_to_hf_staging is True |
| 2222 | + return len(artifacts) |
| 2223 | + |
| 2224 | + build_info = PolicyEngineUSBuildInfo( |
| 2225 | + version="1.693.4", |
| 2226 | + locked_version="1.693.4", |
| 2227 | + package_file_sha256="file-sha", |
| 2228 | + package_tree_sha256="tree-sha", |
| 2229 | + ) |
| 2230 | + monkeypatch.setattr( |
| 2231 | + run_long_term_production.subprocess, |
| 2232 | + "run", |
| 2233 | + fake_run, |
| 2234 | + ) |
| 2235 | + monkeypatch.setattr( |
| 2236 | + run_long_term_production, |
| 2237 | + "assert_locked_policyengine_us_version", |
| 2238 | + lambda: build_info, |
| 2239 | + ) |
| 2240 | + monkeypatch.setattr( |
| 2241 | + run_long_term_production, |
| 2242 | + "upload_artifacts", |
| 2243 | + fake_upload, |
| 2244 | + ) |
| 2245 | + monkeypatch.setattr(run_long_term_production, "_git_sha", lambda: "abc123") |
| 2246 | + monkeypatch.setenv("HUGGING_FACE_TOKEN", "token") |
| 2247 | + monkeypatch.setattr( |
| 2248 | + sys, |
| 2249 | + "argv", |
| 2250 | + [ |
| 2251 | + "run_long_term_production.py", |
| 2252 | + "--years", |
| 2253 | + "2075", |
| 2254 | + "--jobs", |
| 2255 | + "1", |
| 2256 | + "--output-dir", |
| 2257 | + str(output_dir), |
| 2258 | + "--run-id", |
| 2259 | + "run-123", |
| 2260 | + "--source-sha", |
| 2261 | + "abc123", |
| 2262 | + "--upload-to-hf-staging", |
| 2263 | + ], |
| 2264 | + ) |
| 2265 | + |
| 2266 | + assert run_long_term_production.main() == 0 |
| 2267 | + |
| 2268 | + assert captured_command |
| 2269 | + assert "2075.h5.metadata.json" in uploaded |
| 2270 | + assert "calibration_manifest.json" in uploaded |
| 2271 | + with h5py.File(output_dir / "2075.h5") as h5_file: |
| 2272 | + assert h5_file.attrs["source_sha"] == "abc123" |
| 2273 | + assert h5_file.attrs["run_id"] == "run-123" |
| 2274 | + |
| 2275 | + |
| 2276 | +def test_long_term_production_reads_source_tree_data_package_version(monkeypatch): |
| 2277 | + from policyengine_us_data.__version__ import __version__ |
| 2278 | + |
| 2279 | + def fail_metadata_version(package_name): |
| 2280 | + raise metadata.PackageNotFoundError(package_name) |
| 2281 | + |
| 2282 | + monkeypatch.setattr( |
| 2283 | + "policyengine_us_data.datasets.cps.long_term." |
| 2284 | + "run_long_term_production.metadata.version", |
| 2285 | + fail_metadata_version, |
| 2286 | + ) |
| 2287 | + |
| 2288 | + assert _package_version("policyengine-us-data") == __version__ |
| 2289 | + |
| 2290 | + |
2121 | 2291 | def test_parallel_projection_validate_forwarded_args_rejects_wrapper_flags(): |
2122 | 2292 | with pytest.raises(ValueError, match="--output-dir"): |
2123 | 2293 | validate_forwarded_args(["--output-dir", "/tmp/out"]) |
|
0 commit comments