Skip to content

Commit d921ecb

Browse files
committed
add run_id to artifact directories
1 parent 243fdeb commit d921ecb

4 files changed

Lines changed: 55 additions & 17 deletions

File tree

modal_app/data_build.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -350,6 +350,7 @@ def build_datasets(
350350
clear_checkpoints: bool = False,
351351
skip_tests: bool = False,
352352
skip_enhanced_cps: bool = False,
353+
run_id: str = "",
353354
):
354355
"""Build all datasets with preemption-resilient checkpointing.
355356
@@ -593,6 +594,8 @@ def build_datasets(
593594
# failure does not block downstream calibration steps.
594595
print("Copying pipeline artifacts to shared volume...")
595596
artifacts_dir = Path(PIPELINE_MOUNT) / "artifacts"
597+
if run_id:
598+
artifacts_dir = artifacts_dir / run_id
596599
artifacts_dir.mkdir(parents=True, exist_ok=True)
597600

598601
# Copy all intermediate H5 datasets for lineage tracing

modal_app/local_area.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -646,7 +646,9 @@ def coordinate_publish(
646646
version_dir = staging_dir / version
647647

648648
pipeline_volume.reload()
649-
artifacts = Path("/pipeline/artifacts")
649+
artifacts = (
650+
Path(f"/pipeline/artifacts/{run_id}") if run_id else Path("/pipeline/artifacts")
651+
)
650652
weights_path = artifacts / "calibration_weights.npy"
651653
db_path = artifacts / "policy_data.db"
652654
dataset_path = artifacts / "source_imputed_stratified_extended_cps.h5"
@@ -929,7 +931,9 @@ def coordinate_national_publish(
929931
staging_dir = Path(VOLUME_MOUNT)
930932

931933
pipeline_volume.reload()
932-
artifacts = Path("/pipeline/artifacts")
934+
artifacts = (
935+
Path(f"/pipeline/artifacts/{run_id}") if run_id else Path("/pipeline/artifacts")
936+
)
933937
weights_path = artifacts / "national_calibration_weights.npy"
934938
db_path = artifacts / "policy_data.db"
935939
dataset_path = artifacts / "source_imputed_stratified_extended_cps.h5"

modal_app/pipeline.py

Lines changed: 24 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -67,10 +67,21 @@
6767
REPO_URL = "https://github.com/PolicyEngine/policyengine-us-data.git"
6868
PIPELINE_MOUNT = "/pipeline"
6969
STAGING_MOUNT = "/staging"
70-
ARTIFACTS_DIR = f"{PIPELINE_MOUNT}/artifacts"
70+
ARTIFACTS_BASE = f"{PIPELINE_MOUNT}/artifacts"
7171
RUNS_DIR = f"{PIPELINE_MOUNT}/runs"
7272

7373

74+
def artifacts_dir_for_run(run_id: str) -> str:
    """Resolve the artifacts directory path for a pipeline run.

    Args:
        run_id: Identifier of the run; may be an empty string.

    Returns:
        ``ARTIFACTS_BASE`` with ``/<run_id>`` appended when run_id is
        non-empty. For an empty run_id the flat ``ARTIFACTS_BASE`` path
        is returned unchanged, which keeps standalone invocations
        (those made without a run id) on the original layout.
    """
    # Single expression: scoped sub-directory if we have an id,
    # otherwise the shared base path.
    return f"{ARTIFACTS_BASE}/{run_id}" if run_id else ARTIFACTS_BASE
83+
84+
7485
# ── Run metadata ─────────────────────────────────────────────────
7586

7687

@@ -302,7 +313,7 @@ def stage_base_datasets(
302313
version: Package version string for the commit.
303314
branch: Git branch for repo clone.
304315
"""
305-
artifacts = Path(ARTIFACTS_DIR)
316+
artifacts = Path(artifacts_dir_for_run(run_id))
306317

307318
files_with_paths = []
308319

@@ -666,8 +677,8 @@ def run_pipeline(
666677
run_dir.mkdir(parents=True, exist_ok=True)
667678
(run_dir / "diagnostics").mkdir(exist_ok=True)
668679

669-
# Create artifacts directory
670-
Path(ARTIFACTS_DIR).mkdir(parents=True, exist_ok=True)
680+
# Create run-scoped artifacts directory
681+
Path(artifacts_dir_for_run(run_id)).mkdir(parents=True, exist_ok=True)
671682

672683
write_run_meta(meta, pipeline_volume)
673684

@@ -704,6 +715,7 @@ def run_pipeline(
704715
clear_checkpoints=clear_checkpoints,
705716
skip_tests=True,
706717
skip_enhanced_cps=False,
718+
run_id=run_id,
707719
)
708720

709721
# The build_datasets step produces files in its
@@ -732,6 +744,7 @@ def run_pipeline(
732744
branch=branch,
733745
workers=num_workers,
734746
n_clones=n_clones,
747+
run_id=run_id,
735748
)
736749
print(f" Package at: {pkg_path}")
737750

@@ -750,7 +763,7 @@ def run_pipeline(
750763
print("\n[Step 3/5] Fitting calibration weights...")
751764
step_start = time.time()
752765

753-
vol_path = "/pipeline/artifacts/calibration_package.pkl"
766+
vol_path = f"{artifacts_dir_for_run(run_id)}/calibration_package.pkl"
754767
target_cfg = "policyengine_us_data/calibration/target_config.yaml"
755768

756769
# Spawn regional fit
@@ -794,16 +807,17 @@ def run_pipeline(
794807
regional_result = regional_handle.get()
795808
print(" Regional fit complete. Writing to volume...")
796809

797-
# Write regional results to pipeline volume
810+
# Write regional results to pipeline volume (run-scoped)
811+
artifacts_rel = f"artifacts/{run_id}" if run_id else "artifacts"
798812
with pipeline_volume.batch_upload(force=True) as batch:
799813
batch.put_file(
800814
BytesIO(regional_result["weights"]),
801-
"artifacts/calibration_weights.npy",
815+
f"{artifacts_rel}/calibration_weights.npy",
802816
)
803817
if regional_result.get("config"):
804818
batch.put_file(
805819
BytesIO(regional_result["config"]),
806-
"artifacts/unified_run_config.json",
820+
f"{artifacts_rel}/unified_run_config.json",
807821
)
808822

809823
archive_diagnostics(
@@ -822,12 +836,12 @@ def run_pipeline(
822836
with pipeline_volume.batch_upload(force=True) as batch:
823837
batch.put_file(
824838
BytesIO(national_result["weights"]),
825-
"artifacts/national_calibration_weights.npy",
839+
f"{artifacts_rel}/national_calibration_weights.npy",
826840
)
827841
if national_result.get("config"):
828842
batch.put_file(
829843
BytesIO(national_result["config"]),
830-
"artifacts/national_unified_run_config.json",
844+
f"{artifacts_rel}/national_unified_run_config.json",
831845
)
832846

833847
archive_diagnostics(

modal_app/remote_calibration_runner.py

Lines changed: 22 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -156,12 +156,13 @@ def _fit_weights_impl(
156156
log_freq: int = None,
157157
skip_county: bool = True,
158158
workers: int = 8,
159+
artifacts_dir: str = "",
159160
) -> dict:
160161
"""Full pipeline: read data from pipeline volume, build matrix, fit."""
161162
_setup_repo()
162163

163164
pipeline_vol.reload()
164-
artifacts = f"{PIPELINE_MOUNT}/artifacts"
165+
artifacts = artifacts_dir if artifacts_dir else f"{PIPELINE_MOUNT}/artifacts"
165166
db_path = f"{artifacts}/policy_data.db"
166167
dataset_path = f"{artifacts}/source_imputed_stratified_extended_cps.h5"
167168
for label, p in [("database", db_path), ("dataset", dataset_path)]:
@@ -324,12 +325,15 @@ def _build_package_impl(
324325
skip_county: bool = True,
325326
workers: int = 8,
326327
n_clones: int = 430,
328+
run_id: str = "",
327329
) -> str:
328330
"""Read data from pipeline volume, build X matrix, save package."""
329331
_setup_repo()
330332

331333
pipeline_vol.reload()
332334
artifacts = f"{PIPELINE_MOUNT}/artifacts"
335+
if run_id:
336+
artifacts = f"{artifacts}/{run_id}"
333337
db_path = f"{artifacts}/policy_data.db"
334338
dataset_path = f"{artifacts}/source_imputed_stratified_extended_cps.h5"
335339
for label, p in [("database", db_path), ("dataset", dataset_path)]:
@@ -338,7 +342,7 @@ def _build_package_impl(
338342
f"Missing {label} on pipeline volume: {p}. Run data_build first."
339343
)
340344

341-
pkg_path = f"{PIPELINE_MOUNT}/artifacts/calibration_package.pkl"
345+
pkg_path = f"{artifacts}/calibration_package.pkl"
342346
script_path = "policyengine_us_data/calibration/unified_calibration.py"
343347
cmd = [
344348
"uv",
@@ -405,13 +409,15 @@ def build_package_remote(
405409
skip_county: bool = True,
406410
workers: int = 8,
407411
n_clones: int = 430,
412+
run_id: str = "",
408413
) -> str:
409414
return _build_package_impl(
410415
branch,
411416
target_config=target_config,
412417
skip_county=skip_county,
413418
workers=workers,
414419
n_clones=n_clones,
420+
run_id=run_id,
415421
)
416422

417423

@@ -421,7 +427,7 @@ def build_package_remote(
421427
volumes={PIPELINE_MOUNT: pipeline_vol},
422428
nonpreemptible=True,
423429
)
424-
def check_volume_package() -> dict:
430+
def check_volume_package(artifacts_dir: str = "") -> dict:
425431
"""Check if a calibration package exists on the volume.
426432
427433
Reads the lightweight JSON sidecar for provenance fields.
@@ -430,8 +436,9 @@ def check_volume_package() -> dict:
430436
import datetime
431437
import json
432438

433-
pkg_path = f"{PIPELINE_MOUNT}/artifacts/calibration_package.pkl"
434-
sidecar_path = f"{PIPELINE_MOUNT}/artifacts/calibration_package_meta.json"
439+
base = artifacts_dir if artifacts_dir else f"{PIPELINE_MOUNT}/artifacts"
440+
pkg_path = f"{base}/calibration_package.pkl"
441+
sidecar_path = f"{base}/calibration_package_meta.json"
435442
if not os.path.exists(pkg_path):
436443
return {"exists": False}
437444

@@ -485,6 +492,7 @@ def fit_weights_t4(
485492
log_freq: int = None,
486493
skip_county: bool = True,
487494
workers: int = 8,
495+
artifacts_dir: str = "",
488496
) -> dict:
489497
return _fit_weights_impl(
490498
branch,
@@ -497,6 +505,7 @@ def fit_weights_t4(
497505
log_freq,
498506
skip_county=skip_county,
499507
workers=workers,
508+
artifacts_dir=artifacts_dir,
500509
)
501510

502511

@@ -520,6 +529,7 @@ def fit_weights_a10(
520529
log_freq: int = None,
521530
skip_county: bool = True,
522531
workers: int = 8,
532+
artifacts_dir: str = "",
523533
) -> dict:
524534
return _fit_weights_impl(
525535
branch,
@@ -532,6 +542,7 @@ def fit_weights_a10(
532542
log_freq,
533543
skip_county=skip_county,
534544
workers=workers,
545+
artifacts_dir=artifacts_dir,
535546
)
536547

537548

@@ -555,6 +566,7 @@ def fit_weights_a100_40(
555566
log_freq: int = None,
556567
skip_county: bool = True,
557568
workers: int = 8,
569+
artifacts_dir: str = "",
558570
) -> dict:
559571
return _fit_weights_impl(
560572
branch,
@@ -567,6 +579,7 @@ def fit_weights_a100_40(
567579
log_freq,
568580
skip_county=skip_county,
569581
workers=workers,
582+
artifacts_dir=artifacts_dir,
570583
)
571584

572585

@@ -590,6 +603,7 @@ def fit_weights_a100_80(
590603
log_freq: int = None,
591604
skip_county: bool = True,
592605
workers: int = 8,
606+
artifacts_dir: str = "",
593607
) -> dict:
594608
return _fit_weights_impl(
595609
branch,
@@ -602,6 +616,7 @@ def fit_weights_a100_80(
602616
log_freq,
603617
skip_county=skip_county,
604618
workers=workers,
619+
artifacts_dir=artifacts_dir,
605620
)
606621

607622

@@ -625,6 +640,7 @@ def fit_weights_h100(
625640
log_freq: int = None,
626641
skip_county: bool = True,
627642
workers: int = 8,
643+
artifacts_dir: str = "",
628644
) -> dict:
629645
return _fit_weights_impl(
630646
branch,
@@ -637,6 +653,7 @@ def fit_weights_h100(
637653
log_freq,
638654
skip_county=skip_county,
639655
workers=workers,
656+
artifacts_dir=artifacts_dir,
640657
)
641658

642659

0 commit comments

Comments
 (0)