44import datetime
55import time
66import os
7+ import math
78from typing import Type , Any , Literal
89import pandas as pd
910import numpy as np
11+ from dotenv import load_dotenv
1012from google .cloud import workflows_v1
1113from google .cloud .workflows import executions_v1
1214
1921from policyengine_api .endpoints .economy .reform_impact import set_comment_on_job
2022from policyengine_api .constants import COUNTRY_PACKAGE_VERSIONS
2123from policyengine_api .country import COUNTRIES , create_policy_reform
24+ from policyengine_api .utils .v2_v1_comparison import (
25+ V2V1Comparison ,
26+ compute_difference ,
27+ )
2228from policyengine_core .simulations import Microsimulation
2329from policyengine_core .tools .hugging_face import download_huggingface_dataset
2430import h5py
2733from policyengine_uk import Microsimulation
2834import logging
2935
36+ load_dotenv ()
37+
3038reform_impacts_service = ReformImpactsService ()
3139
3240ENHANCED_FRS = "hf://policyengine/policyengine-uk-data/enhanced_frs_2022_23.h5"
3644CPS = "hf://policyengine/policyengine-us-data/cps_2023.h5"
3745POOLED_CPS = "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5"
3846
39- use_api_v2 = os .environ .get ("GOOGLE_APPLICATION_CREDENTIALS" ) is not None
47+ check_against_api_v2 = (
48+ os .environ .get ("GOOGLE_APPLICATION_CREDENTIALS" ) is not None
49+ )
4050
41- if not use_api_v2 :
51+ if not check_against_api_v2 :
4252 logging .warn (
43- "Didn't find any GOOGLE_APPLICATION_CREDENTIALS, so will not use APIv2."
53+ "Didn't find any GOOGLE_APPLICATION_CREDENTIALS, so will not check APIv1 results against APIv2."
4454 )
4555
4656
4757class CalculateEconomySimulationJob (BaseJob ):
4858 def __init__ (self ):
4959 super ().__init__ ()
50- if use_api_v2 :
60+ if check_against_api_v2 :
5161 self .api_v2 = SimulationAPIv2 ()
5262
5363 def run (
@@ -148,8 +158,21 @@ def run(
148158 comment = lambda x : set_comment_on_job (x , * identifiers )
149159 comment ("Computing baseline" )
150160
151- # Kick off APIv2 job
152- if use_api_v2 :
161+ # If comparing against API v2, start job
162+ if check_against_api_v2 :
163+
164+ # Populate v2/v1 comparison config data; we will pass this
165+ # to GCP logs either on error or success
166+ comparison_data = {
167+ "country_id" : country_id ,
168+ "region" : region ,
169+ "reform_policy" : reform_policy ,
170+ "baseline_policy" : baseline_policy ,
171+ "reform_policy_id" : policy_id ,
172+ "baseline_policy_id" : baseline_policy_id ,
173+ "time_period" : time_period ,
174+ "dataset" : dataset ,
175+ }
153176
154177 # Set up APIv2 job
155178 comment ("Setting up APIv2 job" )
@@ -163,37 +186,123 @@ def run(
163186 dataset = dataset ,
164187 )
165188
166- execution = self .api_v2 .run (sim_config )
189+ try :
190+ api_v2_execution = self .api_v2 .run (sim_config )
191+ execution_id = self .api_v2 .get_execution_id (
192+ api_v2_execution
193+ )
167194
168- impact = self .api_v2 .wait_for_completion (execution )
169- else :
170- # Compute baseline economy
171- baseline_economy = self ._compute_economy (
172- country_id = country_id ,
173- region = region ,
174- dataset = dataset ,
175- time_period = time_period ,
176- options = options ,
177- policy_json = baseline_policy ,
178- )
179- comment ("Computing reform" )
195+ comparison_data ["v2_id" ] = execution_id
196+
197+ # Pass name and status to logs
198+ progress_log : V2V1Comparison = (
199+ V2V1Comparison .model_validate (
200+ {
201+ ** comparison_data ,
202+ "v1_impact" : None ,
203+ "v2_impact" : None ,
204+ "v1_v2_diff" : None ,
205+ "message" : "APIv2 job started" ,
206+ }
207+ )
208+ )
209+ logging .info (progress_log .model_dump_json ())
210+ except Exception as e :
211+ # Send error log to GCP
212+ error_log : V2V1Comparison = V2V1Comparison .model_validate (
213+ {
214+ ** comparison_data ,
215+ "v2_error" : str (e ),
216+ }
217+ )
180218
181- # Compute reform economy
182- reform_economy = self ._compute_economy (
183- country_id = country_id ,
184- region = region ,
185- dataset = dataset ,
186- time_period = time_period ,
187- options = options ,
188- policy_json = reform_policy ,
189- )
219+ logging .error (error_log .model_dump_json ())
190220
191- baseline_economy = baseline_economy ["result" ]
192- reform_economy = reform_economy ["result" ]
193- comment ("Comparing baseline and reform" )
194- impact = compare_economic_outputs (
195- baseline_economy , reform_economy , country_id = country_id
196- )
221+ # Compute baseline economy
222+ baseline_economy = self ._compute_economy (
223+ country_id = country_id ,
224+ region = region ,
225+ dataset = dataset ,
226+ time_period = time_period ,
227+ options = options ,
228+ policy_json = baseline_policy ,
229+ )
230+ comment ("Computing reform" )
231+
232+ # Compute reform economy
233+ reform_economy = self ._compute_economy (
234+ country_id = country_id ,
235+ region = region ,
236+ dataset = dataset ,
237+ time_period = time_period ,
238+ options = options ,
239+ policy_json = reform_policy ,
240+ )
241+
242+ baseline_economy = baseline_economy ["result" ]
243+ reform_economy = reform_economy ["result" ]
244+ comment ("Comparing baseline and reform" )
245+ impact : dict [str , Any ] = compare_economic_outputs (
246+ baseline_economy , reform_economy , country_id = country_id
247+ )
248+
249+ # If comparing against API v2, wait for job to complete
250+ if check_against_api_v2 :
251+
252+ try :
253+ execution_id : str = self .api_v2 .get_execution_id (
254+ api_v2_execution
255+ )
256+ api_v2_output = self .api_v2 .wait_for_completion (
257+ api_v2_execution
258+ )
259+
260+ completion_log : V2V1Comparison = (
261+ V2V1Comparison .model_validate (
262+ {
263+ ** comparison_data ,
264+ "v1_impact" : impact ,
265+ "v2_impact" : api_v2_output ,
266+ "v1_v2_diff" : None ,
267+ "message" : "APIv2 job completed" ,
268+ }
269+ )
270+ )
271+
272+ logging .info (completion_log .model_dump_json ())
273+ # Run v2/v1 comparison
274+
275+ v1_v2_diff : dict [str , Any ] = compute_difference (
276+ x = impact ,
277+ y = api_v2_output ,
278+ )
279+ # Push relevant info into logging schema
280+ comparison_log : V2V1Comparison = (
281+ V2V1Comparison .model_validate (
282+ {
283+ ** comparison_data ,
284+ "v1_impact" : impact ,
285+ "v2_impact" : api_v2_output ,
286+ "v1_v2_diff" : v1_v2_diff ,
287+ "message" : "APIv2 job comparison with APIv1 completed" ,
288+ }
289+ )
290+ )
291+ logging .info (comparison_log .model_dump_json ())
292+
293+ except Exception as e :
294+ # If job fails, send error log to GCP
295+ error_log : V2V1Comparison = V2V1Comparison .model_validate (
296+ {
297+ ** comparison_data ,
298+ "v2_error" : str (e ),
299+ "v1_impact" : impact ,
300+ "v2_impact" : None ,
301+ "v1_v2_diff" : None ,
302+ "message" : "APIv2 job failed" ,
303+ }
304+ )
305+ logging .error (error_log .model_dump_json ())
197306
198307 # Finally, update all reform impact rows with the same baseline and reform policy IDs
199308 reform_impacts_service .set_complete_reform_impact (
@@ -491,6 +600,9 @@ def run(self, payload: dict) -> executions_v1.Execution:
491600 )
492601 return execution
493602
603+ def get_execution_id (self , execution : executions_v1 .Execution ) -> str :
604+ return execution .name
605+
494606 def get_execution_status (self , execution : executions_v1 .Execution ) -> str :
495607 """
496608 Get the status of an execution
0 commit comments