1010import logging
1111import os
1212import tempfile
13- from typing import Iterator
13+ from typing import Any , Iterator
1414
15- # Module-level imports - these are SNAPSHOTTED at image build time
16- from policyengine . simulation import Simulation , SimulationOptions
17-
18- from src . modal .telemetry import split_internal_payload
15+ try :
16+ from src . modal . telemetry import split_internal_payload
17+ except ModuleNotFoundError :
18+ from modal .telemetry import split_internal_payload
1919
2020logger = logging .getLogger (__name__ )
2121
2222
# Simulation year used when a request carries neither "time_period" nor "year".
DEFAULT_YEAR = 2026

# Per-country lookup from every accepted dataset spelling (short alias,
# canonical name, or storage URL) to the canonical dataset name.  Canonical
# names map to themselves so a lookup is idempotent.
DATASET_ALIASES = {
    "us": {
        # Enhanced CPS: alias, canonical name, and the two storage URLs.
        "enhanced_cps": "enhanced_cps_2024",
        "enhanced_cps_2024": "enhanced_cps_2024",
        "gs://policyengine-us-data/enhanced_cps_2024.h5": "enhanced_cps_2024",
        "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5": (
            "enhanced_cps_2024"
        ),
        # Small CPS sample.
        "cps_small": "cps_small_2024",
        "cps_small_2024": "cps_small_2024",
    },
    "uk": {
        # Enhanced FRS.
        "enhanced_frs": "enhanced_frs_2023_24",
        "enhanced_frs_2023_24": "enhanced_frs_2023_24",
        # Plain FRS.
        "frs": "frs_2023_24",
        "frs_2023_24": "frs_2023_24",
    },
}
40+
41+
2342def _normalize_credentials_blob (creds_json : str ) -> str :
2443 """Return the raw JSON blob, decoding the outer escape if present.
2544
@@ -115,6 +134,251 @@ def run_simulation_impl(params: dict) -> dict:
115134 return _run_simulation_impl_core (params )
116135
117136
137+ def _parse_year (params : dict [str , Any ]) -> int :
138+ value = params .get ("time_period" ) or params .get ("year" ) or DEFAULT_YEAR
139+ return int (value )
140+
141+
142+ def _normalise_period_key (period_key : Any ) -> str :
143+ """Convert legacy ``start.stop`` period keys to v4 effective dates."""
144+ text = str (period_key )
145+ parts = text .split ("." )
146+ if len (parts ) > 1 and len (parts [0 ]) == 10 :
147+ return parts [0 ]
148+ return text
149+
150+
151+ def _normalise_reform (reform : dict [str , Any ] | None ) -> dict [str , Any ] | None :
152+ if not reform :
153+ return None
154+ normalised : dict [str , Any ] = {}
155+ for parameter , value in reform .items ():
156+ if isinstance (value , dict ):
157+ normalised [parameter ] = {
158+ _normalise_period_key (period ): period_value
159+ for period , period_value in value .items ()
160+ }
161+ else :
162+ normalised [parameter ] = value
163+ return normalised
164+
165+
166+ def _resolve_dataset_name (
167+ country : str , requested_data : str | None , subsample : int | None
168+ ) -> str :
169+ if requested_data is None :
170+ return "enhanced_cps_2024" if country == "us" else "enhanced_frs_2023_24"
171+
172+ requested = requested_data .split ("@" , maxsplit = 1 )[0 ]
173+ return DATASET_ALIASES .get (country , {}).get (requested , requested_data )
174+
175+
def _microframe_like(frame, weights: str):
    """Wrap a copy of ``frame`` in a ``microdf.MicroDataFrame``.

    ``weights`` is forwarded as the ``weights=`` argument of the
    ``MicroDataFrame`` constructor.  ``frame`` itself is left untouched
    because the wrapper is built from ``frame.copy()``.
    """
    # Imported lazily so the module can be loaded without microdf present.
    from microdf import MicroDataFrame

    frame_copy = frame.copy()
    return MicroDataFrame(frame_copy, weights=weights)
180+
181+
182+ def _person_group_column (person , entity : str ) -> str :
183+ prefixed = f"person_{ entity } _id"
184+ if prefixed in person .columns :
185+ return prefixed
186+ return f"{ entity } _id"
187+
188+
189+ def _subsample_us_dataset (dataset , subsample : int | None ):
190+ if not subsample :
191+ return dataset
192+
193+ from policyengine .tax_benefit_models .us .datasets import (
194+ PolicyEngineUSDataset ,
195+ USYearData ,
196+ )
197+
198+ dataset .load ()
199+ data = dataset .data
200+ household = data .household .head (int (subsample )).copy ()
201+ household_ids = set (household ["household_id" ])
202+
203+ person_household_col = _person_group_column (data .person , "household" )
204+ person = data .person [data .person [person_household_col ].isin (household_ids )].copy ()
205+
206+ def group_subset (entity : str ):
207+ person_col = _person_group_column (person , entity )
208+ entity_id_col = f"{ entity } _id"
209+ ids = set (person [person_col ])
210+ frame = getattr (data , entity )
211+ return frame [frame [entity_id_col ].isin (ids )].copy ()
212+
213+ subset_data = USYearData (
214+ person = _microframe_like (person , "person_weight" ),
215+ marital_unit = _microframe_like (
216+ group_subset ("marital_unit" ), "marital_unit_weight"
217+ ),
218+ family = _microframe_like (group_subset ("family" ), "family_weight" ),
219+ spm_unit = _microframe_like (group_subset ("spm_unit" ), "spm_unit_weight" ),
220+ tax_unit = _microframe_like (group_subset ("tax_unit" ), "tax_unit_weight" ),
221+ household = _microframe_like (household , "household_weight" ),
222+ )
223+ subset_path = os .path .join (
224+ os .environ .get ("POLICYENGINE_DATA_FOLDER" , "/tmp/policyengine-data" ),
225+ f"{ dataset .id } _subsample_{ subsample } .h5" ,
226+ )
227+ return PolicyEngineUSDataset (
228+ id = f"{ dataset .id } _subsample_{ subsample } " ,
229+ name = f"{ dataset .name } subsample { subsample } " ,
230+ description = dataset .description ,
231+ filepath = subset_path ,
232+ year = dataset .year ,
233+ is_output_dataset = dataset .is_output_dataset ,
234+ data = subset_data ,
235+ )
236+
237+
def _country_module(country: str):
    """Return the ``policyengine`` country namespace for ``country``.

    The code is lower-cased before matching.  ``"us"`` yields
    ``policyengine.us`` and ``"uk"`` yields ``policyengine.uk``.

    Raises:
        ValueError: for any other country code (the message carries the
            lower-cased code).
    """
    import policyengine as pe

    code = country.lower()
    if code in ("us", "uk"):
        # Attribute is resolved only for the requested country.
        return getattr(pe, code)
    raise ValueError(f"Unsupported country: {code}")
247+
248+
def _load_dataset(params: dict[str, Any]):
    """Fetch the dataset described by ``params`` for the requested year.

    Reads ``country`` (default ``"us"``), the year via ``_parse_year``,
    ``data`` and ``subsample`` from ``params``.  The resolved dataset name is
    passed to the country module's ``ensure_datasets`` together with the data
    folder from ``$POLICYENGINE_DATA_FOLDER`` (default
    ``/tmp/policyengine-data``), and the first dataset in the returned
    mapping is used.

    Only US datasets go through ``_subsample_us_dataset``; for any other
    country the dataset is returned as loaded and ``subsample`` has no effect.
    """
    country = params.get("country", "us").lower()
    year = _parse_year(params)
    module = _country_module(country)
    subsample = params.get("subsample")

    dataset_name = _resolve_dataset_name(country, params.get("data"), subsample)
    data_folder = os.environ.get(
        "POLICYENGINE_DATA_FOLDER", "/tmp/policyengine-data"
    )
    available = module.ensure_datasets(
        datasets=[dataset_name],
        years=[year],
        data_folder=data_folder,
    )
    first_dataset = next(iter(available.values()))

    if country != "us":
        return first_dataset
    return _subsample_us_dataset(first_dataset, subsample)
267+
268+
def _build_simulation(params: dict[str, Any], policy: dict[str, Any] | None):
    """Create a ``policyengine.core.Simulation`` for ``params`` and ``policy``.

    The country (``params["country"]``, default ``"us"``) selects the country
    module whose ``model`` attribute is passed as
    ``tax_benefit_model_version``.  The dataset comes from ``_load_dataset``,
    which is invoked on every call.  ``policy`` is forwarded unchanged.
    """
    from policyengine.core import Simulation

    module = _country_module(params.get("country", "us").lower())
    simulation_kwargs = {
        "dataset": _load_dataset(params),
        "tax_benefit_model_version": module.model,
        "policy": policy,
    }
    return Simulation(**simulation_kwargs)
280+
281+
def _change_sum(baseline, reform, variable: str, entity: str | None = None) -> float:
    """Return the summed baseline-to-reform change in ``variable`` as a float.

    Builds a ``ChangeAggregate`` with ``ChangeAggregateType.SUM`` over the two
    simulations, runs it, and converts its ``result`` with ``float()``.
    ``entity`` is forwarded to ``ChangeAggregate`` as given (``None`` by
    default).  Any exception raised while building or running the aggregate
    propagates to the caller.
    """
    from policyengine.outputs import ChangeAggregate, ChangeAggregateType

    aggregate = ChangeAggregate(
        baseline_simulation=baseline,
        reform_simulation=reform,
        variable=variable,
        entity=entity,
        aggregate_type=ChangeAggregateType.SUM,
    )
    aggregate.run()
    total = aggregate.result
    return float(total)
294+
295+
def _try_change_sum(
    baseline, reform, variable: str, entity: str | None = None
) -> float:
    """Best-effort wrapper around ``_change_sum``.

    Returns the value from ``_change_sum``.  If that call raises any
    ``Exception``, a warning including the traceback is logged and ``0.0`` is
    returned instead, so callers cannot tell "no change" apart from
    "calculation failed" by the return value alone.
    """
    try:
        change = _change_sum(baseline, reform, variable, entity)
    except Exception:
        logger.warning("Unable to calculate change for %s", variable, exc_info=True)
        change = 0.0
    return change
304+
305+
def _budget_result(country: str, baseline, reform) -> dict[str, float]:
    """Summarise the budgetary effect of ``reform`` relative to ``baseline``.

    Returns a dict with:

    * ``tax_revenue_impact``: summed change in ``household_tax``;
    * ``benefit_spending_impact``: summed change in ``household_benefits``;
    * ``budgetary_impact``: tax impact minus benefit impact;
    * ``state_tax_revenue_impact``: only when ``country == "us"``, summed
      change in ``household_state_income_tax``.

    Every figure comes from ``_try_change_sum``, so a component that fails to
    compute is reported as ``0.0`` rather than raising.
    """

    def household_change(variable: str) -> float:
        return _try_change_sum(baseline, reform, variable, entity="household")

    taxes = household_change("household_tax")
    benefits = household_change("household_benefits")
    summary = {
        "tax_revenue_impact": taxes,
        "benefit_spending_impact": benefits,
        "budgetary_impact": taxes - benefits,
    }
    if country != "us":
        return summary

    # NOTE(review): a ``household_*`` variable is aggregated with
    # entity="tax_unit" here, unlike the two household-level calls above.
    # Confirm this entity is intended.
    summary["state_tax_revenue_impact"] = _try_change_sum(
        baseline,
        reform,
        "household_state_income_tax",
        entity="tax_unit",
    )
    return summary
327+
328+
def _poverty_result(country: str, baseline, reform) -> dict[str, list[dict[str, Any]]]:
    """Return baseline and reform poverty tables as lists of row dicts.

    Runs the country's ``economic_impact_analysis(baseline, reform)`` once
    and converts the ``baseline_poverty`` and ``reform_poverty`` dataframes
    of that single result with ``to_dict("records")``.

    ``country == "us"`` selects ``policyengine.us``; every other value
    selects ``policyengine.uk``.

    Returns:
        ``{"baseline": [...], "reform": [...]}``.
    """
    import policyengine as pe

    country_api = pe.us if country == "us" else pe.uk
    # One analysis run serves both tables; running it once per attribute
    # would repeat the whole economic impact computation for the same inputs.
    analysis = country_api.economic_impact_analysis(baseline, reform)

    return {
        "baseline": analysis.baseline_poverty.dataframe.to_dict("records"),
        "reform": analysis.reform_poverty.dataframe.to_dict("records"),
    }
347+
348+
def _analysis_result(country: str, baseline, reform) -> dict[str, Any]:
    """Run the economic impact analysis and flatten it into plain dicts.

    ``country == "us"`` uses ``policyengine.us``; every other value uses
    ``policyengine.uk``.  The result contains:

    * ``decile_impacts`` and ``program_statistics``: row dicts from the
      corresponding dataframes;
    * ``poverty``: ``baseline`` / ``reform`` row dicts;
    * ``inequality``: ``baseline`` / ``reform`` summaries produced by
      ``_inequality_summary``.
    """
    import policyengine as pe

    country_api = pe.us if country == "us" else pe.uk
    analysis = country_api.economic_impact_analysis(baseline, reform)

    def records(output):
        # Each tabular output exposes its rows through ``.dataframe``.
        return output.dataframe.to_dict("records")

    return {
        "decile_impacts": records(analysis.decile_impacts),
        "program_statistics": records(analysis.program_statistics),
        "poverty": {
            "baseline": records(analysis.baseline_poverty),
            "reform": records(analysis.reform_poverty),
        },
        "inequality": {
            "baseline": _inequality_summary(analysis.baseline_inequality),
            "reform": _inequality_summary(analysis.reform_inequality),
        },
    }
369+
370+
371+ def _inequality_summary (inequality ) -> dict [str , Any ]:
372+ return {
373+ "income_variable" : inequality .income_variable ,
374+ "entity" : inequality .entity ,
375+ "gini" : inequality .gini ,
376+ "top_10_share" : inequality .top_10_share ,
377+ "top_1_share" : inequality .top_1_share ,
378+ "bottom_50_share" : inequality .bottom_50_share ,
379+ }
380+
381+
118382def _run_simulation_impl_core (params : dict ) -> dict :
119383 simulation_params , telemetry , metadata = split_internal_payload (params )
120384
@@ -127,17 +391,21 @@ def _run_simulation_impl_core(params: dict) -> dict:
127391 if metadata :
128392 logger .info ("Received simulation metadata keys: %s" , sorted (metadata ))
129393
130- # Validate and create simulation options
131- options = SimulationOptions . model_validate (simulation_params )
132- logger . info ( "Initialising simulation from input" )
394+ country = simulation_params . get ( "country" , "us" ). lower ()
395+ baseline_policy = _normalise_reform (simulation_params . get ( "baseline" ) )
396+ reform_policy = _normalise_reform ( simulation_params . get ( "reform" ) )
133397
134- # Create simulation instance
135- simulation = Simulation ( ** options . model_dump () )
136- logger . info ( "Calculating comparison" )
398+ logger . info ( "Initialising baseline and reform simulations" )
399+ baseline = _build_simulation ( simulation_params , baseline_policy )
400+ reform = _build_simulation ( simulation_params , reform_policy )
137401
138- # Run the economy comparison calculation
139- result = simulation .calculate_economy_comparison ()
402+ logger .info ("Calculating economic impact" )
403+ analysis = _analysis_result (country , baseline , reform )
404+ analysis ["budget" ] = _budget_result (country , baseline , reform )
405+ analysis ["metadata" ] = {
406+ "country" : country ,
407+ "year" : _parse_year (simulation_params ),
408+ "dataset" : getattr (baseline .dataset , "filepath" , None ),
409+ }
140410 logger .info ("Comparison complete" )
141-
142- # Use mode='json' to ensure numpy arrays are converted to lists
143- return result .model_dump (mode = "json" )
411+ return analysis
0 commit comments