Skip to content

Commit 8cad222

Browse files
MaxGhenisclaude
andcommitted
Upgrade SQLAlchemy from v1 to v2 for Python 3.14 compatibility
SQLAlchemy v1 does not support Python 3.14, making this upgrade a blocker. Key changes: - Update setup.py pin from sqlalchemy>=1.4,<2 to sqlalchemy>=2,<3 - Replace removed engine.execute() with connection-based execution using conn.exec_driver_sql() inside a connection context manager - Add _ResultProxy wrapper to eagerly fetch results so they survive connection closure (maintains existing fetchone()/fetchall() API) - Replace LegacyRow (removed in v2) with Row in type annotations - Update test mocks that used spec=LegacyRow to use plain MagicMock() since Row in v2 has a different interface - Add 10 dedicated SQLAlchemy v2 compatibility tests Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent b58fc15 commit 8cad222

10 files changed

Lines changed: 258 additions & 17 deletions

File tree

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- bump: minor
2+
changes:
3+
changed:
4+
- Upgraded SQLAlchemy from v1 (>=1.4,<2) to v2 (>=2,<3) for Python 3.14 compatibility. Replaced removed engine.execute() with connection-based execution, updated LegacyRow to Row, and added _ResultProxy wrapper for eager result fetching.

policyengine_api/data/data.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,34 @@
1313
load_dotenv()
1414

1515

16+
class _ResultProxy:
17+
"""Lightweight wrapper that eagerly fetches results from a
18+
SQLAlchemy CursorResult so they survive connection closure.
19+
Provides fetchone()/fetchall() with dict-like row access."""
20+
21+
def __init__(self, cursor_result):
22+
try:
23+
# Use .mappings() so rows behave like dicts
24+
self._rows = list(cursor_result.mappings())
25+
except Exception:
26+
# For non-SELECT statements (INSERT/UPDATE/DELETE)
27+
# there are no rows to fetch
28+
self._rows = []
29+
self._index = 0
30+
31+
def fetchone(self):
32+
if self._index < len(self._rows):
33+
row = self._rows[self._index]
34+
self._index += 1
35+
return row
36+
return None
37+
38+
def fetchall(self):
39+
remaining = self._rows[self._index :]
40+
self._index = len(self._rows)
41+
return remaining
42+
43+
1644
class PolicyEngineDatabase:
1745
"""
1846
A wrapper around the database connection.
@@ -70,6 +98,22 @@ def _close_pool(self):
7098
except:
7199
pass
72100

101+
def _execute_remote(self, query_args):
102+
"""Execute a query against the remote database using
103+
SQLAlchemy v2 connection-based execution."""
104+
main_query = query_args[0]
105+
params = query_args[1] if len(query_args) > 1 else None
106+
with self.pool.connect() as conn:
107+
if params is not None:
108+
result = conn.exec_driver_sql(main_query, params)
109+
else:
110+
result = conn.exec_driver_sql(main_query)
111+
conn.commit()
112+
# Return a lightweight wrapper that holds
113+
# the fetched results so they survive the
114+
# connection context closing
115+
return _ResultProxy(result)
116+
73117
def query(self, *query):
74118
if self.local:
75119
with sqlite3.connect(self.db_url) as conn:
@@ -89,7 +133,7 @@ def dict_factory(cursor, row):
89133
main_query = main_query.replace("?", "%s")
90134
query[0] = main_query
91135
try:
92-
return self.pool.execute(*query)
136+
return self._execute_remote(query)
93137
# Except InterfaceError and OperationalError, which are thrown when the connection is lost.
94138
except (
95139
sqlalchemy.exc.InterfaceError,
@@ -98,7 +142,7 @@ def dict_factory(cursor, row):
98142
try:
99143
self._close_pool()
100144
self._create_pool()
101-
return self.pool.execute(*query)
145+
return self._execute_remote(query)
102146
except Exception as e:
103147
raise e
104148

policyengine_api/services/household_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from sqlalchemy.engine.row import LegacyRow
2+
from sqlalchemy.engine.row import Row
33

44
from policyengine_api.data import database
55
from policyengine_api.utils import hash_object
@@ -24,7 +24,7 @@ def get_household(self, country_id: str, household_id: int) -> dict | None:
2424
f"Invalid household ID: {household_id}. Must be a positive integer."
2525
)
2626

27-
row: LegacyRow | None = database.query(
27+
row: Row | None = database.query(
2828
f"SELECT * FROM household WHERE id = ? AND country_id = ?",
2929
(household_id, country_id),
3030
).fetchone()

policyengine_api/services/policy_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from sqlalchemy.engine.row import LegacyRow
2+
from sqlalchemy.engine.row import Row
33

44
from policyengine_api.data import database
55
from policyengine_api.utils import hash_object
@@ -37,7 +37,7 @@ def get_policy(self, country_id: str, policy_id: int) -> dict | None:
3737
raise ValueError("country_id cannot be empty or None")
3838

3939
# If no policy found, this will return None
40-
row: LegacyRow | None = database.query(
40+
row: Row | None = database.query(
4141
"SELECT * FROM policy WHERE country_id = ? AND id = ?",
4242
(country_id, policy_id),
4343
).fetchone()

policyengine_api/services/report_output_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sqlalchemy.engine.row import LegacyRow
1+
from sqlalchemy.engine.row import Row
22

33
from policyengine_api.data import database
44
from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS
@@ -137,7 +137,7 @@ def get_report_output(self, report_output_id: int) -> dict | None:
137137
f"Invalid report output ID: {report_output_id}. Must be a positive integer."
138138
)
139139

140-
row: LegacyRow | None = database.query(
140+
row: Row | None = database.query(
141141
"SELECT * FROM report_outputs WHERE id = ?",
142142
(report_output_id,),
143143
).fetchone()

policyengine_api/services/simulation_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from sqlalchemy.engine.row import LegacyRow
2+
from sqlalchemy.engine.row import Row
33

44
from policyengine_api.data import database
55
from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS
@@ -119,7 +119,7 @@ def get_simulation(
119119
f"Invalid simulation ID: {simulation_id}. Must be a positive integer."
120120
)
121121

122-
row: LegacyRow | None = database.query(
122+
row: Row | None = database.query(
123123
"SELECT * FROM simulations WHERE id = ? AND country_id = ?",
124124
(simulation_id, country_id),
125125
).fetchone()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
"python-dotenv",
4343
"redis",
4444
"rq",
45-
"sqlalchemy>=1.4,<2",
45+
"sqlalchemy>=2,<3",
4646
"streamlit",
4747
"werkzeug",
4848
"Flask-Caching>=2,<3",

tests/to_refactor/python/test_household_routes.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pytest
22
import json
33
from unittest.mock import MagicMock, patch
4-
from sqlalchemy.engine.row import LegacyRow
54

65
from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS
76

@@ -16,8 +15,9 @@
1615
class TestGetHousehold:
1716
def test_get_existing_household(self, rest_client, mock_database):
1817
"""Test getting an existing household."""
19-
# Mock database response
20-
mock_row = MagicMock(spec=LegacyRow)
18+
# Mock database response as a dict-like object
19+
# (SQLAlchemy v2 Row objects support dict() via ._mapping)
20+
mock_row = MagicMock()
2121
mock_row.__getitem__.side_effect = lambda x: valid_db_row[x]
2222
mock_row.keys.return_value = valid_db_row.keys()
2323
mock_database.query().fetchone.return_value = mock_row
@@ -57,7 +57,7 @@ def test_create_household_success(
5757
):
5858
"""Test successfully creating a new household."""
5959
# Mock database responses
60-
mock_row = MagicMock(spec=LegacyRow)
60+
mock_row = MagicMock()
6161
mock_row.__getitem__.side_effect = lambda x: {"id": 1}[x]
6262
mock_database.query().fetchone.return_value = mock_row
6363

@@ -111,7 +111,7 @@ def test_update_household_success(
111111
):
112112
"""Test successfully updating an existing household."""
113113
# Mock getting existing household
114-
mock_row = MagicMock(spec=LegacyRow)
114+
mock_row = MagicMock()
115115
mock_row.__getitem__.side_effect = lambda x: valid_db_row[x]
116116
mock_row.keys.return_value = valid_db_row.keys()
117117
mock_database.query().fetchone.return_value = mock_row
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
"""Tests for SQLAlchemy v2 compatibility.
2+
3+
These tests verify that the database layer works correctly with
4+
SQLAlchemy v2, specifically:
5+
- The _ResultProxy wrapper provides fetchone()/fetchall() on eagerly
6+
fetched results.
7+
- The remote (non-local) query path uses connection-based execution
8+
instead of the removed engine.execute().
9+
- Row objects returned from the remote path support dict-like access
10+
(dict(row) and row["key"]).
11+
"""
12+
13+
import pytest
14+
import sqlalchemy
15+
16+
from policyengine_api.data.data import _ResultProxy, PolicyEngineDatabase
17+
18+
19+
class TestSQLAlchemyVersion:
20+
"""Verify that SQLAlchemy v2 is installed."""
21+
22+
def test_sqlalchemy_version_is_v2(self):
23+
major = int(sqlalchemy.__version__.split(".")[0])
24+
assert (
25+
major >= 2
26+
), f"Expected SQLAlchemy v2+, got {sqlalchemy.__version__}"
27+
28+
29+
class TestResultProxy:
30+
"""Test the _ResultProxy wrapper that bridges SQLAlchemy v2
31+
connection-scoped results with the existing query() API."""
32+
33+
def test_fetchone_returns_dict_like_rows(self):
34+
"""Rows returned by fetchone() should support dict() and
35+
key-based access."""
36+
engine = sqlalchemy.create_engine("sqlite://")
37+
with engine.connect() as conn:
38+
conn.exec_driver_sql(
39+
"CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)"
40+
)
41+
conn.exec_driver_sql("INSERT INTO test VALUES (1, 'hello')")
42+
result = conn.exec_driver_sql("SELECT * FROM test")
43+
proxy = _ResultProxy(result)
44+
45+
row = proxy.fetchone()
46+
assert row is not None
47+
assert dict(row) == {"id": 1, "name": "hello"}
48+
assert row["id"] == 1
49+
assert row["name"] == "hello"
50+
51+
def test_fetchone_returns_none_when_exhausted(self):
52+
engine = sqlalchemy.create_engine("sqlite://")
53+
with engine.connect() as conn:
54+
conn.exec_driver_sql("CREATE TABLE test (id INTEGER PRIMARY KEY)")
55+
result = conn.exec_driver_sql("SELECT * FROM test")
56+
proxy = _ResultProxy(result)
57+
58+
assert proxy.fetchone() is None
59+
60+
def test_fetchall_returns_all_rows(self):
61+
engine = sqlalchemy.create_engine("sqlite://")
62+
with engine.connect() as conn:
63+
conn.exec_driver_sql(
64+
"CREATE TABLE test (id INTEGER PRIMARY KEY, val TEXT)"
65+
)
66+
conn.exec_driver_sql("INSERT INTO test VALUES (1, 'a')")
67+
conn.exec_driver_sql("INSERT INTO test VALUES (2, 'b')")
68+
conn.exec_driver_sql("INSERT INTO test VALUES (3, 'c')")
69+
result = conn.exec_driver_sql("SELECT * FROM test")
70+
proxy = _ResultProxy(result)
71+
72+
rows = proxy.fetchall()
73+
assert len(rows) == 3
74+
assert dict(rows[0]) == {"id": 1, "val": "a"}
75+
assert dict(rows[2]) == {"id": 3, "val": "c"}
76+
77+
def test_fetchone_then_fetchall_respects_cursor_position(self):
78+
engine = sqlalchemy.create_engine("sqlite://")
79+
with engine.connect() as conn:
80+
conn.exec_driver_sql("CREATE TABLE test (id INTEGER PRIMARY KEY)")
81+
conn.exec_driver_sql("INSERT INTO test VALUES (1)")
82+
conn.exec_driver_sql("INSERT INTO test VALUES (2)")
83+
conn.exec_driver_sql("INSERT INTO test VALUES (3)")
84+
result = conn.exec_driver_sql("SELECT * FROM test")
85+
proxy = _ResultProxy(result)
86+
87+
first = proxy.fetchone()
88+
assert dict(first) == {"id": 1}
89+
remaining = proxy.fetchall()
90+
assert len(remaining) == 2
91+
assert dict(remaining[0]) == {"id": 2}
92+
93+
def test_result_proxy_for_insert_statement(self):
94+
"""INSERT statements produce no rows; _ResultProxy should
95+
handle this gracefully."""
96+
engine = sqlalchemy.create_engine("sqlite://")
97+
with engine.connect() as conn:
98+
conn.exec_driver_sql("CREATE TABLE test (id INTEGER PRIMARY KEY)")
99+
result = conn.exec_driver_sql("INSERT INTO test VALUES (1)")
100+
proxy = _ResultProxy(result)
101+
102+
assert proxy.fetchone() is None
103+
assert proxy.fetchall() == []
104+
105+
106+
class TestRemoteQueryPath:
107+
"""Test the non-local query path that uses SQLAlchemy engine
108+
with connection-based execution (v2 pattern)."""
109+
110+
def _make_remote_db(self):
111+
"""Create a PolicyEngineDatabase-like object that uses
112+
a SQLAlchemy engine (the 'remote' path) but backed by
113+
in-memory SQLite for testing."""
114+
db = PolicyEngineDatabase.__new__(PolicyEngineDatabase)
115+
db.local = False
116+
db.pool = sqlalchemy.create_engine("sqlite://")
117+
# Initialize schema using the remote path
118+
with db.pool.connect() as conn:
119+
conn.exec_driver_sql(
120+
"CREATE TABLE test_table "
121+
"(id INTEGER PRIMARY KEY, name TEXT, value REAL)"
122+
)
123+
conn.commit()
124+
return db
125+
126+
def test_remote_insert_and_select(self):
127+
"""Test INSERT then SELECT through the remote query path."""
128+
db = self._make_remote_db()
129+
130+
# Note: remote path converts ? to %s for MySQL, but SQLite
131+
# uses ? natively. Since exec_driver_sql passes to the DBAPI
132+
# driver directly and SQLite's driver uses ?, we need to
133+
# test with the actual query() method which does the conversion.
134+
# For SQLite DBAPI, ? is the native marker.
135+
136+
# Use exec_driver_sql directly to bypass ?->%s conversion
137+
# (which would break SQLite)
138+
db._execute_remote(
139+
[
140+
"INSERT INTO test_table (id, name, value) VALUES (?, ?, ?)",
141+
(1, "test", 3.14),
142+
]
143+
)
144+
145+
result = db._execute_remote(
146+
["SELECT * FROM test_table WHERE id = ?", (1,)]
147+
)
148+
row = result.fetchone()
149+
assert row is not None
150+
assert row["id"] == 1
151+
assert row["name"] == "test"
152+
assert row["value"] == 3.14
153+
assert dict(row) == {"id": 1, "name": "test", "value": 3.14}
154+
155+
def test_remote_select_no_results(self):
156+
db = self._make_remote_db()
157+
result = db._execute_remote(
158+
["SELECT * FROM test_table WHERE id = ?", (999,)]
159+
)
160+
assert result.fetchone() is None
161+
162+
def test_remote_update(self):
163+
db = self._make_remote_db()
164+
db._execute_remote(
165+
[
166+
"INSERT INTO test_table (id, name, value) VALUES (?, ?, ?)",
167+
(1, "original", 1.0),
168+
]
169+
)
170+
db._execute_remote(
171+
[
172+
"UPDATE test_table SET name = ? WHERE id = ?",
173+
("updated", 1),
174+
]
175+
)
176+
result = db._execute_remote(
177+
["SELECT * FROM test_table WHERE id = ?", (1,)]
178+
)
179+
row = result.fetchone()
180+
assert row["name"] == "updated"
181+
182+
def test_remote_delete(self):
183+
db = self._make_remote_db()
184+
db._execute_remote(
185+
[
186+
"INSERT INTO test_table (id, name, value) VALUES (?, ?, ?)",
187+
(1, "to_delete", 0.0),
188+
]
189+
)
190+
db._execute_remote(["DELETE FROM test_table WHERE id = ?", (1,)])
191+
result = db._execute_remote(
192+
["SELECT * FROM test_table WHERE id = ?", (1,)]
193+
)
194+
assert result.fetchone() is None

tests/unit/services/test_household_service.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pytest
22
import json
33
from unittest.mock import MagicMock
4-
from sqlalchemy.engine.row import LegacyRow
54
import re
65

76
from policyengine_api.services.household_service import HouseholdService

0 commit comments

Comments
 (0)