Skip to content

Commit d15085e

Browse files
authored
Merge pull request #929 from PolicyEngine/feat/stage-2-calibration-package-contract
Emit Stage 2 calibration package contract
2 parents c1b1e22 + cc0771d commit d15085e

12 files changed

Lines changed: 1533 additions & 21 deletions

File tree

AGENTS.md

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,14 @@ Before creating or sharing any PR, all developers and agents must:
3131
`gh repo view PolicyEngine/policyengine-us-data --json nameWithOwner`.
3232
2. Push the branch to that repository, for example:
3333
`git push upstream HEAD:<branch-name>`.
34-
3. Create the PR from the same repository, for example:
35-
`gh pr create --repo PolicyEngine/policyengine-us-data --head <branch-name> --base main`.
36-
4. Verify the PR head repository before reporting it:
37-
`gh pr view <PR> --repo PolicyEngine/policyengine-us-data --json headRepositoryOwner,headRepository`.
38-
39-
The PR is valid only if the head repository is `PolicyEngine/policyengine-us-data`.
34+
3. Create the PR as a draft from the same repository, for example:
35+
`gh pr create --draft --repo PolicyEngine/policyengine-us-data --head <branch-name> --base main`.
36+
4. Verify the PR is draft and the head repository is canonical before reporting
37+
it:
38+
`gh pr view <PR> --repo PolicyEngine/policyengine-us-data --json isDraft,headRepositoryOwner,headRepository`.
39+
40+
The PR is valid only if `isDraft` is `true` and the head repository is
41+
`PolicyEngine/policyengine-us-data`.
4042
If you cannot push to the canonical repository, stop and ask for access. Do not
4143
create a fork PR as a fallback. If you accidentally create one, immediately
42-
close it and replace it with a same-repository PR.
44+
close it and replace it with a same-repository draft PR.

changelog.d/926.added

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Stage 2 now emits and validates a semantic `calibration_package_contract.json`
2+
sidecar next to `calibration_package.pkl`.

docs/engineering/skills/github-prs.md

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,18 @@ Before creating or sharing a PR:
2424
`make lint`.
2525
6. Push the current branch to the canonical repository:
2626
`make push-pr-branch`.
27-
7. Create the PR from that same repository:
28-
`gh pr create --repo PolicyEngine/policyengine-us-data --head "$(git branch --show-current)" --base main`.
29-
8. Verify the PR head repository:
30-
`gh pr view <PR> --repo PolicyEngine/policyengine-us-data --json headRepositoryOwner,headRepository`.
31-
32-
The PR is valid only if the head repository is
27+
7. Create the PR as a draft from that same repository:
28+
`gh pr create --draft --repo PolicyEngine/policyengine-us-data --head "$(git branch --show-current)" --base main`.
29+
8. Verify the PR is draft and the head repository is canonical:
30+
`gh pr view <PR> --repo PolicyEngine/policyengine-us-data --json isDraft,headRepositoryOwner,headRepository`.
31+
9. Leave the PR as draft unless a maintainer explicitly asks for it to be
32+
marked ready for review.
33+
34+
The PR is valid only if `isDraft` is `true` and the head repository is
3335
`PolicyEngine/policyengine-us-data`. If you cannot push to the canonical
3436
repository, stop and ask for access. Do not create a fork PR as a fallback. If
3537
you accidentally create one, close it immediately and replace it with a
36-
same-repository PR.
38+
same-repository draft PR.
3739

3840
## PR title
3941

modal_app/remote_calibration_runner.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,19 @@ def _build_package_impl(
420420
if build_rc != 0:
421421
raise RuntimeError(f"Package build failed with code {build_rc}")
422422

423+
from policyengine_us_data.stage_contracts.calibration_package import (
424+
CALIBRATION_PACKAGE_CONTRACT_FILENAME,
425+
validate_persisted_calibration_package_contract,
426+
)
427+
428+
contract_path = f"{artifacts}/{CALIBRATION_PACKAGE_CONTRACT_FILENAME}"
429+
validate_persisted_calibration_package_contract(
430+
package_path=Path(pkg_path),
431+
contract_path=Path(contract_path),
432+
dataset_path=Path(dataset_path),
433+
db_path=Path(db_path),
434+
)
435+
423436
sidecar_ok = _write_package_sidecar(pkg_path)
424437
if not sidecar_ok:
425438
print(

policyengine_us_data/calibration/unified_calibration.py

Lines changed: 92 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import logging
3131
import os
3232
import sys
33+
from datetime import UTC, datetime
3334
from pathlib import Path
3435
from typing import Optional
3536

@@ -44,6 +45,9 @@
4445
create_target_groups,
4546
)
4647
from policyengine_us_data.pipeline_metadata import pipeline_node
48+
from policyengine_us_data.stage_contracts.calibration_package import (
49+
CalibrationPackageParameters,
50+
)
4751
from policyengine_us_data.pipeline_schema import PipelineNode
4852

4953
logging.basicConfig(
@@ -71,6 +75,41 @@
7175
DEFAULT_TARGET_CONFIG_PATH = Path(__file__).resolve().parent / "target_config.yaml"
7276

7377

78+
def _utc_now_isoformat() -> str:
79+
"""Return a compact UTC timestamp for contract metadata."""
80+
81+
return datetime.now(UTC).isoformat().replace("+00:00", "Z")
82+
83+
84+
def _calibration_package_contract_parameters(
85+
*,
86+
workers: int,
87+
n_clones: int,
88+
target_config_path: str | None,
89+
skip_county: bool,
90+
skip_source_impute: bool,
91+
skip_takeup_rerandomize: bool,
92+
chunked_matrix: bool,
93+
chunk_size: int,
94+
parallel: bool,
95+
num_matrix_workers: int,
96+
) -> CalibrationPackageParameters:
97+
"""Return Stage 2 parameters that affect package construction."""
98+
99+
return CalibrationPackageParameters.from_runtime_args(
100+
workers=workers,
101+
n_clones=n_clones,
102+
target_config_path=target_config_path,
103+
skip_county=skip_county,
104+
skip_source_impute=skip_source_impute,
105+
skip_takeup_rerandomize=skip_takeup_rerandomize,
106+
chunked_matrix=chunked_matrix,
107+
chunk_size=chunk_size,
108+
parallel=parallel,
109+
num_matrix_workers=num_matrix_workers,
110+
)
111+
112+
74113
def get_git_provenance() -> dict:
75114
"""Capture git state and package version for provenance tracking."""
76115
import subprocess as _sp
@@ -152,7 +191,11 @@ def check_package_staleness(metadata: dict) -> None:
152191
if created:
153192
try:
154193
built_dt = datetime.datetime.fromisoformat(created)
155-
age = datetime.datetime.now() - built_dt
194+
if built_dt.tzinfo is None:
195+
built_dt = built_dt.replace(tzinfo=datetime.UTC)
196+
age = datetime.datetime.now(datetime.UTC) - built_dt.astimezone(
197+
datetime.UTC
198+
)
156199
if age.days > 7:
157200
print(f"WARNING: Package is {age.days} days old (built {created})")
158201
except Exception:
@@ -1303,6 +1346,7 @@ def run_calibration(
13031346
"""
13041347
import time
13051348

1349+
started_at = _utc_now_isoformat()
13061350
t0 = time.time()
13071351

13081352
# Early exit: load pre-built package
@@ -1547,16 +1591,14 @@ def run_calibration(
15471591
# Step 6b: Save the calibration package. By default this is the
15481592
# minimal package selected by target_config.yaml; use
15491593
# --all-active-targets to build a broad diagnostic package.
1550-
import datetime
1551-
15521594
metadata = {
15531595
"dataset_path": dataset_path,
15541596
"db_path": db_path,
15551597
"n_clones": n_clones,
15561598
"n_records": X_sparse.shape[1],
15571599
"base_n_records": n_records,
15581600
"seed": seed,
1559-
"created_at": datetime.datetime.now().isoformat(),
1601+
"created_at": _utc_now_isoformat(),
15601602
"target_config_path": target_config_path,
15611603
"package_scope": "minimal" if target_config else "all_active_targets",
15621604
"matrix_builder": "chunked" if chunked_matrix else "precompute",
@@ -1573,20 +1615,63 @@ def run_calibration(
15731615
Path(target_config_path)
15741616
)
15751617

1618+
initial_weights = compute_initial_weights(X_sparse, targets_df)
15761619
if package_output_path:
1577-
full_initial_weights = compute_initial_weights(X_sparse, targets_df)
1620+
package_payload = {
1621+
"X_sparse": X_sparse,
1622+
"targets_df": targets_df,
1623+
"target_names": target_names,
1624+
"metadata": metadata,
1625+
"initial_weights": initial_weights,
1626+
"cd_geoid": geography.cd_geoid,
1627+
"block_geoid": geography.block_geoid,
1628+
}
15781629
save_calibration_package(
15791630
package_output_path,
15801631
X_sparse,
15811632
targets_df,
15821633
target_names,
15831634
metadata,
1584-
initial_weights=full_initial_weights,
1635+
initial_weights=initial_weights,
15851636
cd_geoid=geography.cd_geoid,
15861637
block_geoid=geography.block_geoid,
15871638
)
1639+
from policyengine_us_data.stage_contracts.calibration_package import (
1640+
validate_calibration_package_contract,
1641+
write_calibration_package_contract,
1642+
)
15881643

1589-
initial_weights = compute_initial_weights(X_sparse, targets_df)
1644+
completed_at = _utc_now_isoformat()
1645+
write_calibration_package_contract(
1646+
package_path=Path(package_output_path),
1647+
dataset_path=Path(dataset_path),
1648+
db_path=Path(db_path),
1649+
package=package_payload,
1650+
parameters=_calibration_package_contract_parameters(
1651+
workers=workers,
1652+
n_clones=n_clones,
1653+
target_config_path=target_config_path,
1654+
skip_county=skip_county,
1655+
skip_source_impute=skip_source_impute,
1656+
skip_takeup_rerandomize=skip_takeup_rerandomize,
1657+
chunked_matrix=chunked_matrix,
1658+
chunk_size=chunk_size,
1659+
parallel=parallel,
1660+
num_matrix_workers=num_matrix_workers,
1661+
),
1662+
run_id=run_id,
1663+
started_at=started_at,
1664+
completed_at=completed_at,
1665+
duration_s=round(time.time() - t0, 1),
1666+
code_sha=metadata.get("git_commit"),
1667+
package_version=metadata.get("package_version"),
1668+
)
1669+
validate_calibration_package_contract(
1670+
package_path=Path(package_output_path),
1671+
package=package_payload,
1672+
dataset_path=Path(dataset_path),
1673+
db_path=Path(db_path),
1674+
)
15901675

15911676
if build_only:
15921677
from policyengine_us_data.calibration.validate_package import (

policyengine_us_data/stage_contracts/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,18 @@
1212
"""
1313

1414
from .artifacts import ArtifactRef
15+
from .calibration_package import (
16+
CALIBRATION_PACKAGE_CONTRACT_FILENAME,
17+
CALIBRATION_PACKAGE_CONTRACT_TYPE,
18+
CalibrationPackageParameters,
19+
CalibrationPackageSummary,
20+
build_calibration_package_contract,
21+
load_calibration_package_payload,
22+
summarize_calibration_package,
23+
validate_calibration_package_contract,
24+
validate_persisted_calibration_package_contract,
25+
write_calibration_package_contract,
26+
)
1527
from .constants import (
1628
CONTRACT_FINGERPRINT_ALGORITHM,
1729
CONTRACT_SCHEMA_VERSION,
@@ -73,6 +85,10 @@
7385
"VALIDATION_REPORT_STATUSES",
7486
"ArtifactRef",
7587
"CANONICAL_STAGE_IDS",
88+
"CALIBRATION_PACKAGE_CONTRACT_FILENAME",
89+
"CALIBRATION_PACKAGE_CONTRACT_TYPE",
90+
"CalibrationPackageParameters",
91+
"CalibrationPackageSummary",
7692
"CONTRACT_TYPE_BY_STAGE_ID",
7793
"DATASET_BUILD_OUTPUT_CONTRACT_FILENAME",
7894
"DATASET_BUILD_OUTPUT_CONTRACT_TYPE",
@@ -97,6 +113,7 @@
97113
"ValidationFindingStatus",
98114
"ValidationReport",
99115
"ValidationReportStatus",
116+
"build_calibration_package_contract",
100117
"build_dataset_build_output_contract",
101118
"canonicalize_for_fingerprint",
102119
"contract_from_json",
@@ -105,7 +122,12 @@
105122
"fingerprint_material",
106123
"is_canonical_stage_id",
107124
"is_canonical_substage_id",
125+
"load_calibration_package_payload",
108126
"read_contract",
127+
"summarize_calibration_package",
109128
"substage_ids_for_stage",
129+
"validate_calibration_package_contract",
130+
"validate_persisted_calibration_package_contract",
131+
"write_calibration_package_contract",
110132
"write_contract",
111133
]

0 commit comments

Comments
 (0)