Skip to content

Commit f40796c

Browse files
committed
Add parallel year wrapper for long-run H5 builds
1 parent b17e083 commit f40796c

2 files changed

Lines changed: 275 additions & 0 deletions

File tree

policyengine_us_data/datasets/cps/long_term/README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ python run_household_projection.py 2050 --profile age-only
2323

2424
# GREG with age + Social Security only
2525
python run_household_projection.py 2100 --profile ss
26+
27+
# Parallel year-level H5 construction with one subprocess per year
28+
python run_household_projection_parallel.py \
29+
--years 2026-2035,2045,2049,2062,2063,2070 \
30+
--jobs 6 \
31+
--output-dir ./projected_datasets_parallel \
32+
--profile ss-payroll-tob \
33+
--target-source oact_2025_08_05_provisional
2634
```
2735

2836
**Arguments:**
@@ -46,6 +54,11 @@ python run_household_projection.py 2100 --profile ss
4654
- `--use-tob`: Include TOB (Taxation of Benefits) revenue as calibration target (requires `--greg`)
4755
- `--save-h5`: Save year-specific .h5 files to `./projected_datasets/` directory
4856

57+
**Parallel wrapper:**
58+
- `run_household_projection_parallel.py` runs one `run_household_projection.py YEAR YEAR ...` subprocess per year and merges the resulting H5 artifacts into one output directory.
59+
- The wrapper always adds `--save-h5` and sets the inner runner's `--output-dir` itself; it exits with an error if either flag is found among the arguments forwarded to the inner runner.
60+
- Per-year stdout/stderr logs are written under `OUTPUT_DIR/.parallel_logs/`.
61+
4962
**Named profiles:**
5063
- `age-only`: IPF age-only calibration
5164
- `ss`: positive entropy calibration with age + Social Security
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
from __future__ import annotations
2+
3+
import argparse
4+
import json
5+
import shutil
6+
import subprocess
7+
import sys
8+
from concurrent.futures import ThreadPoolExecutor, as_completed
9+
from pathlib import Path
10+
11+
from calibration_artifacts import update_dataset_manifest
12+
13+
14+
# Absolute directory containing this wrapper script.
SCRIPT_DIR = Path(__file__).resolve().parents[0]
# Inner runner script, expected to live next to this wrapper.
RUNNER_PATH = SCRIPT_DIR.joinpath("run_household_projection.py")
16+
17+
18+
def parse_years(spec: str) -> list[int]:
    """Expand a comma-separated year specification into sorted unique years.

    Each comma-separated token is either a single year (``2045``) or an
    inclusive range (``2026-2035``). Blank tokens are skipped and duplicate
    years collapse.

    Raises:
        ValueError: if a token is not an integer, a range ends before it
            starts, or no years remain after parsing.
    """
    collected: set[int] = set()
    for raw_token in spec.split(","):
        token = raw_token.strip()
        if token == "":
            continue
        first, dash, last = token.partition("-")
        if not dash:
            # No hyphen: the token is a single year.
            collected.add(int(token))
            continue
        lower = int(first)
        upper = int(last)
        if upper < lower:
            raise ValueError(f"Invalid year range: {token}")
        collected.update(range(lower, upper + 1))
    if len(collected) == 0:
        raise ValueError("No years provided")
    return sorted(collected)
36+
37+
38+
def parse_args() -> tuple[argparse.Namespace, list[str]]:
    """Parse the wrapper's own options from the process command line.

    Returns:
        A pair ``(args, forwarded_args)``: the namespace of options the
        wrapper recognises, and the list of unrecognised arguments left over
        by ``parse_known_args``.
    """
    description = (
        "Run long-run household projections in parallel, one year per "
        "subprocess, then merge the resulting H5 artifacts into one output "
        "directory and rebuild the calibration manifest."
    )
    cli = argparse.ArgumentParser(description=description)

    option_specs = (
        (
            "--years",
            {
                "required": True,
                "help": "Comma-separated years and ranges, e.g. 2026-2035,2045,2070.",
            },
        ),
        (
            "--jobs",
            {
                "type": int,
                "default": 4,
                "help": "Maximum number of year subprocesses to run concurrently.",
            },
        ),
        (
            "--output-dir",
            {
                "required": True,
                "help": "Final output directory for merged YYYY.h5 artifacts.",
            },
        ),
        (
            "--keep-temp",
            {
                "action": "store_true",
                "help": "Keep per-year temporary output directories after a successful merge.",
            },
        ),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    known, passthrough = cli.parse_known_args()
    return known, passthrough
69+
70+
71+
def validate_forwarded_args(forwarded_args: list[str]) -> None:
    """Reject forwarded arguments that the wrapper manages itself.

    ``--output-dir`` and ``--save-h5`` are injected into every inner runner
    command by the wrapper, so they must not also appear among the forwarded
    arguments. Both the bare flag and the ``--flag=value`` spelling are
    rejected.

    Raises:
        ValueError: if a wrapper-controlled flag is found; the message tells
            the user what to do for that specific flag.
    """
    # Per-flag guidance: --output-dir is a wrapper option, while --save-h5 is
    # not accepted by the wrapper at all (it is always added), so telling the
    # user to "pass it to the wrapper" would send them in a loop.
    guidance = {
        "--output-dir": "pass it to the wrapper instead.",
        "--save-h5": "it is always enabled, so remove it from the command line.",
    }
    for arg in forwarded_args:
        flag = arg.split("=", 1)[0]
        hint = guidance.get(flag)
        if hint is not None:
            raise ValueError(
                f"{flag} is controlled by run_household_projection_parallel.py; "
                f"{hint}"
            )
79+
80+
81+
def year_output_dir(root: Path, year: int) -> Path:
    """Return the per-year temporary output directory under *root*."""
    return root.joinpath(".parallel_tmp", str(year))
83+
84+
85+
def year_log_path(root: Path, year: int) -> Path:
    """Return the path of the per-year log file under *root*."""
    log_name = f"{year}.log"
    return root.joinpath(".parallel_logs", log_name)
87+
88+
89+
def run_year(
    *,
    year: int,
    output_root: Path,
    forwarded_args: list[str],
) -> tuple[int, Path]:
    """Run the inner projection runner for a single year in a subprocess.

    The runner is invoked with the year given twice as positional arguments,
    a per-year ``--output-dir``, ``--save-h5``, and then *forwarded_args*.
    Combined stdout/stderr is written to the per-year log file.

    Returns:
        ``(year, output_dir)`` where *output_dir* is the per-year temporary
        directory.

    Raises:
        RuntimeError: if the subprocess exits with a non-zero code.
        FileNotFoundError: if the subprocess succeeded but ``YEAR.h5`` or
            ``YEAR.h5.metadata.json`` is missing from the per-year directory.
    """
    year_dir = year_output_dir(output_root, year)
    log_file_path = year_log_path(output_root, year)
    for directory in (year_dir, log_file_path.parent):
        directory.mkdir(parents=True, exist_ok=True)

    year_arg = str(year)
    command = [sys.executable, str(RUNNER_PATH), year_arg, year_arg]
    command += ["--output-dir", str(year_dir), "--save-h5"]
    command.extend(forwarded_args)

    with log_file_path.open("w", encoding="utf-8") as log_stream:
        # check=False: the exit code is inspected below so the error can
        # point at the log file.
        result = subprocess.run(
            command,
            cwd=SCRIPT_DIR,
            stdout=log_stream,
            stderr=subprocess.STDOUT,
            check=False,
        )

    exit_code = result.returncode
    if exit_code != 0:
        raise RuntimeError(
            f"Year {year} failed with exit code {exit_code}. "
            f"See {log_file_path}."
        )

    required = (
        year_dir / f"{year}.h5",
        year_dir / f"{year}.h5.metadata.json",
    )
    if not all(artifact.exists() for artifact in required):
        raise FileNotFoundError(
            f"Year {year} finished without expected artifacts in {year_dir}."
        )

    return year, year_dir
134+
135+
136+
def copy_support_reports(temp_output_dir: Path, final_output_dir: Path) -> None:
    """Copy support augmentation reports into the final output directory.

    Every ``support_augmentation_report*.json`` file in *temp_output_dir* is
    copied to *final_output_dir* unless a file of the same name already
    exists there. An existing file must be byte-identical to the new one.

    Raises:
        ValueError: if a same-named report already exists with different
            contents.
    """
    candidates = list(temp_output_dir.glob("support_augmentation_report*.json"))
    candidates.sort()
    for source in candidates:
        destination = final_output_dir.joinpath(source.name)
        if destination.exists():
            if source.read_bytes() == destination.read_bytes():
                continue
            raise ValueError(
                f"Conflicting support augmentation report contents for {source.name}"
            )
        shutil.copy2(source, destination)
146+
147+
148+
def _json_clone(value):
    """Return *value* after a JSON encode/decode round trip.

    The result is a fresh object in JSON-normalised form (for example,
    tuples come back as lists), which makes values comparable with data
    loaded from JSON files.
    """
    encoded = json.dumps(value)
    return json.loads(encoded)
150+
151+
152+
def manifest_contract(manifest: dict) -> dict:
    """Extract the manifest fields that must agree across all merged years.

    ``base_dataset_path`` and ``profile`` are required (a missing key raises
    ``KeyError``); ``target_source``, ``tax_assumption`` and
    ``support_augmentation`` default to ``None`` when absent. All values
    except ``base_dataset_path`` are passed through ``_json_clone``.
    """
    contract = {"base_dataset_path": manifest["base_dataset_path"]}
    contract["profile"] = _json_clone(manifest["profile"])
    for optional_key in ("target_source", "tax_assumption", "support_augmentation"):
        contract[optional_key] = _json_clone(manifest.get(optional_key))
    return contract
160+
161+
162+
def merge_outputs(
    *,
    years: list[int],
    output_root: Path,
    keep_temp: bool,
) -> Path:
    """Merge per-year temporary outputs into *output_root*.

    The merge runs in two passes so that a bad manifest cannot leave the
    final directory half-populated:

    1. Every year's ``calibration_manifest.json`` is loaded and checked
       against the contract taken from the first year (see
       ``manifest_contract``). Nothing is copied during this pass.
    2. For each year, ``YEAR.h5``, ``YEAR.h5.metadata.json`` and any support
       augmentation reports are copied into *output_root*, and
       ``update_dataset_manifest`` is called for that year.

    Unless *keep_temp* is true, the ``.parallel_tmp`` directory is removed
    afterwards.

    Returns:
        The value returned by the last ``update_dataset_manifest`` call
        (``None`` if *years* is empty).

    Raises:
        FileNotFoundError: if a year's temporary manifest is missing.
        ValueError: if a year's manifest disagrees with the first year's
            contract, or support reports conflict.
    """
    output_root.mkdir(parents=True, exist_ok=True)

    # Pass 1: validate all manifests before touching the final directory.
    manifest_seed = None
    for year in years:
        temp_manifest_path = (
            year_output_dir(output_root, year) / "calibration_manifest.json"
        )
        if not temp_manifest_path.exists():
            raise FileNotFoundError(
                f"Missing temp manifest for year {year}: {temp_manifest_path}"
            )

        temp_manifest = json.loads(temp_manifest_path.read_text(encoding="utf-8"))
        if manifest_seed is None:
            manifest_seed = manifest_contract(temp_manifest)
            continue
        for key, value in manifest_seed.items():
            if _json_clone(temp_manifest.get(key)) != value:
                raise ValueError(
                    f"Temp manifest mismatch for {key} in year {year}: "
                    f"{temp_manifest.get(key)} != {value}"
                )

    # Pass 2: copy artifacts and register each year in the merged manifest.
    manifest_path = None
    for year in years:
        temp_output_dir = year_output_dir(output_root, year)
        h5_name = f"{year}.h5"
        metadata_name = f"{year}.h5.metadata.json"
        shutil.copy2(temp_output_dir / h5_name, output_root / h5_name)
        shutil.copy2(temp_output_dir / metadata_name, output_root / metadata_name)
        copy_support_reports(temp_output_dir, output_root)

        metadata = json.loads(
            (temp_output_dir / metadata_name).read_text(encoding="utf-8")
        )
        manifest_path = update_dataset_manifest(
            output_root,
            year=year,
            h5_path=output_root / h5_name,
            metadata_path=output_root / metadata_name,
            base_dataset_path=manifest_seed["base_dataset_path"],
            profile=manifest_seed["profile"],
            calibration_audit=metadata["calibration_audit"],
            target_source=manifest_seed["target_source"],
            tax_assumption=manifest_seed["tax_assumption"],
            support_augmentation=manifest_seed["support_augmentation"],
        )

    if not keep_temp:
        # Best-effort cleanup; a leftover temp directory is not an error.
        shutil.rmtree(output_root / ".parallel_tmp", ignore_errors=True)

    return manifest_path
217+
218+
219+
def main() -> int:
    """Run all requested years in parallel, then merge their outputs.

    Each year runs in its own subprocess via ``run_year``, with at most
    ``--jobs`` (minimum 1) running concurrently. Leaving the executor's
    ``with`` block waits for every submitted job, so all outcomes are
    collected and every failure is reported rather than only the first one.
    The merge is skipped if any year failed.

    Returns:
        ``0`` on success, ``1`` if at least one year failed.
    """
    args, forwarded_args = parse_args()
    validate_forwarded_args(forwarded_args)

    output_root = Path(args.output_dir).expanduser().resolve()
    years = parse_years(args.years)
    # Report the concurrency actually used, not the raw (possibly < 1) flag.
    jobs = max(args.jobs, 1)

    print(
        f"Running {len(years)} year jobs with concurrency {jobs} into {output_root}"
    )

    completed_years: list[int] = []
    failed_years: list[int] = []
    with ThreadPoolExecutor(max_workers=jobs) as executor:
        future_map = {
            executor.submit(
                run_year,
                year=year,
                output_root=output_root,
                forwarded_args=forwarded_args,
            ): year
            for year in years
        }
        for future in as_completed(future_map):
            year = future_map[future]
            try:
                future.result()
            except Exception as error:
                # Top-level boundary: report and keep collecting so that
                # every failed year is visible, not just the first.
                print(f"Year {year} failed: {error}", file=sys.stderr)
                failed_years.append(year)
                continue
            completed_years.append(year)
            print(f"Completed year {year}")

    if failed_years:
        failed_list = ", ".join(str(year) for year in sorted(failed_years))
        print(
            f"{len(failed_years)} of {len(years)} year jobs failed "
            f"({failed_list}); skipping merge.",
            file=sys.stderr,
        )
        return 1

    manifest_path = merge_outputs(
        years=years,
        output_root=output_root,
        keep_temp=args.keep_temp,
    )
    print(f"Merged {len(completed_years)} yearly artifacts into {output_root}")
    print(f"Rebuilt manifest at {manifest_path}")
    return 0
259+
260+
261+
# Script entry point: exit with the status code returned by main().
if __name__ == "__main__":
    sys.exit(main())

0 commit comments

Comments
 (0)