Skip to content

Commit 09b2e92

Browse files
committed
Declare local H5 postprocessor dependencies
1 parent 8aab0bf commit 09b2e92

4 files changed

Lines changed: 164 additions & 3 deletions

File tree

policyengine_us_data/build_outputs/builder.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"LocalAreaDatasetBuilder",
2323
"PayloadPostProcessor",
2424
"PayloadPostProcessorRun",
25+
"PayloadPostProcessorSpec",
2526
"PayloadPostProcessorResult",
2627
]
2728

@@ -38,6 +39,8 @@ class PayloadPostProcessorResult(Protocol):
3839
class PayloadPostProcessor(Protocol):
3940
"""Country- or product-specific processor for an H5 payload."""
4041

42+
spec: PayloadPostProcessorSpec
43+
4144
def apply(
4245
self,
4346
*,
@@ -47,10 +50,19 @@ def apply(
4750
"""Return a processed `H5Payload` or structured result with `.payload`."""
4851

4952

53+
@dataclass(frozen=True)
54+
class PayloadPostProcessorSpec:
55+
"""Stable identity and ordering requirements for a payload postprocessor."""
56+
57+
key: str
58+
requires: tuple[str, ...] = ()
59+
60+
5061
@dataclass(frozen=True)
5162
class PayloadPostProcessorRun:
5263
"""Result metadata for one payload postprocessor invocation."""
5364

65+
key: str
5466
name: str
5567
postprocessor_type: type
5668
result: PostProcessorReturn
@@ -97,7 +109,7 @@ def postprocessor_result(self, postprocessor: type | str) -> Any | None:
97109
postprocessor if isinstance(postprocessor, str) else postprocessor.__name__
98110
)
99111
for run in self.postprocessor_runs:
100-
if run.name == key:
112+
if run.key == key or run.name == key:
101113
return run.result
102114
if not isinstance(postprocessor, str) and issubclass(
103115
run.postprocessor_type,
@@ -115,7 +127,8 @@ def postprocessor_results(self, postprocessor: type | str) -> tuple[Any, ...]:
115127
return tuple(
116128
run.result
117129
for run in self.postprocessor_runs
118-
if run.name == key
130+
if run.key == key
131+
or run.name == key
119132
or (
120133
not isinstance(postprocessor, str)
121134
and issubclass(run.postprocessor_type, postprocessor)
@@ -143,6 +156,9 @@ class LocalAreaDatasetBuilder:
143156
variable_cloner: VariableCloner = field(default_factory=VariableCloner)
144157
postprocessors: tuple[PayloadPostProcessor, ...] = ()
145158

159+
def __post_init__(self) -> None:
160+
_validate_postprocessor_order(self.postprocessors)
161+
146162
def build(
147163
self,
148164
*,
@@ -185,6 +201,7 @@ def build(
185201
result = postprocessor.apply(payload=h5_payload, context=context)
186202
postprocessor_runs.append(
187203
PayloadPostProcessorRun(
204+
key=_postprocessor_spec(postprocessor).key,
188205
name=type(postprocessor).__name__,
189206
postprocessor_type=type(postprocessor),
190207
result=result,
@@ -248,3 +265,31 @@ def _payload_from_postprocessor_result(result: PostProcessorReturn) -> H5Payload
248265
"Payload postprocessors must return H5Payload or an object exposing "
249266
"an H5Payload `.payload` attribute"
250267
)
268+
269+
270+
def _validate_postprocessor_order(
271+
postprocessors: tuple[PayloadPostProcessor, ...],
272+
) -> None:
273+
seen: set[str] = set()
274+
for postprocessor in postprocessors:
275+
spec = _postprocessor_spec(postprocessor)
276+
if spec.key in seen:
277+
raise ValueError(f"Duplicate payload postprocessor key: {spec.key}")
278+
missing = tuple(
279+
requirement for requirement in spec.requires if requirement not in seen
280+
)
281+
if missing:
282+
raise ValueError(
283+
f"{type(postprocessor).__name__} requires postprocessor(s) "
284+
f"to run first: {', '.join(missing)}"
285+
)
286+
seen.add(spec.key)
287+
288+
289+
def _postprocessor_spec(
290+
postprocessor: PayloadPostProcessor,
291+
) -> PayloadPostProcessorSpec:
292+
spec = getattr(postprocessor, "spec", None)
293+
if isinstance(spec, PayloadPostProcessorSpec):
294+
return spec
295+
return PayloadPostProcessorSpec(key=type(postprocessor).__name__)

policyengine_us_data/build_outputs/us_augmentations.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,17 @@
1818
reported_subsidized_marketplace_by_tax_unit,
1919
)
2020

21+
from .builder import PayloadPostProcessorSpec
2122
from .payload import H5Payload, PayloadBuildContext
2223
from .selection import CloneSelection
2324
from .simulation_access import calculate_variable_values
2425
from .variables import GEOGRAPHY_VARIABLES
2526

2627
__all__ = [
2728
"TAKEUP_VARIABLE_ENTITIES",
29+
"US_ENTITY_POSTPROCESSOR_KEY",
30+
"US_GEOGRAPHY_POSTPROCESSOR_KEY",
31+
"US_TAKEUP_POSTPROCESSOR_KEY",
2832
"USEntityPostProcessor",
2933
"USEntityPostProcessorResult",
3034
"USGeographyPostProcessor",
@@ -42,6 +46,9 @@
4246
str(spec["variable"]): str(spec["entity"]) for spec in SIMPLE_TAKEUP_VARS
4347
}
4448
REQUIRED_TAKEUP_SUBENTITIES = ("tax_unit", "spm_unit")
49+
US_ENTITY_POSTPROCESSOR_KEY = "us_entity"
50+
US_GEOGRAPHY_POSTPROCESSOR_KEY = "us_geography"
51+
US_TAKEUP_POSTPROCESSOR_KEY = "us_takeup"
4552

4653

4754
@pipeline_node(
@@ -141,6 +148,8 @@ def data(self) -> PayloadData:
141148
class USEntityPostProcessor:
142149
"""Apply US entity IDs and calibrated household weights."""
143150

151+
spec = PayloadPostProcessorSpec(key=US_ENTITY_POSTPROCESSOR_KEY)
152+
144153
def apply(
145154
self,
146155
*,
@@ -192,6 +201,8 @@ def apply(
192201
class USGeographyPostProcessor:
193202
"""Apply block-derived US geography overrides."""
194203

204+
spec = PayloadPostProcessorSpec(key=US_GEOGRAPHY_POSTPROCESSOR_KEY)
205+
195206
geography_deriver: GeographyDeriver = derive_geography_from_blocks
196207
_string_geography_variables: tuple[str, ...] = field(
197208
default=GEOGRAPHY_VARIABLES,
@@ -295,6 +306,10 @@ def _apply_los_angeles_zip_patch(
295306
class USTakeupPostProcessor:
296307
"""Apply US take-up draws after entity and geography postprocessing."""
297308

309+
spec = PayloadPostProcessorSpec(
310+
key=US_TAKEUP_POSTPROCESSOR_KEY,
311+
requires=(US_ENTITY_POSTPROCESSOR_KEY, US_GEOGRAPHY_POSTPROCESSOR_KEY),
312+
)
298313
takeup_applier: TakeupApplier = apply_block_takeup_to_arrays
299314
sum_person_values_to_tax_units: Callable[
300315
[np.ndarray, np.ndarray, np.ndarray],
@@ -312,6 +327,7 @@ def apply(
312327
self._validate_required_subentities(context)
313328
output = _copy_payload(payload.data)
314329
time_period = context.time_period
330+
self._validate_required_payload_fields(output, time_period)
315331
results = self._build_takeup_results(output, context)
316332
takeup_variables = tuple(str(variable) for variable in results)
317333
self._validate_takeup_variables(takeup_variables)
@@ -349,6 +365,32 @@ def _validate_required_subentities(self, context: PayloadBuildContext) -> None:
349365
f"US take-up requires reindexed subentities: {', '.join(missing)}"
350366
)
351367

368+
def _validate_required_payload_fields(
369+
self,
370+
data: PayloadData,
371+
time_period: int,
372+
) -> None:
373+
_required_period_array(
374+
data,
375+
"state_fips",
376+
time_period,
377+
"US take-up requires state_fips from USGeographyPostProcessor",
378+
)
379+
if _has_period_array(
380+
data,
381+
"reported_has_subsidized_marketplace_health_coverage_at_interview",
382+
time_period,
383+
):
384+
for variable in ("person_tax_unit_id", "tax_unit_id"):
385+
_required_period_array(
386+
data,
387+
variable,
388+
time_period,
389+
"US take-up reported ACA anchors require "
390+
"person_tax_unit_id and tax_unit_id from "
391+
"USEntityPostProcessor",
392+
)
393+
352394
def _build_takeup_results(
353395
self,
354396
data: PayloadData,
@@ -481,6 +523,14 @@ def _required_period_array(
481523
return np.asarray(data[variable][time_period])
482524

483525

526+
def _has_period_array(
527+
data: Mapping[str, Mapping[Any, np.ndarray]],
528+
variable: str,
529+
time_period: int,
530+
) -> bool:
531+
return variable in data and time_period in data[variable]
532+
533+
484534
def _build_reported_takeup_anchors(
485535
data: Mapping[str, Mapping[Any, np.ndarray]],
486536
time_period: int,

tests/unit/build_outputs/test_builder.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
from types import SimpleNamespace
33

44
import numpy as np
5+
import pytest
56

67
from policyengine_us_data.build_outputs.builder import (
78
LocalAreaBuildResult,
89
LocalAreaDatasetBuilder,
910
PayloadPostProcessorRun,
11+
PayloadPostProcessorSpec,
1012
)
1113
from policyengine_us_data.build_outputs.payload import H5Payload
1214
from policyengine_us_data.build_outputs.reindexing import ReindexedEntities
@@ -181,11 +183,13 @@ class _PostProcessor:
181183
summary={},
182184
postprocessor_runs=(
183185
PayloadPostProcessorRun(
186+
key="first",
184187
name="_PostProcessor",
185188
postprocessor_type=_PostProcessor,
186189
result=first,
187190
),
188191
PayloadPostProcessorRun(
192+
key="second",
189193
name="_PostProcessor",
190194
postprocessor_type=_PostProcessor,
191195
result=second,
@@ -194,9 +198,36 @@ class _PostProcessor:
194198
)
195199

196200
assert result.postprocessor_result(_PostProcessor) is first
201+
assert result.postprocessor_result("second") is second
197202
assert result.postprocessor_results(_PostProcessor) == (first, second)
198203

199204

205+
def test_local_area_dataset_builder_rejects_missing_postprocessor_dependency():
206+
class _DependentPostProcessor:
207+
spec = PayloadPostProcessorSpec(
208+
key="dependent",
209+
requires=("upstream",),
210+
)
211+
212+
with pytest.raises(
213+
ValueError,
214+
match="_DependentPostProcessor requires postprocessor\\(s\\) "
215+
"to run first: upstream",
216+
):
217+
LocalAreaDatasetBuilder(postprocessors=(_DependentPostProcessor(),))
218+
219+
220+
def test_local_area_dataset_builder_rejects_duplicate_postprocessor_keys():
221+
class _PostProcessor:
222+
spec = PayloadPostProcessorSpec(key="duplicate")
223+
224+
with pytest.raises(
225+
ValueError,
226+
match="Duplicate payload postprocessor key: duplicate",
227+
):
228+
LocalAreaDatasetBuilder(postprocessors=(_PostProcessor(), _PostProcessor()))
229+
230+
200231
def _payload(label: str) -> H5Payload:
201232
return H5Payload(
202233
data={label: {2024: np.array([1])}},

tests/unit/build_outputs/test_us_augmentations.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
SourceDatasetSnapshot,
1414
)
1515
from policyengine_us_data.build_outputs.us_augmentations import (
16+
US_ENTITY_POSTPROCESSOR_KEY,
17+
US_GEOGRAPHY_POSTPROCESSOR_KEY,
18+
US_TAKEUP_POSTPROCESSOR_KEY,
1619
USEntityPostProcessor,
1720
USGeographyPostProcessor,
1821
USTakeupPostProcessor,
@@ -127,11 +130,22 @@ def _geography_payload(context=None) -> H5Payload:
127130

128131

129132
def test_default_us_postprocessors_are_in_runtime_order():
130-
assert tuple(type(processor) for processor in default_us_postprocessors()) == (
133+
postprocessors = default_us_postprocessors()
134+
135+
assert tuple(type(processor) for processor in postprocessors) == (
131136
USEntityPostProcessor,
132137
USGeographyPostProcessor,
133138
USTakeupPostProcessor,
134139
)
140+
assert tuple(processor.spec.key for processor in postprocessors) == (
141+
US_ENTITY_POSTPROCESSOR_KEY,
142+
US_GEOGRAPHY_POSTPROCESSOR_KEY,
143+
US_TAKEUP_POSTPROCESSOR_KEY,
144+
)
145+
seen = set()
146+
for processor in postprocessors:
147+
assert set(processor.spec.requires) <= seen
148+
seen.add(processor.spec.key)
135149

136150

137151
def test_build_reported_takeup_anchors_skips_missing_period():
@@ -356,6 +370,27 @@ def test_us_takeup_postprocessor_requires_geography_first():
356370
)
357371

358372

373+
def test_us_takeup_postprocessor_requires_entity_ids_for_reported_aca_anchor():
374+
payload = _base_payload(
375+
{
376+
"state_fips": {2024: np.array([6, 37], dtype=np.int32)},
377+
"reported_has_subsidized_marketplace_health_coverage_at_interview": {
378+
2024: np.array([True, False, False])
379+
},
380+
}
381+
)
382+
383+
with pytest.raises(
384+
ValueError,
385+
match="US take-up reported ACA anchors require person_tax_unit_id "
386+
"and tax_unit_id from USEntityPostProcessor",
387+
):
388+
USTakeupPostProcessor(takeup_applier=lambda **kwargs: {}).apply(
389+
payload=payload,
390+
context=_context(),
391+
)
392+
393+
359394
def test_us_takeup_postprocessor_rejects_unknown_takeup_results():
360395
service = USTakeupPostProcessor(
361396
takeup_applier=lambda **kwargs: {

0 commit comments

Comments
 (0)