Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 60 additions & 41 deletions modal_app/local_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
modal run modal_app/local_area.py --branch=main --num-workers=8
"""

import json
import os
import subprocess
import sys
import json
import modal
import traceback
from pathlib import Path
from typing import List, Dict
from typing import Dict, List

import modal

_baked = "/root/policyengine-us-data"
_local = str(Path(__file__).resolve().parent.parent)
Expand All @@ -26,6 +28,7 @@
sys.path.insert(0, _p)

from modal_app.images import cpu_image as image
from modal_app.resilience import reconcile_version_dir_fingerprint

app = modal.App("policyengine-us-data-local-area")

Expand Down Expand Up @@ -244,6 +247,10 @@ def run_phase(
for i, handle in enumerate(handles):
try:
result = handle.get()
if result is None:
all_errors.append({"worker": i, "error": "Worker returned None"})
print(f" Worker {i}: returned None (no results)")
continue
all_results.append(result)
print(
f" Worker {i}: {len(result['completed'])} completed, "
Expand All @@ -257,7 +264,9 @@ def run_phase(
all_validation_rows.extend(v_rows)
print(f" Worker {i}: {len(v_rows)} validation rows")
except Exception as e:
all_errors.append({"worker": i, "error": str(e)})
all_errors.append(
{"worker": i, "error": str(e), "traceback": traceback.format_exc()}
)
print(f" Worker {i}: CRASHED - {e}")

total_completed = sum(len(r["completed"]) for r in all_results)
Expand All @@ -277,7 +286,7 @@ def run_phase(
if all_errors:
print(f"\nErrors ({len(all_errors)}):")
for err in all_errors[:5]:
err_msg = err.get("error", "Unknown")[:100]
err_msg = str(err.get("error") or "Unknown")[:200]
print(f" - {err.get('item', err.get('worker'))}: {err_msg}")
if len(all_errors) > 5:
print(f" ... and {len(all_errors) - 5} more")
Expand Down Expand Up @@ -355,15 +364,17 @@ def build_areas_worker(
result = subprocess.run(
worker_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
env=os.environ.copy(),
)

if result.returncode != 0:
print(f"Worker stderr:\n{result.stderr}", file=__import__("sys").stderr)
return {
"completed": [],
"failed": [f"{item['type']}:{item['id']}" for item in work_items],
"errors": [{"error": result.stderr}],
"errors": [{"error": (result.stderr or "No stderr")[:2000]}],
}

try:
Expand Down Expand Up @@ -621,6 +632,7 @@ def coordinate_publish(
n_clones: int = 430,
validate: bool = True,
run_id: str = "",
expected_fingerprint: str = "",
) -> Dict:
"""Coordinate the full publishing workflow."""
setup_gcp_credentials()
Expand All @@ -640,8 +652,6 @@ def coordinate_publish(
print(f"Publishing version {version} from branch {branch}")
print(f"Using {num_workers} parallel workers")

import shutil

staging_dir = Path(VOLUME_MOUNT)
version_dir = staging_dir / version

Expand Down Expand Up @@ -676,45 +686,52 @@ def coordinate_publish(
}
validate_artifacts(config_json_path, artifacts)

if validate:
try:
from sqlalchemy import create_engine as _create_engine
from policyengine_us_data.calibration.validate_staging import (
_query_all_active_targets,
)

_test_engine = _create_engine(f"sqlite:///{db_path}")
_df = _query_all_active_targets(_test_engine, 2024)
print(f"Validation pre-flight OK: {len(_df)} targets queryable")
_test_engine.dispose()
except Exception as e:
print(f"WARNING: Validation pre-flight failed: {e}")
print("Disabling validation to protect H5 builds")
validate = False

# Fingerprint-based cache invalidation
fp_result = subprocess.run(
[
"uv",
"run",
"python",
"-c",
f"""
if expected_fingerprint:
fingerprint = expected_fingerprint
print(f"Using pinned fingerprint from pipeline: {fingerprint}")
else:
fp_result = subprocess.run(
[
"uv",
"run",
"python",
"-c",
f"""
from policyengine_us_data.calibration.publish_local_area import (
compute_input_fingerprint,
)
print(compute_input_fingerprint("{weights_path}", "{dataset_path}", {n_clones}, seed=42))
""",
],
capture_output=True,
text=True,
env=os.environ.copy(),
)
if fp_result.returncode != 0:
raise RuntimeError(f"Failed to compute fingerprint: {fp_result.stderr}")
fingerprint = fp_result.stdout.strip()
fingerprint_file = version_dir / "fingerprint.json"
if version_dir.exists():
if fingerprint_file.exists():
stored = json.loads(fingerprint_file.read_text())
if stored.get("fingerprint") == fingerprint:
print(f"Inputs unchanged ({fingerprint}), resuming...")
else:
print(
f"Inputs changed "
f"({stored.get('fingerprint')} -> {fingerprint}), "
f"rebuilding..."
)
shutil.rmtree(version_dir)
else:
print("No fingerprint found, clearing stale directory...")
shutil.rmtree(version_dir)
version_dir.mkdir(parents=True, exist_ok=True)
fingerprint_file.write_text(json.dumps({"fingerprint": fingerprint}))
],
capture_output=True,
text=True,
env=os.environ.copy(),
)
if fp_result.returncode != 0:
raise RuntimeError(f"Failed to compute fingerprint: {fp_result.stderr}")
fingerprint = fp_result.stdout.strip()
reconcile_action = reconcile_version_dir_fingerprint(version_dir, fingerprint)
if reconcile_action == "resume":
print(f"Inputs unchanged ({fingerprint}), resuming...")
else:
print(f"Prepared staging directory for fingerprint {fingerprint}")
staging_volume.commit()
result = subprocess.run(
[
Expand Down Expand Up @@ -834,6 +851,7 @@ def coordinate_publish(
return {
"message": (f"Build complete for version {version}. Upload skipped."),
"validation_rows": accumulated_validation_rows,
"fingerprint": fingerprint,
}

print("\nValidating staging...")
Expand Down Expand Up @@ -869,6 +887,7 @@ def coordinate_publish(
"message": result,
"run_id": run_id,
"validation_rows": accumulated_validation_rows,
"fingerprint": fingerprint,
}


Expand Down
36 changes: 28 additions & 8 deletions modal_app/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
sys.path.insert(0, _p)

from modal_app.images import cpu_image as image
from modal_app.resilience import ensure_resume_sha_compatible

# ── Modal resources ──────────────────────────────────────────────

Expand Down Expand Up @@ -101,6 +102,8 @@ class RunMetadata:
status: str # running | completed | failed | promoted
step_timings: dict = field(default_factory=dict)
error: Optional[str] = None
resume_history: list = field(default_factory=list)
fingerprint: Optional[str] = None

def to_dict(self) -> dict:
return asdict(self)
Expand Down Expand Up @@ -651,14 +654,24 @@ def run_pipeline(
if resume_run_id:
print(f"Resuming run {resume_run_id}...")
meta = read_run_meta(resume_run_id, pipeline_volume)
if meta.sha != sha:
raise RuntimeError(
f"Branch {branch} has moved since run "
f"started.\n"
f" Run SHA: {meta.sha[:12]}\n"
f" Current SHA: {sha[:12]}\n"
f"Start a fresh run instead."
)
current_sha = sha
ensure_resume_sha_compatible(
branch=branch,
run_sha=meta.sha,
current_sha=current_sha,
)
sha = meta.sha
version = meta.version
if not hasattr(meta, "resume_history") or meta.resume_history is None:
meta.resume_history = []
meta.resume_history.append(
{
"resumed_at": datetime.now(timezone.utc).isoformat(),
"code_sha": current_sha,
"original_sha": meta.sha,
"branch": branch,
}
)
meta.status = "running"
run_id = resume_run_id
else:
Expand Down Expand Up @@ -883,6 +896,7 @@ def run_pipeline(
n_clones=n_clones,
validate=True,
run_id=run_id,
expected_fingerprint=meta.fingerprint or "",
)
print(f" → coordinate_publish fc: {regional_h5_handle.object_id}")

Expand Down Expand Up @@ -919,6 +933,12 @@ def run_pipeline(
)
print(f" Regional H5: {regional_msg}")

if isinstance(regional_h5_result, dict) and regional_h5_result.get(
"fingerprint"
):
meta.fingerprint = regional_h5_result["fingerprint"]
write_run_meta(meta, pipeline_volume)

national_h5_result = None
if national_h5_handle is not None:
print(" Waiting for national H5 build...")
Expand Down
72 changes: 71 additions & 1 deletion modal_app/resilience.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Subprocess retry wrapper for network-dependent operations."""
"""Helpers for retry and resume safety in Modal workflows."""

import json
import shutil
import subprocess
import time
from pathlib import Path
from typing import Optional


Expand Down Expand Up @@ -42,3 +45,70 @@ def run_with_retry(
time.sleep(delay)
else:
raise subprocess.CalledProcessError(result.returncode, cmd)


def ensure_resume_sha_compatible(
branch: str,
run_sha: str,
current_sha: str,
) -> None:
"""Require resumed runs to use the same pinned commit.

Modal workers execute the code baked into the current image,
so resuming across branch movement would mix new code with
artifacts and metadata from the old run.
"""
if run_sha != current_sha:
raise RuntimeError(
f"Branch {branch} has moved since run started.\n"
f" Run SHA: {run_sha[:12]}\n"
f" Current SHA: {current_sha[:12]}\n"
f"Start a fresh run instead."
)


def reconcile_version_dir_fingerprint(
version_dir: Path,
fingerprint: str,
) -> str:
"""Prepare a staging version directory for a specific fingerprint.

Safe behavior:
- same fingerprint: resume in place
- changed or missing fingerprint with existing H5s: stop and preserve
- changed or missing fingerprint without H5s: clear stale directory
"""
fingerprint_file = version_dir / "fingerprint.json"

if not version_dir.exists():
version_dir.mkdir(parents=True, exist_ok=True)
fingerprint_file.write_text(json.dumps({"fingerprint": fingerprint}))
return "initialized"

h5_count = len(list(version_dir.rglob("*.h5")))
if fingerprint_file.exists():
stored = json.loads(fingerprint_file.read_text())
stored_fingerprint = stored.get("fingerprint")
if stored_fingerprint == fingerprint:
return "resume"
if h5_count > 0:
raise RuntimeError(
"Fingerprint mismatch with existing staged H5 files.\n"
f" Stored: {stored_fingerprint}\n"
f" Current: {fingerprint}\n"
f" H5 files preserved: {h5_count}\n"
"Start a fresh version or clear the stale outputs explicitly."
)
shutil.rmtree(version_dir)
else:
if h5_count > 0:
raise RuntimeError(
"Missing fingerprint metadata with existing staged H5 files.\n"
f" H5 files preserved: {h5_count}\n"
"Start a fresh version or clear the stale outputs explicitly."
)
shutil.rmtree(version_dir)

version_dir.mkdir(parents=True, exist_ok=True)
fingerprint_file.write_text(json.dumps({"fingerprint": fingerprint}))
return "initialized"
41 changes: 30 additions & 11 deletions policyengine_us_data/calibration/publish_local_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,36 @@ def validate_or_clear_checkpoints(fingerprint: str):
)
else:
print(f"No checkpoint metadata, starting fresh ({fingerprint})")
for cp in [
CHECKPOINT_FILE,
CHECKPOINT_FILE_DISTRICTS,
CHECKPOINT_FILE_CITIES,
]:
if cp.exists():
cp.unlink()
for subdir in ["states", "districts", "cities"]:
d = WORK_DIR / subdir
if d.exists():
shutil.rmtree(d)
h5_count = sum(
1
for subdir in ["states", "districts", "cities"]
if (WORK_DIR / subdir).exists()
for _ in (WORK_DIR / subdir).rglob("*.h5")
)
if h5_count > 0:
print(
f"WARNING: {h5_count} H5 files exist. "
f"Clearing only checkpoint files, preserving H5s."
)
for cp in [
CHECKPOINT_FILE,
CHECKPOINT_FILE_DISTRICTS,
CHECKPOINT_FILE_CITIES,
]:
if cp.exists():
cp.unlink()
else:
for cp in [
CHECKPOINT_FILE,
CHECKPOINT_FILE_DISTRICTS,
CHECKPOINT_FILE_CITIES,
]:
if cp.exists():
cp.unlink()
for subdir in ["states", "districts", "cities"]:
d = WORK_DIR / subdir
if d.exists():
shutil.rmtree(d)
META_FILE.parent.mkdir(parents=True, exist_ok=True)
META_FILE.write_text(json.dumps({"fingerprint": fingerprint}))

Expand Down
Loading
Loading