Skip to content

Commit 4d1d087

Browse files
committed
Allow fully plural stored household groups
1 parent a16feba commit 4d1d087

4 files changed

Lines changed: 75 additions & 22 deletions

File tree

src/policyengine_api/models/household.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
}
2121

2222
_ENTITY_GROUP_FIELDS = tuple(_ENTITY_ID_KEY_BY_GROUP.keys())
23-
_MULTI_ROW_ALLOWED = {"marital_unit"}
2423

2524

2625
def _coerce_entity_group_collection(value: Any) -> Any:
@@ -55,8 +54,8 @@ class Household(HouseholdBase, table=True):
5554
class HouseholdCreate(StoredHouseholdPayload):
5655
"""Schema for creating a stored household.
5756
58-
Uses list-based entity groups for storage, with multiple rows allowed only
59-
for marital_unit.
57+
Uses the same list-based relational entity shape as the household
58+
calculation APIs.
6059
"""
6160

6261
@field_validator(*_ENTITY_GROUP_FIELDS, mode="before")
@@ -78,11 +77,6 @@ def validate_relationships(self) -> "HouseholdCreate":
7877
entity_records = getattr(self, group_key)
7978
person_link_key = f"person_{entity_id_key}"
8079

81-
if len(entity_records) > 1 and group_key not in _MULTI_ROW_ALLOWED:
82-
raise ValueError(
83-
f"{group_key} supports at most one row in stored households"
84-
)
85-
8680
entity_ids = [
8781
entity[entity_id_key]
8882
for entity in entity_records

src/policyengine_api/models/household_payload.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
class HouseholdEntityCollections(SQLModel):
11-
"""Plural household entity collections used by calculation and storage APIs."""
11+
"""Plural household entity collections used by stored and calculation APIs."""
1212

1313
benunit: list[dict[str, Any]] = Field(default_factory=list)
1414
marital_unit: list[dict[str, Any]] = Field(default_factory=list)

test_fixtures/fixtures_households.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,54 @@
6262
"household": [{"household_id": 0, "state_name": "CA"}],
6363
}
6464

65+
MOCK_US_FULL_MULTI_GROUP_HOUSEHOLD_CREATE = {
66+
"country_id": "us",
67+
"year": 2024,
68+
"label": "US fully multi-group household",
69+
"people": [
70+
{
71+
"person_id": 0,
72+
"person_household_id": 0,
73+
"person_tax_unit_id": 0,
74+
"person_marital_unit_id": 0,
75+
"person_family_id": 0,
76+
"person_spm_unit_id": 0,
77+
"age": 30,
78+
"employment_income": 50000,
79+
},
80+
{
81+
"person_id": 1,
82+
"person_household_id": 1,
83+
"person_tax_unit_id": 1,
84+
"person_marital_unit_id": 1,
85+
"person_family_id": 1,
86+
"person_spm_unit_id": 1,
87+
"age": 28,
88+
"employment_income": 30000,
89+
},
90+
],
91+
"tax_unit": [
92+
{"tax_unit_id": 0, "state_name": "CA"},
93+
{"tax_unit_id": 1, "state_name": "CA"},
94+
],
95+
"marital_unit": [
96+
{"marital_unit_id": 0},
97+
{"marital_unit_id": 1},
98+
],
99+
"family": [
100+
{"family_id": 0},
101+
{"family_id": 1},
102+
],
103+
"spm_unit": [
104+
{"spm_unit_id": 0},
105+
{"spm_unit_id": 1},
106+
],
107+
"household": [
108+
{"household_id": 0, "state_name": "CA"},
109+
{"household_id": 1, "state_name": "NY"},
110+
],
111+
}
112+
65113
MOCK_HOUSEHOLD_MINIMAL = {
66114
"country_id": "us",
67115
"year": 2024,

tests/test_households.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from test_fixtures.fixtures_households import (
66
MOCK_HOUSEHOLD_MINIMAL,
77
MOCK_UK_HOUSEHOLD_CREATE,
8+
MOCK_US_FULL_MULTI_GROUP_HOUSEHOLD_CREATE,
89
MOCK_US_HOUSEHOLD_CREATE,
910
MOCK_US_HOUSEHOLD_CREATE_LEGACY,
1011
MOCK_US_MULTI_GROUP_HOUSEHOLD_CREATE,
@@ -62,14 +63,20 @@ def test_create_household_minimal(client):
6263

6364

6465
def test_create_household_round_trips_multiple_entity_groups(client):
65-
"""Stored household CRUD preserves multiple marital units."""
66-
response = client.post("/households", json=MOCK_US_MULTI_GROUP_HOUSEHOLD_CREATE)
66+
"""Stored household CRUD preserves multiple rows across entity groups."""
67+
response = client.post(
68+
"/households", json=MOCK_US_FULL_MULTI_GROUP_HOUSEHOLD_CREATE
69+
)
6770
assert response.status_code == 201
6871
data = response.json()
69-
assert len(data["tax_unit"]) == 1
70-
assert data["tax_unit"][0]["tax_unit_id"] == 0
72+
assert len(data["tax_unit"]) == 2
7173
assert len(data["marital_unit"]) == 2
72-
assert data["people"][1]["person_marital_unit_id"] == 1
74+
assert len(data["family"]) == 2
75+
assert len(data["spm_unit"]) == 2
76+
assert len(data["household"]) == 2
77+
assert data["people"][0]["person_tax_unit_id"] == 0
78+
assert data["people"][1]["person_tax_unit_id"] == 1
79+
assert data["people"][1]["person_household_id"] == 1
7380

7481

7582
def test_create_household_invalid_country_id(client):
@@ -154,8 +161,8 @@ def test_create_household_requires_person_links_for_multi_group_rows(client):
154161
assert "when " in response.text
155162

156163

157-
def test_create_household_rejects_multiple_tax_units(client):
158-
"""Stored households support at most one tax unit."""
164+
def test_create_household_accepts_multiple_tax_units(client):
165+
"""Stored households preserve multiple tax units when person links are present."""
159166
payload = {
160167
**MOCK_US_MULTI_GROUP_HOUSEHOLD_CREATE,
161168
"tax_unit": [
@@ -182,12 +189,14 @@ def test_create_household_rejects_multiple_tax_units(client):
182189

183190
response = client.post("/households", json=payload)
184191

185-
assert response.status_code == 422
186-
assert "tax_unit supports at most one row" in response.text
192+
assert response.status_code == 201
193+
data = response.json()
194+
assert len(data["tax_unit"]) == 2
195+
assert data["people"][1]["person_tax_unit_id"] == 1
187196

188197

189-
def test_create_household_rejects_multiple_households(client):
190-
"""Stored households support at most one household row."""
198+
def test_create_household_accepts_multiple_households(client):
199+
"""Stored households preserve multiple household rows when person links are present."""
191200
payload = {
192201
**MOCK_US_MULTI_GROUP_HOUSEHOLD_CREATE,
193202
"household": [
@@ -214,8 +223,10 @@ def test_create_household_rejects_multiple_households(client):
214223

215224
response = client.post("/households", json=payload)
216225

217-
assert response.status_code == 422
218-
assert "household supports at most one row" in response.text
226+
assert response.status_code == 201
227+
data = response.json()
228+
assert len(data["household"]) == 2
229+
assert data["people"][1]["person_household_id"] == 1
219230

220231

221232
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)