Skip to content

Commit 0d7be7e

Browse files
committed
Harden stage 2 run and spec services
1 parent 1096031 commit 0d7be7e

10 files changed

Lines changed: 808 additions & 104 deletions

policyengine_api/services/report_output_alias_service.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@
44

55

66
class ReportOutputAliasService:
7+
def _report_output_exists(self, report_output_id: int) -> bool:
8+
row: Row | None = database.query(
9+
"SELECT id FROM report_outputs WHERE id = ?",
10+
(report_output_id,),
11+
).fetchone()
12+
return row is not None
13+
714
def get_alias(self, legacy_report_output_id: int) -> dict | None:
815
row: Row | None = database.query(
916
"""
@@ -19,7 +26,13 @@ def resolve_canonical_report_output_id(
1926
) -> int | None:
2027
alias = self.get_alias(requested_report_output_id)
2128
if alias is not None:
22-
return alias["canonical_report_output_id"]
29+
canonical_report_output_id = alias["canonical_report_output_id"]
30+
if not self._report_output_exists(canonical_report_output_id):
31+
raise ValueError(
32+
"Alias points to missing canonical report output "
33+
f"#{canonical_report_output_id}"
34+
)
35+
return canonical_report_output_id
2336

2437
row: Row | None = database.query(
2538
"SELECT id FROM report_outputs WHERE id = ?",
@@ -32,6 +45,11 @@ def set_alias(
3245
legacy_report_output_id: int,
3346
canonical_report_output_id: int,
3447
) -> bool:
48+
if not self._report_output_exists(canonical_report_output_id):
49+
raise ValueError(
50+
f"Canonical report output #{canonical_report_output_id} not found"
51+
)
52+
3553
existing_alias = self.get_alias(legacy_report_output_id)
3654
if existing_alias is None:
3755
database.query(
@@ -44,12 +62,12 @@ def set_alias(
4462
)
4563
return True
4664

47-
database.query(
48-
"""
49-
UPDATE legacy_report_output_aliases
50-
SET canonical_report_output_id = ?
51-
WHERE legacy_report_output_id = ?
52-
""",
53-
(canonical_report_output_id, legacy_report_output_id),
65+
if existing_alias["canonical_report_output_id"] == canonical_report_output_id:
66+
return True
67+
68+
raise ValueError(
69+
"Legacy report output alias already points to canonical report output "
70+
f"#{existing_alias['canonical_report_output_id']}"
5471
)
72+
5573
return True

policyengine_api/services/report_run_service.py

Lines changed: 69 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import json
2+
import sqlite3
23
import uuid
34
from typing import Any
45

6+
import sqlalchemy.exc
57
from sqlalchemy.engine.row import Row
68

79
from policyengine_api.data import database
@@ -18,9 +20,17 @@
1820
"resolved_dataset",
1921
"resolved_options_hash",
2022
)
23+
MAX_CREATE_RUN_ATTEMPTS = 3
2124

2225

2326
class ReportRunService:
27+
def _report_output_exists(self, report_output_id: int) -> bool:
28+
row: Row | None = database.query(
29+
"SELECT id FROM report_outputs WHERE id = ?",
30+
(report_output_id,),
31+
).fetchone()
32+
return row is not None
33+
2434
def _next_run_sequence(self, report_output_id: int) -> int:
2535
row: Row | None = database.query(
2636
"""
@@ -50,6 +60,14 @@ def _parse_run_row(self, row: Row | dict | None) -> dict | None:
5060
)
5161
return run
5262

63+
def _is_sequence_conflict(self, error: Exception) -> bool:
64+
message = str(error)
65+
return (
66+
"report_output_run_sequence_idx" in message
67+
or "report_output_runs.report_output_id, report_output_runs.run_sequence"
68+
in message
69+
)
70+
5371
def create_report_output_run(
5472
self,
5573
report_output_id: int,
@@ -62,35 +80,53 @@ def create_report_output_run(
6280
version_manifest: dict[str, str | None] | None = None,
6381
run_id: str | None = None,
6482
) -> dict:
83+
if not self._report_output_exists(report_output_id):
84+
raise ValueError(f"Report output #{report_output_id} not found")
85+
6586
run_id = run_id or str(uuid.uuid4())
66-
run_sequence = self._next_run_sequence(report_output_id)
6787
version_manifest = version_manifest or {}
6888

69-
database.query(
70-
f"""
71-
INSERT INTO report_output_runs (
72-
id, report_output_id, run_sequence, status, output, error_message,
73-
trigger_type, requested_at, started_at, finished_at, source_run_id,
74-
report_spec_snapshot_json, {", ".join(REPORT_RUN_VERSION_FIELDS)}
75-
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
76-
""",
77-
(
78-
run_id,
79-
report_output_id,
80-
run_sequence,
81-
status,
82-
self._serialize_json(output),
83-
error_message,
84-
trigger_type,
85-
None,
86-
None,
87-
None,
88-
source_run_id,
89-
self._serialize_json(report_spec_snapshot),
90-
*[version_manifest.get(field) for field in REPORT_RUN_VERSION_FIELDS],
91-
),
89+
for attempt in range(MAX_CREATE_RUN_ATTEMPTS):
90+
run_sequence = self._next_run_sequence(report_output_id)
91+
try:
92+
database.query(
93+
f"""
94+
INSERT INTO report_output_runs (
95+
id, report_output_id, run_sequence, status, output, error_message,
96+
trigger_type, requested_at, started_at, finished_at, source_run_id,
97+
report_spec_snapshot_json, {", ".join(REPORT_RUN_VERSION_FIELDS)}
98+
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
99+
""",
100+
(
101+
run_id,
102+
report_output_id,
103+
run_sequence,
104+
status,
105+
self._serialize_json(output),
106+
error_message,
107+
trigger_type,
108+
None,
109+
None,
110+
None,
111+
source_run_id,
112+
self._serialize_json(report_spec_snapshot),
113+
*[
114+
version_manifest.get(field)
115+
for field in REPORT_RUN_VERSION_FIELDS
116+
],
117+
),
118+
)
119+
return self.get_report_output_run(run_id)
120+
except (sqlite3.IntegrityError, sqlalchemy.exc.IntegrityError) as error:
121+
if (
122+
attempt == MAX_CREATE_RUN_ATTEMPTS - 1
123+
or not self._is_sequence_conflict(error)
124+
):
125+
raise
126+
127+
raise RuntimeError(
128+
f"Unable to allocate report output run sequence for #{report_output_id}"
92129
)
93-
return self.get_report_output_run(run_id)
94130

95131
def get_report_output_run(self, run_id: str) -> dict | None:
96132
row: Row | None = database.query(
@@ -124,7 +160,13 @@ def get_newest_report_output_run(self, report_output_id: int) -> dict | None:
124160

125161
def select_display_run(self, report_output: dict) -> dict | None:
126162
if report_output.get("active_run_id"):
127-
return self.get_report_output_run(report_output["active_run_id"])
163+
active_run = self.get_report_output_run(report_output["active_run_id"])
164+
if active_run is not None:
165+
return active_run
128166
if report_output.get("latest_successful_run_id"):
129-
return self.get_report_output_run(report_output["latest_successful_run_id"])
167+
latest_successful_run = self.get_report_output_run(
168+
report_output["latest_successful_run_id"]
169+
)
170+
if latest_successful_run is not None:
171+
return latest_successful_run
130172
return self.get_newest_report_output_run(report_output["id"])

policyengine_api/services/report_spec_service.py

Lines changed: 121 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,116 @@ class EconomyReportSpec(BaseModel):
4242

4343

4444
class ReportSpecService:
45+
def _validate_schema_version(self, schema_version: int | None) -> None:
46+
if schema_version != REPORT_SPEC_SCHEMA_VERSION:
47+
raise ValueError(
48+
f"Unsupported report spec schema version: {schema_version}"
49+
)
50+
4551
def _get_report_output_row(self, report_output_id: int) -> dict | None:
4652
row: Row | None = database.query(
4753
"SELECT * FROM report_outputs WHERE id = ?",
4854
(report_output_id,),
4955
).fetchone()
5056
return dict(row) if row is not None else None
5157

58+
def _validate_report_country(
59+
self,
60+
report_output: dict,
61+
simulation_1: dict,
62+
simulation_2: dict | None = None,
63+
) -> None:
64+
report_country_id = report_output["country_id"]
65+
if simulation_1["country_id"] != report_country_id:
66+
raise ValueError(
67+
"Simulation 1 country must match report output country to build a "
68+
"report spec"
69+
)
70+
if simulation_2 is not None and simulation_2["country_id"] != report_country_id:
71+
raise ValueError(
72+
"Simulation 2 country must match report output country to build a "
73+
"report spec"
74+
)
75+
76+
def _build_household_report_spec(
77+
self,
78+
report_output: dict,
79+
report_kind: str,
80+
simulation_1: dict,
81+
simulation_2: dict | None,
82+
time_period: str,
83+
) -> HouseholdReportSpec:
84+
if simulation_1["population_type"] != "household":
85+
raise ValueError("Household report specs require household simulations")
86+
if (
87+
simulation_2 is not None
88+
and simulation_2["population_id"] != simulation_1["population_id"]
89+
):
90+
raise ValueError(
91+
"Household comparison report specs require matching household IDs"
92+
)
93+
94+
return HouseholdReportSpec.model_validate(
95+
{
96+
"country_id": report_output["country_id"],
97+
"report_kind": report_kind,
98+
"time_period": time_period,
99+
"simulation_1": {
100+
"population_type": simulation_1["population_type"],
101+
"population_id": simulation_1["population_id"],
102+
"policy_id": simulation_1["policy_id"],
103+
},
104+
"simulation_2": (
105+
{
106+
"population_type": simulation_2["population_type"],
107+
"population_id": simulation_2["population_id"],
108+
"policy_id": simulation_2["policy_id"],
109+
}
110+
if simulation_2 is not None
111+
else None
112+
),
113+
}
114+
)
115+
116+
def _build_economy_report_spec(
117+
self,
118+
report_output: dict,
119+
report_kind: str,
120+
simulation_1: dict,
121+
simulation_2: dict | None,
122+
time_period: str,
123+
dataset: str,
124+
target: Literal["general", "cliff"],
125+
options: dict[str, Any] | None,
126+
) -> EconomyReportSpec:
127+
if simulation_1["population_type"] != "geography":
128+
raise ValueError("Economy report specs require geography simulations")
129+
if (
130+
simulation_2 is not None
131+
and simulation_2["population_id"] != simulation_1["population_id"]
132+
):
133+
raise ValueError(
134+
"Economy comparison report specs require matching geography IDs"
135+
)
136+
137+
return EconomyReportSpec.model_validate(
138+
{
139+
"country_id": report_output["country_id"],
140+
"report_kind": report_kind,
141+
"time_period": time_period,
142+
"region": simulation_1["population_id"],
143+
"baseline_policy_id": simulation_1["policy_id"],
144+
"reform_policy_id": (
145+
simulation_2["policy_id"]
146+
if simulation_2 is not None
147+
else simulation_1["policy_id"]
148+
),
149+
"dataset": dataset,
150+
"target": target,
151+
"options": options or {},
152+
}
153+
)
154+
52155
def infer_report_kind(
53156
self,
54157
simulation_1: dict,
@@ -88,46 +191,26 @@ def build_report_spec(
88191
) -> ReportSpec:
89192
report_kind = self.infer_report_kind(simulation_1, simulation_2)
90193
time_period = report_output["year"]
194+
self._validate_report_country(report_output, simulation_1, simulation_2)
91195

92196
if report_kind in HOUSEHOLD_REPORT_KINDS:
93-
return HouseholdReportSpec.model_validate(
94-
{
95-
"country_id": report_output["country_id"],
96-
"report_kind": report_kind,
97-
"time_period": time_period,
98-
"simulation_1": {
99-
"population_type": simulation_1["population_type"],
100-
"population_id": simulation_1["population_id"],
101-
"policy_id": simulation_1["policy_id"],
102-
},
103-
"simulation_2": (
104-
{
105-
"population_type": simulation_2["population_type"],
106-
"population_id": simulation_2["population_id"],
107-
"policy_id": simulation_2["policy_id"],
108-
}
109-
if simulation_2 is not None
110-
else None
111-
),
112-
}
197+
return self._build_household_report_spec(
198+
report_output=report_output,
199+
report_kind=report_kind,
200+
simulation_1=simulation_1,
201+
simulation_2=simulation_2,
202+
time_period=time_period,
113203
)
114204

115-
return EconomyReportSpec.model_validate(
116-
{
117-
"country_id": report_output["country_id"],
118-
"report_kind": report_kind,
119-
"time_period": time_period,
120-
"region": simulation_1["population_id"],
121-
"baseline_policy_id": simulation_1["policy_id"],
122-
"reform_policy_id": (
123-
simulation_2["policy_id"]
124-
if simulation_2 is not None
125-
else simulation_1["policy_id"]
126-
),
127-
"dataset": dataset,
128-
"target": target,
129-
"options": options or {},
130-
}
205+
return self._build_economy_report_spec(
206+
report_output=report_output,
207+
report_kind=report_kind,
208+
simulation_1=simulation_1,
209+
simulation_2=simulation_2,
210+
time_period=time_period,
211+
dataset=dataset,
212+
target=target,
213+
options=options,
131214
)
132215

133216
def _parse_json_field(self, value: str | dict | None) -> dict | None:
@@ -149,6 +232,7 @@ def get_report_spec(self, report_output_id: int) -> ReportSpec | None:
149232
if report_output is None or report_output["report_spec_json"] is None:
150233
return None
151234

235+
self._validate_schema_version(report_output["report_spec_schema_version"])
152236
raw_spec = self._parse_json_field(report_output["report_spec_json"])
153237
return self._parse_report_spec(report_output["report_kind"], raw_spec)
154238

@@ -161,6 +245,7 @@ def set_report_spec(
161245
) -> bool:
162246
if report_spec_status not in REPORT_SPEC_STATUSES:
163247
raise ValueError(f"Unsupported report spec status: {report_spec_status}")
248+
self._validate_schema_version(schema_version)
164249

165250
report_output = self._get_report_output_row(report_output_id)
166251
if report_output is None:

0 commit comments

Comments
 (0)