Skip to content

Commit 275a1d3

Browse files
authored
Merge pull request #1046 from PolicyEngine/agent/stage-3/pr-3b-scoped-input-output-bundles
Add scoped Stage 3 fit result bundles
2 parents 338157e + 8f41e8b commit 275a1d3

10 files changed

Lines changed: 496 additions & 69 deletions

File tree

changelog.d/1045.changed

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Added scoped Stage 3 fitted-weight input and output bundles.
2+
Bumped policyengine-us to 1.701.1.

docs/engineering/stages/fit_weights.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ builds. The public identity boundary lives in `policyengine_us_data.fit_weights`
88
step manifests for reuse decisions.
99
- `ScopedFitArtifacts` defines the artifact filenames written by the Modal fit
1010
step and consumed by downstream H5 builders.
11+
- `FittedWeightsInputBundle`, `FitResultBytes`, and
12+
`FittedWeightsOutputBundle` keep Stage 3 package inputs and remote result
13+
bytes as typed objects until they are written out as files.
1114

1215
The current artifact names remain behavior-compatible:
1316

@@ -21,3 +24,9 @@ The current artifact names remain behavior-compatible:
2124
When changing Stage 3 fitting parameters, artifact names, or scope behavior,
2225
update the central specs first and then adapt Modal callers to consume those
2326
specs. Do not add parallel filename constants in orchestration code.
27+
28+
When changing remote result handling, keep `_collect_outputs(...)` as the
29+
compatibility adapter for subprocess stdout markers and convert its dictionary
30+
shape into `FittedWeightsOutputBundle` before writing artifacts to the pipeline
31+
volume. Fit step manifests should attach diagnostics from the matching output
32+
scope rather than all run diagnostics.

modal_app/pipeline.py

Lines changed: 37 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
import time
4040
import traceback
4141
from datetime import datetime, timezone
42-
from io import BytesIO
4342
from pathlib import Path
4443

4544
import modal
@@ -110,6 +109,8 @@
110109
from policyengine_us_data.pipeline_schema import PipelineNode # noqa: E402
111110
from policyengine_us_data.fit_weights import ( # noqa: E402
112111
FitScope,
112+
FittedWeightsInputBundle,
113+
FittedWeightsOutputBundle,
113114
NATIONAL_FIT_LAMBDA_L0 as _NATIONAL_FIT_LAMBDA_L0,
114115
fit_artifacts_for_scope,
115116
fitted_weights_spec_for_scope,
@@ -279,23 +280,30 @@ def archive_diagnostics(
279280
vol: modal.Volume,
280281
prefix: str = "",
281282
scope: FitScope | str | None = None,
282-
) -> None:
283+
) -> list[ArtifactReference]:
283284
"""Archive calibration diagnostics to the run directory."""
284285
diag_dir = Path(RUNS_DIR) / run_id / "diagnostics"
285286
diag_dir.mkdir(parents=True, exist_ok=True)
286287

287288
scope = scope or (FitScope.NATIONAL if prefix == "national_" else FitScope.REGIONAL)
288289
file_map = fit_artifacts_for_scope(scope).diagnostic_result_filenames()
290+
written_paths: list[Path] = []
289291

290292
for key, filename in file_map.items():
291293
data = result_bytes.get(key)
292294
if data:
293295
path = diag_dir / filename
294296
with open(path, "wb") as f:
295297
f.write(data)
298+
written_paths.append(path)
296299
print(f" Archived {filename} ({len(data):,} bytes)")
297300

298301
vol.commit()
302+
return collect_artifacts(
303+
written_paths,
304+
role="diagnostic",
305+
missing_ok=True,
306+
)
299307

300308

301309
# ── Include other Modal apps ─────────────────────────────────────
@@ -1315,12 +1323,11 @@ def run_pipeline(
13151323
print(f" Completed in {completed_package_manifest.duration_s}s")
13161324

13171325
# ── Step 3: Fit weights (parallel) ──
1318-
fit_inputs = _artifact_identities(
1319-
{
1320-
"calibration_package": _artifacts_dir(run_id)
1321-
/ "calibration_package.pkl",
1322-
}
1326+
regional_fit_input = FittedWeightsInputBundle(
1327+
scope=FitScope.REGIONAL,
1328+
calibration_package_path=_artifacts_dir(run_id) / "calibration_package.pkl",
13231329
)
1330+
fit_inputs = _artifact_identities(regional_fit_input.artifact_identity_paths())
13241331
regional_fit_spec = fitted_weights_spec_for_scope(FitScope.REGIONAL)
13251332
national_fit_spec = fitted_weights_spec_for_scope(FitScope.NATIONAL)
13261333
regional_fit_artifacts = fit_artifacts_for_scope(FitScope.REGIONAL)
@@ -1425,34 +1432,26 @@ def run_pipeline(
14251432
# Collect regional results
14261433
print(" Waiting for regional fit...")
14271434
regional_result = regional_handle.get()
1435+
regional_output = FittedWeightsOutputBundle.from_result_bytes(
1436+
scope=FitScope.REGIONAL,
1437+
result_bytes=regional_result,
1438+
run_id=run_id,
1439+
)
14281440
print(" Regional fit complete. Writing to volume...")
14291441

14301442
# Write regional results to pipeline volume (run-scoped)
14311443
artifacts_rel = f"artifacts/{run_id}" if run_id else "artifacts"
14321444
with pipeline_volume.batch_upload(force=True) as batch:
1433-
batch.put_file(
1434-
BytesIO(regional_result["weights"]),
1435-
f"{artifacts_rel}/{regional_fit_artifacts.weights.filename}",
1436-
)
1437-
if regional_result.get("geography"):
1438-
batch.put_file(
1439-
BytesIO(regional_result["geography"]),
1440-
f"{artifacts_rel}/{regional_fit_artifacts.geography.filename}",
1441-
)
1442-
if regional_result.get("config"):
1443-
batch.put_file(
1444-
BytesIO(regional_result["config"]),
1445-
f"{artifacts_rel}/{regional_fit_artifacts.run_config.filename}",
1446-
)
1445+
regional_output.write_artifacts(batch, artifacts_rel)
14471446

1448-
archive_diagnostics(
1447+
regional_diagnostics = archive_diagnostics(
14491448
run_id,
1450-
regional_result,
1449+
regional_output.diagnostic_result_bytes(),
14511450
pipeline_volume,
1452-
scope=FitScope.REGIONAL,
1451+
scope=regional_output.scope,
14531452
)
14541453
regional_outputs = collect_artifacts(
1455-
regional_fit_artifacts.artifact_paths(_artifacts_dir(run_id)),
1454+
regional_output.artifact_paths(_artifacts_dir(run_id)),
14561455
missing_ok=True,
14571456
)
14581457
regional_fit_reuse_measurement = ReuseMeasurement(
@@ -1462,7 +1461,7 @@ def run_pipeline(
14621461
_complete_step_manifest(
14631462
regional_fit_manifest,
14641463
outputs=regional_outputs,
1465-
diagnostics=_collect_diagnostics(run_id),
1464+
diagnostics=regional_diagnostics,
14661465
reuse_decision="computed",
14671466
reuse_measurement=regional_fit_reuse_measurement,
14681467
vol=pipeline_volume,
@@ -1473,38 +1472,30 @@ def run_pipeline(
14731472
if national_handle is not None:
14741473
print(" Waiting for national fit...")
14751474
national_result = national_handle.get()
1475+
national_output = FittedWeightsOutputBundle.from_result_bytes(
1476+
scope=FitScope.NATIONAL,
1477+
result_bytes=national_result,
1478+
run_id=run_id,
1479+
)
14761480
print(" National fit complete. Writing to volume...")
14771481

14781482
with pipeline_volume.batch_upload(force=True) as batch:
1479-
batch.put_file(
1480-
BytesIO(national_result["weights"]),
1481-
f"{artifacts_rel}/{national_fit_artifacts.weights.filename}",
1482-
)
1483-
if national_result.get("geography"):
1484-
batch.put_file(
1485-
BytesIO(national_result["geography"]),
1486-
f"{artifacts_rel}/{national_fit_artifacts.geography.filename}",
1487-
)
1488-
if national_result.get("config"):
1489-
batch.put_file(
1490-
BytesIO(national_result["config"]),
1491-
f"{artifacts_rel}/{national_fit_artifacts.run_config.filename}",
1492-
)
1493-
1494-
archive_diagnostics(
1483+
national_output.write_artifacts(batch, artifacts_rel)
1484+
1485+
national_diagnostics = archive_diagnostics(
14951486
run_id,
1496-
national_result,
1487+
national_output.diagnostic_result_bytes(),
14971488
pipeline_volume,
1498-
scope=FitScope.NATIONAL,
1489+
scope=national_output.scope,
14991490
)
15001491
national_outputs = collect_artifacts(
1501-
national_fit_artifacts.artifact_paths(_artifacts_dir(run_id)),
1492+
national_output.artifact_paths(_artifacts_dir(run_id)),
15021493
missing_ok=True,
15031494
)
15041495
_complete_step_manifest(
15051496
national_fit_manifest,
15061497
outputs=national_outputs,
1507-
diagnostics=_collect_diagnostics(run_id),
1498+
diagnostics=national_diagnostics,
15081499
reuse_decision="computed",
15091500
reuse_measurement=ReuseMeasurement(
15101501
expected_outputs=len(national_outputs),

modal_app/remote_calibration_runner.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from modal_app.images import gpu_image as image # noqa: E402
1515
from policyengine_us_data.fit_weights import ( # noqa: E402
16+
FitResultBytes,
1617
FitScope,
1718
NATIONAL_FIT_LAMBDA_L0,
1819
fit_artifacts_for_scope,
@@ -139,13 +140,13 @@ def _collect_outputs(cal_lines):
139140
with open(config_path, "rb") as f:
140141
config_bytes = f.read()
141142

142-
return {
143-
"weights": weights_bytes,
144-
"geography": geography_bytes,
145-
"log": log_bytes,
146-
"cal_log": cal_log_bytes,
147-
"config": config_bytes,
148-
}
143+
return FitResultBytes(
144+
weights=weights_bytes,
145+
geography=geography_bytes,
146+
diagnostics=log_bytes,
147+
epoch_log=cal_log_bytes,
148+
run_config=config_bytes,
149+
).to_result_dict()
149150

150151

151152
def _fit_output_filenames(

policyengine_us_data/fit_weights/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@
77
ScopedFitArtifacts,
88
fit_artifacts_for_scope,
99
)
10+
from policyengine_us_data.fit_weights.bundles import (
11+
FitResultBytes,
12+
FitWeightsBuildContext,
13+
FittedWeightsInputBundle,
14+
FittedWeightsOutputBundle,
15+
MissingFitWeightsOutputError,
16+
)
1017
from policyengine_us_data.fit_weights.specs import (
1118
FIT_BETA,
1219
FIT_LOG_FREQ,
@@ -35,8 +42,13 @@
3542
"FitArtifactRole",
3643
"FitArtifactSpec",
3744
"FitHyperparameters",
45+
"FitResultBytes",
3846
"FitScope",
47+
"FitWeightsBuildContext",
48+
"FittedWeightsInputBundle",
49+
"FittedWeightsOutputBundle",
3950
"FittedWeightsSpec",
51+
"MissingFitWeightsOutputError",
4052
"ScopedFitArtifacts",
4153
"fit_artifacts_for_scope",
4254
"fitted_weights_spec_for_scope",

0 commit comments

Comments
 (0)