Skip to content

Commit 1b57a42

Browse files
authored
Allocate Medicaid cost if enrolled from SLCSP (#1138)
* Allocate Medicaid conditional cost from SLCSP
* Preserve Medicaid costs in local H5 builds
* Pin us-data to PE-US Medicaid cost fix
1 parent a9af6f5 commit 1b57a42

10 files changed

Lines changed: 817 additions & 5 deletions

File tree

changelog.d/1138.changed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Allocate conditional Medicaid cost from an SLCSP age/location index in us-data.

policyengine_us_data/build_outputs/us_augmentations.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,14 @@
3131
"TAKEUP_VARIABLE_ENTITIES",
3232
"US_ENTITY_POSTPROCESSOR_KEY",
3333
"US_GEOGRAPHY_POSTPROCESSOR_KEY",
34+
"US_MEDICAID_COST_POSTPROCESSOR_KEY",
3435
"US_TAKEUP_POSTPROCESSOR_KEY",
3536
"USEntityPostProcessor",
3637
"USEntityPostProcessorResult",
3738
"USGeographyPostProcessor",
3839
"USGeographyPostProcessorResult",
40+
"USMedicaidCostPostProcessor",
41+
"USMedicaidCostPostProcessorResult",
3942
"USTakeupPostProcessor",
4043
"USTakeupPostProcessorResult",
4144
"default_us_postprocessors",
@@ -52,6 +55,7 @@
5255
US_ENTITY_POSTPROCESSOR_KEY = "us_entity"
5356
US_GEOGRAPHY_POSTPROCESSOR_KEY = "us_geography"
5457
US_TAKEUP_POSTPROCESSOR_KEY = "us_takeup"
58+
US_MEDICAID_COST_POSTPROCESSOR_KEY = "us_medicaid_cost"
5559

5660

5761
@pipeline_node(
@@ -134,6 +138,32 @@ def data(self) -> PayloadData:
134138
return self.payload.data
135139

136140

141+
@pipeline_node(
    id="local_h5_us_medicaid_cost_postprocessor_result",
    label="USMedicaidCostPostProcessorResult",
    node_type="library",
    description="US Medicaid conditional-cost local H5 payload data.",
    source_file="policyengine_us_data/build_outputs/us_augmentations.py",
    status="current",
    stability="moving",
    pathways=["local_h5"],
    validation_commands=[
        "uv run pytest tests/unit/build_outputs/test_us_augmentations.py"
    ],
)
@dataclass(frozen=True)
class USMedicaidCostPostProcessorResult:
    """Immutable wrapper around the payload after Medicaid cost fields are applied."""

    # The payload produced once conditional Medicaid cost fields are applied.
    payload: H5Payload

    @property
    def data(self) -> PayloadData:
        """Expose the wrapped payload's data for transitional callers."""

        wrapped = self.payload
        return wrapped.data
165+
166+
137167
@pipeline_node(
138168
id="local_h5_us_entity_postprocessor",
139169
label="USEntityPostProcessor",
@@ -571,22 +601,106 @@ def _build_eligibility_masks(
571601
}
572602

573603

604+
@pipeline_node(
    id="local_h5_us_medicaid_cost_postprocessor",
    label="USMedicaidCostPostProcessor",
    node_type="library",
    description="Preserve Medicaid cost-if-enrolled inputs in local H5 payloads.",
    source_file="policyengine_us_data/build_outputs/us_augmentations.py",
    status="current",
    stability="moving",
    pathways=["local_h5"],
    validation_commands=[
        "uv run pytest tests/unit/build_outputs/test_us_augmentations.py"
    ],
)
@dataclass(frozen=True)
class USMedicaidCostPostProcessor:
    """Preserve source Medicaid conditional costs after local H5 transforms.

    If the payload has no ``medicaid_cost_if_enrolled`` entry for the build
    context's time period, the value is looked up from the source inputs and
    added to a copy of the payload. The incoming payload is never mutated.
    """

    # Declares this postprocessor's key and the postprocessor keys it requires.
    spec = PayloadPostProcessorSpec(
        key=US_MEDICAID_COST_POSTPROCESSOR_KEY,
        requires=(
            US_ENTITY_POSTPROCESSOR_KEY,
            US_GEOGRAPHY_POSTPROCESSOR_KEY,
            US_TAKEUP_POSTPROCESSOR_KEY,
        ),
    )

    def apply(
        self,
        *,
        payload: H5Payload,
        context: PayloadBuildContext,
    ) -> USMedicaidCostPostProcessorResult:
        """Return a payload with source conditional Medicaid costs preserved.

        Args:
            payload: Payload whose data is copied and, if needed, augmented.
            context: Build context supplying the time period and the source
                inputs used for the lookup.

        Returns:
            A result wrapping a new ``H5Payload``. ``medicaid_cost_if_enrolled``
            is registered as a ``"person"`` variable whenever it is present in
            the output data.
        """

        variable = "medicaid_cost_if_enrolled"
        output = _copy_payload(payload.data)
        existing_periods = output.get(variable, {})
        if context.time_period not in existing_periods:
            source_values = _source_person_period_values(
                context=context,
                variable=variable,
            )
            if source_values is not None:
                # Merge with any periods already stored for this variable
                # instead of replacing the whole mapping, so values for other
                # periods are not silently discarded.
                output[variable] = {
                    **existing_periods,
                    context.time_period: source_values.astype(np.float32),
                }

        variable_entities = dict(payload.variable_entities)
        if variable in output:
            variable_entities[variable] = "person"
        return USMedicaidCostPostProcessorResult(
            payload=H5Payload(
                data=output,
                time_period=payload.time_period,
                entity_lengths=payload.entity_lengths,
                variable_entities=variable_entities,
            ),
        )
661+
662+
574663
def default_us_postprocessors() -> tuple[
    USEntityPostProcessor
    | USGeographyPostProcessor
    | USTakeupPostProcessor
    | USMedicaidCostPostProcessor,
    ...,
]:
    """Build the production US postprocessor chain in its required order.

    The returned tuple holds one fresh instance each of the entity,
    geography, takeup and Medicaid-cost postprocessors, in that order.
    """

    entity_step = USEntityPostProcessor()
    geography_step = USGeographyPostProcessor()
    takeup_step = USTakeupPostProcessor()
    medicaid_cost_step = USMedicaidCostPostProcessor()
    return (entity_step, geography_step, takeup_step, medicaid_cost_step)
584678

585679

586680
def _copy_payload(data: Mapping[str, Mapping[Any, np.ndarray]]) -> PayloadData:
587681
return {variable: dict(periods) for variable, periods in data.items()}
588682

589683

684+
def _source_person_period_values(
685+
*,
686+
context: PayloadBuildContext,
687+
variable: str,
688+
) -> np.ndarray | None:
689+
"""Return a source person input mapped to the local payload, if available."""
690+
691+
if variable not in context.source.input_variables:
692+
return None
693+
provider = context.source.variable_provider
694+
get_array = getattr(provider, "get_array", None)
695+
if not callable(get_array):
696+
return None
697+
try:
698+
values = np.asarray(get_array(variable, context.time_period))
699+
except (KeyError, ValueError):
700+
return None
701+
return values[context.reindexed.person_source_indices]
702+
703+
590704
def _required_period_array(
591705
data: Mapping[str, Mapping[Any, np.ndarray]],
592706
variable: str,

policyengine_us_data/datasets/cps/cps.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@
4949
derive_treasury_tipped_occupation_code,
5050
derive_is_tipped_occupation,
5151
)
52+
from policyengine_us_data.datasets.cps.medicaid_cost import (
53+
add_medicaid_cost_if_enrolled_to_dataset,
54+
)
5255
from policyengine_us_data.utils.takeup import (
5356
_sum_person_values_to_tax_units,
5457
_voluntary_filing_age_bin,
@@ -360,6 +363,8 @@ def generate(self):
360363
add_takeup(self)
361364
logging.info("Imputing Marketplace plan benchmark ratio")
362365
add_marketplace_plan_benchmark_ratio(self)
366+
logging.info("Adding Medicaid cost if enrolled")
367+
add_medicaid_cost_if_enrolled_to_dataset(self)
363368
logging.info("Deriving other health insurance premiums")
364369
derive_other_health_insurance_premiums(self)
365370
logging.info("Downsampling")

policyengine_us_data/datasets/cps/extended_cps.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
derive_flsa_overtime_premium,
2626
load_take_up_rate,
2727
)
28+
from policyengine_us_data.datasets.cps.medicaid_cost import (
29+
add_medicaid_cost_if_enrolled_to_time_period_data,
30+
)
2831
from policyengine_us_data.datasets.cps.takeup import prioritize_reported_recipients
2932
from policyengine_us_data.datasets.org import (
3033
ORG_IMPUTED_VARIABLES,
@@ -1190,6 +1193,10 @@ def generate(self):
11901193
new_data,
11911194
self.time_period,
11921195
)
1196+
new_data = add_medicaid_cost_if_enrolled_to_time_period_data(
1197+
new_data,
1198+
self.time_period,
1199+
)
11931200
new_data = self._validate_housing_assistance_microsimulation(
11941201
new_data,
11951202
self.time_period,

0 commit comments

Comments
 (0)