|
31 | 31 | "TAKEUP_VARIABLE_ENTITIES", |
32 | 32 | "US_ENTITY_POSTPROCESSOR_KEY", |
33 | 33 | "US_GEOGRAPHY_POSTPROCESSOR_KEY", |
| 34 | + "US_MEDICAID_COST_POSTPROCESSOR_KEY", |
34 | 35 | "US_TAKEUP_POSTPROCESSOR_KEY", |
35 | 36 | "USEntityPostProcessor", |
36 | 37 | "USEntityPostProcessorResult", |
37 | 38 | "USGeographyPostProcessor", |
38 | 39 | "USGeographyPostProcessorResult", |
| 40 | + "USMedicaidCostPostProcessor", |
| 41 | + "USMedicaidCostPostProcessorResult", |
39 | 42 | "USTakeupPostProcessor", |
40 | 43 | "USTakeupPostProcessorResult", |
41 | 44 | "default_us_postprocessors", |
|
52 | 55 | US_ENTITY_POSTPROCESSOR_KEY = "us_entity" |
53 | 56 | US_GEOGRAPHY_POSTPROCESSOR_KEY = "us_geography" |
54 | 57 | US_TAKEUP_POSTPROCESSOR_KEY = "us_takeup" |
| 58 | +US_MEDICAID_COST_POSTPROCESSOR_KEY = "us_medicaid_cost" |
55 | 59 |
|
56 | 60 |
|
57 | 61 | @pipeline_node( |
@@ -134,6 +138,32 @@ def data(self) -> PayloadData: |
134 | 138 | return self.payload.data |
135 | 139 |
|
136 | 140 |
|
| 141 | +@pipeline_node( |
| 142 | + id="local_h5_us_medicaid_cost_postprocessor_result", |
| 143 | + label="USMedicaidCostPostProcessorResult", |
| 144 | + node_type="library", |
| 145 | + description="US Medicaid conditional-cost local H5 payload data.", |
| 146 | + source_file="policyengine_us_data/build_outputs/us_augmentations.py", |
| 147 | + status="current", |
| 148 | + stability="moving", |
| 149 | + pathways=["local_h5"], |
| 150 | + validation_commands=[ |
| 151 | + "uv run pytest tests/unit/build_outputs/test_us_augmentations.py" |
| 152 | + ], |
| 153 | +) |
| 154 | +@dataclass(frozen=True) |
| 155 | +class USMedicaidCostPostProcessorResult: |
| 156 | + """Payload after conditional Medicaid cost fields are applied.""" |
| 157 | + |
| 158 | + payload: H5Payload |
| 159 | + |
| 160 | + @property |
| 161 | + def data(self) -> PayloadData: |
| 162 | + """Augmented payload data retained for transitional callers.""" |
| 163 | + |
| 164 | + return self.payload.data |
| 165 | + |
| 166 | + |
137 | 167 | @pipeline_node( |
138 | 168 | id="local_h5_us_entity_postprocessor", |
139 | 169 | label="USEntityPostProcessor", |
@@ -571,22 +601,106 @@ def _build_eligibility_masks( |
571 | 601 | } |
572 | 602 |
|
573 | 603 |
|
| 604 | +@pipeline_node( |
| 605 | + id="local_h5_us_medicaid_cost_postprocessor", |
| 606 | + label="USMedicaidCostPostProcessor", |
| 607 | + node_type="library", |
| 608 | + description="Preserve Medicaid cost-if-enrolled inputs in local H5 payloads.", |
| 609 | + source_file="policyengine_us_data/build_outputs/us_augmentations.py", |
| 610 | + status="current", |
| 611 | + stability="moving", |
| 612 | + pathways=["local_h5"], |
| 613 | + validation_commands=[ |
| 614 | + "uv run pytest tests/unit/build_outputs/test_us_augmentations.py" |
| 615 | + ], |
| 616 | +) |
| 617 | +@dataclass(frozen=True) |
| 618 | +class USMedicaidCostPostProcessor: |
| 619 | + """Preserve source Medicaid conditional costs after local H5 transforms.""" |
| 620 | + |
| 621 | + spec = PayloadPostProcessorSpec( |
| 622 | + key=US_MEDICAID_COST_POSTPROCESSOR_KEY, |
| 623 | + requires=( |
| 624 | + US_ENTITY_POSTPROCESSOR_KEY, |
| 625 | + US_GEOGRAPHY_POSTPROCESSOR_KEY, |
| 626 | + US_TAKEUP_POSTPROCESSOR_KEY, |
| 627 | + ), |
| 628 | + ) |
| 629 | + |
| 630 | + def apply( |
| 631 | + self, |
| 632 | + *, |
| 633 | + payload: H5Payload, |
| 634 | + context: PayloadBuildContext, |
| 635 | + ) -> USMedicaidCostPostProcessorResult: |
| 636 | + """Return a payload with source conditional Medicaid costs preserved.""" |
| 637 | + |
| 638 | + output = _copy_payload(payload.data) |
| 639 | + cost_periods = output.get("medicaid_cost_if_enrolled", {}) |
| 640 | + if context.time_period not in cost_periods: |
| 641 | + source_values = _source_person_period_values( |
| 642 | + context=context, |
| 643 | + variable="medicaid_cost_if_enrolled", |
| 644 | + ) |
| 645 | + if source_values is not None: |
| 646 | + output["medicaid_cost_if_enrolled"] = { |
| 647 | + context.time_period: source_values.astype(np.float32) |
| 648 | + } |
| 649 | + |
| 650 | + variable_entities = dict(payload.variable_entities) |
| 651 | + if "medicaid_cost_if_enrolled" in output: |
| 652 | + variable_entities["medicaid_cost_if_enrolled"] = "person" |
| 653 | + return USMedicaidCostPostProcessorResult( |
| 654 | + payload=H5Payload( |
| 655 | + data=output, |
| 656 | + time_period=payload.time_period, |
| 657 | + entity_lengths=payload.entity_lengths, |
| 658 | + variable_entities=variable_entities, |
| 659 | + ), |
| 660 | + ) |
| 661 | + |
| 662 | + |
574 | 663 | def default_us_postprocessors() -> tuple[ |
575 | | - USEntityPostProcessor | USGeographyPostProcessor | USTakeupPostProcessor, ... |
| 664 | + USEntityPostProcessor |
| 665 | + | USGeographyPostProcessor |
| 666 | + | USTakeupPostProcessor |
| 667 | + | USMedicaidCostPostProcessor, |
| 668 | + ..., |
576 | 669 | ]: |
577 | 670 | """Return production US postprocessors in their required order.""" |
578 | 671 |
|
579 | 672 | return ( |
580 | 673 | USEntityPostProcessor(), |
581 | 674 | USGeographyPostProcessor(), |
582 | 675 | USTakeupPostProcessor(), |
| 676 | + USMedicaidCostPostProcessor(), |
583 | 677 | ) |
584 | 678 |
|
585 | 679 |
|
586 | 680 | def _copy_payload(data: Mapping[str, Mapping[Any, np.ndarray]]) -> PayloadData: |
587 | 681 | return {variable: dict(periods) for variable, periods in data.items()} |
588 | 682 |
|
589 | 683 |
|
| 684 | +def _source_person_period_values( |
| 685 | + *, |
| 686 | + context: PayloadBuildContext, |
| 687 | + variable: str, |
| 688 | +) -> np.ndarray | None: |
| 689 | + """Return a source person input mapped to the local payload, if available.""" |
| 690 | + |
| 691 | + if variable not in context.source.input_variables: |
| 692 | + return None |
| 693 | + provider = context.source.variable_provider |
| 694 | + get_array = getattr(provider, "get_array", None) |
| 695 | + if not callable(get_array): |
| 696 | + return None |
| 697 | + try: |
| 698 | + values = np.asarray(get_array(variable, context.time_period)) |
| 699 | + except (KeyError, ValueError): |
| 700 | + return None |
| 701 | + return values[context.reindexed.person_source_indices] |
| 702 | + |
| 703 | + |
590 | 704 | def _required_period_array( |
591 | 705 | data: Mapping[str, Mapping[Any, np.ndarray]], |
592 | 706 | variable: str, |
|
0 commit comments