Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/1100.changed
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Stabilize Stage 2 chunked matrix execution request, result, and resume metadata.
35 changes: 35 additions & 0 deletions docs/pipeline_map.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,8 @@ stages:
- stage2_geography_assignment_result
- stage2_matrix_build_spec
- stage2_matrix_build_service
- stage2_chunk_build_request
- stage2_chunk_worker_result
- stage2_matrix_build_result
- target_resolve
- stage2_target_config_apply
Expand All @@ -838,6 +840,7 @@ stages:
- out_targets
- out_target_facets
- out_geography_summary
- out_chunk_result_manifests
- out_matrix_summary
- stage2_calibration_package_contract_writer
- out_contract
Expand Down Expand Up @@ -911,6 +914,18 @@ stages:
label: matrix_summary.json
node_type: artifact
description: Compact Stage 2 matrix shape, sparsity, target order, builder mode, and chunk lineage summary
- id: out_chunk_result_manifests
label: chunk_results/*.json
node_type: artifact
description: Per-chunk structured progress, cache, and error metadata for chunked matrix execution
- id: stage2_chunk_build_request
label: Chunk Build Request
node_type: library
description: Typed Modal worker request carrying run, chunk, state path, resume flag, and lineage signature material
- id: stage2_chunk_worker_result
label: Chunk Worker Result
node_type: library
description: Typed Modal worker result carrying per-chunk completion, cache, and error metadata
- id: out_contract
label: calibration_package_contract.json
node_type: artifact
Expand Down Expand Up @@ -1024,6 +1039,10 @@ stages:
target: stage2_matrix_build_service
edge_type: data_flow
label: builder mode and chunk settings
- source: stage2_matrix_build_service
target: stage2_chunk_build_request
edge_type: data_flow
label: chunk ids and lineage signature
- source: stage2_target_catalog_load
target: stage2_target_config_apply
edge_type: data_flow
Expand Down Expand Up @@ -1077,6 +1096,22 @@ stages:
target: stage2_matrix_build_result
edge_type: data_flow
label: chunk manifest and shards
- source: stage2_chunk_build_request
target: build_matrix_chunked
edge_type: data_flow
label: typed worker request
- source: build_matrix_chunked
target: stage2_chunk_worker_result
edge_type: data_flow
label: typed worker result
- source: stage2_chunk_worker_result
target: out_chunk_result_manifests
edge_type: produces_artifact
label: per-chunk progress and error metadata
- source: out_chunk_result_manifests
target: stage2_matrix_build_result
edge_type: data_flow
label: chunk execution diagnostics
- source: stage2_matrix_build_service
target: build_matrix
edge_type: uses_library
Expand Down
145 changes: 107 additions & 38 deletions modal_app/matrix_chunk_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,71 +55,140 @@ def _chunk_root(run_id: str) -> str:
nonpreemptible=True,
)
def build_matrix_chunk_worker(
run_id: str,
chunk_ids: List[int],
request: Dict | None = None,
run_id: str | None = None,
chunk_ids: List[int] | None = None,
resume_chunks: bool = False,
) -> Dict:
"""Materialize ``chunk_ids`` from the pickled ``SharedBuildState``.
"""Materialize a typed chunk request from pickled ``SharedBuildState``.

Args:
run_id: Pipeline run identifier; selects the volume path for
this worker's shared state and shard output directory.
chunk_ids: Chunk indices this worker is responsible for.
resume_chunks: Whether to trust matching pre-existing COO shards.
Fresh builds pass ``False`` so workers overwrite stale chunks.
request: Typed worker request material. Legacy ``run_id`` and
``chunk_ids`` arguments are accepted for compatibility with
older local tests and undeployed call sites.
run_id: Legacy pipeline run identifier.
chunk_ids: Legacy chunk indices this worker is responsible for.
resume_chunks: Legacy resume flag.

Returns:
Dict with ``chunk_ids``, ``nnz_per_chunk``, and ``errors``
lists suitable for the coordinator to aggregate.
Structured worker result material suitable for the coordinator to
aggregate.
"""
from policyengine_us_data.calibration.chunked_matrix_assembler import (
ChunkedMatrixAssembler,
)
from policyengine_us_data.calibration.signatures import signature_mismatches
from policyengine_us_data.calibration_package.matrix import (
CHUNK_EXECUTION_SCHEMA_VERSION,
ChunkBuildRequest,
ChunkExecutionResult,
ChunkWorkerResult,
write_chunk_result_manifest,
)

pipeline_vol.reload()
chunk_root = Path(_chunk_root(run_id))
state_path = chunk_root / "chunk_build_state.pkl"
if request is None:
if run_id is None or chunk_ids is None:
raise ValueError("request or legacy run_id/chunk_ids are required")
chunk_root = Path(_chunk_root(run_id))
request_obj = ChunkBuildRequest(
schema_version=CHUNK_EXECUTION_SCHEMA_VERSION,
run_id=run_id,
chunk_ids=tuple(chunk_ids),
chunk_root=str(chunk_root),
state_path=str(chunk_root / "chunk_build_state.pkl"),
resume_chunks=resume_chunks,
lineage_signature={},
)
else:
request_obj = ChunkBuildRequest.from_dict(request)
chunk_root = Path(request_obj.chunk_root)
state_path = Path(request_obj.state_path)
if not state_path.exists():
return {
"chunk_ids": list(chunk_ids),
"nnz_per_chunk": [],
"errors": [
{
"chunk_ids": list(chunk_ids),
"error": f"Missing shared state at {state_path}",
}
],
}
chunk_results = tuple(
ChunkExecutionResult.failure(
run_id=request_obj.run_id,
chunk_id=chunk_id,
error=f"Missing shared state at {state_path}",
)
for chunk_id in request_obj.chunk_ids
)
for result in chunk_results:
write_chunk_result_manifest(chunk_root, result)
pipeline_vol.commit()
return ChunkWorkerResult(
schema_version=CHUNK_EXECUTION_SCHEMA_VERSION,
run_id=request_obj.run_id,
chunk_ids=request_obj.chunk_ids,
chunk_results=chunk_results,
).to_dict()

with open(state_path, "rb") as f:
shared_state = pickle.load(f)
if request_obj.lineage_signature:
state_lineage_signature = getattr(shared_state, "lineage_signature", {})
fatal, _ = signature_mismatches(
state_lineage_signature,
request_obj.lineage_signature,
)
if fatal:
error = "Chunk request lineage mismatch: " + "; ".join(fatal)
chunk_results = tuple(
ChunkExecutionResult.failure(
run_id=request_obj.run_id,
chunk_id=chunk_id,
error=error,
)
for chunk_id in request_obj.chunk_ids
)
for result in chunk_results:
write_chunk_result_manifest(chunk_root, result)
pipeline_vol.commit()
return ChunkWorkerResult(
schema_version=CHUNK_EXECUTION_SCHEMA_VERSION,
run_id=request_obj.run_id,
chunk_ids=request_obj.chunk_ids,
chunk_results=chunk_results,
).to_dict()

assembler = ChunkedMatrixAssembler(
shared_state=shared_state,
chunk_root=chunk_root,
chunk_size=shared_state.chunk_size,
resume=resume_chunks,
resume=request_obj.resume_chunks,
keep_chunks=False,
)

errors: List[Dict] = []
nnz_per_chunk: List[int] = []
for chunk_id in chunk_ids:
chunk_results: List[ChunkExecutionResult] = []
for chunk_id in request_obj.chunk_ids:
try:
result = assembler.run_single_chunk(chunk_id)
nnz_per_chunk.append(result.nnz)
chunk_results.append(
ChunkExecutionResult.from_chunk_result(
run_id=request_obj.run_id,
result=result,
)
)
except Exception as exc:
errors.append(
{
"chunk_id": chunk_id,
"error": str(exc),
"traceback": traceback.format_exc(),
}
traceback_text = traceback.format_exc()
assembler.record_chunk_error(
chunk_id=chunk_id,
error=str(exc),
traceback=traceback_text,
)
chunk_results.append(
ChunkExecutionResult.failure(
run_id=request_obj.run_id,
chunk_id=chunk_id,
error=str(exc),
traceback=traceback_text,
)
)

pipeline_vol.commit()
return {
"chunk_ids": list(chunk_ids),
"nnz_per_chunk": nnz_per_chunk,
"errors": errors,
}
return ChunkWorkerResult(
schema_version=CHUNK_EXECUTION_SCHEMA_VERSION,
run_id=request_obj.run_id,
chunk_ids=request_obj.chunk_ids,
chunk_results=tuple(chunk_results),
).to_dict()
75 changes: 67 additions & 8 deletions policyengine_us_data/calibration/chunked_matrix_assembler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,16 @@
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set, Tuple
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

import numpy as np
from scipy import sparse

from policyengine_us_data.calibration_package.matrix import (
ChunkExecutionResult,
write_chunk_result_manifest,
)

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -81,6 +86,7 @@ class SharedBuildState:
cd_geoid: np.ndarray
county_fips: np.ndarray
state_fips: np.ndarray
lineage_signature: Dict[str, Any]

@property
def n_total(self) -> int:
Expand Down Expand Up @@ -190,6 +196,26 @@ def stream_csr_from_shards(
return X


def _format_duration(seconds: float) -> str:
seconds = max(0, int(round(seconds)))
hours, remainder = divmod(seconds, 3600)
minutes, seconds = divmod(remainder, 60)
if hours:
return f"{hours}h {minutes:02d}m {seconds:02d}s"
if minutes:
return f"{minutes}m {seconds:02d}s"
return f"{seconds}s"


def _current_rss_mb() -> Optional[float]:
try:
import psutil

return psutil.Process().memory_info().rss / 1024**2
except Exception:
return None


class ChunkedMatrixAssembler:
"""Coordinate partitioning, per-chunk execution, and streaming assembly.

Expand Down Expand Up @@ -267,11 +293,6 @@ def run_chunks(self, chunk_ids: Iterable[int]) -> List[ChunkResult]:
cached_chunks,
)
else:
from policyengine_us_data.calibration.unified_matrix_builder import (
_current_rss_mb,
_format_duration,
)

rss = _current_rss_mb()
rss_part = f", rss={rss:,.0f} MB" if rss is not None else ""
logger.info(
Expand Down Expand Up @@ -322,7 +343,9 @@ def run_single_chunk(self, chunk_id: int) -> ChunkResult:
f"{cached_col_start}-{cached_col_end - 1}, "
f"expected {plan.col_start}-{plan.col_end - 1}"
)
return ChunkResult(chunk_id=chunk_id, nnz=cached_nnz, cached=True)
result = ChunkResult(chunk_id=chunk_id, nnz=cached_nnz, cached=True)
self.write_result_manifest(result)
return result

# Imports are local so the module is import-safe in lightweight
# environments (e.g., cold Modal containers that haven't yet
Expand Down Expand Up @@ -520,7 +543,7 @@ def run_single_chunk(self, chunk_id: int) -> ChunkResult:
if not self.keep_chunks and plan.h5_path.exists():
plan.h5_path.unlink()

return ChunkResult(
result = ChunkResult(
chunk_id=chunk_id,
nnz=int(vals.shape[0]),
cached=False,
Expand All @@ -530,6 +553,42 @@ def run_single_chunk(self, chunk_id: int) -> ChunkResult:
unique_counties=getattr(summary, "unique_counties", None),
unique_cds=getattr(summary, "unique_cds", None),
)
self.write_result_manifest(result)
return result

def write_result_manifest(self, result: ChunkResult) -> Path:
    """Persist structured progress metadata for one chunk.

    Converts ``result`` into a ``ChunkExecutionResult`` stamped with the
    run id from ``_manifest_run_id()`` and hands it to
    ``write_chunk_result_manifest`` together with ``self.chunk_root``.

    Args:
        result: The chunk result to record.

    Returns:
        Whatever ``write_chunk_result_manifest`` returns (annotated as a
        ``Path``; presumably the written manifest file, confirm against
        the helper).
    """
    execution_result = ChunkExecutionResult.from_chunk_result(
        run_id=self._manifest_run_id(),
        result=result,
    )
    return write_chunk_result_manifest(self.chunk_root, execution_result)

def record_chunk_error(
    self,
    *,
    chunk_id: int,
    error: str,
    traceback: str | None = None,
) -> Path:
    """Persist structured error metadata for one chunk.

    Builds a failure ``ChunkExecutionResult`` stamped with the run id from
    ``_manifest_run_id()`` and hands it to ``write_chunk_result_manifest``
    together with ``self.chunk_root``.

    Args:
        chunk_id: Index of the chunk that failed.
        error: Error message to record.
        traceback: Optional formatted traceback text to record.

    Returns:
        Whatever ``write_chunk_result_manifest`` returns (annotated as a
        ``Path``; presumably the written manifest file, confirm against
        the helper).
    """
    failed_result = ChunkExecutionResult.failure(
        run_id=self._manifest_run_id(),
        chunk_id=chunk_id,
        error=error,
        traceback=traceback,
    )
    return write_chunk_result_manifest(self.chunk_root, failed_result)

def _manifest_run_id(self) -> str:
lineage_signature = getattr(self.shared_state, "lineage_signature", {})
return str(lineage_signature.get("run_id", ""))

def assemble_final(self) -> sparse.csr_matrix:
"""Stream-assemble the final CSR matrix from all shards on disk."""
Expand Down
Loading