Skip to content

Commit 76bc19b

Browse files
authored
Merge pull request #3293 from PolicyEngine/upgrade-sqlalchemy-v2
Upgrade SQLAlchemy v1 to v2 (Python 3.14 blocker)
2 parents 55732b6 + f3a7d44 commit 76bc19b

12 files changed

Lines changed: 383 additions & 18 deletions

File tree

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- bump: minor
2+
changes:
3+
changed:
4+
- Upgraded SQLAlchemy from v1 (>=1.4,<2) to v2 (>=2,<3) for Python 3.14 compatibility. Replaced removed engine.execute() with connection-based execution, updated LegacyRow to Row, and added _ResultProxy wrapper for eager result fetching.

policyengine_api/data/data.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,34 @@
1313
load_dotenv()
1414

1515

16+
class _ResultProxy:
17+
"""Lightweight wrapper that eagerly fetches results from a
18+
SQLAlchemy CursorResult so they survive connection closure.
19+
Provides fetchone()/fetchall() with dict-like row access."""
20+
21+
def __init__(self, cursor_result):
22+
try:
23+
# Use .mappings() so rows behave like dicts
24+
self._rows = list(cursor_result.mappings())
25+
except Exception:
26+
# For non-SELECT statements (INSERT/UPDATE/DELETE)
27+
# there are no rows to fetch
28+
self._rows = []
29+
self._index = 0
30+
31+
def fetchone(self):
32+
if self._index < len(self._rows):
33+
row = self._rows[self._index]
34+
self._index += 1
35+
return row
36+
return None
37+
38+
def fetchall(self):
39+
remaining = self._rows[self._index :]
40+
self._index = len(self._rows)
41+
return remaining
42+
43+
1644
class PolicyEngineDatabase:
1745
"""
1846
A wrapper around the database connection.
@@ -70,6 +98,22 @@ def _close_pool(self):
7098
except:
7199
pass
72100

101+
def _execute_remote(self, query_args):
102+
"""Execute a query against the remote database using
103+
SQLAlchemy v2 connection-based execution."""
104+
main_query = query_args[0]
105+
params = query_args[1] if len(query_args) > 1 else None
106+
with self.pool.connect() as conn:
107+
if params is not None:
108+
result = conn.exec_driver_sql(main_query, params)
109+
else:
110+
result = conn.exec_driver_sql(main_query)
111+
conn.commit()
112+
# Return a lightweight wrapper that holds
113+
# the fetched results so they survive the
114+
# connection context closing
115+
return _ResultProxy(result)
116+
73117
def query(self, *query):
74118
if self.local:
75119
with sqlite3.connect(self.db_url) as conn:
@@ -89,7 +133,7 @@ def dict_factory(cursor, row):
89133
main_query = main_query.replace("?", "%s")
90134
query[0] = main_query
91135
try:
92-
return self.pool.execute(*query)
136+
return self._execute_remote(query)
93137
# Except InterfaceError and OperationalError, which are thrown when the connection is lost.
94138
except (
95139
sqlalchemy.exc.InterfaceError,
@@ -98,7 +142,7 @@ def dict_factory(cursor, row):
98142
try:
99143
self._close_pool()
100144
self._create_pool()
101-
return self.pool.execute(*query)
145+
return self._execute_remote(query)
102146
except Exception as e:
103147
raise e
104148

policyengine_api/endpoints/policy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,9 +228,13 @@ def set_user_policy(country_id: str) -> dict:
228228
f"AND dataset {dataset_select_str}"
229229
)
230230

231+
params = [country_id, reform_id, baseline_id, user_id, year, geography]
232+
if dataset:
233+
params.append(dataset)
234+
231235
row = database.query(
232236
query,
233-
(country_id, reform_id, baseline_id, user_id, year, geography),
237+
tuple(params),
234238
).fetchone()
235239

236240
except Exception as e:

policyengine_api/services/household_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from sqlalchemy.engine.row import LegacyRow
2+
from sqlalchemy.engine.row import Row
33

44
from policyengine_api.data import database
55
from policyengine_api.utils import hash_object
@@ -24,7 +24,7 @@ def get_household(self, country_id: str, household_id: int) -> dict | None:
2424
f"Invalid household ID: {household_id}. Must be a positive integer."
2525
)
2626

27-
row: LegacyRow | None = database.query(
27+
row: Row | None = database.query(
2828
f"SELECT * FROM household WHERE id = ? AND country_id = ?",
2929
(household_id, country_id),
3030
).fetchone()

policyengine_api/services/policy_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from sqlalchemy.engine.row import LegacyRow
2+
from sqlalchemy.engine.row import Row
33

44
from policyengine_api.data import database
55
from policyengine_api.utils import hash_object
@@ -37,7 +37,7 @@ def get_policy(self, country_id: str, policy_id: int) -> dict | None:
3737
raise ValueError("country_id cannot be empty or None")
3838

3939
# If no policy found, this will return None
40-
row: LegacyRow | None = database.query(
40+
row: Row | None = database.query(
4141
"SELECT * FROM policy WHERE country_id = ? AND id = ?",
4242
(country_id, policy_id),
4343
).fetchone()

policyengine_api/services/report_output_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sqlalchemy.engine.row import LegacyRow
1+
from sqlalchemy.engine.row import Row
22

33
from policyengine_api.data import database
44
from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS
@@ -137,7 +137,7 @@ def get_report_output(self, report_output_id: int) -> dict | None:
137137
f"Invalid report output ID: {report_output_id}. Must be a positive integer."
138138
)
139139

140-
row: LegacyRow | None = database.query(
140+
row: Row | None = database.query(
141141
"SELECT * FROM report_outputs WHERE id = ?",
142142
(report_output_id,),
143143
).fetchone()

policyengine_api/services/simulation_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from sqlalchemy.engine.row import LegacyRow
2+
from sqlalchemy.engine.row import Row
33

44
from policyengine_api.data import database
55
from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS
@@ -119,7 +119,7 @@ def get_simulation(
119119
f"Invalid simulation ID: {simulation_id}. Must be a positive integer."
120120
)
121121

122-
row: LegacyRow | None = database.query(
122+
row: Row | None = database.query(
123123
"SELECT * FROM simulations WHERE id = ? AND country_id = ?",
124124
(simulation_id, country_id),
125125
).fetchone()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
"python-dotenv",
4343
"redis",
4444
"rq",
45-
"sqlalchemy>=1.4,<2",
45+
"sqlalchemy>=2,<3",
4646
"streamlit",
4747
"werkzeug",
4848
"Flask-Caching>=2,<3",

tests/to_refactor/python/test_household_routes.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pytest
22
import json
33
from unittest.mock import MagicMock, patch
4-
from sqlalchemy.engine.row import LegacyRow
54

65
from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS
76

@@ -16,8 +15,9 @@
1615
class TestGetHousehold:
1716
def test_get_existing_household(self, rest_client, mock_database):
1817
"""Test getting an existing household."""
19-
# Mock database response
20-
mock_row = MagicMock(spec=LegacyRow)
18+
# Mock database response as a dict-like object
19+
# (SQLAlchemy v2 Row objects support dict() via ._mapping)
20+
mock_row = MagicMock()
2121
mock_row.__getitem__.side_effect = lambda x: valid_db_row[x]
2222
mock_row.keys.return_value = valid_db_row.keys()
2323
mock_database.query().fetchone.return_value = mock_row
@@ -57,7 +57,7 @@ def test_create_household_success(
5757
):
5858
"""Test successfully creating a new household."""
5959
# Mock database responses
60-
mock_row = MagicMock(spec=LegacyRow)
60+
mock_row = MagicMock()
6161
mock_row.__getitem__.side_effect = lambda x: {"id": 1}[x]
6262
mock_database.query().fetchone.return_value = mock_row
6363

@@ -111,7 +111,7 @@ def test_update_household_success(
111111
):
112112
"""Test successfully updating an existing household."""
113113
# Mock getting existing household
114-
mock_row = MagicMock(spec=LegacyRow)
114+
mock_row = MagicMock()
115115
mock_row.__getitem__.side_effect = lambda x: valid_db_row[x]
116116
mock_row.keys.return_value = valid_db_row.keys()
117117
mock_database.query().fetchone.return_value = mock_row

0 commit comments

Comments
 (0)