Skip to content

Commit f6bb1be

Browse files
feat: add poverty and inequality outputs to economy comparison (#65)
* feat: add poverty and inequality outputs to economy comparison Adds poverty rates and inequality metrics to the economy comparison analysis, matching policyengine.py PR #207. Poverty outputs: - UK: absolute BHC/AHC, relative BHC/AHC - US: SPM, deep SPM Inequality outputs: - Gini coefficient - Top 10%, top 1%, bottom 50% income shares New models: Poverty, Inequality Migration: 20260103000000_add_poverty_inequality.sql Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * fix: sanitize NaN/Inf values in household results for JSON serialization Also updated tests to use the async job polling pattern. --------- Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent da06d77 commit f6bb1be

7 files changed

Lines changed: 300 additions & 39 deletions

File tree

src/policyengine_api/api/household.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Poll the status endpoint until the job is complete.
55
"""
66

7+
import math
78
from typing import Any, Literal
89
from uuid import UUID
910

@@ -22,12 +23,26 @@
2223
from policyengine_api.services.database import get_session
2324

2425

26+
def _sanitize_for_json(obj: Any) -> Any:
27+
"""Replace NaN/Inf values with None for JSON serialization."""
28+
if isinstance(obj, float):
29+
if math.isnan(obj) or math.isinf(obj):
30+
return None
31+
return obj
32+
elif isinstance(obj, dict):
33+
return {k: _sanitize_for_json(v) for k, v in obj.items()}
34+
elif isinstance(obj, list):
35+
return [_sanitize_for_json(v) for v in obj]
36+
return obj
37+
38+
2539
def get_traceparent() -> str | None:
2640
"""Get the current W3C traceparent header for distributed tracing."""
2741
carrier: dict[str, str] = {}
2842
TraceContextTextMapPropagator().inject(carrier)
2943
return carrier.get("traceparent")
3044

45+
3146
router = APIRouter(prefix="/household", tags=["household"])
3247

3348

@@ -254,11 +269,13 @@ def _run_local_household_uk(
254269
job = session.get(HouseholdJob, job_id)
255270
if job:
256271
job.status = HouseholdJobStatus.COMPLETED
257-
job.result = {
258-
"person": result.person,
259-
"benunit": result.benunit,
260-
"household": result.household,
261-
}
272+
job.result = _sanitize_for_json(
273+
{
274+
"person": result.person,
275+
"benunit": result.benunit,
276+
"household": result.household,
277+
}
278+
)
262279
job.completed_at = datetime.now(timezone.utc)
263280
session.add(job)
264281
session.commit()
@@ -343,14 +360,16 @@ def _run_local_household_us(
343360
job = session.get(HouseholdJob, job_id)
344361
if job:
345362
job.status = HouseholdJobStatus.COMPLETED
346-
job.result = {
347-
"person": result.person,
348-
"marital_unit": result.marital_unit,
349-
"family": result.family,
350-
"spm_unit": result.spm_unit,
351-
"tax_unit": result.tax_unit,
352-
"household": result.household,
353-
}
363+
job.result = _sanitize_for_json(
364+
{
365+
"person": result.person,
366+
"marital_unit": result.marital_unit,
367+
"family": result.family,
368+
"spm_unit": result.spm_unit,
369+
"tax_unit": result.tax_unit,
370+
"household": result.household,
371+
}
372+
)
354373
job.completed_at = datetime.now(timezone.utc)
355374
session.add(job)
356375
session.commit()

src/policyengine_api/modal_app.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,8 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None:
707707
from policyengine_api.models import (
708708
Dataset,
709709
DecileImpact,
710+
Inequality,
711+
Poverty,
710712
ProgramStatistics,
711713
Report,
712714
ReportStatus,
@@ -748,6 +750,12 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None:
748750
# Import policyengine
749751
from policyengine.core import Simulation as PESimulation
750752
from policyengine.outputs import DecileImpact as PEDecileImpact
753+
from policyengine.outputs.inequality import (
754+
calculate_uk_inequality,
755+
)
756+
from policyengine.outputs.poverty import (
757+
calculate_uk_poverty_rates,
758+
)
751759
from policyengine.tax_benefit_models.uk import uk_latest
752760
from policyengine.tax_benefit_models.uk.datasets import (
753761
PolicyEngineUKDataset,
@@ -881,6 +889,45 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None:
881889
except KeyError:
882890
pass # Variable not in model, skip silently
883891

892+
# Calculate poverty rates
893+
with logfire.span("calculate_poverty"):
894+
for sim, sim_id in [
895+
(pe_baseline_sim, baseline_sim.id),
896+
(pe_reform_sim, reform_sim.id),
897+
]:
898+
poverty_collection = calculate_uk_poverty_rates(sim)
899+
for pov in poverty_collection.outputs:
900+
poverty_record = Poverty(
901+
simulation_id=sim_id,
902+
report_id=report.id,
903+
poverty_type=pov.poverty_type,
904+
entity=pov.entity,
905+
filter_variable=pov.filter_variable,
906+
headcount=pov.headcount,
907+
total_population=pov.total_population,
908+
rate=pov.rate,
909+
)
910+
session.add(poverty_record)
911+
912+
# Calculate inequality
913+
with logfire.span("calculate_inequality"):
914+
for sim, sim_id in [
915+
(pe_baseline_sim, baseline_sim.id),
916+
(pe_reform_sim, reform_sim.id),
917+
]:
918+
ineq = calculate_uk_inequality(sim)
919+
inequality_record = Inequality(
920+
simulation_id=sim_id,
921+
report_id=report.id,
922+
income_variable=ineq.income_variable,
923+
entity=ineq.entity,
924+
gini=ineq.gini,
925+
top_10_share=ineq.top_10_share,
926+
top_1_share=ineq.top_1_share,
927+
bottom_50_share=ineq.bottom_50_share,
928+
)
929+
session.add(inequality_record)
930+
884931
# Mark simulations and report as completed
885932
baseline_sim.status = SimulationStatus.COMPLETED
886933
baseline_sim.completed_at = datetime.now(timezone.utc)
@@ -949,6 +996,8 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None:
949996
from policyengine_api.models import (
950997
Dataset,
951998
DecileImpact,
999+
Inequality,
1000+
Poverty,
9521001
ProgramStatistics,
9531002
Report,
9541003
ReportStatus,
@@ -983,6 +1032,12 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None:
9831032
# Import policyengine
9841033
from policyengine.core import Simulation as PESimulation
9851034
from policyengine.outputs import DecileImpact as PEDecileImpact
1035+
from policyengine.outputs.inequality import (
1036+
calculate_us_inequality,
1037+
)
1038+
from policyengine.outputs.poverty import (
1039+
calculate_us_poverty_rates,
1040+
)
9861041
from policyengine.tax_benefit_models.us import us_latest
9871042
from policyengine.tax_benefit_models.us.datasets import (
9881043
PolicyEngineUSDataset,
@@ -1113,6 +1168,45 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None:
11131168
except KeyError:
11141169
pass # Variable not in model, skip silently
11151170

1171+
# Calculate poverty rates
1172+
with logfire.span("calculate_poverty"):
1173+
for sim, sim_id in [
1174+
(pe_baseline_sim, baseline_sim.id),
1175+
(pe_reform_sim, reform_sim.id),
1176+
]:
1177+
poverty_collection = calculate_us_poverty_rates(sim)
1178+
for pov in poverty_collection.outputs:
1179+
poverty_record = Poverty(
1180+
simulation_id=sim_id,
1181+
report_id=report.id,
1182+
poverty_type=pov.poverty_type,
1183+
entity=pov.entity,
1184+
filter_variable=pov.filter_variable,
1185+
headcount=pov.headcount,
1186+
total_population=pov.total_population,
1187+
rate=pov.rate,
1188+
)
1189+
session.add(poverty_record)
1190+
1191+
# Calculate inequality
1192+
with logfire.span("calculate_inequality"):
1193+
for sim, sim_id in [
1194+
(pe_baseline_sim, baseline_sim.id),
1195+
(pe_reform_sim, reform_sim.id),
1196+
]:
1197+
ineq = calculate_us_inequality(sim)
1198+
inequality_record = Inequality(
1199+
simulation_id=sim_id,
1200+
report_id=report.id,
1201+
income_variable=ineq.income_variable,
1202+
entity=ineq.entity,
1203+
gini=ineq.gini,
1204+
top_10_share=ineq.top_10_share,
1205+
top_1_share=ineq.top_1_share,
1206+
bottom_50_share=ineq.bottom_50_share,
1207+
)
1208+
session.add(inequality_record)
1209+
11161210
# Mark simulations and report as completed
11171211
baseline_sim.status = SimulationStatus.COMPLETED
11181212
baseline_sim.completed_at = datetime.now(timezone.utc)

src/policyengine_api/models/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
HouseholdJobRead,
1717
HouseholdJobStatus,
1818
)
19+
from .inequality import Inequality, InequalityCreate, InequalityRead
1920
from .output import (
2021
AggregateOutput,
2122
AggregateOutputCreate,
@@ -25,6 +26,7 @@
2526
from .parameter import Parameter, ParameterCreate, ParameterRead
2627
from .parameter_value import ParameterValue, ParameterValueCreate, ParameterValueRead
2728
from .policy import Policy, PolicyCreate, PolicyRead
29+
from .poverty import Poverty, PovertyCreate, PovertyRead
2830
from .program_statistics import (
2931
ProgramStatistics,
3032
ProgramStatisticsCreate,
@@ -70,6 +72,9 @@
7072
"HouseholdJobCreate",
7173
"HouseholdJobRead",
7274
"HouseholdJobStatus",
75+
"Inequality",
76+
"InequalityCreate",
77+
"InequalityRead",
7378
"Parameter",
7479
"ParameterCreate",
7580
"ParameterRead",
@@ -79,6 +84,9 @@
7984
"Policy",
8085
"PolicyCreate",
8186
"PolicyRead",
87+
"Poverty",
88+
"PovertyCreate",
89+
"PovertyRead",
8290
"ProgramStatistics",
8391
"ProgramStatisticsCreate",
8492
"ProgramStatisticsRead",
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""Inequality output model."""
2+
3+
from datetime import datetime, timezone
4+
from uuid import UUID, uuid4
5+
6+
from sqlmodel import Field, SQLModel
7+
8+
9+
class InequalityBase(SQLModel):
10+
"""Base inequality fields."""
11+
12+
simulation_id: UUID = Field(foreign_key="simulations.id")
13+
report_id: UUID | None = Field(default=None, foreign_key="reports.id")
14+
income_variable: str
15+
entity: str = "household"
16+
gini: float | None = None
17+
top_10_share: float | None = None
18+
top_1_share: float | None = None
19+
bottom_50_share: float | None = None
20+
21+
22+
class Inequality(InequalityBase, table=True):
23+
"""Inequality database model."""
24+
25+
__tablename__ = "inequality"
26+
27+
id: UUID = Field(default_factory=uuid4, primary_key=True)
28+
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
29+
30+
31+
class InequalityCreate(InequalityBase):
32+
"""Schema for creating inequality records."""
33+
34+
pass
35+
36+
37+
class InequalityRead(InequalityBase):
38+
"""Schema for reading inequality records."""
39+
40+
id: UUID
41+
created_at: datetime
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""Poverty output model."""
2+
3+
from datetime import datetime, timezone
4+
from uuid import UUID, uuid4
5+
6+
from sqlmodel import Field, SQLModel
7+
8+
9+
class PovertyBase(SQLModel):
10+
"""Base poverty fields."""
11+
12+
simulation_id: UUID = Field(foreign_key="simulations.id")
13+
report_id: UUID | None = Field(default=None, foreign_key="reports.id")
14+
poverty_type: str # e.g. "absolute_bhc", "spm", etc.
15+
entity: str = "person"
16+
filter_variable: str | None = None
17+
headcount: float | None = None
18+
total_population: float | None = None
19+
rate: float | None = None
20+
21+
22+
class Poverty(PovertyBase, table=True):
23+
"""Poverty database model."""
24+
25+
__tablename__ = "poverty"
26+
27+
id: UUID = Field(default_factory=uuid4, primary_key=True)
28+
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
29+
30+
31+
class PovertyCreate(PovertyBase):
32+
"""Schema for creating poverty records."""
33+
34+
pass
35+
36+
37+
class PovertyRead(PovertyBase):
38+
"""Schema for reading poverty records."""
39+
40+
id: UUID
41+
created_at: datetime
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
-- Add poverty and inequality tables for economic analysis
2+
3+
CREATE TABLE IF NOT EXISTS poverty (
4+
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
5+
simulation_id UUID NOT NULL REFERENCES simulations(id) ON DELETE CASCADE,
6+
report_id UUID REFERENCES reports(id) ON DELETE CASCADE,
7+
poverty_type VARCHAR NOT NULL,
8+
entity VARCHAR NOT NULL DEFAULT 'person',
9+
filter_variable VARCHAR,
10+
headcount FLOAT,
11+
total_population FLOAT,
12+
rate FLOAT,
13+
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
14+
);
15+
16+
CREATE TABLE IF NOT EXISTS inequality (
17+
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
18+
simulation_id UUID NOT NULL REFERENCES simulations(id) ON DELETE CASCADE,
19+
report_id UUID REFERENCES reports(id) ON DELETE CASCADE,
20+
income_variable VARCHAR NOT NULL,
21+
entity VARCHAR NOT NULL DEFAULT 'household',
22+
gini FLOAT,
23+
top_10_share FLOAT,
24+
top_1_share FLOAT,
25+
bottom_50_share FLOAT,
26+
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
27+
);
28+
29+
-- Indexes for efficient querying
30+
CREATE INDEX IF NOT EXISTS idx_poverty_simulation_id ON poverty(simulation_id);
31+
CREATE INDEX IF NOT EXISTS idx_poverty_report_id ON poverty(report_id);
32+
CREATE INDEX IF NOT EXISTS idx_inequality_simulation_id ON inequality(simulation_id);
33+
CREATE INDEX IF NOT EXISTS idx_inequality_report_id ON inequality(report_id);

0 commit comments

Comments
 (0)