Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog.d/1045.changed
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Added scoped Stage 3 fitted-weight input and output bundles.
Bumped policyengine-us to 1.701.1.
9 changes: 9 additions & 0 deletions docs/engineering/stages/fit_weights.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ builds. The public identity boundary lives in `policyengine_us_data.fit_weights`
step manifests for reuse decisions.
- `ScopedFitArtifacts` defines the artifact filenames written by the Modal fit
step and consumed by downstream H5 builders.
- `FittedWeightsInputBundle`, `FitResultBytes`, and
`FittedWeightsOutputBundle` keep Stage 3 package inputs and remote result
bytes in typed containers until they are written to files.

The current artifact names remain behavior-compatible:

Expand All @@ -21,3 +24,9 @@ The current artifact names remain behavior-compatible:
When changing Stage 3 fitting parameters, artifact names, or scope behavior,
update the central specs first and then adapt Modal callers to consume those
specs. Do not add parallel filename constants in orchestration code.

When changing remote result handling, keep `_collect_outputs(...)` as the
compatibility adapter for subprocess stdout markers, and convert the dictionary
it returns into a `FittedWeightsOutputBundle` before writing artifacts to the
pipeline volume. Each fit step manifest should attach only the diagnostics from
its own output scope, not every diagnostic in the run.
83 changes: 37 additions & 46 deletions modal_app/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import time
import traceback
from datetime import datetime, timezone
from io import BytesIO
from pathlib import Path

import modal
Expand Down Expand Up @@ -110,6 +109,8 @@
from policyengine_us_data.pipeline_schema import PipelineNode # noqa: E402
from policyengine_us_data.fit_weights import ( # noqa: E402
FitScope,
FittedWeightsInputBundle,
FittedWeightsOutputBundle,
NATIONAL_FIT_LAMBDA_L0 as _NATIONAL_FIT_LAMBDA_L0,
fit_artifacts_for_scope,
fitted_weights_spec_for_scope,
Expand Down Expand Up @@ -279,23 +280,30 @@ def archive_diagnostics(
vol: modal.Volume,
prefix: str = "",
scope: FitScope | str | None = None,
) -> None:
) -> list[ArtifactReference]:
"""Archive calibration diagnostics to the run directory."""
diag_dir = Path(RUNS_DIR) / run_id / "diagnostics"
diag_dir.mkdir(parents=True, exist_ok=True)

scope = scope or (FitScope.NATIONAL if prefix == "national_" else FitScope.REGIONAL)
file_map = fit_artifacts_for_scope(scope).diagnostic_result_filenames()
written_paths: list[Path] = []

for key, filename in file_map.items():
data = result_bytes.get(key)
if data:
path = diag_dir / filename
with open(path, "wb") as f:
f.write(data)
written_paths.append(path)
print(f" Archived {filename} ({len(data):,} bytes)")

vol.commit()
return collect_artifacts(
written_paths,
role="diagnostic",
missing_ok=True,
)


# ── Include other Modal apps ─────────────────────────────────────
Expand Down Expand Up @@ -1315,12 +1323,11 @@ def run_pipeline(
print(f" Completed in {completed_package_manifest.duration_s}s")

# ── Step 3: Fit weights (parallel) ──
fit_inputs = _artifact_identities(
{
"calibration_package": _artifacts_dir(run_id)
/ "calibration_package.pkl",
}
regional_fit_input = FittedWeightsInputBundle(
scope=FitScope.REGIONAL,
calibration_package_path=_artifacts_dir(run_id) / "calibration_package.pkl",
)
fit_inputs = _artifact_identities(regional_fit_input.artifact_identity_paths())
regional_fit_spec = fitted_weights_spec_for_scope(FitScope.REGIONAL)
national_fit_spec = fitted_weights_spec_for_scope(FitScope.NATIONAL)
regional_fit_artifacts = fit_artifacts_for_scope(FitScope.REGIONAL)
Expand Down Expand Up @@ -1425,34 +1432,26 @@ def run_pipeline(
# Collect regional results
print(" Waiting for regional fit...")
regional_result = regional_handle.get()
regional_output = FittedWeightsOutputBundle.from_result_bytes(
scope=FitScope.REGIONAL,
result_bytes=regional_result,
run_id=run_id,
)
print(" Regional fit complete. Writing to volume...")

# Write regional results to pipeline volume (run-scoped)
artifacts_rel = f"artifacts/{run_id}" if run_id else "artifacts"
with pipeline_volume.batch_upload(force=True) as batch:
batch.put_file(
BytesIO(regional_result["weights"]),
f"{artifacts_rel}/{regional_fit_artifacts.weights.filename}",
)
if regional_result.get("geography"):
batch.put_file(
BytesIO(regional_result["geography"]),
f"{artifacts_rel}/{regional_fit_artifacts.geography.filename}",
)
if regional_result.get("config"):
batch.put_file(
BytesIO(regional_result["config"]),
f"{artifacts_rel}/{regional_fit_artifacts.run_config.filename}",
)
regional_output.write_artifacts(batch, artifacts_rel)

archive_diagnostics(
regional_diagnostics = archive_diagnostics(
run_id,
regional_result,
regional_output.diagnostic_result_bytes(),
pipeline_volume,
scope=FitScope.REGIONAL,
scope=regional_output.scope,
)
regional_outputs = collect_artifacts(
regional_fit_artifacts.artifact_paths(_artifacts_dir(run_id)),
regional_output.artifact_paths(_artifacts_dir(run_id)),
missing_ok=True,
)
regional_fit_reuse_measurement = ReuseMeasurement(
Expand All @@ -1462,7 +1461,7 @@ def run_pipeline(
_complete_step_manifest(
regional_fit_manifest,
outputs=regional_outputs,
diagnostics=_collect_diagnostics(run_id),
diagnostics=regional_diagnostics,
reuse_decision="computed",
reuse_measurement=regional_fit_reuse_measurement,
vol=pipeline_volume,
Expand All @@ -1473,38 +1472,30 @@ def run_pipeline(
if national_handle is not None:
print(" Waiting for national fit...")
national_result = national_handle.get()
national_output = FittedWeightsOutputBundle.from_result_bytes(
scope=FitScope.NATIONAL,
result_bytes=national_result,
run_id=run_id,
)
print(" National fit complete. Writing to volume...")

with pipeline_volume.batch_upload(force=True) as batch:
batch.put_file(
BytesIO(national_result["weights"]),
f"{artifacts_rel}/{national_fit_artifacts.weights.filename}",
)
if national_result.get("geography"):
batch.put_file(
BytesIO(national_result["geography"]),
f"{artifacts_rel}/{national_fit_artifacts.geography.filename}",
)
if national_result.get("config"):
batch.put_file(
BytesIO(national_result["config"]),
f"{artifacts_rel}/{national_fit_artifacts.run_config.filename}",
)

archive_diagnostics(
national_output.write_artifacts(batch, artifacts_rel)

national_diagnostics = archive_diagnostics(
run_id,
national_result,
national_output.diagnostic_result_bytes(),
pipeline_volume,
scope=FitScope.NATIONAL,
scope=national_output.scope,
)
national_outputs = collect_artifacts(
national_fit_artifacts.artifact_paths(_artifacts_dir(run_id)),
national_output.artifact_paths(_artifacts_dir(run_id)),
missing_ok=True,
)
_complete_step_manifest(
national_fit_manifest,
outputs=national_outputs,
diagnostics=_collect_diagnostics(run_id),
diagnostics=national_diagnostics,
reuse_decision="computed",
reuse_measurement=ReuseMeasurement(
expected_outputs=len(national_outputs),
Expand Down
15 changes: 8 additions & 7 deletions modal_app/remote_calibration_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from modal_app.images import gpu_image as image # noqa: E402
from policyengine_us_data.fit_weights import ( # noqa: E402
FitResultBytes,
FitScope,
NATIONAL_FIT_LAMBDA_L0,
fit_artifacts_for_scope,
Expand Down Expand Up @@ -139,13 +140,13 @@ def _collect_outputs(cal_lines):
with open(config_path, "rb") as f:
config_bytes = f.read()

return {
"weights": weights_bytes,
"geography": geography_bytes,
"log": log_bytes,
"cal_log": cal_log_bytes,
"config": config_bytes,
}
return FitResultBytes(
weights=weights_bytes,
geography=geography_bytes,
diagnostics=log_bytes,
epoch_log=cal_log_bytes,
run_config=config_bytes,
).to_result_dict()


def _fit_output_filenames(
Expand Down
12 changes: 12 additions & 0 deletions policyengine_us_data/fit_weights/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
ScopedFitArtifacts,
fit_artifacts_for_scope,
)
from policyengine_us_data.fit_weights.bundles import (
FitResultBytes,
FitWeightsBuildContext,
FittedWeightsInputBundle,
FittedWeightsOutputBundle,
MissingFitWeightsOutputError,
)
from policyengine_us_data.fit_weights.specs import (
FIT_BETA,
FIT_LOG_FREQ,
Expand Down Expand Up @@ -35,8 +42,13 @@
"FitArtifactRole",
"FitArtifactSpec",
"FitHyperparameters",
"FitResultBytes",
"FitScope",
"FitWeightsBuildContext",
"FittedWeightsInputBundle",
"FittedWeightsOutputBundle",
"FittedWeightsSpec",
"MissingFitWeightsOutputError",
"ScopedFitArtifacts",
"fit_artifacts_for_scope",
"fitted_weights_spec_for_scope",
Expand Down
Loading