3939import time
4040import traceback
4141from datetime import datetime , timezone
42- from io import BytesIO
4342from pathlib import Path
4443
4544import modal
110109from policyengine_us_data .pipeline_schema import PipelineNode # noqa: E402
111110from policyengine_us_data .fit_weights import ( # noqa: E402
112111 FitScope ,
112+ FittedWeightsInputBundle ,
113+ FittedWeightsOutputBundle ,
113114 NATIONAL_FIT_LAMBDA_L0 as _NATIONAL_FIT_LAMBDA_L0 ,
114115 fit_artifacts_for_scope ,
115116 fitted_weights_spec_for_scope ,
@@ -279,23 +280,30 @@ def archive_diagnostics(
279280 vol : modal .Volume ,
280281 prefix : str = "" ,
281282 scope : FitScope | str | None = None ,
282- ) -> None :
283+ ) -> list [ ArtifactReference ] :
283284 """Archive calibration diagnostics to the run directory."""
284285 diag_dir = Path (RUNS_DIR ) / run_id / "diagnostics"
285286 diag_dir .mkdir (parents = True , exist_ok = True )
286287
287288 scope = scope or (FitScope .NATIONAL if prefix == "national_" else FitScope .REGIONAL )
288289 file_map = fit_artifacts_for_scope (scope ).diagnostic_result_filenames ()
290+ written_paths : list [Path ] = []
289291
290292 for key , filename in file_map .items ():
291293 data = result_bytes .get (key )
292294 if data :
293295 path = diag_dir / filename
294296 with open (path , "wb" ) as f :
295297 f .write (data )
298+ written_paths .append (path )
296299 print (f" Archived { filename } ({ len (data ):,} bytes)" )
297300
298301 vol .commit ()
302+ return collect_artifacts (
303+ written_paths ,
304+ role = "diagnostic" ,
305+ missing_ok = True ,
306+ )
299307
300308
301309# ── Include other Modal apps ─────────────────────────────────────
@@ -1315,12 +1323,11 @@ def run_pipeline(
13151323 print (f" Completed in { completed_package_manifest .duration_s } s" )
13161324
13171325 # ── Step 3: Fit weights (parallel) ──
1318- fit_inputs = _artifact_identities (
1319- {
1320- "calibration_package" : _artifacts_dir (run_id )
1321- / "calibration_package.pkl" ,
1322- }
1326+ regional_fit_input = FittedWeightsInputBundle (
1327+ scope = FitScope .REGIONAL ,
1328+ calibration_package_path = _artifacts_dir (run_id ) / "calibration_package.pkl" ,
13231329 )
1330+ fit_inputs = _artifact_identities (regional_fit_input .artifact_identity_paths ())
13241331 regional_fit_spec = fitted_weights_spec_for_scope (FitScope .REGIONAL )
13251332 national_fit_spec = fitted_weights_spec_for_scope (FitScope .NATIONAL )
13261333 regional_fit_artifacts = fit_artifacts_for_scope (FitScope .REGIONAL )
@@ -1425,34 +1432,26 @@ def run_pipeline(
14251432 # Collect regional results
14261433 print (" Waiting for regional fit..." )
14271434 regional_result = regional_handle .get ()
1435+ regional_output = FittedWeightsOutputBundle .from_result_bytes (
1436+ scope = FitScope .REGIONAL ,
1437+ result_bytes = regional_result ,
1438+ run_id = run_id ,
1439+ )
14281440 print (" Regional fit complete. Writing to volume..." )
14291441
14301442 # Write regional results to pipeline volume (run-scoped)
14311443 artifacts_rel = f"artifacts/{ run_id } " if run_id else "artifacts"
14321444 with pipeline_volume .batch_upload (force = True ) as batch :
1433- batch .put_file (
1434- BytesIO (regional_result ["weights" ]),
1435- f"{ artifacts_rel } /{ regional_fit_artifacts .weights .filename } " ,
1436- )
1437- if regional_result .get ("geography" ):
1438- batch .put_file (
1439- BytesIO (regional_result ["geography" ]),
1440- f"{ artifacts_rel } /{ regional_fit_artifacts .geography .filename } " ,
1441- )
1442- if regional_result .get ("config" ):
1443- batch .put_file (
1444- BytesIO (regional_result ["config" ]),
1445- f"{ artifacts_rel } /{ regional_fit_artifacts .run_config .filename } " ,
1446- )
1445+ regional_output .write_artifacts (batch , artifacts_rel )
14471446
1448- archive_diagnostics (
1447+ regional_diagnostics = archive_diagnostics (
14491448 run_id ,
1450- regional_result ,
1449+ regional_output . diagnostic_result_bytes () ,
14511450 pipeline_volume ,
1452- scope = FitScope . REGIONAL ,
1451+ scope = regional_output . scope ,
14531452 )
14541453 regional_outputs = collect_artifacts (
1455- regional_fit_artifacts .artifact_paths (_artifacts_dir (run_id )),
1454+ regional_output .artifact_paths (_artifacts_dir (run_id )),
14561455 missing_ok = True ,
14571456 )
14581457 regional_fit_reuse_measurement = ReuseMeasurement (
@@ -1462,7 +1461,7 @@ def run_pipeline(
14621461 _complete_step_manifest (
14631462 regional_fit_manifest ,
14641463 outputs = regional_outputs ,
1465- diagnostics = _collect_diagnostics ( run_id ) ,
1464+ diagnostics = regional_diagnostics ,
14661465 reuse_decision = "computed" ,
14671466 reuse_measurement = regional_fit_reuse_measurement ,
14681467 vol = pipeline_volume ,
@@ -1473,38 +1472,30 @@ def run_pipeline(
14731472 if national_handle is not None :
14741473 print (" Waiting for national fit..." )
14751474 national_result = national_handle .get ()
1475+ national_output = FittedWeightsOutputBundle .from_result_bytes (
1476+ scope = FitScope .NATIONAL ,
1477+ result_bytes = national_result ,
1478+ run_id = run_id ,
1479+ )
14761480 print (" National fit complete. Writing to volume..." )
14771481
14781482 with pipeline_volume .batch_upload (force = True ) as batch :
1479- batch .put_file (
1480- BytesIO (national_result ["weights" ]),
1481- f"{ artifacts_rel } /{ national_fit_artifacts .weights .filename } " ,
1482- )
1483- if national_result .get ("geography" ):
1484- batch .put_file (
1485- BytesIO (national_result ["geography" ]),
1486- f"{ artifacts_rel } /{ national_fit_artifacts .geography .filename } " ,
1487- )
1488- if national_result .get ("config" ):
1489- batch .put_file (
1490- BytesIO (national_result ["config" ]),
1491- f"{ artifacts_rel } /{ national_fit_artifacts .run_config .filename } " ,
1492- )
1493-
1494- archive_diagnostics (
1483+ national_output .write_artifacts (batch , artifacts_rel )
1484+
1485+ national_diagnostics = archive_diagnostics (
14951486 run_id ,
1496- national_result ,
1487+ national_output . diagnostic_result_bytes () ,
14971488 pipeline_volume ,
1498- scope = FitScope . NATIONAL ,
1489+ scope = national_output . scope ,
14991490 )
15001491 national_outputs = collect_artifacts (
1501- national_fit_artifacts .artifact_paths (_artifacts_dir (run_id )),
1492+ national_output .artifact_paths (_artifacts_dir (run_id )),
15021493 missing_ok = True ,
15031494 )
15041495 _complete_step_manifest (
15051496 national_fit_manifest ,
15061497 outputs = national_outputs ,
1507- diagnostics = _collect_diagnostics ( run_id ) ,
1498+ diagnostics = national_diagnostics ,
15081499 reuse_decision = "computed" ,
15091500 reuse_measurement = ReuseMeasurement (
15101501 expected_outputs = len (national_outputs ),
0 commit comments