Skip to content

Commit 32b4b18

Browse files
committed
Enforce stage 2 lineage invariants
1 parent 0d7be7e commit 32b4b18

11 files changed

Lines changed: 709 additions & 221 deletions

policyengine_api/data/data.py

Lines changed: 77 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
import sqlite3
2-
from policyengine_api.constants import REPO, VERSION, COUNTRY_PACKAGE_VERSIONS
2+
from policyengine_api.constants import REPO, COUNTRY_PACKAGE_VERSIONS
33
from policyengine_api.utils import hash_object
44
from pathlib import Path
55
from dotenv import load_dotenv
@@ -41,6 +41,29 @@ def fetchall(self):
4141
return remaining
4242

4343

44+
class _TransactionProxy:
    """Run queries on an already-open connection that is inside a transaction.

    The proxy never opens, commits, rolls back or closes anything itself;
    it only forwards statements to the connection it was given.
    """

    def __init__(self, connection, local: bool):
        """Store the connection and which backend flavour it belongs to.

        Args:
            connection: The open connection to execute against. When
                ``local`` is true it must provide ``cursor()``; otherwise
                it must provide ``exec_driver_sql()``.
            local: True to use the cursor-based path with ``?``
                placeholders left untouched, False to use the
                ``exec_driver_sql`` path.
        """
        self._connection = connection
        self._local = local

    def query(self, *query):
        """Execute one statement on the wrapped connection.

        Args:
            *query: The SQL string, optionally followed by its parameters.

        Returns:
            The value of ``cursor.execute(...)`` on the local path, or a
            ``_ResultProxy`` wrapping the driver result otherwise.
        """
        if self._local:
            return self._connection.cursor().execute(*query)

        # The non-local path rewrites qmark placeholders to "%s" before
        # handing the statement to the driver.
        statement = query[0].replace("?", "%s")
        if len(query) > 1 and query[1] is not None:
            raw = self._connection.exec_driver_sql(statement, query[1])
        else:
            raw = self._connection.exec_driver_sql(statement)
        return _ResultProxy(raw)
65+
66+
4467
class PolicyEngineDatabase:
4568
"""
4669
A wrapper around the database connection.
@@ -50,6 +73,13 @@ class PolicyEngineDatabase:
5073

5174
household_cache: dict = {}
5275

76+
@staticmethod
77+
def _dict_factory(cursor, row):
78+
d = {}
79+
for idx, col in enumerate(cursor.description):
80+
d[col[0]] = row[idx]
81+
return d
82+
5383
def __init__(
5484
self,
5585
local: bool = False,
@@ -91,7 +121,7 @@ def _close_pool(self):
91121
try:
92122
self.pool.dispose()
93123
self.connector.close()
94-
except:
124+
except Exception:
95125
pass
96126

97127
def _execute_remote(self, query_args):
@@ -110,17 +140,22 @@ def _execute_remote(self, query_args):
110140
# connection context closing
111141
return _ResultProxy(result)
112142

143+
def _execute_remote_transaction(self, callback):
    """Run ``callback`` inside one transaction on a pooled connection.

    Args:
        callback: Callable taking a ``_TransactionProxy`` bound to the
            checked-out connection. Its return value is passed through.

    Returns:
        Whatever ``callback`` returned, after the transaction committed.

    Raises:
        Exception: Anything raised by ``callback`` or by the commit is
            re-raised after the transaction has been rolled back.
    """
    with self.pool.connect() as conn:
        txn = conn.begin()
        tx_proxy = _TransactionProxy(conn, local=False)
        try:
            outcome = callback(tx_proxy)
            # Commit stays inside the try so a failing commit is also
            # followed by a rollback.
            txn.commit()
            return outcome
        except Exception:
            txn.rollback()
            raise
154+
113155
def query(self, *query):
114156
if self.local:
115157
with sqlite3.connect(self.db_url) as conn:
116-
117-
def dict_factory(cursor, row):
118-
d = {}
119-
for idx, col in enumerate(cursor.description):
120-
d[col[0]] = row[idx]
121-
return d
122-
123-
conn.row_factory = dict_factory
158+
conn.row_factory = self._dict_factory
124159
cursor = conn.cursor()
125160
return cursor.execute(*query)
126161
else:
@@ -134,14 +169,44 @@ def dict_factory(cursor, row):
134169
except (
135170
sqlalchemy.exc.InterfaceError,
136171
sqlalchemy.exc.OperationalError,
137-
) as e:
172+
):
138173
try:
139174
self._close_pool()
140175
self._create_pool()
141176
return self._execute_remote(query)
142177
except Exception as e:
143178
raise e
144179

180+
def transaction(self, callback):
    """Run ``callback`` atomically against the configured database.

    Args:
        callback: Callable taking a ``_TransactionProxy``; every query it
            issues through the proxy belongs to the same transaction.

    Returns:
        Whatever ``callback`` returned, after a successful commit.

    Raises:
        Exception: Anything raised by ``callback`` (or by the commit) is
            re-raised after a rollback.

    Note:
        On the remote path an ``InterfaceError`` or ``OperationalError``
        causes the pool to be closed and recreated and the whole callback
        to be run a second time, so the callback may be invoked twice.
    """
    if not self.local:
        try:
            return self._execute_remote_transaction(callback)
        except (
            sqlalchemy.exc.InterfaceError,
            sqlalchemy.exc.OperationalError,
        ):
            # Rebuild the pool once and retry the whole transaction.
            self._close_pool()
            self._create_pool()
            return self._execute_remote_transaction(callback)

    # Reuse an existing connection when the instance has one; otherwise
    # open a private connection that is closed again on the way out.
    shared = getattr(self, "_connection", None)
    opened_here = shared is None
    if opened_here:
        conn = sqlite3.connect(self.db_url)
        conn.row_factory = self._dict_factory
    else:
        conn = shared

    try:
        conn.execute("BEGIN IMMEDIATE")
        outcome = callback(_TransactionProxy(conn, local=True))
        conn.commit()
        return outcome
    except Exception:
        conn.rollback()
        raise
    finally:
        if opened_here:
            conn.close()
209+
145210
def initialize(self):
146211
"""
147212
Create the database tables.
@@ -175,7 +240,7 @@ def initialize(self):
175240
range(1, 1 + len(COUNTRY_PACKAGE_VERSIONS)),
176241
):
177242
self.query(
178-
f"INSERT INTO policy (id, country_id, label, api_version, policy_json, policy_hash) VALUES (?, ?, ?, ?, ?, ?)",
243+
"INSERT INTO policy (id, country_id, label, api_version, policy_json, policy_hash) VALUES (?, ?, ?, ?, ?, ?)",
179244
(
180245
policy_id,
181246
country_id,

policyengine_api/services/report_output_alias_service.py

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,16 @@
44

55

66
class ReportOutputAliasService:
7-
def _report_output_exists(self, report_output_id: int) -> bool:
7+
def _get_report_output_row(self, report_output_id: int) -> dict | None:
88
row: Row | None = database.query(
9-
"SELECT id FROM report_outputs WHERE id = ?",
9+
"""
10+
SELECT id, country_id, simulation_1_id, simulation_2_id, year
11+
FROM report_outputs
12+
WHERE id = ?
13+
""",
1014
(report_output_id,),
1115
).fetchone()
12-
return row is not None
16+
return dict(row) if row is not None else None
1317

1418
def get_alias(self, legacy_report_output_id: int) -> dict | None:
1519
row: Row | None = database.query(
@@ -27,7 +31,7 @@ def resolve_canonical_report_output_id(
2731
alias = self.get_alias(requested_report_output_id)
2832
if alias is not None:
2933
canonical_report_output_id = alias["canonical_report_output_id"]
30-
if not self._report_output_exists(canonical_report_output_id):
34+
if self._get_report_output_row(canonical_report_output_id) is None:
3135
raise ValueError(
3236
"Alias points to missing canonical report output "
3337
f"#{canonical_report_output_id}"
@@ -45,29 +49,49 @@ def set_alias(
4549
legacy_report_output_id: int,
4650
canonical_report_output_id: int,
4751
) -> bool:
48-
if not self._report_output_exists(canonical_report_output_id):
52+
legacy_report_output = self._get_report_output_row(legacy_report_output_id)
53+
if legacy_report_output is None:
54+
raise ValueError(
55+
f"Legacy report output #{legacy_report_output_id} not found"
56+
)
57+
58+
canonical_report_output = self._get_report_output_row(
59+
canonical_report_output_id
60+
)
61+
if canonical_report_output is None:
4962
raise ValueError(
5063
f"Canonical report output #{canonical_report_output_id} not found"
5164
)
65+
if legacy_report_output_id == canonical_report_output_id:
66+
raise ValueError("Legacy and canonical report outputs must be different")
5267

5368
existing_alias = self.get_alias(legacy_report_output_id)
54-
if existing_alias is None:
55-
database.query(
56-
"""
57-
INSERT INTO legacy_report_output_aliases
58-
(legacy_report_output_id, canonical_report_output_id)
59-
VALUES (?, ?)
60-
""",
61-
(legacy_report_output_id, canonical_report_output_id),
62-
)
63-
return True
69+
if existing_alias is not None:
70+
if (
71+
existing_alias["canonical_report_output_id"]
72+
== canonical_report_output_id
73+
):
74+
return True
6475

65-
if existing_alias["canonical_report_output_id"] == canonical_report_output_id:
66-
return True
76+
raise ValueError(
77+
"Legacy report output alias already points to canonical report output "
78+
f"#{existing_alias['canonical_report_output_id']}"
79+
)
6780

68-
raise ValueError(
69-
"Legacy report output alias already points to canonical report output "
70-
f"#{existing_alias['canonical_report_output_id']}"
81+
logical_key = ("country_id", "simulation_1_id", "simulation_2_id", "year")
82+
if any(
83+
legacy_report_output[field] != canonical_report_output[field]
84+
for field in logical_key
85+
):
86+
raise ValueError(
87+
"Legacy and canonical report outputs must describe the same report"
88+
)
89+
database.query(
90+
"""
91+
INSERT INTO legacy_report_output_aliases
92+
(legacy_report_output_id, canonical_report_output_id)
93+
VALUES (?, ?)
94+
""",
95+
(legacy_report_output_id, canonical_report_output_id),
7196
)
72-
7397
return True

0 commit comments

Comments
 (0)