Skip to content

Commit 2a1816f

Browse files
committed
CI data build parallelization
1 parent 6a90ecf commit 2a1816f

2 files changed

Lines changed: 128 additions & 36 deletions

File tree

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- bump: minor
2+
changes:
3+
added:
4+
- Parallelized data build in Modal for data_build.py.

modal_app/data_build.py

Lines changed: 124 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import os
22
import subprocess
3+
from concurrent.futures import ThreadPoolExecutor, as_completed
4+
from typing import Optional
5+
36
import modal
47

58
app = modal.App("policyengine-us-data")
@@ -28,6 +31,33 @@ def setup_gcp_credentials():
2831
return None
2932

3033

34+
def run_script(
    script_path: str,
    args: Optional[list] = None,
    env: Optional[dict] = None,
) -> str:
    """Run a Python script through ``uv run python`` and wait for it to finish.

    Args:
        script_path: Path to the Python script to run.
        args: Optional list of command-line arguments appended after the
            script path.
        env: Optional environment for the child process. ``None`` means
            "use a copy of the current process environment"; any dict that
            is passed, including an empty one, is used as given.

    Returns:
        The script_path that was executed, so callers can log which
        script completed.

    Raises:
        subprocess.CalledProcessError: If the script exits with a
            non-zero status.
    """
    cmd = ["uv", "run", "python", script_path, *(args or [])]

    # Only fall back to the parent environment when the caller passed
    # nothing at all; a truthiness check would also discard an explicit {}.
    run_env = os.environ.copy() if env is None else env

    # Flush so the progress line is emitted before the child process starts
    # writing its own output to the same stream.
    print(f"Starting {script_path}...", flush=True)
    subprocess.run(cmd, check=True, env=run_env)
    print(f"Completed {script_path}", flush=True)
    return script_path
59+
60+
3161
@app.function(
3262
image=image,
3363
secrets=[hf_secret, gcp_secret],
@@ -38,6 +68,7 @@ def setup_gcp_credentials():
3868
def build_datasets(
3969
upload: bool = False,
4070
branch: str = "main",
71+
sequential: bool = False,
4172
):
4273
setup_gcp_credentials()
4374

@@ -50,45 +81,100 @@ def build_datasets(
5081
env = os.environ.copy()
5182

5283
# Download prerequisites
53-
subprocess.run(
54-
[
55-
"uv",
56-
"run",
57-
"python",
58-
"policyengine_us_data/storage/download_private_prerequisites.py",
59-
],
60-
check=True,
84+
run_script(
85+
"policyengine_us_data/storage/download_private_prerequisites.py",
6186
env=env,
6287
)
6388

64-
# Build main datasets
65-
scripts = [
66-
"policyengine_us_data/utils/uprating.py",
67-
"policyengine_us_data/datasets/acs/acs.py",
68-
"policyengine_us_data/datasets/cps/cps.py",
69-
"policyengine_us_data/datasets/puf/irs_puf.py",
70-
"policyengine_us_data/datasets/puf/puf.py",
71-
"policyengine_us_data/datasets/cps/extended_cps.py",
72-
"policyengine_us_data/datasets/cps/enhanced_cps.py",
73-
"policyengine_us_data/datasets/cps/small_enhanced_cps.py",
74-
]
75-
for script in scripts:
76-
print(f"Running {script}...")
77-
subprocess.run(["uv", "run", "python", script], check=True, env=env)
78-
79-
# Build stratified CPS for local area calibration
80-
print("Running create_stratified_cps.py...")
81-
subprocess.run(
82-
[
83-
"uv",
84-
"run",
85-
"python",
86-
"policyengine_us_data/datasets/cps/local_area_calibration/create_stratified_cps.py",
87-
"10500",
88-
],
89-
check=True,
90-
env=env,
91-
)
89+
if sequential:
90+
# Original sequential execution for backward compatibility
91+
scripts = [
92+
"policyengine_us_data/utils/uprating.py",
93+
"policyengine_us_data/datasets/acs/acs.py",
94+
"policyengine_us_data/datasets/cps/cps.py",
95+
"policyengine_us_data/datasets/puf/irs_puf.py",
96+
"policyengine_us_data/datasets/puf/puf.py",
97+
"policyengine_us_data/datasets/cps/extended_cps.py",
98+
"policyengine_us_data/datasets/cps/enhanced_cps.py",
99+
"policyengine_us_data/datasets/cps/small_enhanced_cps.py",
100+
]
101+
for script in scripts:
102+
run_script(script, env=env)
103+
104+
# Build stratified CPS
105+
run_script(
106+
"policyengine_us_data/datasets/cps/"
107+
"local_area_calibration/create_stratified_cps.py",
108+
args=["10500"],
109+
env=env,
110+
)
111+
else:
112+
# Parallel execution based on dependency groups
113+
# GROUP 1: Independent scripts - run in parallel
114+
print("=== Phase 1: Building independent datasets (parallel) ===")
115+
group1 = [
116+
"policyengine_us_data/utils/uprating.py",
117+
"policyengine_us_data/datasets/acs/acs.py",
118+
"policyengine_us_data/datasets/puf/irs_puf.py",
119+
]
120+
with ThreadPoolExecutor(max_workers=3) as executor:
121+
futures = {
122+
executor.submit(run_script, s, env=env): s for s in group1
123+
}
124+
for future in as_completed(futures):
125+
future.result() # Raises if script failed
126+
127+
# GROUP 2: Depends on Group 1 - run in parallel
128+
# cps.py needs acs, puf.py needs irs_puf + uprating
129+
print("=== Phase 2: Building CPS and PUF (parallel) ===")
130+
group2 = [
131+
"policyengine_us_data/datasets/cps/cps.py",
132+
"policyengine_us_data/datasets/puf/puf.py",
133+
]
134+
with ThreadPoolExecutor(max_workers=2) as executor:
135+
futures = {
136+
executor.submit(run_script, s, env=env): s for s in group2
137+
}
138+
for future in as_completed(futures):
139+
future.result()
140+
141+
# SEQUENTIAL: Extended CPS (needs both cps and puf)
142+
print("=== Phase 3: Building extended CPS ===")
143+
run_script(
144+
"policyengine_us_data/datasets/cps/extended_cps.py",
145+
env=env,
146+
)
147+
148+
# GROUP 3: After extended_cps - run in parallel
149+
# enhanced_cps and stratified_cps both depend on extended_cps
150+
print(
151+
"=== Phase 4: Building enhanced and stratified CPS (parallel)"
152+
" ==="
153+
)
154+
with ThreadPoolExecutor(max_workers=2) as executor:
155+
futures = [
156+
executor.submit(
157+
run_script,
158+
"policyengine_us_data/datasets/cps/enhanced_cps.py",
159+
env=env,
160+
),
161+
executor.submit(
162+
run_script,
163+
"policyengine_us_data/datasets/cps/"
164+
"local_area_calibration/create_stratified_cps.py",
165+
args=["10500"],
166+
env=env,
167+
),
168+
]
169+
for future in as_completed(futures):
170+
future.result()
171+
172+
# SEQUENTIAL: Small enhanced CPS (needs enhanced_cps)
173+
print("=== Phase 5: Building small enhanced CPS ===")
174+
run_script(
175+
"policyengine_us_data/datasets/cps/small_enhanced_cps.py",
176+
env=env,
177+
)
92178

93179
# Run local area calibration tests
94180
print("Running local area calibration tests...")
@@ -128,9 +214,11 @@ def build_datasets(
128214
def main(
    upload: bool = False,
    branch: str = "main",
    sequential: bool = False,
):
    """Invoke ``build_datasets`` remotely and print whatever it returns.

    Args:
        upload: Forwarded unchanged to ``build_datasets.remote``.
        branch: Forwarded unchanged to ``build_datasets.remote``.
        sequential: Forwarded unchanged to ``build_datasets.remote``.
    """
    # Collect the options once so the remote call mirrors this signature.
    forwarded = {
        "upload": upload,
        "branch": branch,
        "sequential": sequential,
    }
    outcome = build_datasets.remote(**forwarded)
    print(outcome)

0 commit comments

Comments
 (0)