|
| 1 | +"""Tests for scripts/seed_utils.py bulk_insert serialization.""" |
| 2 | + |
| 3 | +import json |
| 4 | +import sys |
| 5 | +from unittest.mock import MagicMock |
| 6 | +from uuid import uuid4 |
| 7 | + |
| 8 | +import pytest |
| 9 | + |
| 10 | +# Mock the settings module before importing seed_utils, |
| 11 | +# since seed_utils has module-level imports that require env vars. |
| 12 | +mock_settings = MagicMock() |
| 13 | +mock_settings.logfire_token = None |
| 14 | +mock_settings.database_url = "sqlite://" |
| 15 | +sys.modules["policyengine_api.config.settings"] = MagicMock(settings=mock_settings) |
| 16 | +sys.modules["policyengine_api.config"] = MagicMock(settings=mock_settings) |
| 17 | + |
| 18 | +from scripts.seed_utils import bulk_insert # noqa: E402 |
| 19 | + |
| 20 | + |
| 21 | +@pytest.fixture |
| 22 | +def mock_session(): |
| 23 | + """Create a mock SQLModel session with a captured copy_from call.""" |
| 24 | + session = MagicMock() |
| 25 | + cursor = MagicMock() |
| 26 | + session.connection().connection.dbapi_connection.cursor.return_value = cursor |
| 27 | + return session, cursor |
| 28 | + |
| 29 | + |
| 30 | +def _capture_copy_buffer(mock_cursor) -> str: |
| 31 | + """Extract the StringIO content passed to copy_from.""" |
| 32 | + call_args = mock_cursor.copy_from.call_args |
| 33 | + buffer = call_args[0][0] |
| 34 | + buffer.seek(0) |
| 35 | + return buffer.read() |
| 36 | + |
| 37 | + |
| 38 | +class TestBulkInsertSerialization: |
| 39 | + def test_list_serialized_as_json(self, mock_session): |
| 40 | + """Lists should be serialized with json.dumps (double quotes), not str() (single quotes).""" |
| 41 | + session, cursor = mock_session |
| 42 | + rows = [{"adds": ["child_benefit", "housing_benefit"]}] |
| 43 | + |
| 44 | + bulk_insert(session, "variables", ["adds"], rows) |
| 45 | + |
| 46 | + content = _capture_copy_buffer(cursor) |
| 47 | + assert '["child_benefit", "housing_benefit"]' in content |
| 48 | + assert "['child_benefit'" not in content |
| 49 | + |
| 50 | + def test_dict_serialized_as_json(self, mock_session): |
| 51 | + """Dicts should be serialized with json.dumps.""" |
| 52 | + session, cursor = mock_session |
| 53 | + rows = [{"metadata": {"key": "value", "nested": True}}] |
| 54 | + |
| 55 | + bulk_insert(session, "test_table", ["metadata"], rows) |
| 56 | + |
| 57 | + content = _capture_copy_buffer(cursor) |
| 58 | + parsed = json.loads(content.strip()) |
| 59 | + assert parsed == {"key": "value", "nested": True} |
| 60 | + |
| 61 | + def test_empty_list_serialized_as_json(self, mock_session): |
| 62 | + """Empty lists should produce valid JSON.""" |
| 63 | + session, cursor = mock_session |
| 64 | + rows = [{"adds": []}] |
| 65 | + |
| 66 | + bulk_insert(session, "variables", ["adds"], rows) |
| 67 | + |
| 68 | + content = _capture_copy_buffer(cursor) |
| 69 | + assert "[]" in content |
| 70 | + |
| 71 | + def test_none_serialized_as_null(self, mock_session): |
| 72 | + """None values should produce \\N for Postgres COPY null.""" |
| 73 | + session, cursor = mock_session |
| 74 | + rows = [{"adds": None}] |
| 75 | + |
| 76 | + bulk_insert(session, "variables", ["adds"], rows) |
| 77 | + |
| 78 | + content = _capture_copy_buffer(cursor) |
| 79 | + assert "\\N" in content |
| 80 | + |
| 81 | + def test_string_with_special_chars_escaped(self, mock_session): |
| 82 | + """Strings with tabs and newlines should be escaped.""" |
| 83 | + session, cursor = mock_session |
| 84 | + rows = [{"description": "line1\tline2\nline3"}] |
| 85 | + |
| 86 | + bulk_insert(session, "variables", ["description"], rows) |
| 87 | + |
| 88 | + content = _capture_copy_buffer(cursor) |
| 89 | + assert "\\t" in content |
| 90 | + assert "\\n" in content |
| 91 | + |
| 92 | + def test_multiple_columns_with_mixed_types(self, mock_session): |
| 93 | + """Verify correct serialization across multiple column types in one row.""" |
| 94 | + session, cursor = mock_session |
| 95 | + row_id = uuid4() |
| 96 | + rows = [ |
| 97 | + { |
| 98 | + "id": row_id, |
| 99 | + "name": "income_tax", |
| 100 | + "adds": ["gross_income"], |
| 101 | + "subtracts": [], |
| 102 | + "description": None, |
| 103 | + } |
| 104 | + ] |
| 105 | + columns = ["id", "name", "adds", "subtracts", "description"] |
| 106 | + |
| 107 | + bulk_insert(session, "variables", columns, rows) |
| 108 | + |
| 109 | + content = _capture_copy_buffer(cursor) |
| 110 | + parts = content.strip().split("\t") |
| 111 | + assert parts[0] == str(row_id) |
| 112 | + assert parts[1] == "income_tax" |
| 113 | + assert parts[2] == '["gross_income"]' |
| 114 | + assert parts[3] == "[]" |
| 115 | + assert parts[4] == "\\N" |
| 116 | + |
| 117 | + def test_empty_rows_skips_copy(self, mock_session): |
| 118 | + """Empty row list should return without calling copy_from.""" |
| 119 | + session, cursor = mock_session |
| 120 | + |
| 121 | + bulk_insert(session, "variables", ["adds"], []) |
| 122 | + |
| 123 | + cursor.copy_from.assert_not_called() |
0 commit comments