Skip to content

Commit 6421e34

Browse files
authored
Merge pull request #257 from PolicyEngine/fix/allow-multiple-group-entities
Allow multiple stored household group entities
2 parents bb710a2 + 4d1d087 commit 6421e34

11 files changed

Lines changed: 573 additions & 45 deletions

File tree

changelog.d/256.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Stored `/households` definitions now use the same plural entity-list contract as the `/household/calculate` endpoint. Person-to-entity links are validated (and required when an entity group has multiple rows), and legacy singular (dict) entity payloads are coerced into one-element lists for backward compatibility.

src/policyengine_api/api/household.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,6 @@ class HouseholdCalculateRequest(BaseModel):
9696
}
9797
"""
9898

99-
country_id: CountryId = Field(
100-
description="Which country model to use ('us' or 'uk')"
101-
)
10299
people: list[dict[str, Any]] = Field(
103100
description="List of people with flat variable values. Include person_id and person_{entity}_id fields to link to entities."
104101
)
@@ -130,6 +127,9 @@ class HouseholdCalculateRequest(BaseModel):
130127
default=None,
131128
description="Simulation year (default: 2024 for US, 2026 for UK). Specify this instead of embedding years in variable values.",
132129
)
130+
country_id: CountryId = Field(
131+
description="Which country model to use ('us' or 'uk')"
132+
)
133133
policy_id: UUID | None = Field(
134134
default=None, description="Optional policy reform ID"
135135
)
@@ -183,9 +183,6 @@ class HouseholdImpactRequest(BaseModel):
183183
}
184184
"""
185185

186-
country_id: CountryId = Field(
187-
description="Which country model to use ('us' or 'uk')"
188-
)
189186
people: list[dict[str, Any]] = Field(
190187
description="List of people with flat variable values. Include person_id and person_{entity}_id fields to link to entities."
191188
)
@@ -216,6 +213,9 @@ class HouseholdImpactRequest(BaseModel):
216213
year: int | None = Field(
217214
default=None, description="Simulation year (default: 2024 for US, 2026 for UK)"
218215
)
216+
country_id: CountryId = Field(
217+
description="Which country model to use ('us' or 'uk')"
218+
)
219219
policy_id: UUID | None = Field(
220220
default=None, description="Reform policy ID to compare against baseline"
221221
)
@@ -854,7 +854,7 @@ def calculate_household(
854854
Use flat values for all variables - do NOT use time-period format like {"2024": value}.
855855
The simulation year is specified via the `year` parameter.
856856
857-
US example: people=[{"employment_income": 70000, "age": 40}], tax_unit={"state_code": "CA"}, year=2024
857+
US example: people=[{"employment_income": 70000, "age": 40}], tax_unit=[{"state_code": "CA"}], year=2024
858858
UK example: people=[{"employment_income": 50000, "age": 30}], year=2026
859859
"""
860860
with logfire.span(

src/policyengine_api/api/household_analysis.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
3. Run analysis: POST /analysis/household-impact (returns report_id)
1010
4. Poll GET /analysis/household-impact/{report_id} until status="completed"
1111
5. Results include baseline_result, reform_result (if comparison), and impact diff
12+
13+
Stored households now use the same plural entity-list contract as the
14+
/household/calculate endpoint. Analysis should pass those lists through
15+
without collapsing multi-group households into a single entity row.
1216
"""
1317

1418
from dataclasses import dataclass
@@ -109,7 +113,7 @@ def calculate_uk_household(
109113
year: int,
110114
policy_data: dict | None,
111115
) -> dict:
112-
"""Calculate UK household using the existing implementation."""
116+
"""Calculate UK household using the stored-household plural entity contract."""
113117
from policyengine_api.api.household import _calculate_household_uk
114118

115119
return _calculate_household_uk(
@@ -126,7 +130,7 @@ def calculate_us_household(
126130
year: int,
127131
policy_data: dict | None,
128132
) -> dict:
129-
"""Calculate US household using the existing implementation."""
133+
"""Calculate US household using the stored-household plural entity contract."""
130134
from policyengine_api.api.household import _calculate_household_us
131135

132136
return _calculate_household_us(

src/policyengine_api/api/households.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@ def _pack_household_data(body: HouseholdCreate) -> dict[str, Any]:
3434
"""Pack the flat request fields into a single JSON blob for storage."""
3535
data: dict[str, Any] = {"people": body.people}
3636
for key in _ENTITY_GROUP_KEYS:
37-
val = getattr(body, key)
38-
if val is not None:
39-
data[key] = val
37+
data[key] = list(getattr(body, key))
4038
return data
4139

4240

@@ -49,12 +47,12 @@ def _to_read(record: Household) -> HouseholdRead:
4947
year=record.year,
5048
label=record.label,
5149
people=data["people"],
52-
tax_unit=data.get("tax_unit"),
53-
family=data.get("family"),
54-
spm_unit=data.get("spm_unit"),
55-
marital_unit=data.get("marital_unit"),
56-
household=data.get("household"),
57-
benunit=data.get("benunit"),
50+
tax_unit=data.get("tax_unit", []),
51+
family=data.get("family", []),
52+
spm_unit=data.get("spm_unit", []),
53+
marital_unit=data.get("marital_unit", []),
54+
household=data.get("household", []),
55+
benunit=data.get("benunit", []),
5856
created_at=record.created_at,
5957
updated_at=record.updated_at,
6058
)

src/policyengine_api/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
HouseholdJobRead,
3535
HouseholdJobStatus,
3636
)
37+
from .household_payload import HouseholdEntityCollections, StoredHouseholdPayload
3738
from .inequality import Inequality, InequalityCreate, InequalityRead
3839
from .intra_decile_impact import (
3940
DecileType,
@@ -152,6 +153,8 @@
152153
"DynamicRead",
153154
"Household",
154155
"HouseholdCreate",
156+
"HouseholdEntityCollections",
157+
"StoredHouseholdPayload",
155158
"HouseholdRead",
156159
"HouseholdJob",
157160
"HouseholdJobCreate",

src/policyengine_api/models/household.py

Lines changed: 83 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,32 @@
44
from typing import Any
55
from uuid import UUID, uuid4
66

7+
from pydantic import field_validator, model_validator
78
from sqlalchemy import JSON
89
from sqlmodel import Column, Field, SQLModel
910

10-
from policyengine_api.config.constants import CountryId
11+
from policyengine_api.models.household_payload import StoredHouseholdPayload
12+
13+
_ENTITY_ID_KEY_BY_GROUP = {
14+
"benunit": "benunit_id",
15+
"marital_unit": "marital_unit_id",
16+
"family": "family_id",
17+
"spm_unit": "spm_unit_id",
18+
"tax_unit": "tax_unit_id",
19+
"household": "household_id",
20+
}
21+
22+
_ENTITY_GROUP_FIELDS = tuple(_ENTITY_ID_KEY_BY_GROUP.keys())
23+
24+
25+
def _coerce_entity_group_collection(value: Any) -> Any:
26+
if value is None:
27+
return []
28+
if isinstance(value, list):
29+
return value
30+
if isinstance(value, dict):
31+
return [value]
32+
return value
1133

1234

1335
class HouseholdBase(SQLModel):
@@ -29,23 +51,70 @@ class Household(HouseholdBase, table=True):
2951
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
3052

3153

32-
class HouseholdCreate(SQLModel):
54+
class HouseholdCreate(StoredHouseholdPayload):
3355
"""Schema for creating a stored household.
3456
35-
Accepts the flat structure matching the frontend Household interface:
36-
people as an array, entity groups as optional dicts.
57+
Uses the same list-based relational entity shape as the household
58+
calculation APIs.
3759
"""
3860

39-
country_id: CountryId
40-
year: int
41-
label: str | None = None
42-
people: list[dict[str, Any]]
43-
tax_unit: dict[str, Any] | None = None
44-
family: dict[str, Any] | None = None
45-
spm_unit: dict[str, Any] | None = None
46-
marital_unit: dict[str, Any] | None = None
47-
household: dict[str, Any] | None = None
48-
benunit: dict[str, Any] | None = None
61+
@field_validator(*_ENTITY_GROUP_FIELDS, mode="before")
62+
@classmethod
63+
def coerce_legacy_singular_entity_groups(cls, value: Any) -> Any:
64+
return _coerce_entity_group_collection(value)
65+
66+
@model_validator(mode="after")
67+
def validate_relationships(self) -> "HouseholdCreate":
68+
person_ids = [
69+
person["person_id"]
70+
for person in self.people
71+
if person.get("person_id") is not None
72+
]
73+
if len(person_ids) != len(set(person_ids)):
74+
raise ValueError("people contains duplicate person_id values")
75+
76+
for group_key, entity_id_key in _ENTITY_ID_KEY_BY_GROUP.items():
77+
entity_records = getattr(self, group_key)
78+
person_link_key = f"person_{entity_id_key}"
79+
80+
entity_ids = [
81+
entity[entity_id_key]
82+
for entity in entity_records
83+
if entity.get(entity_id_key) is not None
84+
]
85+
86+
if len(entity_ids) != len(set(entity_ids)):
87+
raise ValueError(
88+
f"{group_key} contains duplicate {entity_id_key} values"
89+
)
90+
91+
requires_linkage = len(entity_records) > 1
92+
person_links = []
93+
for person in self.people:
94+
person_link = person.get(person_link_key)
95+
if person_link is None:
96+
if requires_linkage:
97+
raise ValueError(
98+
f"people must include {person_link_key} when {group_key} has multiple rows"
99+
)
100+
continue
101+
person_links.append(person_link)
102+
103+
if not person_links:
104+
continue
105+
106+
if len(entity_ids) != len(entity_records):
107+
raise ValueError(
108+
f"{group_key} rows must all include {entity_id_key} when people reference {person_link_key}"
109+
)
110+
111+
unknown_links = sorted(set(person_links) - set(entity_ids))
112+
if unknown_links:
113+
raise ValueError(
114+
f"{group_key} is missing rows for referenced {entity_id_key} values: {unknown_links}"
115+
)
116+
117+
return self
49118

50119

51120
class HouseholdRead(HouseholdCreate):
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""Shared household payload models."""
2+
3+
from typing import Any
4+
5+
from sqlmodel import Field, SQLModel
6+
7+
from policyengine_api.config.constants import CountryId
8+
9+
10+
class HouseholdEntityCollections(SQLModel):
11+
"""Plural household entity collections used by stored and calculation APIs."""
12+
13+
benunit: list[dict[str, Any]] = Field(default_factory=list)
14+
marital_unit: list[dict[str, Any]] = Field(default_factory=list)
15+
family: list[dict[str, Any]] = Field(default_factory=list)
16+
spm_unit: list[dict[str, Any]] = Field(default_factory=list)
17+
tax_unit: list[dict[str, Any]] = Field(default_factory=list)
18+
household: list[dict[str, Any]] = Field(default_factory=list)
19+
20+
21+
class StoredHouseholdPayload(HouseholdEntityCollections):
22+
"""Core payload shared by stored household create/read flows."""
23+
24+
country_id: CountryId
25+
people: list[dict[str, Any]]
26+
year: int
27+
label: str | None = None

test_fixtures/fixtures_households.py

Lines changed: 98 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
{"age": 30, "employment_income": 50000},
1515
{"age": 28, "employment_income": 30000},
1616
],
17-
"tax_unit": {},
18-
"family": {},
19-
"household": {"state_name": "CA"},
17+
"tax_unit": [{}],
18+
"family": [{}],
19+
"household": [{"state_name": "CA"}],
2020
}
2121

2222
MOCK_UK_HOUSEHOLD_CREATE = {
@@ -26,8 +26,88 @@
2626
"people": [
2727
{"age": 40, "employment_income": 35000},
2828
],
29-
"benunit": {"is_married": False},
30-
"household": {"region": "LONDON"},
29+
"benunit": [{"is_married": False}],
30+
"household": [{"region": "LONDON"}],
31+
}
32+
33+
MOCK_US_MULTI_GROUP_HOUSEHOLD_CREATE = {
34+
"country_id": "us",
35+
"year": 2024,
36+
"label": "US multi-group household",
37+
"people": [
38+
{
39+
"person_id": 0,
40+
"person_household_id": 0,
41+
"person_tax_unit_id": 0,
42+
"person_marital_unit_id": 0,
43+
"age": 30,
44+
"employment_income": 50000,
45+
},
46+
{
47+
"person_id": 1,
48+
"person_household_id": 0,
49+
"person_tax_unit_id": 0,
50+
"person_marital_unit_id": 1,
51+
"age": 28,
52+
"employment_income": 30000,
53+
},
54+
],
55+
"tax_unit": [{"tax_unit_id": 0, "state_name": "CA"}],
56+
"marital_unit": [
57+
{"marital_unit_id": 0},
58+
{"marital_unit_id": 1},
59+
],
60+
"family": [{"family_id": 0}],
61+
"spm_unit": [{"spm_unit_id": 0}],
62+
"household": [{"household_id": 0, "state_name": "CA"}],
63+
}
64+
65+
MOCK_US_FULL_MULTI_GROUP_HOUSEHOLD_CREATE = {
66+
"country_id": "us",
67+
"year": 2024,
68+
"label": "US fully multi-group household",
69+
"people": [
70+
{
71+
"person_id": 0,
72+
"person_household_id": 0,
73+
"person_tax_unit_id": 0,
74+
"person_marital_unit_id": 0,
75+
"person_family_id": 0,
76+
"person_spm_unit_id": 0,
77+
"age": 30,
78+
"employment_income": 50000,
79+
},
80+
{
81+
"person_id": 1,
82+
"person_household_id": 1,
83+
"person_tax_unit_id": 1,
84+
"person_marital_unit_id": 1,
85+
"person_family_id": 1,
86+
"person_spm_unit_id": 1,
87+
"age": 28,
88+
"employment_income": 30000,
89+
},
90+
],
91+
"tax_unit": [
92+
{"tax_unit_id": 0, "state_name": "CA"},
93+
{"tax_unit_id": 1, "state_name": "CA"},
94+
],
95+
"marital_unit": [
96+
{"marital_unit_id": 0},
97+
{"marital_unit_id": 1},
98+
],
99+
"family": [
100+
{"family_id": 0},
101+
{"family_id": 1},
102+
],
103+
"spm_unit": [
104+
{"spm_unit_id": 0},
105+
{"spm_unit_id": 1},
106+
],
107+
"household": [
108+
{"household_id": 0, "state_name": "CA"},
109+
{"household_id": 1, "state_name": "NY"},
110+
],
31111
}
32112

33113
MOCK_HOUSEHOLD_MINIMAL = {
@@ -36,6 +116,19 @@
36116
"people": [{"age": 25}],
37117
}
38118

119+
MOCK_US_HOUSEHOLD_CREATE_LEGACY = {
120+
"country_id": "us",
121+
"year": 2024,
122+
"label": "US legacy household",
123+
"people": [
124+
{"age": 30, "employment_income": 50000},
125+
{"age": 28, "employment_income": 30000},
126+
],
127+
"tax_unit": {},
128+
"family": {},
129+
"household": {"state_name": "CA"},
130+
}
131+
39132

40133
# -----------------------------------------------------------------------------
41134
# Factory functions

0 commit comments

Comments
 (0)