Skip to content

Commit 7507d29

Browse files
authored
Version economy and report output caches against runtime (#3384)
* Version economy and report output caches against runtime
* Fix stale report output aliasing
* Keep stale report updates isolated
1 parent b60fc53 commit 7507d29

9 files changed

Lines changed: 373 additions & 51 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Version economy caches and report output reuse against the full runtime, and strip stale congressional district payloads from legacy US reports so clients refresh district outcomes from live state summaries.

policyengine_api/constants.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pathlib import Path
22
from importlib.metadata import distributions
33
from datetime import datetime
4+
import hashlib
45

56
REPO = Path(__file__).parents[1]
67
GET = "GET"
@@ -17,14 +18,85 @@
1718
"policyengine_ng",
1819
"policyengine_il",
1920
)
21+
22+
23+
def _normalize_distribution_name(name: str | None) -> str:
24+
if name is None:
25+
return ""
26+
return name.replace("_", "-").lower()
27+
28+
29+
def _resolve_distribution_version(
    dist_versions: dict[str, str], *package_names: str
) -> str:
    """
    Return the version of the first candidate package found.

    Each name in ``package_names`` is normalized and looked up in
    ``dist_versions`` in the order given; the first non-``None`` hit wins.
    Falls back to ``"0.0.0"`` when no candidate is present.
    """
    lookups = (
        dist_versions.get(_normalize_distribution_name(candidate))
        for candidate in package_names
    )
    # The generator is lazy, so later candidates are only looked up on a miss.
    return next((found for found in lookups if found is not None), "0.0.0")
37+
38+
2039
try:
21-
_dist_versions = {d.metadata["Name"]: d.version for d in distributions()}
40+
_dist_versions = {
41+
_normalize_distribution_name(d.metadata["Name"]): d.version
42+
for d in distributions()
43+
}
2244
COUNTRY_PACKAGE_VERSIONS = {
23-
country: _dist_versions.get(package_name.replace("_", "-"), "0.0.0")
45+
country: _resolve_distribution_version(_dist_versions, package_name)
2446
for country, package_name in zip(COUNTRIES, COUNTRY_PACKAGE_NAMES)
2547
}
48+
POLICYENGINE_CORE_VERSION = _resolve_distribution_version(
49+
_dist_versions, "policyengine-core", "policyengine"
50+
)
2651
except Exception:
2752
COUNTRY_PACKAGE_VERSIONS = {country: "0.0.0" for country in COUNTRIES}
53+
POLICYENGINE_CORE_VERSION = "0.0.0"
54+
55+
# Schema revision per cache scope. The value for a scope is folded into the
# version token produced by _build_runtime_cache_version.
RUNTIME_CACHE_SCHEMA_VERSIONS = dict(
    economy_impact=1,
    report_output=1,
)
59+
60+
61+
def _build_runtime_cache_version(
    scope: str, country_id: str, caller_version: str | None = None
) -> str:
    """
    Build the compact version token used in cache keys that are stored in
    legacy VARCHAR(10) columns.

    Token layout: a one-letter prefix ("e" for the ``economy_impact`` scope,
    "r" for anything else), the scope's schema version, then as many leading
    hex characters of a SHA-1 digest as still fit in 10 characters. The
    digest covers the scope, the country, the caller version (or the country
    package version when the caller version is falsy), the country package
    version, the core version and the schema version.

    Raises:
        KeyError: If ``scope`` has no entry in RUNTIME_CACHE_SCHEMA_VERSIONS.
        ValueError: If fewer than 4 digest characters would fit.
    """
    schema_tag = str(RUNTIME_CACHE_SCHEMA_VERSIONS[scope])
    if scope == "economy_impact":
        scope_letter = "e"
    else:
        scope_letter = "r"
    head = scope_letter + schema_tag

    room = 10 - len(head)
    if room < 4:
        raise ValueError(
            f"Runtime cache version for {scope} does not fit in VARCHAR(10)"
        )

    package_version = COUNTRY_PACKAGE_VERSIONS.get(country_id, "0.0.0")
    fields = [
        scope,
        country_id,
        caller_version or package_version,
        package_version,
        POLICYENGINE_CORE_VERSION,
        schema_tag,
    ]
    fingerprint = hashlib.sha1("|".join(fields).encode("utf-8")).hexdigest()
    return head + fingerprint[:room]
89+
90+
91+
def get_economy_impact_cache_version(
    country_id: str, caller_version: str | None = None
) -> str:
    """Return the cache version token for the ``economy_impact`` scope."""
    return _build_runtime_cache_version(
        scope="economy_impact",
        country_id=country_id,
        caller_version=caller_version,
    )
95+
96+
97+
def get_report_output_cache_version(country_id: str) -> str:
    """Return the cache version token for the ``report_output`` scope."""
    return _build_runtime_cache_version(
        scope="report_output",
        country_id=country_id,
    )
99+
28100

29101
# Valid region types for each country
30102
# These define the geographic scope categories for regions

policyengine_api/routes/report_output_routes.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def update_report_output(country_id: str) -> Response:
160160

161161
try:
162162
# First check if the report output exists
163-
existing_report = report_output_service.get_report_output(report_id)
163+
existing_report = report_output_service.get_stored_report_output(report_id)
164164
if existing_report is None:
165165
raise NotFound(f"Report #{report_id} not found.")
166166

@@ -176,8 +176,9 @@ def update_report_output(country_id: str) -> Response:
176176
if not success:
177177
raise BadRequest("No fields to update")
178178

179-
# Get the updated record
180-
updated_report = report_output_service.get_report_output(report_id)
179+
# Get the updated stored record so stale-runtime jobs do not appear to
180+
# complete the current runtime lineage in the PATCH response.
181+
updated_report = report_output_service.get_stored_report_output(report_id)
181182

182183
response_body = dict(
183184
status="ok",

policyengine_api/services/economy_service.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
EXECUTION_STATUSES_SUCCESS,
99
EXECUTION_STATUSES_FAILURE,
1010
EXECUTION_STATUSES_PENDING,
11+
get_economy_impact_cache_version,
1112
)
1213
from policyengine_api.gcp_logging import logger
1314
from policyengine_api.libs.simulation_api_modal import simulation_api_modal
@@ -164,6 +165,8 @@ def get_economic_impact(
164165
if country_id == "uk":
165166
country_package_version = None
166167

168+
cache_version = get_economy_impact_cache_version(country_id, api_version)
169+
167170
economic_impact_setup_options = EconomicImpactSetupOptions.model_validate(
168171
{
169172
"process_id": process_id,
@@ -174,7 +177,7 @@ def get_economic_impact(
174177
"dataset": dataset,
175178
"time_period": time_period,
176179
"options": options,
177-
"api_version": api_version,
180+
"api_version": cache_version,
178181
"target": target,
179182
"model_version": country_package_version,
180183
"data_version": get_dataset_version(country_id),

policyengine_api/services/report_output_service.py

Lines changed: 70 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,52 @@
11
from sqlalchemy.engine.row import Row
22

33
from policyengine_api.data import database
4-
from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS
4+
from policyengine_api.constants import get_report_output_cache_version
55

66

77
class ReportOutputService:
8+
def _get_report_output_row(self, report_output_id: int) -> dict | None:
    """
    Fetch a single ``report_outputs`` row by its ID.

    Returns the row converted to a plain dict, or ``None`` when no row
    matches.
    """
    cursor = database.query(
        "SELECT * FROM report_outputs WHERE id = ?",
        (report_output_id,),
    )
    record: Row | None = cursor.fetchone()
    if record is None:
        return None
    return dict(record)
14+
15+
def get_stored_report_output(self, report_output_id: int) -> dict | None:
    """
    Return the report output row exactly as stored for the given ID.

    No aliasing to the current runtime lineage takes place here. Mutation
    paths rely on this so they update the row that was originally
    addressed rather than a resolved alias.
    """
    stored_row = self._get_report_output_row(report_output_id)
    return stored_row
22+
23+
def _is_current_report_output(self, report_output: dict) -> bool:
    """
    Tell whether a report output row carries the current cache version.

    Compares the row's ``api_version`` (``None`` when absent) with the
    report output cache version computed for the row's ``country_id``.
    """
    expected_version = get_report_output_cache_version(
        report_output["country_id"]
    )
    return report_output.get("api_version") == expected_version
27+
28+
def _get_or_create_current_report_output(self, report_output: dict) -> dict:
    """
    Return the report output that find_existing_report_output yields for
    the same country, simulations and year as ``report_output``, creating
    one through create_report_output when none is found.
    """
    lineage_key = dict(
        country_id=report_output["country_id"],
        simulation_1_id=report_output["simulation_1_id"],
        simulation_2_id=report_output["simulation_2_id"],
        year=report_output["year"],
    )

    resolved = self.find_existing_report_output(**lineage_key)
    if resolved is None:
        resolved = self.create_report_output(**lineage_key)
    return resolved
44+
45+
def _alias_report_output(self, report_output_id: int, report_output: dict) -> dict:
46+
aliased_report = dict(report_output)
47+
aliased_report["id"] = report_output_id
48+
return aliased_report
49+
850
def find_existing_report_output(
951
self,
1052
country_id: str,
@@ -25,18 +67,20 @@ def find_existing_report_output(
2567
dict | None: The existing report output data or None if not found.
2668
"""
2769
print("Checking for existing report output")
70+
api_version = get_report_output_cache_version(country_id)
2871

2972
try:
30-
# Check for existing record with the same simulation IDs and year (excluding api_version)
31-
query = "SELECT * FROM report_outputs WHERE country_id = ? AND simulation_1_id = ? AND year = ?"
32-
params = [country_id, simulation_1_id, year]
73+
query = "SELECT * FROM report_outputs WHERE country_id = ? AND simulation_1_id = ? AND year = ? AND api_version = ?"
74+
params = [country_id, simulation_1_id, year, api_version]
3375

3476
if simulation_2_id is not None:
3577
query += " AND simulation_2_id = ?"
3678
params.append(simulation_2_id)
3779
else:
3880
query += " AND simulation_2_id IS NULL"
3981

82+
query += " ORDER BY id DESC"
83+
4084
row = database.query(query, tuple(params)).fetchone()
4185

4286
existing_report = None
@@ -71,9 +115,18 @@ def create_report_output(
71115
dict: The created report output record.
72116
"""
73117
print("Creating new report output")
74-
api_version: str = COUNTRY_PACKAGE_VERSIONS.get(country_id)
118+
api_version = get_report_output_cache_version(country_id)
75119

76120
try:
121+
existing_report = self.find_existing_report_output(
122+
country_id, simulation_1_id, simulation_2_id, year
123+
)
124+
if existing_report is not None:
125+
print(
126+
f"Reusing existing report output with ID: {existing_report['id']}"
127+
)
128+
return existing_report
129+
77130
# Insert with default status 'pending'
78131
if simulation_2_id is not None:
79132
database.query(
@@ -132,18 +185,15 @@ def get_report_output(self, report_output_id: int) -> dict | None:
132185
f"Invalid report output ID: {report_output_id}. Must be a positive integer."
133186
)
134187

135-
row: Row | None = database.query(
136-
"SELECT * FROM report_outputs WHERE id = ?",
137-
(report_output_id,),
138-
).fetchone()
188+
report_output = self._get_report_output_row(report_output_id)
189+
if report_output is None:
190+
return None
139191

140-
report_output = None
141-
if row is not None:
142-
report_output = dict(row)
143-
# Keep output as JSON string - frontend expects string format
144-
# Frontend will parse it using JSON.parse()
192+
if self._is_current_report_output(report_output):
193+
return report_output
145194

146-
return report_output
195+
current_report = self._get_or_create_current_report_output(report_output)
196+
return self._alias_report_output(report_output_id, current_report)
147197

148198
except Exception as e:
149199
print(
@@ -172,10 +222,12 @@ def update_report_output(
172222
bool: True if update was successful.
173223
"""
174224
print(f"Updating report output {report_id}")
175-
# Automatically update api_version on every update to latest
176-
api_version: str = COUNTRY_PACKAGE_VERSIONS.get(country_id)
177225

178226
try:
227+
requested_report = self._get_report_output_row(report_id)
228+
if requested_report is None:
229+
raise Exception(f"Report output #{report_id} not found")
230+
179231
# Build the update query dynamically based on provided fields
180232
update_fields = []
181233
update_values = []
@@ -193,16 +245,12 @@ def update_report_output(
193245
update_fields.append("error_message = ?")
194246
update_values.append(error_message)
195247

196-
# Always update API version
197-
update_fields.append("api_version = ?")
198-
update_values.append(api_version)
199-
200248
if not update_fields:
201249
print("No fields to update")
202250
return False
203251

204252
# Add report_id to the end of values for WHERE clause
205-
update_values.append(report_id)
253+
update_values.append(requested_report["id"])
206254

207255
query = f"UPDATE report_outputs SET {', '.join(update_fields)} WHERE id = ?"
208256

tests/fixtures/services/report_output_fixtures.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import pytest
22
import json
33

4+
from policyengine_api.constants import get_report_output_cache_version
5+
46
valid_report_data = {
57
"country_id": "us",
68
"simulation_1_id": 1,
79
"simulation_2_id": None,
8-
"api_version": "1.0.0",
10+
"api_version": get_report_output_cache_version("us"),
911
"status": "pending",
1012
"output": None,
1113
"error_message": None,

tests/unit/services/test_economy_service.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,40 @@ def test__given_no_previous_impact__includes_metadata_in_simulation_params(
212212
)
213213
assert sim_params["_metadata"]["process_id"] == MOCK_PROCESS_ID
214214

215+
def test__given_runtime_cache_version__uses_versioned_economy_cache_key(
    self,
    economy_service,
    base_params,
    mock_country_package_versions,
    mock_get_dataset_version,
    mock_policy_service,
    mock_reform_impacts_service,
    mock_simulation_api,
    mock_logger,
    mock_datetime,
    mock_numpy_random,
    monkeypatch,
):
    """The reform-impact lookup is called with the runtime cache version."""
    cache_version = "e1cache01"

    def fake_cache_version(country_id, api_version=None):
        # Stand-in for get_economy_impact_cache_version: fixed token.
        return cache_version

    monkeypatch.setattr(
        "policyengine_api.services.economy_service.get_economy_impact_cache_version",
        fake_cache_version,
    )
    mock_reform_impacts_service.get_all_reform_impacts.return_value = []

    economy_service.get_economic_impact(**base_params)

    expected_args = (
        MOCK_COUNTRY_ID,
        MOCK_POLICY_ID,
        MOCK_BASELINE_POLICY_ID,
        MOCK_REGION,
        MOCK_DATASET,
        MOCK_TIME_PERIOD,
        MOCK_OPTIONS_HASH,
        cache_version,
    )
    mock_reform_impacts_service.get_all_reform_impacts.assert_called_once_with(
        *expected_args
    )
248+
215249
def test__given_exception__raises_error(
216250
self,
217251
economy_service,

0 commit comments

Comments
 (0)