Skip to content

Commit 3fe5102

Browse files
committed
Add epsilon-insensitive calibration policy
1 parent 7160996 commit 3fe5102

19 files changed

Lines changed: 1362 additions & 25 deletions

changelog.d/1053.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add epsilon-insensitive calibration target tolerances, target-policy artifacts, and hard-fail versus warning enforcement for calibration diagnostics.

modal_app/pipeline.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def _calibration_package_parameters(
162162
workers: int,
163163
n_clones: int,
164164
target_config: str | None,
165+
target_policy: str | None,
165166
skip_county: bool,
166167
chunked_matrix: bool,
167168
chunk_size: int,
@@ -174,6 +175,7 @@ def _calibration_package_parameters(
174175
"workers": workers if not chunked_matrix else None,
175176
"n_clones": n_clones,
176177
"target_config": target_config,
178+
"target_policy": target_policy,
177179
"skip_county": skip_county,
178180
"chunked_matrix": bool(chunked_matrix),
179181
"chunk_size": chunk_size if chunked_matrix else None,
@@ -281,6 +283,8 @@ def archive_diagnostics(
281283
"log": f"{prefix}unified_diagnostics.csv",
282284
"cal_log": f"{prefix}calibration_log.csv",
283285
"config": f"{prefix}unified_run_config.json",
286+
"target_policy": f"{prefix}calibration_target_policy.jsonl",
287+
"target_policy_summary": (f"{prefix}calibration_target_policy_summary.json"),
284288
}
285289

286290
for key, filename in file_map.items():
@@ -1242,6 +1246,7 @@ def run_pipeline(
12421246
workers=num_workers,
12431247
n_clones=n_clones,
12441248
target_config=None,
1249+
target_policy="policyengine_us_data/calibration/target_policy.yaml",
12451250
skip_county=True,
12461251
chunked_matrix=chunked_matrix,
12471252
chunk_size=chunk_size,
@@ -1302,7 +1307,12 @@ def run_pipeline(
13021307
completed_package_manifest = _complete_step_manifest(
13031308
active_step_manifest,
13041309
outputs=collect_artifacts(
1305-
[_artifacts_dir(run_id) / "calibration_package.pkl"],
1310+
[
1311+
_artifacts_dir(run_id) / "calibration_package.pkl",
1312+
_artifacts_dir(run_id) / "calibration_target_policy.jsonl",
1313+
_artifacts_dir(run_id)
1314+
/ "calibration_target_policy_summary.json",
1315+
],
13061316
missing_ok=True,
13071317
),
13081318
vol=pipeline_volume,
@@ -1321,19 +1331,23 @@ def run_pipeline(
13211331
"gpu": gpu,
13221332
"epochs": epochs,
13231333
"target_config": "policyengine_us_data/calibration/target_config.yaml",
1334+
"target_policy": "policyengine_us_data/calibration/target_policy.yaml",
13241335
"beta": 0.65,
13251336
"lambda_l0": 1e-7,
13261337
"lambda_l2": 1e-8,
13271338
"log_freq": 100,
1339+
"loss_type": "relative_epsilon",
13281340
}
13291341
national_fit_parameters = {
13301342
"gpu": national_gpu,
13311343
"epochs": national_epochs,
13321344
"target_config": "policyengine_us_data/calibration/target_config.yaml",
1345+
"target_policy": "policyengine_us_data/calibration/target_policy.yaml",
13331346
"beta": 0.65,
13341347
"lambda_l0": NATIONAL_FIT_LAMBDA_L0,
13351348
"lambda_l2": 1e-12,
13361349
"log_freq": 100,
1350+
"loss_type": "relative_epsilon",
13371351
"skip_national": skip_national,
13381352
}
13391353
regional_fit_reuse = _step_reusable(

modal_app/remote_calibration_runner.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ def _collect_outputs(cal_lines):
9999
log_path = None
100100
cal_log_path = None
101101
config_path = None
102+
target_policy_path = None
103+
target_policy_summary_path = None
102104
for line in cal_lines:
103105
if "OUTPUT_PATH:" in line:
104106
output_path = line.split("OUTPUT_PATH:")[1].strip()
@@ -110,6 +112,12 @@ def _collect_outputs(cal_lines):
110112
cal_log_path = line.split("CAL_LOG_PATH:")[1].strip()
111113
elif "LOG_PATH:" in line:
112114
log_path = line.split("LOG_PATH:")[1].strip()
115+
elif "TARGET_POLICY_PATH:" in line:
116+
target_policy_path = line.split("TARGET_POLICY_PATH:")[1].strip()
117+
elif "TARGET_POLICY_SUMMARY_PATH:" in line:
118+
target_policy_summary_path = line.split("TARGET_POLICY_SUMMARY_PATH:")[
119+
1
120+
].strip()
113121

114122
with open(output_path, "rb") as f:
115123
weights_bytes = f.read()
@@ -134,12 +142,24 @@ def _collect_outputs(cal_lines):
134142
with open(config_path, "rb") as f:
135143
config_bytes = f.read()
136144

145+
target_policy_bytes = None
146+
if target_policy_path:
147+
with open(target_policy_path, "rb") as f:
148+
target_policy_bytes = f.read()
149+
150+
target_policy_summary_bytes = None
151+
if target_policy_summary_path:
152+
with open(target_policy_summary_path, "rb") as f:
153+
target_policy_summary_bytes = f.read()
154+
137155
return {
138156
"weights": weights_bytes,
139157
"geography": geography_bytes,
140158
"log": log_bytes,
141159
"cal_log": cal_log_bytes,
142160
"config": config_bytes,
161+
"target_policy": target_policy_bytes,
162+
"target_policy_summary": target_policy_summary_bytes,
143163
}
144164

145165

modal_app/step_manifests/state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def artifact_identities(paths: dict[str, str | Path]) -> dict:
215215
def collect_diagnostics(run_id: str) -> list[ArtifactReference]:
216216
return collect_directory_artifacts(
217217
run_dir(run_id) / "diagnostics",
218-
patterns=("*.csv", "*.json", "*.txt"),
218+
patterns=("*.csv", "*.json", "*.jsonl", "*.txt"),
219219
role="diagnostic",
220220
)
221221

policyengine_us_data/calibration/signatures.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ def build_checkpoint_signature(
108108
lambda_l2: float,
109109
learning_rate: float,
110110
target_groups: np.ndarray | None = None,
111+
target_weights: np.ndarray | None = None,
112+
target_tolerances: np.ndarray | None = None,
113+
target_scales: np.ndarray | None = None,
114+
calibration_loss_type: str = "relative",
111115
) -> dict:
112116
"""Build a compact signature to validate calibration checkpoint resume."""
113117
targets_arr = np.asarray(targets, dtype=np.float64)
@@ -116,20 +120,33 @@ def build_checkpoint_signature(
116120
if target_groups is None
117121
else np.asarray(target_groups, dtype=np.int64)
118122
)
123+
target_weights_arr = _optional_float_signature_array(target_weights)
124+
target_tolerances_arr = _optional_float_signature_array(target_tolerances)
125+
target_scales_arr = _optional_float_signature_array(target_scales)
119126
return {
120127
"n_features": int(X_sparse.shape[1]),
121128
"n_targets": int(len(targets_arr)),
122129
"x_sparse_sha256": hash_sparse_matrix(X_sparse),
123130
"target_names_sha256": hash_string_list(target_names),
124131
"targets_sha256": hashlib.sha256(targets_arr.tobytes()).hexdigest(),
125132
"target_groups_sha256": hash_numpy_array(target_groups_arr),
133+
"target_weights_sha256": hash_numpy_array(target_weights_arr),
134+
"target_tolerances_sha256": hash_numpy_array(target_tolerances_arr),
135+
"target_scales_sha256": hash_numpy_array(target_scales_arr),
136+
"calibration_loss_type": str(calibration_loss_type),
126137
"lambda_l0": float(lambda_l0),
127138
"beta": float(beta),
128139
"lambda_l2": float(lambda_l2),
129140
"learning_rate": float(learning_rate),
130141
}
131142

132143

144+
def _optional_float_signature_array(values: np.ndarray | None) -> np.ndarray:
145+
if values is None:
146+
return np.array([], dtype=np.float64)
147+
return np.asarray(values, dtype=np.float64)
148+
149+
133150
def checkpoint_signature_mismatches(
134151
expected: dict,
135152
actual: dict,

0 commit comments

Comments
 (0)