Skip to content

Commit e3943d2

Browse files
baogorek and claude committed
Enable takeup re-randomization in stacked dataset H5 builds
Workers now always re-draw takeup using block-level seeded draws, matching the calibration matrix builder's computation. This fixes H5 files producing aca_ptc values 6-40x off from calibration targets.

Pipeline changes:
- publish_local_area: thread rerandomize_takeup/blocks/filter params
- worker_script: always rerandomize, optionally use calibration blocks
- local_area: pass blocks path to workers when available
- huggingface: optionally download stacked_blocks.npy
- unified_calibration: print BLOCKS_PATH for Modal collection
- remote_calibration_runner: collect, save, and upload blocks to HF

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 45aebc8 commit e3943d2

6 files changed

Lines changed: 331 additions & 22 deletions

File tree

modal_app/local_area.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -154,23 +154,32 @@ def build_areas_worker(
154154

155155
work_items_json = json.dumps(work_items)
156156

157+
worker_cmd = [
158+
"uv",
159+
"run",
160+
"python",
161+
"modal_app/worker_script.py",
162+
"--work-items",
163+
work_items_json,
164+
"--weights-path",
165+
calibration_inputs["weights"],
166+
"--dataset-path",
167+
calibration_inputs["dataset"],
168+
"--db-path",
169+
calibration_inputs["database"],
170+
"--output-dir",
171+
str(output_dir),
172+
]
173+
if "blocks" in calibration_inputs:
174+
worker_cmd.extend(
175+
[
176+
"--calibration-blocks",
177+
calibration_inputs["blocks"],
178+
]
179+
)
180+
157181
result = subprocess.run(
158-
[
159-
"uv",
160-
"run",
161-
"python",
162-
"modal_app/worker_script.py",
163-
"--work-items",
164-
work_items_json,
165-
"--weights-path",
166-
calibration_inputs["weights"],
167-
"--dataset-path",
168-
calibration_inputs["dataset"],
169-
"--db-path",
170-
calibration_inputs["database"],
171-
"--output-dir",
172-
str(output_dir),
173-
],
182+
worker_cmd,
174183
capture_output=True,
175184
text=True,
176185
env=os.environ.copy(),
@@ -474,11 +483,15 @@ def coordinate_publish(
474483
staging_volume.commit()
475484
print("Calibration inputs downloaded")
476485

486+
blocks_path = calibration_dir / "calibration" / "stacked_blocks.npy"
477487
calibration_inputs = {
478488
"weights": str(weights_path),
479489
"dataset": str(dataset_path),
480490
"database": str(db_path),
481491
}
492+
if blocks_path.exists():
493+
calibration_inputs["blocks"] = str(blocks_path)
494+
print(f"Calibration blocks found: {blocks_path}")
482495

483496
result = subprocess.run(
484497
[
@@ -582,8 +595,7 @@ def coordinate_publish(
582595
for err in all_errors[:5]:
583596
err_msg = err.get("error", "Unknown")[:100]
584597
print(
585-
f" - {err.get('item', err.get('worker'))}: "
586-
f"{err_msg}"
598+
f" - {err.get('item', err.get('worker'))}: " f"{err_msg}"
587599
)
588600
if len(all_errors) > 5:
589601
print(f" ... and {len(all_errors) - 5} more")

modal_app/remote_calibration_runner.py

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,17 @@ def _collect_outputs(cal_lines):
7272
output_path = None
7373
log_path = None
7474
cal_log_path = None
75+
config_path = None
76+
blocks_path = None
7577
for line in cal_lines:
7678
if "OUTPUT_PATH:" in line:
7779
output_path = line.split("OUTPUT_PATH:")[1].strip()
80+
elif "CONFIG_PATH:" in line:
81+
config_path = line.split("CONFIG_PATH:")[1].strip()
7882
elif "CAL_LOG_PATH:" in line:
7983
cal_log_path = line.split("CAL_LOG_PATH:")[1].strip()
84+
elif "BLOCKS_PATH:" in line:
85+
blocks_path = line.split("BLOCKS_PATH:")[1].strip()
8086
elif "LOG_PATH:" in line:
8187
log_path = line.split("LOG_PATH:")[1].strip()
8288

@@ -93,13 +99,94 @@ def _collect_outputs(cal_lines):
9399
with open(cal_log_path, "rb") as f:
94100
cal_log_bytes = f.read()
95101

102+
config_bytes = None
103+
if config_path:
104+
with open(config_path, "rb") as f:
105+
config_bytes = f.read()
106+
107+
blocks_bytes = None
108+
if blocks_path and os.path.exists(blocks_path):
109+
with open(blocks_path, "rb") as f:
110+
blocks_bytes = f.read()
111+
96112
return {
97113
"weights": weights_bytes,
98114
"log": log_bytes,
99115
"cal_log": cal_log_bytes,
116+
"config": config_bytes,
117+
"blocks": blocks_bytes,
100118
}
101119

102120

121+
def _upload_logs_to_hf(log_files: dict):
    """Push calibration log files to the HuggingFace repo in one commit.

    Every entry whose local file exists is staged under
    ``calibration/logs/<name>``; entries whose file is missing are
    reported and skipped. When nothing is left to stage, no commit is
    created.

    Args:
        log_files: mapping of destination file name (appended to
            ``calibration/logs/``) to the local path to read from,
            e.g. {"calibration_log.csv": "calibration_log.csv"}
    """
    from huggingface_hub import HfApi, CommitOperationAdd

    # Token comes from the environment; it is passed through as-is
    # (possibly None) to create_commit.
    hf_token = os.environ.get("HUGGING_FACE_TOKEN")
    repo_id = "policyengine/policyengine-us-data"

    client = HfApi()
    staged = []
    for remote_name, source_path in log_files.items():
        if os.path.exists(source_path):
            staged.append(
                CommitOperationAdd(
                    path_in_repo=f"calibration/logs/{remote_name}",
                    path_or_fileobj=source_path,
                )
            )
        else:
            print(f"Skipping {source_path} (not found)", flush=True)

    if staged:
        # All staged files go up together as a single commit.
        client.create_commit(
            token=hf_token,
            repo_id=repo_id,
            operations=staged,
            repo_type="model",
            commit_message=f"Upload {len(staged)} calibration log file(s)",
        )
        uploaded = [operation.path_in_repo for operation in staged]
        print(f"Uploaded to HuggingFace: {uploaded}", flush=True)
    else:
        print("No log files to upload.", flush=True)
159+
160+
161+
def _upload_calibration_artifact(local_path: str, hf_name: str):
    """Commit one local file to ``calibration/<hf_name>`` on HuggingFace.

    Args:
        local_path: path of the file to upload; if it does not exist
            the call prints a notice and does nothing else.
        hf_name: file name to use beneath ``calibration/`` in the repo.
    """
    from huggingface_hub import HfApi, CommitOperationAdd

    if os.path.exists(local_path):
        # Token comes from the environment; it is passed through as-is
        # (possibly None) to create_commit.
        hf_token = os.environ.get("HUGGING_FACE_TOKEN")
        repo_id = "policyengine/policyengine-us-data"
        client = HfApi()

        destination = f"calibration/{hf_name}"
        add_artifact = CommitOperationAdd(
            path_in_repo=destination,
            path_or_fileobj=local_path,
        )
        client.create_commit(
            token=hf_token,
            repo_id=repo_id,
            operations=[add_artifact],
            repo_type="model",
            commit_message=f"Upload calibration artifact: {hf_name}",
        )
        print(f"Uploaded {local_path} to {destination}", flush=True)
    else:
        print(f"Skipping {local_path} (not found)", flush=True)
188+
189+
103190
def _fit_weights_impl(
104191
branch: str,
105192
epochs: int,
@@ -631,6 +718,7 @@ def main(
631718
package_volume: bool = False,
632719
county_level: bool = False,
633720
workers: int = 1,
721+
upload_logs: bool = False,
634722
):
635723
if gpu not in GPU_FUNCTIONS:
636724
raise ValueError(
@@ -706,8 +794,31 @@ def main(
706794
f.write(result["log"])
707795
print(f"Diagnostics log saved to: {log_output}")
708796

797+
cal_log_output = "calibration_log.csv"
709798
if result.get("cal_log"):
710-
cal_log_output = "calibration_log.csv"
711799
with open(cal_log_output, "wb") as f:
712800
f.write(result["cal_log"])
713801
print(f"Calibration log saved to: {cal_log_output}")
802+
803+
config_output = "unified_run_config.json"
804+
if result.get("config"):
805+
with open(config_output, "wb") as f:
806+
f.write(result["config"])
807+
print(f"Run config saved to: {config_output}")
808+
809+
blocks_output = "stacked_blocks.npy"
810+
if result.get("blocks"):
811+
with open(blocks_output, "wb") as f:
812+
f.write(result["blocks"])
813+
print(f"Stacked blocks saved to: {blocks_output}")
814+
815+
if upload_logs:
816+
log_files = {
817+
"calibration_log.csv": cal_log_output,
818+
"unified_diagnostics.csv": log_output,
819+
"unified_run_config.json": config_output,
820+
}
821+
_upload_logs_to_hf(log_files)
822+
823+
if result.get("blocks"):
824+
_upload_calibration_artifact(blocks_output, "stacked_blocks.npy")

modal_app/worker_script.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ def main():
2020
parser.add_argument("--dataset-path", required=True)
2121
parser.add_argument("--db-path", required=True)
2222
parser.add_argument("--output-dir", required=True)
23+
parser.add_argument(
24+
"--calibration-blocks",
25+
type=str,
26+
default=None,
27+
help="Path to stacked_blocks.npy from calibration",
28+
)
2329
args = parser.parse_args()
2430

2531
work_items = json.loads(args.work_items)
@@ -28,6 +34,19 @@ def main():
2834
db_path = Path(args.db_path)
2935
output_dir = Path(args.output_dir)
3036

37+
calibration_blocks = None
38+
if args.calibration_blocks:
39+
calibration_blocks = np.load(args.calibration_blocks)
40+
41+
rerandomize_takeup = True
42+
from policyengine_us_data.utils.takeup import (
43+
TAKEUP_AFFECTED_TARGETS,
44+
)
45+
46+
takeup_filter = [
47+
info["takeup_var"] for info in TAKEUP_AFFECTED_TARGETS.values()
48+
]
49+
3150
original_stdout = sys.stdout
3251
sys.stdout = sys.stderr
3352

@@ -63,6 +82,9 @@ def main():
6382
cds_to_calibrate=cds_to_calibrate,
6483
dataset_path=dataset_path,
6584
output_dir=output_dir,
85+
rerandomize_takeup=rerandomize_takeup,
86+
calibration_blocks=calibration_blocks,
87+
takeup_filter=takeup_filter,
6688
)
6789
elif item_type == "district":
6890
state_code, dist_num = item_id.split("-")
@@ -72,9 +94,7 @@ def main():
7294
state_fips = fips
7395
break
7496
if state_fips is None:
75-
raise ValueError(
76-
f"Unknown state in district: {item_id}"
77-
)
97+
raise ValueError(f"Unknown state in district: {item_id}")
7898

7999
candidate = f"{state_fips}{int(dist_num):02d}"
80100
if candidate in cds_to_calibrate:
@@ -100,6 +120,9 @@ def main():
100120
cds_to_calibrate=cds_to_calibrate,
101121
dataset_path=dataset_path,
102122
output_dir=output_dir,
123+
rerandomize_takeup=rerandomize_takeup,
124+
calibration_blocks=calibration_blocks,
125+
takeup_filter=takeup_filter,
103126
)
104127
elif item_type == "city":
105128
path = build_city_h5(
@@ -108,6 +131,9 @@ def main():
108131
cds_to_calibrate=cds_to_calibrate,
109132
dataset_path=dataset_path,
110133
output_dir=output_dir,
134+
rerandomize_takeup=rerandomize_takeup,
135+
calibration_blocks=calibration_blocks,
136+
takeup_filter=takeup_filter,
111137
)
112138
else:
113139
raise ValueError(f"Unknown item type: {item_type}")

0 commit comments

Comments
 (0)