Skip to content

Commit 8aab0bf

Browse files
committed
Extract local H5 builder postprocessors
1 parent a64a437 commit 8aab0bf

12 files changed

Lines changed: 1900 additions & 303 deletions

File tree

docs/pipeline_map.yaml

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,10 +1167,6 @@ stages:
11671167
label: Modal Worker Container
11681168
node_type: external
11691169
description: 16GB RAM, 1 CPU each, 8-hour timeout
1170-
- id: spm_recalc
1171-
label: SPM Threshold Recalculation
1172-
node_type: process
1173-
description: Local median rents, family composition, tenure type
11741170
- id: takeup_apply
11751171
label: Takeup Re-application
11761172
node_type: process
@@ -1235,9 +1231,6 @@ stages:
12351231
target: geo_derive
12361232
edge_type: data_flow
12371233
- source: geo_derive
1238-
target: spm_recalc
1239-
edge_type: data_flow
1240-
- source: spm_recalc
12411234
target: takeup_apply
12421235
edge_type: data_flow
12431236
- source: modal_coord

policyengine_us_data/build_outputs/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,7 @@
66
fingerprinting, clone-weight shape contracts, worker partitioning, source
77
dataset snapshot contracts, worker input normalization, worker-bootstrap
88
artifacts, worker-scoped session and validation context setup, microsimulation
9-
access helpers, clone selection, entity reindexing, and source-variable cloning.
9+
access helpers, clone selection, entity reindexing, source-variable cloning,
10+
validated H5 payload contracts, ordered output postprocessing, one-area payload
11+
building, and H5 writing.
1012
"""
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
"""One-area local H5 build orchestration."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass, field
6+
from typing import Any, Mapping, Protocol, TypeAlias
7+
8+
import numpy as np
9+
10+
from policyengine_us_data.pipeline_metadata import pipeline_node
11+
12+
from .payload import H5Payload, PayloadBuildContext
13+
from .reindexing import EntityReindexer, ReindexedEntities
14+
from .requests import AreaBuildRequest
15+
from .selection import AreaSelector, CloneSelection
16+
from .source_dataset import SourceDatasetSnapshot
17+
from .variables import VariableCloner
18+
from .weights import CloneWeightMatrix
19+
20+
__all__ = [
21+
"LocalAreaBuildResult",
22+
"LocalAreaDatasetBuilder",
23+
"PayloadPostProcessor",
24+
"PayloadPostProcessorRun",
25+
"PayloadPostProcessorResult",
26+
]
27+
28+
29+
class PayloadPostProcessorResult(Protocol):
30+
"""Result contract for a payload postprocessor."""
31+
32+
payload: H5Payload
33+
34+
35+
PostProcessorReturn: TypeAlias = H5Payload | PayloadPostProcessorResult
36+
37+
38+
class PayloadPostProcessor(Protocol):
39+
"""Country- or product-specific processor for an H5 payload."""
40+
41+
def apply(
42+
self,
43+
*,
44+
payload: H5Payload,
45+
context: PayloadBuildContext,
46+
) -> PostProcessorReturn:
47+
"""Return a processed `H5Payload` or structured result with `.payload`."""
48+
49+
50+
@dataclass(frozen=True)
51+
class PayloadPostProcessorRun:
52+
"""Result metadata for one payload postprocessor invocation."""
53+
54+
name: str
55+
postprocessor_type: type
56+
result: PostProcessorReturn
57+
58+
59+
@pipeline_node(
60+
id="local_h5_build_result",
61+
label="LocalAreaBuildResult",
62+
node_type="library",
63+
description="In-memory local H5 payload and diagnostics for one area.",
64+
source_file="policyengine_us_data/build_outputs/builder.py",
65+
status="current",
66+
stability="moving",
67+
pathways=["local_h5"],
68+
validation_commands=["uv run pytest tests/unit/build_outputs/test_builder.py"],
69+
)
70+
@dataclass(frozen=True)
71+
class LocalAreaBuildResult:
72+
"""In-memory output from building one local H5 area."""
73+
74+
payload: H5Payload
75+
selection: CloneSelection
76+
reindexed: ReindexedEntities
77+
variables_saved: int
78+
summary: Mapping[str, int | float | str]
79+
postprocessor_runs: tuple[PayloadPostProcessorRun, ...] = ()
80+
81+
@property
82+
def data(self) -> Mapping[str, Mapping[Any, np.ndarray]]:
83+
"""Payload data retained for transitional callers."""
84+
85+
return self.payload.data
86+
87+
@property
88+
def time_period(self) -> int:
89+
"""Payload time period retained for transitional callers."""
90+
91+
return self.payload.time_period
92+
93+
def postprocessor_result(self, postprocessor: type | str) -> Any | None:
94+
"""Return the result for one configured postprocessor."""
95+
96+
key = (
97+
postprocessor if isinstance(postprocessor, str) else postprocessor.__name__
98+
)
99+
for run in self.postprocessor_runs:
100+
if run.name == key:
101+
return run.result
102+
if not isinstance(postprocessor, str) and issubclass(
103+
run.postprocessor_type,
104+
postprocessor,
105+
):
106+
return run.result
107+
return None
108+
109+
def postprocessor_results(self, postprocessor: type | str) -> tuple[Any, ...]:
110+
"""Return every result for a configured postprocessor type or name."""
111+
112+
key = (
113+
postprocessor if isinstance(postprocessor, str) else postprocessor.__name__
114+
)
115+
return tuple(
116+
run.result
117+
for run in self.postprocessor_runs
118+
if run.name == key
119+
or (
120+
not isinstance(postprocessor, str)
121+
and issubclass(run.postprocessor_type, postprocessor)
122+
)
123+
)
124+
125+
126+
@pipeline_node(
127+
id="local_h5_dataset_builder",
128+
label="LocalAreaDatasetBuilder",
129+
node_type="library",
130+
description="Build the in-memory payload for one local-area or national H5 output.",
131+
source_file="policyengine_us_data/build_outputs/builder.py",
132+
status="current",
133+
stability="moving",
134+
pathways=["local_h5"],
135+
validation_commands=["uv run pytest tests/unit/build_outputs/test_builder.py"],
136+
)
137+
@dataclass(frozen=True)
138+
class LocalAreaDatasetBuilder:
139+
"""Coordinate clone selection, reindexing, variable cloning, and postprocessing."""
140+
141+
selector: AreaSelector = field(default_factory=AreaSelector)
142+
reindexer: EntityReindexer = field(default_factory=EntityReindexer)
143+
variable_cloner: VariableCloner = field(default_factory=VariableCloner)
144+
postprocessors: tuple[PayloadPostProcessor, ...] = ()
145+
146+
def build(
147+
self,
148+
*,
149+
source: SourceDatasetSnapshot,
150+
simulation: Any,
151+
weights: CloneWeightMatrix,
152+
geography: Any,
153+
request: AreaBuildRequest,
154+
takeup_filter: tuple[str, ...] | None = None,
155+
) -> LocalAreaBuildResult:
156+
"""Build one local H5 payload without writing it to disk."""
157+
158+
selection = self.selector.select(
159+
weights=weights,
160+
geography=geography,
161+
filters=request.filters,
162+
)
163+
reindexed = self.reindexer.reindex(source=source, selection=selection)
164+
payload = self.variable_cloner.clone(
165+
source=source,
166+
selection=selection,
167+
reindexed=reindexed,
168+
)
169+
h5_payload = H5Payload(
170+
data=payload.data,
171+
time_period=int(source.time_period),
172+
entity_lengths=_entity_lengths(reindexed),
173+
)
174+
context = PayloadBuildContext(
175+
source=source,
176+
simulation=simulation,
177+
selection=selection,
178+
reindexed=reindexed,
179+
geography=geography,
180+
time_period=int(source.time_period),
181+
takeup_filter=takeup_filter,
182+
)
183+
postprocessor_runs: list[PayloadPostProcessorRun] = []
184+
for postprocessor in self.postprocessors:
185+
result = postprocessor.apply(payload=h5_payload, context=context)
186+
postprocessor_runs.append(
187+
PayloadPostProcessorRun(
188+
name=type(postprocessor).__name__,
189+
postprocessor_type=type(postprocessor),
190+
result=result,
191+
)
192+
)
193+
h5_payload = _payload_from_postprocessor_result(result)
194+
195+
return LocalAreaBuildResult(
196+
payload=h5_payload,
197+
selection=selection,
198+
reindexed=reindexed,
199+
variables_saved=payload.values_saved,
200+
summary=_build_summary(
201+
request=request,
202+
selection=selection,
203+
reindexed=reindexed,
204+
variables_saved=payload.values_saved,
205+
),
206+
postprocessor_runs=tuple(postprocessor_runs),
207+
)
208+
209+
210+
def _build_summary(
211+
*,
212+
request: AreaBuildRequest,
213+
selection: CloneSelection,
214+
reindexed: ReindexedEntities,
215+
variables_saved: int,
216+
) -> dict[str, int | float | str]:
217+
summary: dict[str, int | float | str] = {
218+
"area_type": request.area_type,
219+
"area_id": request.area_id,
220+
"display_name": request.display_name,
221+
"active_clones": selection.n_selected_clones,
222+
"total_weight": float(np.sum(selection.weights)),
223+
"persons": int(len(reindexed.person_ids)),
224+
"variables_saved": int(variables_saved),
225+
}
226+
for entity_key, entity_source_indices in reindexed.subentity_source_indices.items():
227+
summary[f"{entity_key}s"] = int(len(entity_source_indices))
228+
return summary
229+
230+
231+
def _entity_lengths(reindexed: ReindexedEntities) -> dict[str, int]:
232+
lengths = {
233+
"household": int(len(reindexed.household_ids)),
234+
"person": int(len(reindexed.person_ids)),
235+
}
236+
for entity_key, entity_ids in reindexed.subentity_ids.items():
237+
lengths[entity_key] = int(len(entity_ids))
238+
return lengths
239+
240+
241+
def _payload_from_postprocessor_result(result: PostProcessorReturn) -> H5Payload:
242+
if isinstance(result, H5Payload):
243+
return result
244+
payload = getattr(result, "payload", None)
245+
if isinstance(payload, H5Payload):
246+
return payload
247+
raise TypeError(
248+
"Payload postprocessors must return H5Payload or an object exposing "
249+
"an H5Payload `.payload` attribute"
250+
)

0 commit comments

Comments
 (0)