Skip to content

Commit b401d2b

Browse files
authored
Merge pull request #687 from PolicyEngine/codex/parallel-long-run-wrapper-upstream
Add parallel year wrapper for long-run H5 builds
2 parents b17e083 + 6ff719a commit b401d2b

3 files changed

Lines changed: 434 additions & 0 deletions

File tree

policyengine_us_data/datasets/cps/long_term/README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ python run_household_projection.py 2050 --profile age-only
2323

2424
# GREG with age + Social Security only
2525
python run_household_projection.py 2100 --profile ss
26+
27+
# Parallel year-level H5 construction with one subprocess per year
28+
python run_household_projection_parallel.py \
29+
--years 2026-2035,2045,2049,2062,2063,2070 \
30+
--jobs 6 \
31+
--output-dir ./projected_datasets_parallel \
32+
--profile ss-payroll-tob \
33+
--target-source oact_2025_08_05_provisional
2634
```
2735

2836
**Arguments:**
@@ -46,6 +54,11 @@ python run_household_projection.py 2100 --profile ss
4654
- `--use-tob`: Include TOB (Taxation of Benefits) revenue as calibration target (requires `--greg`)
4755
- `--save-h5`: Save year-specific .h5 files to `./projected_datasets/` directory
4856

57+
**Parallel wrapper:**
58+
- `run_household_projection_parallel.py` runs one `run_household_projection.py YEAR YEAR ...` subprocess per year and merges the resulting H5 artifacts into one output directory.
59+
- The wrapper forces `--save-h5` and controls `--output-dir` itself, so those flags should not be forwarded to the inner runner.
60+
- Per-year stdout/stderr logs are written under `OUTPUT_DIR/.parallel_logs/`.
61+
4962
**Named profiles:**
5063
- `age-only`: IPF age-only calibration
5164
- `ss`: positive entropy calibration with age + Social Security
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
from __future__ import annotations
2+
3+
import argparse
4+
import json
5+
import shutil
6+
import subprocess
7+
import sys
8+
from concurrent.futures import ThreadPoolExecutor, as_completed
9+
from pathlib import Path
10+
11+
try:
12+
from .calibration_artifacts import update_dataset_manifest
13+
except ImportError: # pragma: no cover - script execution fallback
14+
from calibration_artifacts import update_dataset_manifest
15+
16+
17+
SCRIPT_DIR = Path(__file__).resolve().parent
18+
RUNNER_PATH = SCRIPT_DIR / "run_household_projection.py"
19+
20+
21+
def parse_years(spec: str) -> list[int]:
22+
years: set[int] = set()
23+
for part in spec.split(","):
24+
chunk = part.strip()
25+
if not chunk:
26+
continue
27+
if "-" in chunk:
28+
start_str, end_str = chunk.split("-", 1)
29+
start = int(start_str)
30+
end = int(end_str)
31+
if end < start:
32+
raise ValueError(f"Invalid year range: {chunk}")
33+
years.update(range(start, end + 1))
34+
else:
35+
years.add(int(chunk))
36+
if not years:
37+
raise ValueError("No years provided")
38+
return sorted(years)
39+
40+
41+
def parse_args() -> tuple[argparse.Namespace, list[str]]:
42+
parser = argparse.ArgumentParser(
43+
description=(
44+
"Run long-run household projections in parallel, one year per "
45+
"subprocess, then merge the resulting H5 artifacts into one output "
46+
"directory and rebuild the calibration manifest."
47+
)
48+
)
49+
parser.add_argument(
50+
"--years",
51+
required=True,
52+
help="Comma-separated years and ranges, e.g. 2026-2035,2045,2070.",
53+
)
54+
parser.add_argument(
55+
"--jobs",
56+
type=int,
57+
default=4,
58+
help="Maximum number of year subprocesses to run concurrently.",
59+
)
60+
parser.add_argument(
61+
"--output-dir",
62+
required=True,
63+
help="Final output directory for merged YYYY.h5 artifacts.",
64+
)
65+
parser.add_argument(
66+
"--keep-temp",
67+
action="store_true",
68+
help="Keep per-year temporary output directories after a successful merge.",
69+
)
70+
args, forwarded_args = parser.parse_known_args()
71+
return args, forwarded_args
72+
73+
74+
def validate_forwarded_args(forwarded_args: list[str]) -> None:
75+
blocked = {"--output-dir", "--save-h5"}
76+
for arg in forwarded_args:
77+
if arg in blocked:
78+
raise ValueError(
79+
f"{arg} is controlled by run_household_projection_parallel.py; "
80+
"pass it to the wrapper instead."
81+
)
82+
83+
84+
def year_output_dir(root: Path, year: int) -> Path:
85+
return root / ".parallel_tmp" / str(year)
86+
87+
88+
def year_log_path(root: Path, year: int) -> Path:
89+
return root / ".parallel_logs" / f"{year}.log"
90+
91+
92+
def run_year(
93+
*,
94+
year: int,
95+
output_root: Path,
96+
forwarded_args: list[str],
97+
) -> tuple[int, Path]:
98+
output_dir = year_output_dir(output_root, year)
99+
log_path = year_log_path(output_root, year)
100+
output_dir.mkdir(parents=True, exist_ok=True)
101+
log_path.parent.mkdir(parents=True, exist_ok=True)
102+
103+
command = [
104+
sys.executable,
105+
str(RUNNER_PATH),
106+
str(year),
107+
str(year),
108+
"--output-dir",
109+
str(output_dir),
110+
"--save-h5",
111+
*forwarded_args,
112+
]
113+
114+
with log_path.open("w", encoding="utf-8") as log_file:
115+
completed = subprocess.run(
116+
command,
117+
cwd=SCRIPT_DIR,
118+
stdout=log_file,
119+
stderr=subprocess.STDOUT,
120+
check=False,
121+
)
122+
123+
if completed.returncode != 0:
124+
raise RuntimeError(
125+
f"Year {year} failed with exit code {completed.returncode}. See {log_path}."
126+
)
127+
128+
expected_h5 = output_dir / f"{year}.h5"
129+
expected_metadata = output_dir / f"{year}.h5.metadata.json"
130+
if not expected_h5.exists() or not expected_metadata.exists():
131+
raise FileNotFoundError(
132+
f"Year {year} finished without expected artifacts in {output_dir}."
133+
)
134+
135+
return year, output_dir
136+
137+
138+
def copy_support_reports(temp_output_dir: Path, final_output_dir: Path) -> None:
139+
for report_path in sorted(
140+
temp_output_dir.glob("support_augmentation_report*.json")
141+
):
142+
target_path = final_output_dir / report_path.name
143+
if not target_path.exists():
144+
shutil.copy2(report_path, target_path)
145+
continue
146+
if report_path.read_bytes() != target_path.read_bytes():
147+
raise ValueError(
148+
f"Conflicting support augmentation report contents for {report_path.name}"
149+
)
150+
151+
152+
def _json_clone(value):
153+
return json.loads(json.dumps(value))
154+
155+
156+
def manifest_contract(manifest: dict) -> dict:
157+
return {
158+
"base_dataset_path": manifest["base_dataset_path"],
159+
"profile": _json_clone(manifest["profile"]),
160+
"target_source": _json_clone(manifest.get("target_source")),
161+
"tax_assumption": _json_clone(manifest.get("tax_assumption")),
162+
"support_augmentation": _json_clone(manifest.get("support_augmentation")),
163+
}
164+
165+
166+
def merge_outputs(
167+
*,
168+
years: list[int],
169+
output_root: Path,
170+
keep_temp: bool,
171+
) -> Path:
172+
output_root.mkdir(parents=True, exist_ok=True)
173+
manifest_seed = None
174+
manifest_path = None
175+
176+
for year in years:
177+
temp_output_dir = year_output_dir(output_root, year)
178+
temp_manifest_path = temp_output_dir / "calibration_manifest.json"
179+
if not temp_manifest_path.exists():
180+
raise FileNotFoundError(
181+
f"Missing temp manifest for year {year}: {temp_manifest_path}"
182+
)
183+
184+
temp_manifest = json.loads(temp_manifest_path.read_text(encoding="utf-8"))
185+
if manifest_seed is None:
186+
manifest_seed = manifest_contract(temp_manifest)
187+
else:
188+
for key, value in manifest_seed.items():
189+
if _json_clone(temp_manifest.get(key)) != value:
190+
raise ValueError(
191+
f"Temp manifest mismatch for {key} in year {year}: "
192+
f"{temp_manifest.get(key)} != {value}"
193+
)
194+
195+
h5_name = f"{year}.h5"
196+
metadata_name = f"{year}.h5.metadata.json"
197+
shutil.copy2(temp_output_dir / h5_name, output_root / h5_name)
198+
shutil.copy2(temp_output_dir / metadata_name, output_root / metadata_name)
199+
copy_support_reports(temp_output_dir, output_root)
200+
201+
metadata = json.loads(
202+
(temp_output_dir / metadata_name).read_text(encoding="utf-8")
203+
)
204+
manifest_path = update_dataset_manifest(
205+
output_root,
206+
year=year,
207+
h5_path=output_root / h5_name,
208+
metadata_path=output_root / metadata_name,
209+
base_dataset_path=manifest_seed["base_dataset_path"],
210+
profile=manifest_seed["profile"],
211+
calibration_audit=metadata["calibration_audit"],
212+
target_source=manifest_seed["target_source"],
213+
tax_assumption=manifest_seed["tax_assumption"],
214+
support_augmentation=manifest_seed["support_augmentation"],
215+
)
216+
217+
if not keep_temp:
218+
shutil.rmtree(output_root / ".parallel_tmp", ignore_errors=True)
219+
220+
return manifest_path
221+
222+
223+
def main() -> int:
224+
args, forwarded_args = parse_args()
225+
validate_forwarded_args(forwarded_args)
226+
227+
output_root = Path(args.output_dir).expanduser().resolve()
228+
years = parse_years(args.years)
229+
230+
print(
231+
f"Running {len(years)} year jobs with concurrency {args.jobs} into {output_root}"
232+
)
233+
234+
completed_years: list[int] = []
235+
with ThreadPoolExecutor(max_workers=max(args.jobs, 1)) as executor:
236+
future_map = {
237+
executor.submit(
238+
run_year,
239+
year=year,
240+
output_root=output_root,
241+
forwarded_args=forwarded_args,
242+
): year
243+
for year in years
244+
}
245+
for future in as_completed(future_map):
246+
year = future_map[future]
247+
try:
248+
future.result()
249+
except Exception as error:
250+
print(f"Year {year} failed: {error}", file=sys.stderr)
251+
return 1
252+
completed_years.append(year)
253+
print(f"Completed year {year}")
254+
255+
manifest_path = merge_outputs(
256+
years=years,
257+
output_root=output_root,
258+
keep_temp=args.keep_temp,
259+
)
260+
print(f"Merged {len(completed_years)} yearly artifacts into {output_root}")
261+
print(f"Rebuilt manifest at {manifest_path}")
262+
return 0
263+
264+
265+
if __name__ == "__main__":
266+
raise SystemExit(main())

0 commit comments

Comments
 (0)