11import os
22import subprocess
3+ from concurrent .futures import ThreadPoolExecutor , as_completed
4+ from typing import Optional
5+
36import modal
47
58app = modal .App ("policyengine-us-data" )
@@ -28,6 +31,33 @@ def setup_gcp_credentials():
2831 return None
2932
3033
def run_script(
    script_path: str,
    args: Optional[list] = None,
    env: Optional[dict] = None,
) -> str:
    """Run a Python script through ``uv run`` and return its path.

    Args:
        script_path: Path to the Python script to run.
        args: Optional command-line arguments appended after the script
            path.
        env: Optional environment mapping for the child process. When
            ``None``, a copy of the current process environment is used.
            Any mapping that is passed, including an empty one, is handed
            to the child as given.

    Returns:
        The ``script_path`` that was executed, so callers can log it.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero
            status.
    """
    cmd = ["uv", "run", "python", script_path, *(args or [])]
    # Compare against None explicitly: `env or ...` would silently swap a
    # deliberately empty environment for the parent's.
    child_env = os.environ.copy() if env is None else env
    # The child inherits this process's stdout, so flush the progress lines
    # to keep them ordered relative to the script's own output.
    print(f"Starting {script_path}...", flush=True)
    subprocess.run(cmd, check=True, env=child_env)
    print(f"Completed {script_path}", flush=True)
    return script_path
60+
3161@app .function (
3262 image = image ,
3363 secrets = [hf_secret , gcp_secret ],
@@ -38,6 +68,7 @@ def setup_gcp_credentials():
3868def build_datasets (
3969 upload : bool = False ,
4070 branch : str = "main" ,
71+ sequential : bool = False ,
4172):
4273 setup_gcp_credentials ()
4374
@@ -50,45 +81,100 @@ def build_datasets(
5081 env = os .environ .copy ()
5182
5283 # Download prerequisites
53- subprocess .run (
54- [
55- "uv" ,
56- "run" ,
57- "python" ,
58- "policyengine_us_data/storage/download_private_prerequisites.py" ,
59- ],
60- check = True ,
84+ run_script (
85+ "policyengine_us_data/storage/download_private_prerequisites.py" ,
6186 env = env ,
6287 )
6388
64- # Build main datasets
65- scripts = [
66- "policyengine_us_data/utils/uprating.py" ,
67- "policyengine_us_data/datasets/acs/acs.py" ,
68- "policyengine_us_data/datasets/cps/cps.py" ,
69- "policyengine_us_data/datasets/puf/irs_puf.py" ,
70- "policyengine_us_data/datasets/puf/puf.py" ,
71- "policyengine_us_data/datasets/cps/extended_cps.py" ,
72- "policyengine_us_data/datasets/cps/enhanced_cps.py" ,
73- "policyengine_us_data/datasets/cps/small_enhanced_cps.py" ,
74- ]
75- for script in scripts :
76- print (f"Running { script } ..." )
77- subprocess .run (["uv" , "run" , "python" , script ], check = True , env = env )
78-
79- # Build stratified CPS for local area calibration
80- print ("Running create_stratified_cps.py..." )
81- subprocess .run (
82- [
83- "uv" ,
84- "run" ,
85- "python" ,
86- "policyengine_us_data/datasets/cps/local_area_calibration/create_stratified_cps.py" ,
87- "10500" ,
88- ],
89- check = True ,
90- env = env ,
91- )
89+ if sequential :
90+ # Original sequential execution for backward compatibility
91+ scripts = [
92+ "policyengine_us_data/utils/uprating.py" ,
93+ "policyengine_us_data/datasets/acs/acs.py" ,
94+ "policyengine_us_data/datasets/cps/cps.py" ,
95+ "policyengine_us_data/datasets/puf/irs_puf.py" ,
96+ "policyengine_us_data/datasets/puf/puf.py" ,
97+ "policyengine_us_data/datasets/cps/extended_cps.py" ,
98+ "policyengine_us_data/datasets/cps/enhanced_cps.py" ,
99+ "policyengine_us_data/datasets/cps/small_enhanced_cps.py" ,
100+ ]
101+ for script in scripts :
102+ run_script (script , env = env )
103+
104+ # Build stratified CPS
105+ run_script (
106+ "policyengine_us_data/datasets/cps/"
107+ "local_area_calibration/create_stratified_cps.py" ,
108+ args = ["10500" ],
109+ env = env ,
110+ )
111+ else :
112+ # Parallel execution based on dependency groups
113+ # GROUP 1: Independent scripts - run in parallel
114+ print ("=== Phase 1: Building independent datasets (parallel) ===" )
115+ group1 = [
116+ "policyengine_us_data/utils/uprating.py" ,
117+ "policyengine_us_data/datasets/acs/acs.py" ,
118+ "policyengine_us_data/datasets/puf/irs_puf.py" ,
119+ ]
120+ with ThreadPoolExecutor (max_workers = 3 ) as executor :
121+ futures = {
122+ executor .submit (run_script , s , env = env ): s for s in group1
123+ }
124+ for future in as_completed (futures ):
125+ future .result () # Raises if script failed
126+
127+ # GROUP 2: Depends on Group 1 - run in parallel
128+ # cps.py needs acs, puf.py needs irs_puf + uprating
129+ print ("=== Phase 2: Building CPS and PUF (parallel) ===" )
130+ group2 = [
131+ "policyengine_us_data/datasets/cps/cps.py" ,
132+ "policyengine_us_data/datasets/puf/puf.py" ,
133+ ]
134+ with ThreadPoolExecutor (max_workers = 2 ) as executor :
135+ futures = {
136+ executor .submit (run_script , s , env = env ): s for s in group2
137+ }
138+ for future in as_completed (futures ):
139+ future .result ()
140+
141+ # SEQUENTIAL: Extended CPS (needs both cps and puf)
142+ print ("=== Phase 3: Building extended CPS ===" )
143+ run_script (
144+ "policyengine_us_data/datasets/cps/extended_cps.py" ,
145+ env = env ,
146+ )
147+
148+ # GROUP 3: After extended_cps - run in parallel
149+ # enhanced_cps and stratified_cps both depend on extended_cps
150+ print (
151+ "=== Phase 4: Building enhanced and stratified CPS (parallel)"
152+ " ==="
153+ )
154+ with ThreadPoolExecutor (max_workers = 2 ) as executor :
155+ futures = [
156+ executor .submit (
157+ run_script ,
158+ "policyengine_us_data/datasets/cps/enhanced_cps.py" ,
159+ env = env ,
160+ ),
161+ executor .submit (
162+ run_script ,
163+ "policyengine_us_data/datasets/cps/"
164+ "local_area_calibration/create_stratified_cps.py" ,
165+ args = ["10500" ],
166+ env = env ,
167+ ),
168+ ]
169+ for future in as_completed (futures ):
170+ future .result ()
171+
172+ # SEQUENTIAL: Small enhanced CPS (needs enhanced_cps)
173+ print ("=== Phase 5: Building small enhanced CPS ===" )
174+ run_script (
175+ "policyengine_us_data/datasets/cps/small_enhanced_cps.py" ,
176+ env = env ,
177+ )
92178
93179 # Run local area calibration tests
94180 print ("Running local area calibration tests..." )
@@ -128,9 +214,11 @@ def build_datasets(
def main(
    upload: bool = False,
    branch: str = "main",
    sequential: bool = False,
):
    """Trigger the remote dataset build and print whatever it returns.

    Args:
        upload: Forwarded unchanged to ``build_datasets``.
        branch: Forwarded unchanged to ``build_datasets``.
        sequential: Forwarded unchanged to ``build_datasets``.
    """
    # Collect the options once so the remote call mirrors this signature.
    options = {
        "upload": upload,
        "branch": branch,
        "sequential": sequential,
    }
    outcome = build_datasets.remote(**options)
    print(outcome)
0 commit comments