Skip to content

Commit b921172

Browse files
authored
Roll back simulation API to v1 (#2448)
* fix: Roll back to sim API v1, re-add logging
* chore: Changelog
* fix: Add python-dotenv to setup.py
* fix: Properly handle decile keys
1 parent 882e4e1 commit b921172

8 files changed

Lines changed: 594 additions & 37 deletions

File tree

changelog_entry.yaml

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,7 @@
1+
- bump: minor
2+
changes:
3+
changed:
4+
- Rolled back to simulation API v1
5+
added:
6+
- New logging structure
7+
- New logging outputs for simulation API v1 and v2 runs

gcp/policyengine_api/app.yaml

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,9 @@
11
runtime: custom
22
env: flex
33
resources:
4-
cpu: 4
5-
memory_gb: 16
6-
disk_size_gb: 64
4+
cpu: 24
5+
memory_gb: 128
6+
disk_size_gb: 128
77
automatic_scaling:
88
min_num_instances: 1
99
max_num_instances: 1

policyengine_api/data/data.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,13 +2,16 @@
22
from policyengine_api.constants import REPO, VERSION, COUNTRY_PACKAGE_VERSIONS
33
from policyengine_api.utils import hash_object
44
from pathlib import Path
5+
from dotenv import load_dotenv
56
import json
67
from google.cloud.sql.connector import Connector
78
import sqlalchemy
89
import sqlalchemy.exc
910
import os
1011
import sys
1112

13+
load_dotenv()
14+
1215

1316
class PolicyEngineDatabase:
1417
"""

policyengine_api/jobs/calculate_economy_simulation_job.py

Lines changed: 146 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -4,9 +4,11 @@
44
import datetime
55
import time
66
import os
7+
import math
78
from typing import Type, Any, Literal
89
import pandas as pd
910
import numpy as np
11+
from dotenv import load_dotenv
1012
from google.cloud import workflows_v1
1113
from google.cloud.workflows import executions_v1
1214

@@ -19,6 +21,10 @@
1921
from policyengine_api.endpoints.economy.reform_impact import set_comment_on_job
2022
from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS
2123
from policyengine_api.country import COUNTRIES, create_policy_reform
24+
from policyengine_api.utils.v2_v1_comparison import (
25+
V2V1Comparison,
26+
compute_difference,
27+
)
2228
from policyengine_core.simulations import Microsimulation
2329
from policyengine_core.tools.hugging_face import download_huggingface_dataset
2430
import h5py
@@ -27,6 +33,8 @@
2733
from policyengine_uk import Microsimulation
2834
import logging
2935

36+
load_dotenv()
37+
3038
reform_impacts_service = ReformImpactsService()
3139

3240
ENHANCED_FRS = "hf://policyengine/policyengine-uk-data/enhanced_frs_2022_23.h5"
@@ -36,18 +44,20 @@
3644
CPS = "hf://policyengine/policyengine-us-data/cps_2023.h5"
3745
POOLED_CPS = "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5"
3846

39-
use_api_v2 = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") is not None
47+
check_against_api_v2 = (
48+
os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") is not None
49+
)
4050

41-
if not use_api_v2:
51+
if not check_against_api_v2:
4252
logging.warn(
43-
"Didn't find any GOOGLE_APPLICATION_CREDENTIALS, so will not use APIv2."
53+
"Didn't find any GOOGLE_APPLICATION_CREDENTIALS, so will not check APIv1 results against APIv2."
4454
)
4555

4656

4757
class CalculateEconomySimulationJob(BaseJob):
4858
def __init__(self):
4959
super().__init__()
50-
if use_api_v2:
60+
if check_against_api_v2:
5161
self.api_v2 = SimulationAPIv2()
5262

5363
def run(
@@ -148,8 +158,21 @@ def run(
148158
comment = lambda x: set_comment_on_job(x, *identifiers)
149159
comment("Computing baseline")
150160

151-
# Kick off APIv2 job
152-
if use_api_v2:
161+
# If comparing against API v2, start job
162+
if check_against_api_v2:
163+
164+
# Populate v2/v1 comparison config data; we will pass this
165+
# to GCP logs either on error or success
166+
comparison_data = {
167+
"country_id": country_id,
168+
"region": region,
169+
"reform_policy": reform_policy,
170+
"baseline_policy": baseline_policy,
171+
"reform_policy_id": policy_id,
172+
"baseline_policy_id": baseline_policy_id,
173+
"time_period": time_period,
174+
"dataset": dataset,
175+
}
153176

154177
# Set up APIv2 job
155178
comment("Setting up APIv2 job")
@@ -163,37 +186,123 @@ def run(
163186
dataset=dataset,
164187
)
165188

166-
execution = self.api_v2.run(sim_config)
189+
try:
190+
api_v2_execution = self.api_v2.run(sim_config)
191+
execution_id = self.api_v2.get_execution_id(
192+
api_v2_execution
193+
)
167194

168-
impact = self.api_v2.wait_for_completion(execution)
169-
else:
170-
# Compute baseline economy
171-
baseline_economy = self._compute_economy(
172-
country_id=country_id,
173-
region=region,
174-
dataset=dataset,
175-
time_period=time_period,
176-
options=options,
177-
policy_json=baseline_policy,
178-
)
179-
comment("Computing reform")
195+
comparison_data["v2_id"] = execution_id
196+
197+
# Pass name and status to logs
198+
progress_log: V2V1Comparison = (
199+
V2V1Comparison.model_validate(
200+
{
201+
**comparison_data,
202+
"v1_impact": None,
203+
"v2_impact": None,
204+
"v1_v2_diff": None,
205+
"message": "APIv2 job started",
206+
}
207+
)
208+
)
209+
logging.info(progress_log.model_dump_json())
210+
except Exception as e:
211+
# Send error log to GCP
212+
error_log: V2V1Comparison = V2V1Comparison.model_validate(
213+
{
214+
**comparison_data,
215+
"v2_error": str(e),
216+
}
217+
)
180218

181-
# Compute reform economy
182-
reform_economy = self._compute_economy(
183-
country_id=country_id,
184-
region=region,
185-
dataset=dataset,
186-
time_period=time_period,
187-
options=options,
188-
policy_json=reform_policy,
189-
)
219+
logging.error(error_log.model_dump_json())
190220

191-
baseline_economy = baseline_economy["result"]
192-
reform_economy = reform_economy["result"]
193-
comment("Comparing baseline and reform")
194-
impact = compare_economic_outputs(
195-
baseline_economy, reform_economy, country_id=country_id
196-
)
221+
# Compute baseline economy
222+
baseline_economy = self._compute_economy(
223+
country_id=country_id,
224+
region=region,
225+
dataset=dataset,
226+
time_period=time_period,
227+
options=options,
228+
policy_json=baseline_policy,
229+
)
230+
comment("Computing reform")
231+
232+
# Compute reform economy
233+
reform_economy = self._compute_economy(
234+
country_id=country_id,
235+
region=region,
236+
dataset=dataset,
237+
time_period=time_period,
238+
options=options,
239+
policy_json=reform_policy,
240+
)
241+
242+
baseline_economy = baseline_economy["result"]
243+
reform_economy = reform_economy["result"]
244+
comment("Comparing baseline and reform")
245+
impact: dict[str, Any] = compare_economic_outputs(
246+
baseline_economy, reform_economy, country_id=country_id
247+
)
248+
249+
# If comparing against API v2, wait for job to complete
250+
if check_against_api_v2:
251+
252+
try:
253+
execution_id: str = self.api_v2.get_execution_id(
254+
api_v2_execution
255+
)
256+
api_v2_output = self.api_v2.wait_for_completion(
257+
api_v2_execution
258+
)
259+
260+
completion_log: V2V1Comparison = (
261+
V2V1Comparison.model_validate(
262+
{
263+
**comparison_data,
264+
"v1_impact": impact,
265+
"v2_impact": api_v2_output,
266+
"v1_v2_diff": None,
267+
"message": "APIv2 job completed",
268+
}
269+
)
270+
)
271+
272+
logging.info(completion_log.model_dump_json())
273+
# Run v2/v1 comparison
274+
275+
v1_v2_diff: dict[str, Any] = compute_difference(
276+
x=impact,
277+
y=api_v2_output,
278+
)
279+
# Push relevant info into logging schema
280+
comparison_log: V2V1Comparison = (
281+
V2V1Comparison.model_validate(
282+
{
283+
**comparison_data,
284+
"v1_impact": impact,
285+
"v2_impact": api_v2_output,
286+
"v1_v2_diff": v1_v2_diff,
287+
"message": "APIv2 job comparison with APIv1 completed",
288+
}
289+
)
290+
)
291+
logging.info(comparison_log.model_dump_json())
292+
293+
except Exception as e:
294+
# If job fails, send error log to GCP
295+
error_log: V2V1Comparison = V2V1Comparison.model_validate(
296+
{
297+
**comparison_data,
298+
"v2_error": str(e),
299+
"v1_impact": impact,
300+
"v2_impact": None,
301+
"v1_v2_diff": None,
302+
"message": "APIv2 job failed",
303+
}
304+
)
305+
logging.error(error_log.model_dump_json())
197306

198307
# Finally, update all reform impact rows with the same baseline and reform policy IDs
199308
reform_impacts_service.set_complete_reform_impact(
@@ -491,6 +600,9 @@ def run(self, payload: dict) -> executions_v1.Execution:
491600
)
492601
return execution
493602

603+
def get_execution_id(self, execution: executions_v1.Execution) -> str:
604+
return execution.name
605+
494606
def get_execution_status(self, execution: executions_v1.Execution) -> str:
495607
"""
496608
Get the status of an execution

0 commit comments

Comments
 (0)