1- """Test same-state values match fresh simulations."""
1+ """Test same-state values match original simulation values."""
22
33import pytest
44import numpy as np
55from collections import defaultdict
66
7- from policyengine_us import Microsimulation
8- from policyengine_us_data .datasets .cps .local_area_calibration .calibration_utils import (
9- get_calculated_variables ,
10- )
11-
127from .conftest import VARIABLES_TO_TEST , N_VERIFICATION_SAMPLES
138
149
1510def test_same_state_matches_original (
11+ sim ,
1612 X_sparse ,
1713 targets_df ,
1814 test_cds ,
19- dataset_path ,
2015 n_households ,
2116 household_ids ,
2217 household_states ,
2318):
2419 """
25- Same-state non-zero cells must match fresh same-state simulation.
20+ Same-state non-zero cells must match ORIGINAL simulation values.
2621
2722 When household stays in same state, X_sparse should contain the value
28- calculated from a fresh simulation with state_fips set to that state.
23+ from the original simulation (ground truth from H5 dataset).
2924
3025 Uses stratified sampling to ensure all variables in VARIABLES_TO_TEST
3126 are covered with approximately equal samples per variable.
@@ -36,19 +31,6 @@ def test_same_state_matches_original(
3631 hh_ids = household_ids
3732 hh_states = household_states
3833
39- state_sims = {}
40-
41- def get_state_sim (state ):
42- if state not in state_sims :
43- s = Microsimulation (dataset = dataset_path )
44- s .set_input (
45- "state_fips" , 2023 , np .full (n_hh , state , dtype = np .int32 )
46- )
47- for var in get_calculated_variables (s ):
48- s .delete_arrays (var )
49- state_sims [state ] = s
50- return state_sims [state ]
51-
5234 nonzero_rows , nonzero_cols = X_sparse .nonzero ()
5335
5436 # Group same-state cells by variable for stratified sampling
@@ -68,7 +50,6 @@ def get_state_sim(state):
6850 if dest_state != orig_state :
6951 continue
7052
71- # Get variable for this row
7253 variable = targets_df .iloc [row_idx ]["variable" ]
7354 if variable in variables_to_test :
7455 variable_to_indices [variable ].append (i )
@@ -87,6 +68,16 @@ def get_state_sim(state):
8768 sampled = rng .choice (indices , n_to_sample , replace = False )
8869 sample_indices .extend (sampled )
8970
71+ # Cache original values per variable to avoid repeated calculations
72+ original_values_cache = {}
73+
74+ def get_original_values (variable ):
75+ if variable not in original_values_cache :
76+ original_values_cache [variable ] = sim .calculate (
77+ variable , map_to = "household"
78+ ).values
79+ return original_values_cache [variable ]
80+
9081 errors = []
9182 variables_tested = set ()
9283
@@ -95,28 +86,32 @@ def get_state_sim(state):
9586 col_idx = nonzero_cols [idx ]
9687 cd_idx = col_idx // n_hh
9788 hh_idx = col_idx % n_hh
98- cd = test_cds [cd_idx ]
99- dest_state = int (cd ) // 100
10089 variable = targets_df .iloc [row_idx ]["variable" ]
10190 actual = float (X_sparse [row_idx , col_idx ])
102- state_sim = get_state_sim ( dest_state )
103- expected = float (
104- state_sim . calculate (variable , map_to = "household" ). values [ hh_idx ]
105- )
91+
92+ # Compare to ORIGINAL simulation values (ground truth)
93+ original_values = get_original_values (variable )
94+ expected = float ( original_values [ hh_idx ] )
10695
10796 variables_tested .add (variable )
10897
10998 if not np .isclose (actual , expected , atol = 0.5 ):
11099 errors .append (
111100 {
112101 "hh_id" : hh_ids [hh_idx ],
102+ "hh_idx" : hh_idx ,
113103 "variable" : variable ,
114104 "actual" : actual ,
115105 "expected" : expected ,
106+ "diff" : actual - expected ,
107+ "rel_diff" : (
108+ (actual - expected ) / expected
109+ if expected != 0
110+ else np .inf
111+ ),
116112 }
117113 )
118114
119- # Report which variables were tested
120115 missing_vars = variables_to_test - variables_tested
121116 if missing_vars :
122117 print (f"Warning: No same-state cells found for: { missing_vars } " )
0 commit comments