Skip to content

Commit e5d674f

Browse files
baogorek and claude committed
Pipeline resilience: prevent artifact loss and cascading failures
Never rmtree version directories containing H5 files on fingerprint mismatch — update the fingerprint and resume instead. Pin the fingerprint in RunMetadata so resumed runs aren't invalidated by branch drift. Add a validation pre-flight to catch schema mismatches before spawning workers, and harden error dicts with stderr fallback/truncation and traceback capture. Add CI tests for query-schema compatibility.

Closes #652

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 94abbf9 commit e5d674f

4 files changed

Lines changed: 287 additions & 39 deletions

File tree

modal_app/local_area.py

Lines changed: 68 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import subprocess
1616
import sys
1717
import json
18+
import traceback
1819
import modal
1920
from pathlib import Path
2021
from typing import List, Dict
@@ -244,6 +245,10 @@ def run_phase(
244245
for i, handle in enumerate(handles):
245246
try:
246247
result = handle.get()
248+
if result is None:
249+
all_errors.append({"worker": i, "error": "Worker returned None"})
250+
print(f" Worker {i}: returned None (no results)")
251+
continue
247252
all_results.append(result)
248253
print(
249254
f" Worker {i}: {len(result['completed'])} completed, "
@@ -257,7 +262,7 @@ def run_phase(
257262
all_validation_rows.extend(v_rows)
258263
print(f" Worker {i}: {len(v_rows)} validation rows")
259264
except Exception as e:
260-
all_errors.append({"worker": i, "error": str(e)})
265+
all_errors.append({"worker": i, "error": str(e), "traceback": traceback.format_exc()})
261266
print(f" Worker {i}: CRASHED - {e}")
262267

263268
total_completed = sum(len(r["completed"]) for r in all_results)
@@ -277,7 +282,7 @@ def run_phase(
277282
if all_errors:
278283
print(f"\nErrors ({len(all_errors)}):")
279284
for err in all_errors[:5]:
280-
err_msg = err.get("error", "Unknown")[:100]
285+
err_msg = str(err.get("error") or "Unknown")[:200]
281286
print(f" - {err.get('item', err.get('worker'))}: {err_msg}")
282287
if len(all_errors) > 5:
283288
print(f" ... and {len(all_errors) - 5} more")
@@ -355,15 +360,17 @@ def build_areas_worker(
355360
result = subprocess.run(
356361
worker_cmd,
357362
stdout=subprocess.PIPE,
363+
stderr=subprocess.PIPE,
358364
text=True,
359365
env=os.environ.copy(),
360366
)
361367

362368
if result.returncode != 0:
369+
print(f"Worker stderr:\n{result.stderr}", file=__import__('sys').stderr)
363370
return {
364371
"completed": [],
365372
"failed": [f"{item['type']}:{item['id']}" for item in work_items],
366-
"errors": [{"error": result.stderr}],
373+
"errors": [{"error": (result.stderr or "No stderr")[:2000]}],
367374
}
368375

369376
try:
@@ -621,6 +628,7 @@ def coordinate_publish(
621628
n_clones: int = 430,
622629
validate: bool = True,
623630
run_id: str = "",
631+
expected_fingerprint: str = "",
624632
) -> Dict:
625633
"""Coordinate the full publishing workflow."""
626634
setup_gcp_credentials()
@@ -676,43 +684,77 @@ def coordinate_publish(
676684
}
677685
validate_artifacts(config_json_path, artifacts)
678686

687+
if validate:
688+
try:
689+
from sqlalchemy import create_engine as _create_engine
690+
from policyengine_us_data.calibration.validate_staging import (
691+
_query_all_active_targets,
692+
)
693+
_test_engine = _create_engine(f"sqlite:///{db_path}")
694+
_df = _query_all_active_targets(_test_engine, 2024)
695+
print(f"Validation pre-flight OK: {len(_df)} targets queryable")
696+
_test_engine.dispose()
697+
except Exception as e:
698+
print(f"WARNING: Validation pre-flight failed: {e}")
699+
print("Disabling validation to protect H5 builds")
700+
validate = False
701+
679702
# Fingerprint-based cache invalidation
680-
fp_result = subprocess.run(
681-
[
682-
"uv",
683-
"run",
684-
"python",
685-
"-c",
686-
f"""
703+
if expected_fingerprint:
704+
fingerprint = expected_fingerprint
705+
print(f"Using pinned fingerprint from pipeline: {fingerprint}")
706+
else:
707+
fp_result = subprocess.run(
708+
[
709+
"uv",
710+
"run",
711+
"python",
712+
"-c",
713+
f"""
687714
from policyengine_us_data.calibration.publish_local_area import (
688715
compute_input_fingerprint,
689716
)
690717
print(compute_input_fingerprint("{weights_path}", "{dataset_path}", {n_clones}, seed=42))
691718
""",
692-
],
693-
capture_output=True,
694-
text=True,
695-
env=os.environ.copy(),
696-
)
697-
if fp_result.returncode != 0:
698-
raise RuntimeError(f"Failed to compute fingerprint: {fp_result.stderr}")
699-
fingerprint = fp_result.stdout.strip()
719+
],
720+
capture_output=True,
721+
text=True,
722+
env=os.environ.copy(),
723+
)
724+
if fp_result.returncode != 0:
725+
raise RuntimeError(f"Failed to compute fingerprint: {fp_result.stderr}")
726+
fingerprint = fp_result.stdout.strip()
700727
fingerprint_file = version_dir / "fingerprint.json"
701728
if version_dir.exists():
729+
h5_count = len(list(version_dir.rglob("*.h5")))
702730
if fingerprint_file.exists():
703731
stored = json.loads(fingerprint_file.read_text())
704732
if stored.get("fingerprint") == fingerprint:
705733
print(f"Inputs unchanged ({fingerprint}), resuming...")
706734
else:
735+
if h5_count > 0:
736+
print(
737+
f"WARNING: Inputs changed "
738+
f"({stored.get('fingerprint')} -> {fingerprint}) "
739+
f"but {h5_count} H5 files exist. "
740+
f"Updating fingerprint and resuming."
741+
)
742+
else:
743+
print(
744+
f"Inputs changed "
745+
f"({stored.get('fingerprint')} -> {fingerprint}), "
746+
f"clearing empty directory..."
747+
)
748+
shutil.rmtree(version_dir)
749+
else:
750+
if h5_count > 0:
707751
print(
708-
f"Inputs changed "
709-
f"({stored.get('fingerprint')} -> {fingerprint}), "
710-
f"rebuilding..."
752+
f"WARNING: No fingerprint found but {h5_count} H5 files exist. "
753+
f"Writing fingerprint and resuming."
711754
)
755+
else:
756+
print("No fingerprint found, clearing empty stale directory...")
712757
shutil.rmtree(version_dir)
713-
else:
714-
print("No fingerprint found, clearing stale directory...")
715-
shutil.rmtree(version_dir)
716758
version_dir.mkdir(parents=True, exist_ok=True)
717759
fingerprint_file.write_text(json.dumps({"fingerprint": fingerprint}))
718760
staging_volume.commit()
@@ -834,6 +876,7 @@ def coordinate_publish(
834876
return {
835877
"message": (f"Build complete for version {version}. Upload skipped."),
836878
"validation_rows": accumulated_validation_rows,
879+
"fingerprint": fingerprint,
837880
}
838881

839882
print("\nValidating staging...")
@@ -869,6 +912,7 @@ def coordinate_publish(
869912
"message": result,
870913
"run_id": run_id,
871914
"validation_rows": accumulated_validation_rows,
915+
"fingerprint": fingerprint,
872916
}
873917

874918

modal_app/pipeline.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ class RunMetadata:
101101
status: str # running | completed | failed | promoted
102102
step_timings: dict = field(default_factory=dict)
103103
error: Optional[str] = None
104+
resume_history: list = field(default_factory=list)
105+
fingerprint: Optional[str] = None
104106

105107
def to_dict(self) -> dict:
106108
return asdict(self)
@@ -651,14 +653,24 @@ def run_pipeline(
651653
if resume_run_id:
652654
print(f"Resuming run {resume_run_id}...")
653655
meta = read_run_meta(resume_run_id, pipeline_volume)
656+
current_sha = sha
654657
if meta.sha != sha:
655-
raise RuntimeError(
656-
f"Branch {branch} has moved since run "
657-
f"started.\n"
658+
print(
659+
f"WARNING: Branch {branch} has moved since run started.\n"
658660
f" Run SHA: {meta.sha[:12]}\n"
659661
f" Current SHA: {sha[:12]}\n"
660-
f"Start a fresh run instead."
662+
f" Resuming with original run artifacts, current code."
661663
)
664+
sha = meta.sha
665+
version = meta.version
666+
if not hasattr(meta, "resume_history") or meta.resume_history is None:
667+
meta.resume_history = []
668+
meta.resume_history.append({
669+
"resumed_at": datetime.now(timezone.utc).isoformat(),
670+
"code_sha": current_sha,
671+
"original_sha": meta.sha,
672+
"branch": branch,
673+
})
662674
meta.status = "running"
663675
run_id = resume_run_id
664676
else:
@@ -883,6 +895,7 @@ def run_pipeline(
883895
n_clones=n_clones,
884896
validate=True,
885897
run_id=run_id,
898+
expected_fingerprint=meta.fingerprint or "",
886899
)
887900
print(f" → coordinate_publish fc: {regional_h5_handle.object_id}")
888901

@@ -919,6 +932,10 @@ def run_pipeline(
919932
)
920933
print(f" Regional H5: {regional_msg}")
921934

935+
if isinstance(regional_h5_result, dict) and regional_h5_result.get("fingerprint"):
936+
meta.fingerprint = regional_h5_result["fingerprint"]
937+
write_run_meta(meta, pipeline_volume)
938+
922939
national_h5_result = None
923940
if national_h5_handle is not None:
924941
print(" Waiting for national H5 build...")

policyengine_us_data/calibration/publish_local_area.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -99,17 +99,36 @@ def validate_or_clear_checkpoints(fingerprint: str):
9999
)
100100
else:
101101
print(f"No checkpoint metadata, starting fresh ({fingerprint})")
102-
for cp in [
103-
CHECKPOINT_FILE,
104-
CHECKPOINT_FILE_DISTRICTS,
105-
CHECKPOINT_FILE_CITIES,
106-
]:
107-
if cp.exists():
108-
cp.unlink()
109-
for subdir in ["states", "districts", "cities"]:
110-
d = WORK_DIR / subdir
111-
if d.exists():
112-
shutil.rmtree(d)
102+
h5_count = sum(
103+
1
104+
for subdir in ["states", "districts", "cities"]
105+
if (WORK_DIR / subdir).exists()
106+
for _ in (WORK_DIR / subdir).rglob("*.h5")
107+
)
108+
if h5_count > 0:
109+
print(
110+
f"WARNING: {h5_count} H5 files exist. "
111+
f"Clearing only checkpoint files, preserving H5s."
112+
)
113+
for cp in [
114+
CHECKPOINT_FILE,
115+
CHECKPOINT_FILE_DISTRICTS,
116+
CHECKPOINT_FILE_CITIES,
117+
]:
118+
if cp.exists():
119+
cp.unlink()
120+
else:
121+
for cp in [
122+
CHECKPOINT_FILE,
123+
CHECKPOINT_FILE_DISTRICTS,
124+
CHECKPOINT_FILE_CITIES,
125+
]:
126+
if cp.exists():
127+
cp.unlink()
128+
for subdir in ["states", "districts", "cities"]:
129+
d = WORK_DIR / subdir
130+
if d.exists():
131+
shutil.rmtree(d)
113132
META_FILE.parent.mkdir(parents=True, exist_ok=True)
114133
META_FILE.write_text(json.dumps({"fingerprint": fingerprint}))
115134

0 commit comments

Comments
 (0)