Skip to content

Commit 89ce2c8

Browse files
committed
update test_same_state so it compares to original values instead of fresh calculations
1 parent 6534151 commit 89ce2c8

1 file changed

Lines changed: 25 additions & 30 deletions

File tree

policyengine_us_data/tests/test_local_area_calibration/test_same_state.py

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,26 @@
1-
"""Test same-state values match fresh simulations."""
1+
"""Test same-state values match original simulation values."""
22

33
import pytest
44
import numpy as np
55
from collections import defaultdict
66

7-
from policyengine_us import Microsimulation
8-
from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import (
9-
get_calculated_variables,
10-
)
11-
127
from .conftest import VARIABLES_TO_TEST, N_VERIFICATION_SAMPLES
138

149

1510
def test_same_state_matches_original(
11+
sim,
1612
X_sparse,
1713
targets_df,
1814
test_cds,
19-
dataset_path,
2015
n_households,
2116
household_ids,
2217
household_states,
2318
):
2419
"""
25-
Same-state non-zero cells must match fresh same-state simulation.
20+
Same-state non-zero cells must match ORIGINAL simulation values.
2621
2722
When household stays in same state, X_sparse should contain the value
28-
calculated from a fresh simulation with state_fips set to that state.
23+
from the original simulation (ground truth from H5 dataset).
2924
3025
Uses stratified sampling to ensure all variables in VARIABLES_TO_TEST
3126
are covered with approximately equal samples per variable.
@@ -36,19 +31,6 @@ def test_same_state_matches_original(
3631
hh_ids = household_ids
3732
hh_states = household_states
3833

39-
state_sims = {}
40-
41-
def get_state_sim(state):
42-
if state not in state_sims:
43-
s = Microsimulation(dataset=dataset_path)
44-
s.set_input(
45-
"state_fips", 2023, np.full(n_hh, state, dtype=np.int32)
46-
)
47-
for var in get_calculated_variables(s):
48-
s.delete_arrays(var)
49-
state_sims[state] = s
50-
return state_sims[state]
51-
5234
nonzero_rows, nonzero_cols = X_sparse.nonzero()
5335

5436
# Group same-state cells by variable for stratified sampling
@@ -68,7 +50,6 @@ def get_state_sim(state):
6850
if dest_state != orig_state:
6951
continue
7052

71-
# Get variable for this row
7253
variable = targets_df.iloc[row_idx]["variable"]
7354
if variable in variables_to_test:
7455
variable_to_indices[variable].append(i)
@@ -87,6 +68,16 @@ def get_state_sim(state):
8768
sampled = rng.choice(indices, n_to_sample, replace=False)
8869
sample_indices.extend(sampled)
8970

71+
# Cache original values per variable to avoid repeated calculations
72+
original_values_cache = {}
73+
74+
def get_original_values(variable):
75+
if variable not in original_values_cache:
76+
original_values_cache[variable] = sim.calculate(
77+
variable, map_to="household"
78+
).values
79+
return original_values_cache[variable]
80+
9081
errors = []
9182
variables_tested = set()
9283

@@ -95,28 +86,32 @@ def get_state_sim(state):
9586
col_idx = nonzero_cols[idx]
9687
cd_idx = col_idx // n_hh
9788
hh_idx = col_idx % n_hh
98-
cd = test_cds[cd_idx]
99-
dest_state = int(cd) // 100
10089
variable = targets_df.iloc[row_idx]["variable"]
10190
actual = float(X_sparse[row_idx, col_idx])
102-
state_sim = get_state_sim(dest_state)
103-
expected = float(
104-
state_sim.calculate(variable, map_to="household").values[hh_idx]
105-
)
91+
92+
# Compare to ORIGINAL simulation values (ground truth)
93+
original_values = get_original_values(variable)
94+
expected = float(original_values[hh_idx])
10695

10796
variables_tested.add(variable)
10897

10998
if not np.isclose(actual, expected, atol=0.5):
11099
errors.append(
111100
{
112101
"hh_id": hh_ids[hh_idx],
102+
"hh_idx": hh_idx,
113103
"variable": variable,
114104
"actual": actual,
115105
"expected": expected,
106+
"diff": actual - expected,
107+
"rel_diff": (
108+
(actual - expected) / expected
109+
if expected != 0
110+
else np.inf
111+
),
116112
}
117113
)
118114

119-
# Report which variables were tested
120115
missing_vars = variables_to_test - variables_tested
121116
if missing_vars:
122117
print(f"Warning: No same-state cells found for: {missing_vars}")

0 commit comments

Comments
 (0)