Skip to content

Commit 6550da6

Browse files
committed
Tighten Stage 2 target config identity validation
1 parent c7929cd commit 6550da6

6 files changed

Lines changed: 88 additions & 18 deletions

File tree

policyengine_us_data/calibration/unified_calibration.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,6 +1452,11 @@ def run_calibration(
14521452

14531453
started_at = _utc_now_isoformat()
14541454
t0 = time.time()
1455+
resolved_target_identity = _target_config_identity_for_metadata(
1456+
target_config=target_config,
1457+
target_config_path=target_config_path,
1458+
target_config_identity=target_config_identity,
1459+
)
14551460

14561461
# Early exit: load pre-built package
14571462
if package_path is not None:
@@ -1697,11 +1702,6 @@ def run_calibration(
16971702
# Step 6b: Save the calibration package. By default this is the
16981703
# minimal package selected by target_config.yaml; use
16991704
# --all-active-targets to build a broad diagnostic package.
1700-
resolved_target_identity = _target_config_identity_for_metadata(
1701-
target_config=target_config,
1702-
target_config_path=target_config_path,
1703-
target_config_identity=target_config_identity,
1704-
)
17051705
metadata = {
17061706
"dataset_path": dataset_path,
17071707
"db_path": db_path,

policyengine_us_data/stage_contracts/calibration_package.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ def _parameters_with_package_identity(
568568
payload["target_config_mode"] = (
569569
"all_active_targets" if payload.get("target_config") is None else "explicit"
570570
)
571-
return payload
571+
return CalibrationPackageParameters.from_dict(payload).to_dict()
572572

573573

574574
def _require_existing_file(path: Path, label: str) -> None:

policyengine_us_data/stage_contracts/calibration_package_schema.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -238,11 +238,10 @@ def __post_init__(self) -> None:
238238
_validate_bool(self.parallel_matrix, "parallel_matrix")
239239
if self.target_config is not None and not isinstance(self.target_config, str):
240240
raise ValueError("target_config must be a string or None")
241-
if self.target_config_sha256 is not None and not isinstance(
241+
_validate_optional_sha256(
242242
self.target_config_sha256,
243-
str,
244-
):
245-
raise ValueError("target_config_sha256 must be a string or None")
243+
"target_config_sha256",
244+
)
246245
if self.target_config_mode is not None:
247246
if not isinstance(self.target_config_mode, str):
248247
raise ValueError("target_config_mode must be a string or None")
@@ -649,5 +648,8 @@ def _validate_optional_sha256(value: Any, key: str) -> None:
649648
return
650649
if not isinstance(value, str) or not value.startswith("sha256:"):
651650
raise ValueError(f"Calibration package field {key!r} must be a SHA-256 digest")
652-
if len(value) != len("sha256:") + 64:
651+
digest = value.removeprefix("sha256:")
652+
if len(digest) != 64:
653+
raise ValueError(f"Calibration package field {key!r} must be a SHA-256 digest")
654+
if any(character not in "0123456789abcdef" for character in digest.lower()):
653655
raise ValueError(f"Calibration package field {key!r} must be a SHA-256 digest")

tests/unit/calibration/test_unified_calibration.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,21 @@
4444
_calibration_package_contract_parameters,
4545
_target_config_identity_for_metadata,
4646
check_package_staleness,
47+
run_calibration,
4748
)
4849
from policyengine_us_data.stage_contracts.calibration_package import (
4950
CalibrationPackageParameters,
5051
)
5152

53+
TARGET_CONFIG_SHA256 = "sha256:" + "a" * 64
54+
5255

5356
def test_calibration_package_contract_parameters_track_effective_matrix_mode():
5457
params = _calibration_package_contract_parameters(
5558
workers=8,
5659
n_clones=430,
5760
target_config_path="policyengine_us_data/calibration/target_config.yaml",
58-
target_config_sha256="abc123",
61+
target_config_sha256=TARGET_CONFIG_SHA256,
5962
target_config_mode="default",
6063
skip_county=True,
6164
skip_source_impute=True,
@@ -71,7 +74,7 @@ def test_calibration_package_contract_parameters_track_effective_matrix_mode():
7174
"workers": None,
7275
"n_clones": 430,
7376
"target_config": "policyengine_us_data/calibration/target_config.yaml",
74-
"target_config_sha256": "abc123",
77+
"target_config_sha256": TARGET_CONFIG_SHA256,
7578
"target_config_mode": "default",
7679
"skip_county": True,
7780
"skip_source_impute": True,
@@ -116,6 +119,22 @@ def test_target_config_identity_for_metadata_requires_identity_for_parsed_config
116119
)
117120

118121

122+
def test_run_calibration_validates_target_identity_before_dataset_loading():
123+
with (
124+
patch.dict(sys.modules, {"policyengine_us": None}),
125+
pytest.raises(
126+
ValueError,
127+
match="target_config_path or target_config_identity",
128+
),
129+
):
130+
run_calibration(
131+
dataset_path="/missing/source.h5",
132+
db_path="/missing/policy_data.db",
133+
target_config={"include": []},
134+
target_config_path=None,
135+
)
136+
137+
119138
def test_check_package_staleness_warns_for_old_utc_timestamp(
120139
capsys,
121140
monkeypatch,

tests/unit/fixtures/calibration_package_stage_contract.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
CALIBRATION_COMPLETED_AT = "2026-05-08T12:02:00Z"
2323
CALIBRATION_DURATION_S = 120.0
2424
TARGET_CONFIG_PATH = "policyengine_us_data/calibration/target_config.yaml"
25+
TARGET_CONFIG_SHA256 = "sha256:" + "a" * 64
2526

2627
CALIBRATION_BLOCK_GEOIDS = ("010010001", "010010002", "020010001")
2728
CALIBRATION_CD_GEOIDS = ("0101", "0102", "0201")
@@ -66,7 +67,7 @@ def calibration_package_payload() -> dict[str, Any]:
6667
"dataset_sha256": "sha256:dataset",
6768
"db_sha256": "sha256:db",
6869
"target_config_path": TARGET_CONFIG_PATH,
69-
"target_config_sha256": "sha256:target-config",
70+
"target_config_sha256": TARGET_CONFIG_SHA256,
7071
"target_config_mode": "explicit",
7172
"n_clones": 3,
7273
"seed": 42,
@@ -149,7 +150,7 @@ def calibration_package_parameters() -> dict[str, Any]:
149150
"workers": None,
150151
"n_clones": 3,
151152
"target_config": TARGET_CONFIG_PATH,
152-
"target_config_sha256": "sha256:target-config",
153+
"target_config_sha256": TARGET_CONFIG_SHA256,
153154
"target_config_mode": "explicit",
154155
"skip_county": True,
155156
"skip_source_impute": True,

tests/unit/test_calibration_package_stage_contract.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from tests.unit.fixtures.calibration_package_stage_contract import (
2+
TARGET_CONFIG_SHA256,
23
TARGET_CONFIG_PATH,
34
calibration_package_contract,
45
calibration_package_parameters,
@@ -60,7 +61,7 @@ def test_calibration_package_parameters_parse_runtime_args():
6061
workers=8,
6162
n_clones=430,
6263
target_config_path=TARGET_CONFIG_PATH,
63-
target_config_sha256="sha256:target-config",
64+
target_config_sha256=TARGET_CONFIG_SHA256,
6465
target_config_mode="explicit",
6566
skip_county=True,
6667
skip_source_impute=True,
@@ -82,7 +83,7 @@ def test_calibration_package_parameters_parse_runtime_args():
8283
"skip_takeup_rerandomize": False,
8384
"target_config": TARGET_CONFIG_PATH,
8485
"target_config_mode": "explicit",
85-
"target_config_sha256": "sha256:target-config",
86+
"target_config_sha256": TARGET_CONFIG_SHA256,
8687
"workers": None,
8788
}
8889

@@ -109,6 +110,28 @@ def test_calibration_package_parameters_require_identity_for_config_modes():
109110
raise AssertionError("Explicit target config mode should require checksum")
110111

111112

113+
def test_calibration_package_parameters_reject_malformed_target_config_checksum():
114+
try:
115+
CalibrationPackageParameters.from_runtime_args(
116+
workers=8,
117+
n_clones=430,
118+
target_config_path=TARGET_CONFIG_PATH,
119+
target_config_sha256="sha256:target-config",
120+
target_config_mode="explicit",
121+
skip_county=True,
122+
skip_source_impute=True,
123+
skip_takeup_rerandomize=False,
124+
chunked_matrix=False,
125+
chunk_size=25_000,
126+
parallel=False,
127+
num_matrix_workers=50,
128+
)
129+
except ValueError as exc:
130+
assert "SHA-256 digest" in str(exc)
131+
else:
132+
raise AssertionError("Malformed target config checksum should fail")
133+
134+
112135
def test_calibration_package_parameters_accept_legacy_identity_fields_missing():
113136
params = CalibrationPackageParameters.from_dict(
114137
{
@@ -130,6 +153,31 @@ def test_calibration_package_parameters_accept_legacy_identity_fields_missing():
130153
assert params.target_config_sha256 is None
131154

132155

156+
def test_calibration_package_contract_revalidates_backfilled_identity(tmp_path):
157+
dataset_path, db_path, package_path = contract_input_paths(tmp_path)
158+
package = calibration_package_payload()
159+
package["metadata"].pop("target_config_sha256")
160+
write_calibration_package_payload(package_path, package)
161+
parameters = calibration_package_parameters()
162+
parameters.pop("target_config_mode")
163+
parameters.pop("target_config_sha256")
164+
165+
try:
166+
build_calibration_package_contract(
167+
package_path=package_path,
168+
dataset_path=dataset_path,
169+
db_path=db_path,
170+
package=package,
171+
parameters=parameters,
172+
run_id="run-a",
173+
completed_at="2026-05-08T12:02:00Z",
174+
)
175+
except ValueError as exc:
176+
assert "target_config and target_config_sha256" in str(exc)
177+
else:
178+
raise AssertionError("Backfilled target config identity should be revalidated")
179+
180+
133181
def test_calibration_package_parameters_reject_inconsistent_chunk_shape():
134182
try:
135183
CalibrationPackageParameters(
@@ -252,7 +300,7 @@ def test_calibration_package_contract_records_matrix_summary(tmp_path):
252300
assert summary["matrix_density"] == 0.5
253301
assert summary["n_targets"] == 2
254302
assert summary["target_name_count"] == 2
255-
assert summary["target_config_sha256"] == "sha256:target-config"
303+
assert summary["target_config_sha256"] == TARGET_CONFIG_SHA256
256304
assert summary["n_clones"] == 3
257305
assert summary["seed"] == 42
258306
assert summary["matrix_builder"] == "chunked"

0 commit comments

Comments
 (0)