Skip to content

Commit bc5dcff

Browse files
anth-volkclaude
andcommitted
fix: Add PolicyReformAnalysis.model_rebuild() and update example script
- Add model_rebuild() for PolicyReformAnalysis in both US and UK __init__.py to resolve BudgetSummaryItem forward reference (TYPE_CHECKING import) - Fix test_aggregate to expect ValueError instead of StopIteration - Fix example script bp.metric → bp.poverty_type to match Poverty class Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent af51b6b commit bc5dcff

4 files changed

Lines changed: 163 additions & 2 deletions

File tree

examples/us_budgetary_impact.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
"""Example: US budgetary impact comparison between baseline and reform.
2+
3+
Demonstrates the canonical policyengine.py workflow:
4+
1. Ensure datasets exist (download + compute or load from cache)
5+
2. Define a parametric reform
6+
3. Run baseline and reform simulations
7+
4. Use economic_impact_analysis() for the full analysis
8+
5. Use ChangeAggregate for targeted single-metric queries
9+
10+
Run: python examples/us_budgetary_impact.py
11+
"""
12+
13+
import datetime
14+
15+
from policyengine.core import Parameter, ParameterValue, Policy, Simulation
16+
from policyengine.outputs.change_aggregate import (
17+
ChangeAggregate,
18+
ChangeAggregateType,
19+
)
20+
from policyengine.tax_benefit_models.us import (
21+
economic_impact_analysis,
22+
ensure_datasets,
23+
us_latest,
24+
)
25+
26+
27+
def main():
28+
year = 2026
29+
30+
# ── Step 1: Get dataset (downloads from HuggingFace on first run) ──
31+
print("Ensuring datasets are available...")
32+
datasets = ensure_datasets(
33+
datasets=["hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5"],
34+
years=[year],
35+
data_folder="./data",
36+
)
37+
dataset = datasets[f"enhanced_cps_2024_{year}"]
38+
print(f" Loaded: {dataset}")
39+
40+
# ── Step 2: Define a reform ──
41+
# Example: double the standard deduction for single filers
42+
param = Parameter(
43+
name="gov.irs.deductions.standard.amount.SINGLE",
44+
tax_benefit_model_version=us_latest,
45+
)
46+
reform = Policy(
47+
name="Double standard deduction (single)",
48+
parameter_values=[
49+
ParameterValue(
50+
parameter=param,
51+
start_date=datetime.date(year, 1, 1),
52+
end_date=datetime.date(year, 12, 31),
53+
value=30_950,
54+
),
55+
],
56+
)
57+
58+
# ── Step 3: Create simulations ──
59+
baseline_sim = Simulation(
60+
dataset=dataset,
61+
tax_benefit_model_version=us_latest,
62+
)
63+
reform_sim = Simulation(
64+
dataset=dataset,
65+
tax_benefit_model_version=us_latest,
66+
policy=reform,
67+
)
68+
69+
# ── Step 4a: Quick budgetary number via ChangeAggregate ──
70+
# This requires running the simulations first.
71+
print("\nRunning simulations...")
72+
baseline_sim.run()
73+
reform_sim.run()
74+
75+
tax_change = ChangeAggregate(
76+
baseline_simulation=baseline_sim,
77+
reform_simulation=reform_sim,
78+
variable="household_tax",
79+
aggregate_type=ChangeAggregateType.SUM,
80+
)
81+
tax_change.run()
82+
print("\nQuick budgetary result:")
83+
print(f" Tax revenue change: ${tax_change.result / 1e9:.2f}B")
84+
85+
# Count winners and losers
86+
winners = ChangeAggregate(
87+
baseline_simulation=baseline_sim,
88+
reform_simulation=reform_sim,
89+
variable="household_net_income",
90+
aggregate_type=ChangeAggregateType.COUNT,
91+
change_geq=1,
92+
)
93+
losers = ChangeAggregate(
94+
baseline_simulation=baseline_sim,
95+
reform_simulation=reform_sim,
96+
variable="household_net_income",
97+
aggregate_type=ChangeAggregateType.COUNT,
98+
change_leq=-1,
99+
)
100+
winners.run()
101+
losers.run()
102+
print(f" Winners: {winners.result / 1e6:.2f}M households")
103+
print(f" Losers: {losers.result / 1e6:.2f}M households")
104+
105+
# ── Step 4b: Full analysis via economic_impact_analysis ──
106+
# Note: this calls .ensure() internally, which is a no-op here since
107+
# we already ran the simulations above. If we hadn't called .run(),
108+
# ensure() would run + cache them automatically.
109+
print("\nRunning full economic impact analysis...")
110+
analysis = economic_impact_analysis(baseline_sim, reform_sim)
111+
112+
print("\n=== Program-by-Program Impact ===")
113+
for prog in analysis.program_statistics.outputs:
114+
print(
115+
f" {prog.program_name:30s} "
116+
f"baseline=${prog.baseline_total / 1e9:8.1f}B "
117+
f"reform=${prog.reform_total / 1e9:8.1f}B "
118+
f"change=${prog.change / 1e9:+8.1f}B"
119+
)
120+
121+
print("\n=== Decile Impacts ===")
122+
for d in analysis.decile_impacts.outputs:
123+
print(
124+
f" Decile {d.decile:2d}: "
125+
f"avg change=${d.absolute_change:+8.0f} "
126+
f"relative={d.relative_change:+.2%}"
127+
)
128+
129+
print("\n=== Poverty ===")
130+
for bp, rp in zip(
131+
analysis.baseline_poverty.outputs,
132+
analysis.reform_poverty.outputs,
133+
strict=True,
134+
):
135+
print(
136+
f" {bp.poverty_type:30s} "
137+
f"baseline={bp.rate:.4f} "
138+
f"reform={rp.rate:.4f} "
139+
f"change={rp.rate - bp.rate:+.4f}"
140+
)
141+
142+
print("\n=== Inequality ===")
143+
bi = analysis.baseline_inequality
144+
ri = analysis.reform_inequality
145+
print(f" Gini: baseline={bi.gini:.4f} reform={ri.gini:.4f}")
146+
print(
147+
f" Top 10% share: baseline={bi.top_10_share:.4f} reform={ri.top_10_share:.4f}"
148+
)
149+
print(
150+
f" Top 1% share: baseline={bi.top_1_share:.4f} reform={ri.top_1_share:.4f}"
151+
)
152+
153+
154+
if __name__ == "__main__":
155+
main()

src/policyengine/tax_benefit_models/uk/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
PolicyEngineUKLatest.model_rebuild()
3838
ProgrammeStatistics.model_rebuild(_types_namespace={"Simulation": Simulation})
3939
BudgetSummaryItem.model_rebuild(_types_namespace={"Simulation": Simulation})
40+
PolicyReformAnalysis.model_rebuild(
41+
_types_namespace={"BudgetSummaryItem": BudgetSummaryItem}
42+
)
4043

4144
__all__ = [
4245
"UKYearData",

src/policyengine/tax_benefit_models/us/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
PolicyEngineUSLatest.model_rebuild()
3838
ProgramStatistics.model_rebuild(_types_namespace={"Simulation": Simulation})
3939
BudgetSummaryItem.model_rebuild(_types_namespace={"Simulation": Simulation})
40+
PolicyReformAnalysis.model_rebuild(
41+
_types_namespace={"BudgetSummaryItem": BudgetSummaryItem}
42+
)
4043

4144
__all__ = [
4245
"USYearData",

tests/test_aggregate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def test_aggregate_invalid_variable():
478478
variable="nonexistent_variable",
479479
aggregate_type=AggregateType.SUM,
480480
)
481-
with pytest.raises(StopIteration):
481+
with pytest.raises(ValueError):
482482
agg.run()
483483

484484
# Invalid filter variable name should raise error on run()
@@ -488,5 +488,5 @@ def test_aggregate_invalid_variable():
488488
aggregate_type=AggregateType.SUM,
489489
filter_variable="nonexistent_filter",
490490
)
491-
with pytest.raises(StopIteration):
491+
with pytest.raises(ValueError):
492492
agg.run()

0 commit comments

Comments
 (0)