Skip to content

Commit 7e6909d

Browse files
baogorekclaude
andcommitted
Fix args sync, DRY sequential path, thread-safe volume commits
- Bake correct defaults (12000, 99.5, seed=3526) into create_stratified_cps.py so callers no longer need to pass args - Remove hardcoded args from Makefile and data_build.py (both paths) - Replace verbose sequential scripts_with_outputs list with SCRIPT_OUTPUTS iteration - Add threading.Lock around volume.commit() in save_checkpoint to prevent concurrent commit races in parallel phases Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f009156 commit 7e6909d

3 files changed

Lines changed: 9 additions & 70 deletions

File tree

Makefile

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,7 @@ data: download
9595
python policyengine_us_data/datasets/cps/extended_cps.py
9696
python policyengine_us_data/datasets/cps/enhanced_cps.py
9797
python policyengine_us_data/datasets/cps/small_enhanced_cps.py
98-
# 12000: number of households our GPUs can handle (found via trial and error).
99-
# --top=99.5: include only top 0.5% (vs default 1%) to preserve
100-
# representation of lower-income households.
101-
# --seed=3526: reproducible stratified sampling.
102-
python policyengine_us_data/datasets/cps/local_area_calibration/create_stratified_cps.py 12000 --top=99.5 --seed=3526
98+
python policyengine_us_data/datasets/cps/local_area_calibration/create_stratified_cps.py
10399

104100
publish-local-area:
105101
python policyengine_us_data/datasets/cps/local_area_calibration/publish_local_area.py

modal_app/data_build.py

Lines changed: 5 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import shutil
33
import subprocess
4+
import threading
45
from concurrent.futures import ThreadPoolExecutor, as_completed
56
from pathlib import Path
67
from typing import Optional
@@ -26,6 +27,7 @@
2627

2728
REPO_URL = "https://github.com/PolicyEngine/policyengine-us-data.git"
2829
VOLUME_MOUNT = "/checkpoints"
30+
_volume_lock = threading.Lock()
2931

3032
# Script to output file mapping for checkpointing
3133
# Values can be a single file path (str) or a list of file paths
@@ -122,7 +124,8 @@ def save_checkpoint(
122124
checkpoint_path = get_checkpoint_path(branch, output_file)
123125
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
124126
shutil.copy2(local_path, checkpoint_path)
125-
volume.commit()
127+
with _volume_lock:
128+
volume.commit()
126129
print(f"Checkpointed: {output_file}")
127130

128131

@@ -304,71 +307,12 @@ def build_datasets(
304307
)
305308

306309
if sequential:
307-
# Sequential execution with checkpointing
308-
scripts_with_outputs = [
309-
(
310-
"policyengine_us_data/utils/uprating.py",
311-
SCRIPT_OUTPUTS["policyengine_us_data/utils/uprating.py"],
312-
None,
313-
),
314-
(
315-
"policyengine_us_data/datasets/acs/acs.py",
316-
SCRIPT_OUTPUTS["policyengine_us_data/datasets/acs/acs.py"],
317-
None,
318-
),
319-
(
320-
"policyengine_us_data/datasets/cps/cps.py",
321-
SCRIPT_OUTPUTS["policyengine_us_data/datasets/cps/cps.py"],
322-
None,
323-
),
324-
(
325-
"policyengine_us_data/datasets/puf/irs_puf.py",
326-
SCRIPT_OUTPUTS["policyengine_us_data/datasets/puf/irs_puf.py"],
327-
None,
328-
),
329-
(
330-
"policyengine_us_data/datasets/puf/puf.py",
331-
SCRIPT_OUTPUTS["policyengine_us_data/datasets/puf/puf.py"],
332-
None,
333-
),
334-
(
335-
"policyengine_us_data/datasets/cps/extended_cps.py",
336-
SCRIPT_OUTPUTS[
337-
"policyengine_us_data/datasets/cps/extended_cps.py"
338-
],
339-
None,
340-
),
341-
(
342-
"policyengine_us_data/datasets/cps/enhanced_cps.py",
343-
SCRIPT_OUTPUTS[
344-
"policyengine_us_data/datasets/cps/enhanced_cps.py"
345-
],
346-
None,
347-
),
348-
(
349-
"policyengine_us_data/datasets/cps/"
350-
"local_area_calibration/create_stratified_cps.py",
351-
SCRIPT_OUTPUTS[
352-
"policyengine_us_data/datasets/cps/"
353-
"local_area_calibration/create_stratified_cps.py"
354-
],
355-
["10500"],
356-
),
357-
(
358-
"policyengine_us_data/datasets/cps/small_enhanced_cps.py",
359-
SCRIPT_OUTPUTS[
360-
"policyengine_us_data/datasets/cps/small_enhanced_cps.py"
361-
],
362-
None,
363-
),
364-
]
365-
for script, output, args in scripts_with_outputs:
310+
for script, output in SCRIPT_OUTPUTS.items():
366311
run_script_with_checkpoint(
367312
script,
368313
output,
369314
branch,
370315
checkpoint_volume,
371-
args=args,
372316
env=env,
373317
)
374318
else:
@@ -472,7 +416,6 @@ def build_datasets(
472416
],
473417
branch,
474418
checkpoint_volume,
475-
args=["10500"],
476419
env=env,
477420
),
478421
]

policyengine_us_data/datasets/cps/local_area_calibration/create_stratified_cps.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,10 +307,10 @@ def create_stratified_cps_dataset(
307307
if __name__ == "__main__":
308308
import sys
309309

310-
target = 30_000
311-
high_pct = 99
310+
target = 12_000
311+
high_pct = 99.5
312312
oversample = False
313-
seed = None
313+
seed = 3526
314314

315315
for arg in sys.argv[1:]:
316316
if arg == "--oversample-poor":

0 commit comments

Comments
 (0)