Skip to content

Commit c85ed87

Browse files
authored
Merge pull request #3428 from PolicyEngine/feat/report-output-run-stage-2
Add internal services for report output runs
2 parents ec2e88d + 5e15e2f commit c85ed87

14 files changed

Lines changed: 2523 additions & 13 deletions
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add internal report and simulation spec, alias, and run services for the report-output run migration.

policyengine_api/data/data.py

Lines changed: 77 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import sqlite3
2-
from policyengine_api.constants import REPO, VERSION, COUNTRY_PACKAGE_VERSIONS
2+
from policyengine_api.constants import REPO, COUNTRY_PACKAGE_VERSIONS
33
from policyengine_api.utils import hash_object
44
from pathlib import Path
55
from dotenv import load_dotenv
@@ -41,6 +41,29 @@ def fetchall(self):
4141
return remaining
4242

4343

44+
class _TransactionProxy:
45+
"""Execute queries against an existing connection inside a transaction."""
46+
47+
def __init__(self, connection, local: bool):
48+
self._connection = connection
49+
self._local = local
50+
51+
def query(self, *query):
52+
if self._local:
53+
cursor = self._connection.cursor()
54+
return cursor.execute(*query)
55+
56+
query = list(query)
57+
main_query = query[0].replace("?", "%s")
58+
query[0] = main_query
59+
params = query[1] if len(query) > 1 else None
60+
if params is not None:
61+
result = self._connection.exec_driver_sql(main_query, params)
62+
else:
63+
result = self._connection.exec_driver_sql(main_query)
64+
return _ResultProxy(result)
65+
66+
4467
class PolicyEngineDatabase:
4568
"""
4669
A wrapper around the database connection.
@@ -50,6 +73,13 @@ class PolicyEngineDatabase:
5073

5174
household_cache: dict = {}
5275

76+
@staticmethod
77+
def _dict_factory(cursor, row):
78+
d = {}
79+
for idx, col in enumerate(cursor.description):
80+
d[col[0]] = row[idx]
81+
return d
82+
5383
def __init__(
5484
self,
5585
local: bool = False,
@@ -91,7 +121,7 @@ def _close_pool(self):
91121
try:
92122
self.pool.dispose()
93123
self.connector.close()
94-
except:
124+
except Exception:
95125
pass
96126

97127
def _execute_remote(self, query_args):
@@ -110,17 +140,22 @@ def _execute_remote(self, query_args):
110140
# connection context closing
111141
return _ResultProxy(result)
112142

143+
def _execute_remote_transaction(self, callback):
144+
with self.pool.connect() as conn:
145+
transaction = conn.begin()
146+
proxy = _TransactionProxy(conn, local=False)
147+
try:
148+
result = callback(proxy)
149+
transaction.commit()
150+
return result
151+
except Exception:
152+
transaction.rollback()
153+
raise
154+
113155
def query(self, *query):
114156
if self.local:
115157
with sqlite3.connect(self.db_url) as conn:
116-
117-
def dict_factory(cursor, row):
118-
d = {}
119-
for idx, col in enumerate(cursor.description):
120-
d[col[0]] = row[idx]
121-
return d
122-
123-
conn.row_factory = dict_factory
158+
conn.row_factory = self._dict_factory
124159
cursor = conn.cursor()
125160
return cursor.execute(*query)
126161
else:
@@ -134,14 +169,44 @@ def dict_factory(cursor, row):
134169
except (
135170
sqlalchemy.exc.InterfaceError,
136171
sqlalchemy.exc.OperationalError,
137-
) as e:
172+
):
138173
try:
139174
self._close_pool()
140175
self._create_pool()
141176
return self._execute_remote(query)
142177
except Exception as e:
143178
raise e
144179

180+
def transaction(self, callback):
181+
if self.local:
182+
connection = getattr(self, "_connection", None)
183+
owns_connection = connection is None
184+
if owns_connection:
185+
connection = sqlite3.connect(self.db_url)
186+
connection.row_factory = self._dict_factory
187+
try:
188+
connection.execute("BEGIN IMMEDIATE")
189+
proxy = _TransactionProxy(connection, local=True)
190+
result = callback(proxy)
191+
connection.commit()
192+
return result
193+
except Exception:
194+
connection.rollback()
195+
raise
196+
finally:
197+
if owns_connection:
198+
connection.close()
199+
200+
try:
201+
return self._execute_remote_transaction(callback)
202+
except (
203+
sqlalchemy.exc.InterfaceError,
204+
sqlalchemy.exc.OperationalError,
205+
):
206+
self._close_pool()
207+
self._create_pool()
208+
return self._execute_remote_transaction(callback)
209+
145210
def initialize(self):
146211
"""
147212
Create the database tables.
@@ -175,7 +240,7 @@ def initialize(self):
175240
range(1, 1 + len(COUNTRY_PACKAGE_VERSIONS)),
176241
):
177242
self.query(
178-
f"INSERT INTO policy (id, country_id, label, api_version, policy_json, policy_hash) VALUES (?, ?, ?, ?, ?, ?)",
243+
"INSERT INTO policy (id, country_id, label, api_version, policy_json, policy_hash) VALUES (?, ?, ?, ?, ?, ?)",
179244
(
180245
policy_id,
181246
country_id,
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from sqlalchemy.engine.row import Row
2+
3+
from policyengine_api.data import database
4+
5+
6+
class ReportOutputAliasService:
7+
def _get_report_output_row(self, report_output_id: int) -> dict | None:
8+
row: Row | None = database.query(
9+
"""
10+
SELECT id, country_id, simulation_1_id, simulation_2_id, year
11+
FROM report_outputs
12+
WHERE id = ?
13+
""",
14+
(report_output_id,),
15+
).fetchone()
16+
return dict(row) if row is not None else None
17+
18+
def get_alias(self, legacy_report_output_id: int) -> dict | None:
19+
row: Row | None = database.query(
20+
"""
21+
SELECT * FROM legacy_report_output_aliases
22+
WHERE legacy_report_output_id = ?
23+
""",
24+
(legacy_report_output_id,),
25+
).fetchone()
26+
return dict(row) if row is not None else None
27+
28+
def resolve_canonical_report_output_id(
29+
self, requested_report_output_id: int
30+
) -> int | None:
31+
alias = self.get_alias(requested_report_output_id)
32+
if alias is not None:
33+
canonical_report_output_id = alias["canonical_report_output_id"]
34+
if self._get_report_output_row(canonical_report_output_id) is None:
35+
raise ValueError(
36+
"Alias points to missing canonical report output "
37+
f"#{canonical_report_output_id}"
38+
)
39+
return canonical_report_output_id
40+
41+
row: Row | None = database.query(
42+
"SELECT id FROM report_outputs WHERE id = ?",
43+
(requested_report_output_id,),
44+
).fetchone()
45+
return row["id"] if row is not None else None
46+
47+
def set_alias(
48+
self,
49+
legacy_report_output_id: int,
50+
canonical_report_output_id: int,
51+
) -> bool:
52+
legacy_report_output = self._get_report_output_row(legacy_report_output_id)
53+
if legacy_report_output is None:
54+
raise ValueError(
55+
f"Legacy report output #{legacy_report_output_id} not found"
56+
)
57+
58+
canonical_report_output = self._get_report_output_row(
59+
canonical_report_output_id
60+
)
61+
if canonical_report_output is None:
62+
raise ValueError(
63+
f"Canonical report output #{canonical_report_output_id} not found"
64+
)
65+
if legacy_report_output_id == canonical_report_output_id:
66+
raise ValueError("Legacy and canonical report outputs must be different")
67+
68+
existing_alias = self.get_alias(legacy_report_output_id)
69+
if existing_alias is not None:
70+
if (
71+
existing_alias["canonical_report_output_id"]
72+
== canonical_report_output_id
73+
):
74+
return True
75+
76+
raise ValueError(
77+
"Legacy report output alias already points to canonical report output "
78+
f"#{existing_alias['canonical_report_output_id']}"
79+
)
80+
81+
logical_key = ("country_id", "simulation_1_id", "simulation_2_id", "year")
82+
if any(
83+
legacy_report_output[field] != canonical_report_output[field]
84+
for field in logical_key
85+
):
86+
raise ValueError(
87+
"Legacy and canonical report outputs must describe the same report"
88+
)
89+
database.query(
90+
"""
91+
INSERT INTO legacy_report_output_aliases
92+
(legacy_report_output_id, canonical_report_output_id)
93+
VALUES (?, ?)
94+
""",
95+
(legacy_report_output_id, canonical_report_output_id),
96+
)
97+
return True
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import json
2+
import uuid
3+
from typing import Any
4+
5+
from sqlalchemy.engine.row import Row
6+
7+
from policyengine_api.data import database
8+
9+
10+
REPORT_RUN_VERSION_FIELDS = (
11+
"country_package_version",
12+
"policyengine_version",
13+
"data_version",
14+
"runtime_app_name",
15+
"report_cache_version",
16+
"simulation_cache_version",
17+
"requested_version_override",
18+
"resolved_dataset",
19+
"resolved_options_hash",
20+
)
21+
22+
23+
class ReportRunService:
24+
def _serialize_json(
25+
self, value: dict[str, Any] | list[Any] | str | None
26+
) -> str | None:
27+
if value is None or isinstance(value, str):
28+
return value
29+
return json.dumps(value)
30+
31+
def _parse_run_row(self, row: Row | dict | None) -> dict | None:
32+
if row is None:
33+
return None
34+
35+
run = dict(row)
36+
if isinstance(run.get("report_spec_snapshot_json"), str):
37+
run["report_spec_snapshot_json"] = json.loads(
38+
run["report_spec_snapshot_json"]
39+
)
40+
return run
41+
42+
def create_report_output_run(
43+
self,
44+
report_output_id: int,
45+
status: str = "pending",
46+
trigger_type: str = "initial",
47+
output: dict[str, Any] | list[Any] | str | None = None,
48+
error_message: str | None = None,
49+
source_run_id: str | None = None,
50+
report_spec_snapshot: dict[str, Any] | str | None = None,
51+
version_manifest: dict[str, str | None] | None = None,
52+
run_id: str | None = None,
53+
) -> dict:
54+
run_id = run_id or str(uuid.uuid4())
55+
version_manifest = version_manifest or {}
56+
lock_clause = "" if database.local else " FOR UPDATE"
57+
58+
def create_run_transaction(tx) -> None:
59+
parent_row: Row | None = tx.query(
60+
f"SELECT id FROM report_outputs WHERE id = ?{lock_clause}",
61+
(report_output_id,),
62+
).fetchone()
63+
if parent_row is None:
64+
raise ValueError(f"Report output #{report_output_id} not found")
65+
66+
run_sequence_row: Row | None = tx.query(
67+
"""
68+
SELECT COALESCE(MAX(run_sequence), 0) AS max_run_sequence
69+
FROM report_output_runs
70+
WHERE report_output_id = ?
71+
""",
72+
(report_output_id,),
73+
).fetchone()
74+
run_sequence = (
75+
int(run_sequence_row["max_run_sequence"]) + 1
76+
if run_sequence_row is not None
77+
else 1
78+
)
79+
80+
tx.query(
81+
f"""
82+
INSERT INTO report_output_runs (
83+
id, report_output_id, run_sequence, status, output, error_message,
84+
trigger_type, requested_at, started_at, finished_at, source_run_id,
85+
report_spec_snapshot_json, {", ".join(REPORT_RUN_VERSION_FIELDS)}
86+
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
87+
""",
88+
(
89+
run_id,
90+
report_output_id,
91+
run_sequence,
92+
status,
93+
self._serialize_json(output),
94+
error_message,
95+
trigger_type,
96+
None,
97+
None,
98+
None,
99+
source_run_id,
100+
self._serialize_json(report_spec_snapshot),
101+
*[
102+
version_manifest.get(field)
103+
for field in REPORT_RUN_VERSION_FIELDS
104+
],
105+
),
106+
)
107+
108+
database.transaction(create_run_transaction)
109+
return self.get_report_output_run(run_id)
110+
111+
def get_report_output_run(self, run_id: str) -> dict | None:
112+
row: Row | None = database.query(
113+
"SELECT * FROM report_output_runs WHERE id = ?",
114+
(run_id,),
115+
).fetchone()
116+
return self._parse_run_row(row)
117+
118+
def list_report_output_runs(self, report_output_id: int) -> list[dict]:
119+
rows = database.query(
120+
"""
121+
SELECT * FROM report_output_runs
122+
WHERE report_output_id = ?
123+
ORDER BY run_sequence ASC
124+
""",
125+
(report_output_id,),
126+
).fetchall()
127+
return [self._parse_run_row(row) for row in rows]
128+
129+
def get_newest_report_output_run(self, report_output_id: int) -> dict | None:
130+
row: Row | None = database.query(
131+
"""
132+
SELECT * FROM report_output_runs
133+
WHERE report_output_id = ?
134+
ORDER BY run_sequence DESC
135+
LIMIT 1
136+
""",
137+
(report_output_id,),
138+
).fetchone()
139+
return self._parse_run_row(row)
140+
141+
def select_display_run(self, report_output: dict) -> dict | None:
142+
if report_output.get("active_run_id"):
143+
active_run = self.get_report_output_run(report_output["active_run_id"])
144+
if active_run is not None:
145+
return active_run
146+
if report_output.get("latest_successful_run_id"):
147+
latest_successful_run = self.get_report_output_run(
148+
report_output["latest_successful_run_id"]
149+
)
150+
if latest_successful_run is not None:
151+
return latest_successful_run
152+
return self.get_newest_report_output_run(report_output["id"])

0 commit comments

Comments
 (0)