Skip to content

Commit 6534151

Browse files
committed
remove redundant tests and fixtures
1 parent fe70932 commit 6534151

4 files changed

Lines changed: 172 additions & 446 deletions

File tree

policyengine_us_data/tests/test_local_area_calibration/conftest.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
"""Shared fixtures for local area calibration tests."""
1+
"""Shared fixtures for local area calibration tests.
2+
3+
Importantly, this file determines which variables will be included in the sparse matrix and calibrating routine.
4+
"""
25

36
import pytest
47
import numpy as np
@@ -16,6 +19,56 @@
1619
get_calculated_variables,
1720
)
1821

22+
# Variables to test for state-level value matching
23+
# Format: (variable_name, rtol)
24+
# variable_name as per the targets in policy_data.db
25+
# rtol is relative tolerance for comparison
26+
VARIABLES_TO_TEST = [
27+
("snap", 1e-2),
28+
("health_insurance_premiums_without_medicare_part_b", 1e-2),
29+
("medicaid", 1e-2),
30+
("medicare_part_b_premiums", 1e-2),
31+
("other_medical_expenses", 1e-2),
32+
("over_the_counter_health_expenses", 1e-2),
33+
("salt_deduction", 1e-2),
34+
("spm_unit_capped_work_childcare_expenses", 1e-2),
35+
("spm_unit_capped_housing_subsidy", 1e-2),
36+
("ssi", 1e-2),
37+
("tanf", 1e-2),
38+
("tip_income", 1e-2),
39+
("unemployment_compensation", 1e-2),
40+
]
41+
42+
# Combined filter config to build matrix with all variables at once
43+
COMBINED_FILTER_CONFIG = {
44+
"stratum_group_ids": [
45+
4, # SNAP targets
46+
5, # Medicaid targets
47+
112, # Unemployment compensation targets
48+
],
49+
"variables": [
50+
"snap",
51+
"health_insurance_premiums_without_medicare_part_b",
52+
"medicaid",
53+
"medicare_part_b_premiums",
54+
"other_medical_expenses",
55+
"over_the_counter_health_expenses",
56+
"salt_deduction",
57+
"spm_unit_capped_work_childcare_expenses",
58+
"spm_unit_capped_housing_subsidy",
59+
"ssi",
60+
"tanf",
61+
"tip_income",
62+
"unemployment_compensation",
63+
],
64+
}
65+
66+
# Maximum allowed mismatch rate for state-level value comparison
67+
MAX_MISMATCH_RATE = 0.02
68+
69+
# Number of samples for cell-level verification tests
70+
N_VERIFICATION_SAMPLES = 200
71+
1972

2073
@pytest.fixture(scope="module")
2174
def db_uri():
@@ -30,7 +83,7 @@ def dataset_path():
3083

3184
@pytest.fixture(scope="module")
3285
def test_cds(db_uri):
33-
"""CDs from NC, HI, MT, AK (manageable size, multiple same-state CDs)."""
86+
"""CDs from multiple states for comprehensive testing."""
3487
engine = create_engine(db_uri)
3588
query = """
3689
SELECT DISTINCT sc.value as cd_geoid
@@ -43,6 +96,10 @@ def test_cds(db_uri):
4396
OR sc.value LIKE '150_'
4497
OR sc.value LIKE '300_'
4598
OR sc.value = '200' OR sc.value = '201'
99+
OR sc.value IN ('101', '102')
100+
OR sc.value IN ('601', '602')
101+
OR sc.value IN ('3601', '3602')
102+
OR sc.value IN ('4801', '4802')
46103
)
47104
ORDER BY sc.value
48105
"""
@@ -58,15 +115,15 @@ def sim(dataset_path):
58115

59116
@pytest.fixture(scope="module")
60117
def matrix_data(db_uri, dataset_path, test_cds, sim):
61-
"""Build sparse matrix, return (targets_df, X_sparse, household_id_mapping)."""
118+
"""Build sparse matrix with all configured variables."""
62119
builder = SparseMatrixBuilder(
63120
db_uri,
64121
time_period=2023,
65122
cds_to_calibrate=test_cds,
66123
dataset_path=dataset_path,
67124
)
68125
targets_df, X_sparse, household_id_mapping = builder.build_matrix(
69-
sim, target_filter={"stratum_group_ids": [4], "variables": ["snap"]}
126+
sim, target_filter=COMBINED_FILTER_CONFIG
70127
)
71128
return targets_df, X_sparse, household_id_mapping
72129

policyengine_us_data/tests/test_local_area_calibration/test_cross_state.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,19 @@
22

33
import pytest
44
import numpy as np
5+
from collections import defaultdict
56

67
from policyengine_us import Microsimulation
78
from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import (
89
get_calculated_variables,
910
)
1011

12+
from .conftest import VARIABLES_TO_TEST, N_VERIFICATION_SAMPLES
13+
1114

1215
def test_cross_state_matches_swapped_sim(
1316
X_sparse,
1417
targets_df,
15-
tracer,
1618
test_cds,
1719
dataset_path,
1820
n_households,
@@ -25,8 +27,10 @@ def test_cross_state_matches_swapped_sim(
2527
When household moves to different state, X_sparse should contain the
2628
value calculated from a fresh simulation with state_fips set to
2729
destination state.
30+
31+
Uses stratified sampling to ensure all variables in VARIABLES_TO_TEST
32+
are covered with approximately equal samples per variable.
2833
"""
29-
n_samples = 200
3034
seed = 42
3135
rng = np.random.default_rng(seed)
3236
n_hh = n_households
@@ -48,28 +52,46 @@ def get_state_sim(state):
4852

4953
nonzero_rows, nonzero_cols = X_sparse.nonzero()
5054

51-
cross_state_indices = []
55+
# Group cross-state cells by variable for stratified sampling
56+
variable_to_indices = defaultdict(list)
57+
variables_to_test = {v[0] for v in VARIABLES_TO_TEST}
58+
5259
for i in range(len(nonzero_rows)):
60+
row_idx = nonzero_rows[i]
5361
col_idx = nonzero_cols[i]
5462
cd_idx = col_idx // n_hh
5563
hh_idx = col_idx % n_hh
5664
cd = test_cds[cd_idx]
5765
dest_state = int(cd) // 100
5866
orig_state = int(hh_states[hh_idx])
59-
if dest_state != orig_state:
60-
cross_state_indices.append(i)
6167

62-
if not cross_state_indices:
63-
pytest.skip("No cross-state non-zero cells found")
68+
# Only include cross-state cells
69+
if dest_state == orig_state:
70+
continue
71+
72+
# Get variable for this row
73+
variable = targets_df.iloc[row_idx]["variable"]
74+
if variable in variables_to_test:
75+
variable_to_indices[variable].append(i)
76+
77+
if not variable_to_indices:
78+
pytest.skip("No cross-state non-zero cells found for test variables")
6479

65-
sample_idx = rng.choice(
66-
cross_state_indices,
67-
min(n_samples, len(cross_state_indices)),
68-
replace=False,
80+
# Stratified sampling: sample proportionally from each variable
81+
samples_per_var = max(
82+
1, N_VERIFICATION_SAMPLES // len(variable_to_indices)
6983
)
84+
sample_indices = []
85+
86+
for variable, indices in variable_to_indices.items():
87+
n_to_sample = min(samples_per_var, len(indices))
88+
sampled = rng.choice(indices, n_to_sample, replace=False)
89+
sample_indices.extend(sampled)
90+
7091
errors = []
92+
variables_tested = set()
7193

72-
for idx in sample_idx:
94+
for idx in sample_indices:
7395
row_idx = nonzero_rows[idx]
7496
col_idx = nonzero_cols[idx]
7597
cd_idx = col_idx // n_hh
@@ -83,6 +105,8 @@ def get_state_sim(state):
83105
state_sim.calculate(variable, map_to="household").values[hh_idx]
84106
)
85107

108+
variables_tested.add(variable)
109+
86110
if not np.isclose(actual, expected, atol=0.5):
87111
errors.append(
88112
{
@@ -95,7 +119,13 @@ def get_state_sim(state):
95119
}
96120
)
97121

122+
# Report which variables were tested
123+
missing_vars = variables_to_test - variables_tested
124+
if missing_vars:
125+
print(f"Warning: No cross-state cells found for: {missing_vars}")
126+
98127
assert not errors, (
99-
f"Cross-state verification failed: {len(errors)}/{len(sample_idx)} "
100-
f"mismatches. First 5: {errors[:5]}"
128+
f"Cross-state verification failed: {len(errors)}/{len(sample_indices)} "
129+
f"mismatches across {len(variables_tested)} variables. "
130+
f"First 5: {errors[:5]}"
101131
)

0 commit comments

Comments
 (0)