Skip to content

Commit 5e232d8

Browse files
authored
Merge pull request #141 from PolicyEngine/fix/seed-json-serialization
fix: Serialize lists/dicts as JSON in seed bulk_insert
2 parents dc1dc50 + d4c4ec2 commit 5e232d8

4 files changed

Lines changed: 133 additions & 0 deletions

File tree

changelog.d/fix-seed-json.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Serialize lists and dicts as JSON in bulk_insert for Postgres COPY compatibility

scripts/seed_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Shared utilities for seed scripts."""
22

33
import io
4+
import json
45
import logging
56
import sys
67
import warnings
@@ -61,6 +62,8 @@ def bulk_insert(session: Session, table: str, columns: list[str], rows: list[dic
6162
val.replace("\\", "\\\\").replace("\t", "\\t").replace("\n", "\\n")
6263
)
6364
values.append(val)
65+
elif isinstance(val, (list, dict)):
66+
values.append(json.dumps(val))
6467
else:
6568
values.append(str(val))
6669
output.write("\t".join(values) + "\n")

tests/test_analysis_household_impact.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from datetime import date
44
from uuid import UUID, uuid4
55

6+
import pytest
7+
68
from policyengine_api.api.household_analysis import (
79
UK_CONFIG,
810
US_CONFIG,
@@ -314,6 +316,7 @@ def test_get_report_not_found(self, client):
314316
# ---------------------------------------------------------------------------
315317

316318

319+
@pytest.mark.integration
317320
class TestHouseholdImpactRecordCreation:
318321
"""Tests for correct record creation."""
319322

@@ -404,6 +407,7 @@ def test_report_links_simulations(self, client, session):
404407
# ---------------------------------------------------------------------------
405408

406409

410+
@pytest.mark.integration
407411
class TestHouseholdImpactDeduplication:
408412
"""Tests for simulation/report deduplication."""
409413

@@ -470,6 +474,7 @@ def test_different_policy_creates_different_simulation(self, client, session):
470474
# ---------------------------------------------------------------------------
471475

472476

477+
@pytest.mark.integration
473478
class TestGetHouseholdImpact:
474479
"""Tests for GET /analysis/household-impact/{report_id}."""
475480

@@ -500,6 +505,7 @@ def test_get_returns_report_data(self, client, session):
500505
# ---------------------------------------------------------------------------
501506

502507

508+
@pytest.mark.integration
503509
class TestUSHouseholdImpact:
504510
"""Tests specific to US households."""
505511

tests/test_seed_utils.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""Tests for scripts/seed_utils.py bulk_insert serialization."""
2+
3+
import json
4+
import sys
5+
from unittest.mock import MagicMock
6+
from uuid import uuid4
7+
8+
import pytest
9+
10+
# Stub the settings modules before importing seed_utils, since seed_utils has
# module-level imports that require env vars.
mock_settings = MagicMock()
mock_settings.logfire_token = None
mock_settings.database_url = "sqlite://"

_STUBBED_MODULES = {
    "policyengine_api.config.settings": MagicMock(settings=mock_settings),
    "policyengine_api.config": MagicMock(settings=mock_settings),
}
_MISSING = object()
# Remember what (if anything) was registered under each name so it can be
# put back once seed_utils has been imported.
_saved_modules = {
    name: sys.modules.get(name, _MISSING) for name in _STUBBED_MODULES
}
sys.modules.update(_STUBBED_MODULES)
try:
    from scripts.seed_utils import bulk_insert  # noqa: E402
finally:
    # Restore sys.modules so the stubs do not leak into other test modules
    # collected in the same pytest session.
    for _name, _previous in _saved_modules.items():
        if _previous is _MISSING:
            sys.modules.pop(_name, None)
        else:
            sys.modules[_name] = _previous
19+
20+
21+
@pytest.fixture
def mock_session():
    """Return a ``(session, cursor)`` pair of mocks.

    The cursor is the object handed back by
    ``session.connection().connection.dbapi_connection.cursor()``, so tests
    can inspect the ``copy_from`` call made on it.
    """
    fake_cursor = MagicMock()
    fake_session = MagicMock()
    dbapi_connection = fake_session.connection().connection.dbapi_connection
    dbapi_connection.cursor.return_value = fake_cursor
    return fake_session, fake_cursor
28+
29+
30+
def _capture_copy_buffer(mock_cursor) -> str:
31+
"""Extract the StringIO content passed to copy_from."""
32+
call_args = mock_cursor.copy_from.call_args
33+
buffer = call_args[0][0]
34+
buffer.seek(0)
35+
return buffer.read()
36+
37+
38+
class TestBulkInsertSerialization:
    """Check how bulk_insert renders Python values into the COPY buffer."""

    def test_list_serialized_as_json(self, mock_session):
        """A list is written as JSON (double quotes), not as its Python repr."""
        session, cursor = mock_session
        payload = [{"adds": ["child_benefit", "housing_benefit"]}]

        bulk_insert(session, "variables", ["adds"], payload)

        copied = _capture_copy_buffer(cursor)
        assert '["child_benefit", "housing_benefit"]' in copied
        assert "['child_benefit'" not in copied

    def test_dict_serialized_as_json(self, mock_session):
        """A dict is written as JSON that parses back to the same mapping."""
        session, cursor = mock_session
        expected = {"key": "value", "nested": True}
        payload = [{"metadata": {"key": "value", "nested": True}}]

        bulk_insert(session, "test_table", ["metadata"], payload)

        copied = _capture_copy_buffer(cursor)
        assert json.loads(copied.strip()) == expected

    def test_empty_list_serialized_as_json(self, mock_session):
        """An empty list is written as ``[]``."""
        session, cursor = mock_session
        payload = [{"adds": []}]

        bulk_insert(session, "variables", ["adds"], payload)

        copied = _capture_copy_buffer(cursor)
        assert "[]" in copied

    def test_none_serialized_as_null(self, mock_session):
        """None is written as the backslash-N null marker."""
        session, cursor = mock_session
        payload = [{"adds": None}]

        bulk_insert(session, "variables", ["adds"], payload)

        copied = _capture_copy_buffer(cursor)
        assert "\\N" in copied

    def test_string_with_special_chars_escaped(self, mock_session):
        """Tabs and newlines inside a string show up as escape sequences."""
        session, cursor = mock_session
        payload = [{"description": "line1\tline2\nline3"}]

        bulk_insert(session, "variables", ["description"], payload)

        copied = _capture_copy_buffer(cursor)
        assert "\\t" in copied
        assert "\\n" in copied

    def test_multiple_columns_with_mixed_types(self, mock_session):
        """Each column of a mixed-type row is serialized according to its type."""
        session, cursor = mock_session
        row_id = uuid4()
        record = {
            "id": row_id,
            "name": "income_tax",
            "adds": ["gross_income"],
            "subtracts": [],
            "description": None,
        }
        # Column order mirrors the record's key order.
        columns = list(record)

        bulk_insert(session, "variables", columns, [record])

        fields = _capture_copy_buffer(cursor).strip().split("\t")
        id_field, name_field, adds_field, subtracts_field, description_field = (
            fields[:5]
        )
        assert id_field == str(row_id)
        assert name_field == "income_tax"
        assert adds_field == '["gross_income"]'
        assert subtracts_field == "[]"
        assert description_field == "\\N"

    def test_empty_rows_skips_copy(self, mock_session):
        """With no rows, ``copy_from`` is never called."""
        session, cursor = mock_session
        no_rows = []

        bulk_insert(session, "variables", ["adds"], no_rows)

        cursor.copy_from.assert_not_called()

0 commit comments

Comments
 (0)