Skip to content
This repository was archived by the owner on Jun 14, 2026. It is now read-only.

Commit a9262ea

Browse files
committed
Use PolicyEngine formulas for oracle targets
1 parent c08252f commit a9262ea

15 files changed

Lines changed: 451 additions & 95 deletions

src/microplex_us/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@
167167
"infer_policyengine_us_variable_bindings",
168168
"load_policyengine_us_entity_tables",
169169
"materialize_policyengine_us_variables",
170+
"policyengine_us_formula_variables_for_targets",
170171
"policyengine_us_variables_to_materialize",
171172
"project_frame_to_time_period_arrays",
172173
"write_policyengine_us_time_period_dataset",
@@ -356,6 +357,7 @@ def __getattr__(name: str) -> Any:
356357
"infer_policyengine_us_variable_bindings",
357358
"load_policyengine_us_entity_tables",
358359
"materialize_policyengine_us_variables",
360+
"policyengine_us_formula_variables_for_targets",
359361
"policyengine_us_variables_to_materialize",
360362
"project_frame_to_time_period_arrays",
361363
"write_policyengine_us_time_period_dataset",

src/microplex_us/pipelines/summarize_policyengine_oracle_target_drilldown.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def summarize_us_policyengine_oracle_target_drilldown(
6565
_supported_targets,
6666
_constraints,
6767
_feasibility_filter_summary,
68-
_materialized_variables,
68+
calibration_materialized_variables,
6969
_materialization_failures,
7070
) = pipeline._resolve_policyengine_calibration_targets(
7171
tables,
@@ -100,6 +100,8 @@ def summarize_us_policyengine_oracle_target_drilldown(
100100
str(variable)
101101
for variable in manifest.get("calibration", {}).get("materialized_variables", ())
102102
}
103+
materialized_variables.update(str(variable) for variable in calibration_materialized_variables)
104+
materialized_variables.update(str(variable) for variable in report.materialized_variables)
103105
ledger_by_name = {
104106
str(entry["target_name"]): dict(entry)
105107
for entry in target_ledger

src/microplex_us/pipelines/us.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
infer_policyengine_us_variable_bindings,
7373
load_us_pipeline_checkpoint,
7474
materialize_policyengine_us_variables_safely,
75+
policyengine_us_formula_variables_for_targets,
7576
policyengine_us_variables_to_materialize,
7677
resolve_policyengine_excluded_export_variables,
7778
save_us_pipeline_checkpoint,
@@ -3831,9 +3832,15 @@ def _resolve_policyengine_calibration_targets(
38313832
period=target_period,
38323833
for_calibration=True,
38333834
).targets
3835+
force_materialize_variables = policyengine_us_formula_variables_for_targets(
3836+
canonical_targets,
3837+
simulation_cls=self.config.policyengine_simulation_cls,
3838+
direct_override_variables=self.config.policyengine_direct_override_variables,
3839+
)
38343840
missing_variables = policyengine_us_variables_to_materialize(
38353841
canonical_targets,
38363842
bindings,
3843+
force_materialize_variables=force_materialize_variables,
38373844
)
38383845
materialization_failures: dict[str, str] = {}
38393846
materialized_variables: set[str] = set()
@@ -3844,9 +3851,20 @@ def _resolve_policyengine_calibration_targets(
38443851
period=target_period,
38453852
dataset_year=self.config.policyengine_dataset_year or target_period,
38463853
simulation_cls=self.config.policyengine_simulation_cls,
3854+
direct_override_variables=self.config.policyengine_direct_override_variables,
38473855
batch_size=self.config.policyengine_materialize_batch_size,
38483856
)
38493857
tables = materialization_result.tables
3858+
unmaterialized_forced_variables = (
3859+
force_materialize_variables
3860+
& missing_variables
3861+
- set(materialization_result.bindings)
3862+
)
3863+
bindings = {
3864+
variable: binding
3865+
for variable, binding in bindings.items()
3866+
if variable not in unmaterialized_forced_variables
3867+
}
38503868
bindings = {
38513869
**bindings,
38523870
**materialization_result.bindings,

src/microplex_us/policyengine/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
infer_policyengine_us_variable_bindings,
4040
load_policyengine_us_entity_tables,
4141
materialize_policyengine_us_variables,
42+
policyengine_us_formula_variables_for_targets,
4243
policyengine_us_variables_to_materialize,
4344
project_frame_to_time_period_arrays,
4445
write_policyengine_us_time_period_dataset,
@@ -79,6 +80,7 @@
7980
"infer_policyengine_us_variable_bindings",
8081
"load_policyengine_us_entity_tables",
8182
"materialize_policyengine_us_variables",
83+
"policyengine_us_formula_variables_for_targets",
8284
"policyengine_us_variables_to_materialize",
8385
"project_frame_to_time_period_arrays",
8486
"write_policyengine_us_time_period_dataset",

src/microplex_us/policyengine/comparison.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
infer_policyengine_us_variable_bindings,
3535
load_policyengine_us_entity_tables,
3636
materialize_policyengine_us_variables_safely,
37+
policyengine_us_formula_variables_for_targets,
38+
policyengine_us_variables_to_materialize,
3739
)
3840

3941
POLICYENGINE_US_BENCHMARK_GROUP_FIELDS = (
@@ -363,20 +365,35 @@ def evaluate_policyengine_us_target_set(
363365
target_list = _normalize_target_list(targets)
364366
working_tables = tables
365367
bindings = infer_policyengine_us_variable_bindings(working_tables)
368+
force_materialize_variables = policyengine_us_formula_variables_for_targets(
369+
target_list,
370+
simulation_cls=simulation_cls,
371+
direct_override_variables=direct_override_variables,
372+
)
373+
variables_to_materialize = policyengine_us_variables_to_materialize(
374+
target_list,
375+
bindings,
376+
force_materialize_variables=force_materialize_variables,
377+
)
366378
materialization_result = materialize_policyengine_us_variables_safely(
367379
working_tables,
368-
variables=tuple(
369-
feature
370-
for target in target_list
371-
for feature in target.required_features
372-
if feature not in bindings
373-
),
380+
variables=tuple(sorted(variables_to_materialize)),
374381
period=period,
375382
dataset_year=dataset_year,
376383
simulation_cls=simulation_cls,
377384
direct_override_variables=direct_override_variables,
378385
)
379386
working_tables = materialization_result.tables
387+
unmaterialized_forced_variables = (
388+
force_materialize_variables
389+
& variables_to_materialize
390+
- set(materialization_result.bindings)
391+
)
392+
bindings = {
393+
variable: binding
394+
for variable, binding in bindings.items()
395+
if variable not in unmaterialized_forced_variables
396+
}
380397
bindings = {
381398
**bindings,
382399
**materialization_result.bindings,

src/microplex_us/policyengine/us.py

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ class PolicyEngineUSVariableMaterializationResult:
286286
"other_medical_expenses",
287287
"over_the_counter_health_expenses",
288288
"self_employment_income_before_lsr",
289-
"social_security_retirement",
289+
"social_security_retirement_reported",
290290
"social_security_disability",
291291
"social_security_survivors",
292292
"social_security_dependents",
@@ -327,6 +327,7 @@ class PolicyEngineUSVariableMaterializationResult:
327327

328328
POLICYENGINE_US_EXPORT_COLUMN_ALIASES: dict[str, str] = {
329329
"race": "cps_race",
330+
"social_security_retirement": "social_security_retirement_reported",
330331
}
331332

332333
POLICYENGINE_US_EXPORT_DEFAULTS: dict[str, Any] = {
@@ -1866,18 +1867,70 @@ def compile_supported_policyengine_us_household_linear_constraints(
18661867
return supported_targets, unsupported_targets, tuple(constraints)
18671868

18681869

1870+
def _policyengine_us_target_required_variables(targets: list[TargetSpec]) -> set[str]:
1871+
return {
1872+
feature
1873+
for target in targets
1874+
for feature in target.required_features
1875+
}
1876+
1877+
1878+
def policyengine_us_formula_variables_for_targets(
1879+
targets: list[TargetSpec],
1880+
*,
1881+
simulation_cls: Any | None = None,
1882+
tax_benefit_system: Any | None = None,
1883+
direct_override_variables: tuple[str, ...] = (),
1884+
) -> set[str]:
1885+
"""Return target features that should be recalculated by PolicyEngine."""
1886+
required_variables = _policyengine_us_target_required_variables(targets)
1887+
if not required_variables:
1888+
return set()
1889+
if tax_benefit_system is None:
1890+
tax_benefit_system = _resolve_policyengine_us_tax_benefit_system(
1891+
simulation_cls
1892+
)
1893+
variables = getattr(tax_benefit_system, "variables", {})
1894+
direct_overrides = set(direct_override_variables)
1895+
formula_variables: set[str] = set()
1896+
for variable in required_variables:
1897+
if variable in direct_overrides:
1898+
continue
1899+
variable_metadata = variables.get(variable)
1900+
if variable_metadata is None:
1901+
continue
1902+
if _policyengine_us_variable_is_calculated(variable_metadata):
1903+
formula_variables.add(variable)
1904+
return formula_variables
1905+
1906+
1907+
def _policyengine_us_variable_is_calculated(variable_metadata: Any) -> bool:
1908+
if getattr(variable_metadata, "formulas", {}):
1909+
return True
1910+
if getattr(variable_metadata, "adds", ()) or getattr(variable_metadata, "subtracts", ()):
1911+
return True
1912+
is_input_variable = getattr(variable_metadata, "is_input_variable", None)
1913+
if callable(is_input_variable):
1914+
try:
1915+
return not bool(is_input_variable())
1916+
except TypeError:
1917+
return False
1918+
return False
1919+
1920+
18691921
def policyengine_us_variables_to_materialize(
18701922
targets: list[TargetSpec],
18711923
bindings: dict[str, PolicyEngineUSVariableBinding],
1924+
*,
1925+
force_materialize_variables: set[str] | tuple[str, ...] | None = None,
18721926
) -> set[str]:
18731927
"""Compute the missing features required to score the given targets."""
1874-
requested_variables = {
1875-
feature
1876-
for target in targets
1877-
for feature in target.required_features
1878-
}
1928+
requested_variables = _policyengine_us_target_required_variables(targets)
1929+
force_variables = set(force_materialize_variables or ())
18791930
return {
1880-
variable for variable in requested_variables if variable not in bindings
1931+
variable
1932+
for variable in requested_variables
1933+
if variable not in bindings or variable in force_variables
18811934
}
18821935

18831936

tests/pipelines/test_artifacts.py

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -176,19 +176,9 @@ def _create_policyengine_targets_db(path: Path) -> None:
176176
t.value,
177177
t.period,
178178
t.active,
179-
CASE
180-
WHEN t.variable = 'snap' THEN 'state'
181-
ELSE 'district'
182-
END AS geo_level,
183-
CASE
184-
WHEN t.variable = 'snap' THEN '06'
185-
ELSE '0601'
186-
END AS geographic_id,
187-
CASE
188-
WHEN t.variable = 'snap' THEN 'snap'
189-
WHEN t.variable = 'household_count' THEN 'snap'
190-
ELSE NULL
191-
END AS domain_variable
179+
'state' AS geo_level,
180+
'06' AS geographic_id,
181+
'household_count' AS domain_variable
192182
FROM targets AS t;
193183
"""
194184
)
@@ -216,7 +206,6 @@ def _create_policyengine_targets_db(path: Path) -> None:
216206
""",
217207
[
218208
(1, "household_count", 2024, 1, 0, 3.0, 1, None, "test", "count"),
219-
(2, "snap", 2024, 1, 0, 250.0, 1, None, "test", "snap"),
220209
],
221210
)
222211
conn.commit()
@@ -604,12 +593,11 @@ def test_writes_policyengine_harness_when_baseline_and_targets_are_provided(
604593
TargetSet(
605594
[
606595
TargetSpec(
607-
name="snap_total",
596+
name="household_count",
608597
entity=EntityType.HOUSEHOLD,
609-
value=250.0,
598+
value=3.0,
610599
period=2024,
611-
measure="snap",
612-
aggregation="sum",
600+
aggregation="count",
613601
),
614602
]
615603
)
@@ -622,9 +610,9 @@ def test_writes_policyengine_harness_when_baseline_and_targets_are_provided(
622610
policyengine_baseline_dataset=baseline_dataset,
623611
policyengine_harness_slices=(
624612
PolicyEngineUSHarnessSlice(
625-
name="snap",
626-
description="SNAP parity",
627-
query=TargetQuery(period=2024, names=("snap_total",)),
613+
name="household_count",
614+
description="Household count parity",
615+
query=TargetQuery(period=2024, names=("household_count",)),
628616
),
629617
),
630618
policyengine_harness_metadata={"baseline_dataset": baseline_dataset.name},
@@ -838,7 +826,7 @@ def test_writes_policyengine_harness_from_build_config_defaults(self, tmp_path):
838826
policyengine_dataset_year=2024,
839827
policyengine_targets_db=str(targets_db),
840828
policyengine_baseline_dataset=str(baseline_dataset),
841-
policyengine_target_variables=("snap", "household_count"),
829+
policyengine_target_variables=("household_count",),
842830
),
843831
seed_data=pd.DataFrame({"income": [10.0], "hh_weight": [1.0]}),
844832
synthetic_data=pd.DataFrame({"income": [10.0, 20.0], "weight": [1.0, 1.0]}),
@@ -921,10 +909,7 @@ def test_writes_policyengine_harness_from_build_config_defaults(self, tmp_path):
921909
assert harness_payload["metadata"]["targets_db"] == "policyengine_targets.db"
922910
assert harness_payload["metadata"]["harness_suite"] == "policyengine_us_all_targets"
923911
assert harness_payload["metadata"]["harness_slice_names"] == ["all_targets"]
924-
assert harness_payload["metadata"]["target_variables"] == [
925-
"snap",
926-
"household_count",
927-
]
912+
assert harness_payload["metadata"]["target_variables"] == ["household_count"]
928913
assert harness_payload["metadata"]["policyengine_us_runtime_version"] is not None
929914
assert [slice_payload["name"] for slice_payload in harness_payload["slices"]] == [
930915
"all_targets",

0 commit comments

Comments
 (0)