1010import logging
1111import os
1212import tempfile
13- import importlib
14- from typing import Any , Iterator
13+ from typing import Iterator
1514
16- # policyengine.core is imported for every simulation. Without this guard,
17- # importing the package pulls both country modules into the process; a US run
18- # can then fail before it starts if UK private-data credentials are absent.
19- os .environ .setdefault ("POLICYENGINE_SKIP_COUNTRY_IMPORTS" , "1" )
15+ # Module-level imports - these are SNAPSHOTTED at image build time
16+ from policyengine .simulation import Simulation , SimulationOptions
2017
21- try :
22- from src .modal .telemetry import split_internal_payload
23- except ModuleNotFoundError :
24- from modal .telemetry import split_internal_payload
18+ from src .modal .telemetry import split_internal_payload
2519
2620logger = logging .getLogger (__name__ )
2721
2822
# Simulation year used when a request supplies neither "time_period" nor "year".
DEFAULT_YEAR = 2026

# Per-country lookup from the dataset identifiers a caller may send (short
# aliases, canonical names, and storage URIs) to canonical dataset names.
DATASET_ALIASES = dict(
    us={
        "enhanced_cps": "enhanced_cps_2024",
        "enhanced_cps_2024": "enhanced_cps_2024",
        "gs://policyengine-us-data/enhanced_cps_2024.h5": "enhanced_cps_2024",
        "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5": "enhanced_cps_2024",
        "cps_small": "cps_small_2024",
        "cps_small_2024": "cps_small_2024",
    },
    uk={
        "enhanced_frs": "enhanced_frs_2023_24",
        "enhanced_frs_2023_24": "enhanced_frs_2023_24",
        "frs": "frs_2023_24",
        "frs_2023_24": "frs_2023_24",
    },
)
46-
47-
4823def _normalize_credentials_blob (creds_json : str ) -> str :
4924 """Return the raw JSON blob, decoding the outer escape if present.
5025
@@ -140,237 +115,6 @@ def run_simulation_impl(params: dict) -> dict:
140115 return _run_simulation_impl_core (params )
141116
142117
143- def _parse_year (params : dict [str , Any ]) -> int :
144- value = params .get ("time_period" ) or params .get ("year" ) or DEFAULT_YEAR
145- return int (value )
146-
147-
148- def _normalise_period_key (period_key : Any ) -> str :
149- """Convert legacy ``start.stop`` period keys to v4 effective dates."""
150- text = str (period_key )
151- parts = text .split ("." )
152- if len (parts ) > 1 and len (parts [0 ]) == 10 :
153- return parts [0 ]
154- return text
155-
156-
157- def _normalise_reform (reform : dict [str , Any ] | None ) -> dict [str , Any ] | None :
158- if not reform :
159- return None
160- normalised : dict [str , Any ] = {}
161- for parameter , value in reform .items ():
162- if isinstance (value , dict ):
163- normalised [parameter ] = {
164- _normalise_period_key (period ): period_value
165- for period , period_value in value .items ()
166- }
167- else :
168- normalised [parameter ] = value
169- return normalised
170-
171-
172- def _resolve_dataset_name (
173- country : str , requested_data : str | None , subsample : int | None
174- ) -> str :
175- if requested_data is None :
176- return "enhanced_cps_2024" if country == "us" else "enhanced_frs_2023_24"
177-
178- requested = requested_data .split ("@" , maxsplit = 1 )[0 ]
179- return DATASET_ALIASES .get (country , {}).get (requested , requested_data )
180-
181-
def _microframe_like(frame, weights: str):
    """Wrap a copy of ``frame`` in a ``microdf.MicroDataFrame``.

    ``weights`` is forwarded as the ``weights`` argument of
    ``MicroDataFrame``. ``microdf`` is imported on first use rather than at
    module import time.
    """
    from microdf import MicroDataFrame

    frame_copy = frame.copy()
    return MicroDataFrame(frame_copy, weights=weights)
186-
187-
188- def _person_group_column (person , entity : str ) -> str :
189- prefixed = f"person_{ entity } _id"
190- if prefixed in person .columns :
191- return prefixed
192- return f"{ entity } _id"
193-
194-
195- def _subsample_us_dataset (dataset , subsample : int | None ):
196- if not subsample :
197- return dataset
198-
199- from policyengine .tax_benefit_models .us .datasets import (
200- PolicyEngineUSDataset ,
201- USYearData ,
202- )
203-
204- dataset .load ()
205- data = dataset .data
206- household = data .household .head (int (subsample )).copy ()
207- household_ids = set (household ["household_id" ])
208-
209- person_household_col = _person_group_column (data .person , "household" )
210- person = data .person [data .person [person_household_col ].isin (household_ids )].copy ()
211-
212- def group_subset (entity : str ):
213- person_col = _person_group_column (person , entity )
214- entity_id_col = f"{ entity } _id"
215- ids = set (person [person_col ])
216- frame = getattr (data , entity )
217- return frame [frame [entity_id_col ].isin (ids )].copy ()
218-
219- subset_data = USYearData (
220- person = _microframe_like (person , "person_weight" ),
221- marital_unit = _microframe_like (
222- group_subset ("marital_unit" ), "marital_unit_weight"
223- ),
224- family = _microframe_like (group_subset ("family" ), "family_weight" ),
225- spm_unit = _microframe_like (group_subset ("spm_unit" ), "spm_unit_weight" ),
226- tax_unit = _microframe_like (group_subset ("tax_unit" ), "tax_unit_weight" ),
227- household = _microframe_like (household , "household_weight" ),
228- )
229- subset_path = os .path .join (
230- os .environ .get ("POLICYENGINE_DATA_FOLDER" , "/tmp/policyengine-data" ),
231- f"{ dataset .id } _subsample_{ subsample } .h5" ,
232- )
233- return PolicyEngineUSDataset (
234- id = f"{ dataset .id } _subsample_{ subsample } " ,
235- name = f"{ dataset .name } subsample { subsample } " ,
236- description = dataset .description ,
237- filepath = subset_path ,
238- year = dataset .year ,
239- is_output_dataset = dataset .is_output_dataset ,
240- data = subset_data ,
241- )
242-
243-
244- def _country_module (country : str ):
245- country = country .lower ()
246- if country == "us" :
247- return importlib .import_module ("policyengine.tax_benefit_models.us" )
248- if country == "uk" :
249- return importlib .import_module ("policyengine.tax_benefit_models.uk" )
250- raise ValueError (f"Unsupported country: { country } " )
251-
252-
def _load_dataset(params: dict[str, Any]):
    """Resolve and return the dataset described by ``params``.

    Reads ``"country"`` (default ``"us"``), the year via ``_parse_year``,
    and the optional ``"data"`` and ``"subsample"`` entries. The resolved
    dataset name and year are passed to the country module's
    ``ensure_datasets`` together with the data folder taken from
    ``POLICYENGINE_DATA_FOLDER`` (default ``/tmp/policyengine-data``); the
    first value of the mapping it returns is used. For ``"us"`` the dataset
    is then passed through ``_subsample_us_dataset``; other countries return
    it as is, so ``"subsample"`` has no effect for them.
    """
    country_code = params.get("country", "us").lower()
    year = _parse_year(params)
    model_module = _country_module(country_code)

    subsample = params.get("subsample")
    dataset_name = _resolve_dataset_name(country_code, params.get("data"), subsample)
    data_folder = os.environ.get("POLICYENGINE_DATA_FOLDER", "/tmp/policyengine-data")

    available = model_module.ensure_datasets(
        datasets=[dataset_name],
        years=[year],
        data_folder=data_folder,
    )
    first_dataset = next(iter(available.values()))

    if country_code != "us":
        return first_dataset
    return _subsample_us_dataset(first_dataset, subsample)
271-
272-
def _build_simulation(params: dict[str, Any], policy: dict[str, Any] | None):
    """Construct a ``policyengine.core.Simulation`` for ``params`` and ``policy``.

    The country (default ``"us"``) selects the model module, whose ``model``
    attribute is passed as ``tax_benefit_model_version``. The dataset comes
    from ``_load_dataset(params)``. ``Simulation`` is imported on first use
    rather than at module import time.
    """
    from policyengine.core import Simulation

    country_code = params.get("country", "us").lower()
    model_module = _country_module(country_code)
    simulation_kwargs = {
        "dataset": _load_dataset(params),
        "tax_benefit_model_version": model_module.model,
        "policy": policy,
    }
    return Simulation(**simulation_kwargs)
284-
285-
def _change_sum(baseline, reform, variable: str, entity: str | None = None) -> float:
    """Return the summed change in ``variable`` between two simulations.

    Builds a ``ChangeAggregate`` of type ``ChangeAggregateType.SUM`` over
    ``baseline`` and ``reform`` (optionally at ``entity`` level), runs it,
    and returns its ``result`` converted to ``float``.
    """
    from policyengine.outputs import ChangeAggregate, ChangeAggregateType

    aggregate = ChangeAggregate(
        baseline_simulation=baseline,
        reform_simulation=reform,
        variable=variable,
        entity=entity,
        aggregate_type=ChangeAggregateType.SUM,
    )
    aggregate.run()
    total_change = aggregate.result
    return float(total_change)
298-
299-
def _try_change_sum(
    baseline, reform, variable: str, entity: str | None = None
) -> float:
    """Best-effort wrapper around ``_change_sum``.

    Any ``Exception`` raised while computing the change is logged as a
    warning (with traceback) and reported as ``0.0`` instead of propagating.
    """
    try:
        change = _change_sum(baseline, reform, variable, entity)
    except Exception:
        # Deliberate best-effort: one failing variable must not abort the run.
        logger.warning("Unable to calculate change for %s", variable, exc_info=True)
        return 0.0
    else:
        return change
308-
309-
def _budget_result(country: str, baseline, reform) -> dict[str, float]:
    """Summarise the budget effect of ``reform`` relative to ``baseline``.

    Returns ``tax_revenue_impact`` (change in ``household_tax``),
    ``benefit_spending_impact`` (change in ``household_benefits``), and
    ``budgetary_impact`` (tax change minus benefit change). For ``"us"`` the
    result also contains ``state_tax_revenue_impact``, the change in
    ``household_state_income_tax``. Each figure comes from
    ``_try_change_sum``, so a failed calculation shows up as ``0.0``.
    """
    tax_change = _try_change_sum(
        baseline, reform, "household_tax", entity="household"
    )
    benefit_change = _try_change_sum(
        baseline, reform, "household_benefits", entity="household"
    )

    summary = {
        "tax_revenue_impact": tax_change,
        "benefit_spending_impact": benefit_change,
        "budgetary_impact": tax_change - benefit_change,
    }
    if country != "us":
        return summary

    summary["state_tax_revenue_impact"] = _try_change_sum(
        baseline, reform, "household_state_income_tax", entity="tax_unit"
    )
    return summary
331-
332-
def _poverty_result(country: str, baseline, reform) -> dict[str, list[dict[str, Any]]]:
    """Return baseline and reform poverty tables as lists of row records.

    Runs the country module's ``economic_impact_analysis`` on the two
    simulations and converts the ``dataframe`` of its ``baseline_poverty``
    and ``reform_poverty`` outputs with ``to_dict("records")``.
    """
    impact = _country_module(country).economic_impact_analysis(baseline, reform)
    poverty_outputs = {
        "baseline": impact.baseline_poverty,
        "reform": impact.reform_poverty,
    }
    return {
        label: output.dataframe.to_dict("records")
        for label, output in poverty_outputs.items()
    }
343-
344-
def _analysis_result(country: str, baseline, reform) -> dict[str, Any]:
    """Run the country's economic impact analysis and flatten it to plain data.

    The result has four keys: ``decile_impacts`` and ``program_statistics``
    (row records of the corresponding output dataframes), ``poverty`` (row
    records for baseline and reform), and ``inequality`` (baseline and reform
    summaries produced by ``_inequality_summary``).
    """
    analysis = _country_module(country).economic_impact_analysis(baseline, reform)

    def records(output) -> list[dict[str, Any]]:
        # Every tabular output exposes its rows through ``dataframe``.
        return output.dataframe.to_dict("records")

    decile_rows = records(analysis.decile_impacts)
    program_rows = records(analysis.program_statistics)
    poverty = {
        "baseline": records(analysis.baseline_poverty),
        "reform": records(analysis.reform_poverty),
    }
    inequality = {
        "baseline": _inequality_summary(analysis.baseline_inequality),
        "reform": _inequality_summary(analysis.reform_inequality),
    }
    return {
        "decile_impacts": decile_rows,
        "program_statistics": program_rows,
        "poverty": poverty,
        "inequality": inequality,
    }
361-
362-
363- def _inequality_summary (inequality ) -> dict [str , Any ]:
364- return {
365- "income_variable" : inequality .income_variable ,
366- "entity" : inequality .entity ,
367- "gini" : inequality .gini ,
368- "top_10_share" : inequality .top_10_share ,
369- "top_1_share" : inequality .top_1_share ,
370- "bottom_50_share" : inequality .bottom_50_share ,
371- }
372-
373-
374118def _run_simulation_impl_core (params : dict ) -> dict :
375119 simulation_params , telemetry , metadata = split_internal_payload (params )
376120
@@ -383,21 +127,17 @@ def _run_simulation_impl_core(params: dict) -> dict:
383127 if metadata :
384128 logger .info ("Received simulation metadata keys: %s" , sorted (metadata ))
385129
386- country = simulation_params . get ( "country" , "us" ). lower ()
387- baseline_policy = _normalise_reform ( simulation_params . get ( "baseline" ) )
388- reform_policy = _normalise_reform ( simulation_params . get ( "reform" ) )
130+ # Validate and create simulation options
131+ options = SimulationOptions . model_validate ( simulation_params )
132+ logger . info ( "Initialising simulation from input" )
389133
390- logger . info ( "Initialising baseline and reform simulations" )
391- baseline = _build_simulation ( simulation_params , baseline_policy )
392- reform = _build_simulation ( simulation_params , reform_policy )
134+ # Create simulation instance
135+ simulation = Simulation ( ** options . model_dump () )
136+ logger . info ( "Calculating comparison" )
393137
394- logger .info ("Calculating economic impact" )
395- analysis = _analysis_result (country , baseline , reform )
396- analysis ["budget" ] = _budget_result (country , baseline , reform )
397- analysis ["metadata" ] = {
398- "country" : country ,
399- "year" : _parse_year (simulation_params ),
400- "dataset" : getattr (baseline .dataset , "filepath" , None ),
401- }
138+ # Run the economy comparison calculation
139+ result = simulation .calculate_economy_comparison ()
402140 logger .info ("Comparison complete" )
403- return analysis
141+
142+ # Use mode='json' to ensure numpy arrays are converted to lists
143+ return result .model_dump (mode = "json" )
0 commit comments