Skip to content

Commit 5a402f5

Browse files
authored
Merge pull request #475 from PolicyEngine/maria/sparse_matrix_tests
Adding test (and variables) for sparse matrix building logic
2 parents 854fa08 + b5b1f1d commit 5a402f5

23 files changed

Lines changed: 808 additions & 117 deletions

changelog_entry.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
- bump: patch
1+
- bump: minor
22
changes:
3-
fixed:
4-
- Versioning workflow checkout for push events
3+
added:
4+
- Tests verifying that SparseMatrixBuilder correctly builds variables and constraints into the calibration matrix.

policyengine_us_data/datasets/cps/cps.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from microimpute.models.qrf import QRF
1616
import logging
1717

18-
1918
test_lite = os.environ.get("TEST_LITE") == "true"
2019
print(f"TEST_LITE == {test_lite}")
2120

policyengine_us_data/datasets/cps/enhanced_cps.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from pathlib import Path
2323
import logging
2424

25-
2625
try:
2726
import torch
2827
except ImportError:

policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
StateCode,
1818
)
1919

20-
2120
# State/Geographic Mappings
2221
STATE_CODES = {
2322
1: "AL",
@@ -193,15 +192,21 @@ def get_calculated_variables(sim) -> List[str]:
193192
"""
194193
Return variables that should be cleared for state-swap recalculation.
195194
196-
Includes variables with formulas, adds, or subtracts.
197-
198-
Excludes ID variables (person_id, household_id, etc.) because:
199-
1. They have formulas that generate sequential IDs (0, 1, 2, ...)
200-
2. We need the original H5 values, not regenerated sequences
201-
3. PolicyEngine's random() function uses entity IDs as seeds:
202-
seed = abs(entity_id * 100 + count_random_calls)
203-
If IDs change, random-dependent variables (SSI resource test,
204-
WIC nutritional risk, WIC takeup) produce different results.
195+
Includes variables with formulas, or adds/subtracts that are lists.
196+
197+
Excludes:
198+
1. ID variables (person_id, household_id, etc.) - needed for random seeds
199+
2. Variables with string adds/subtracts (parameter paths) - these are
200+
pseudo-inputs stored in H5 that would recalculate differently using
201+
parameter lookups. Examples: pre_tax_contributions.
202+
3. Variables in input_variables (have stored H5 values) even if they
203+
have formulas - the stored values represent original survey data
204+
that should be preserved. Examples: cdcc_relevant_expenses, rent.
205+
206+
The exclusions are critical because:
207+
- The H5 file stores pre-computed values from original CPS processing
208+
- If deleted, recalculation produces different values, corrupting
209+
downstream calculations like income_tax
205210
"""
206211
exclude_ids = {
207212
"person_id",
@@ -211,16 +216,36 @@ def get_calculated_variables(sim) -> List[str]:
211216
"family_id",
212217
"marital_unit_id",
213218
}
214-
return [
215-
name
216-
for name, var in sim.tax_benefit_system.variables.items()
217-
if (
218-
var.formulas
219-
or getattr(var, "adds", None)
220-
or getattr(var, "subtracts", None)
221-
)
222-
and name not in exclude_ids
223-
]
219+
220+
# Get stored input variables to exclude
221+
input_vars = set(sim.input_variables)
222+
223+
result = []
224+
for name, var in sim.tax_benefit_system.variables.items():
225+
if name in exclude_ids:
226+
continue
227+
228+
# Exclude variables that have stored values (input_variables)
229+
# These represent original survey data that should be preserved
230+
if name in input_vars:
231+
continue
232+
233+
# Include if has formulas
234+
if var.formulas:
235+
result.append(name)
236+
continue
237+
238+
# Include if adds/subtracts is a list (explicit component aggregation)
239+
# Exclude if adds/subtracts is a string (parameter path - pseudo-input)
240+
adds = getattr(var, "adds", None)
241+
subtracts = getattr(var, "subtracts", None)
242+
243+
if adds and isinstance(adds, list):
244+
result.append(name)
245+
elif subtracts and isinstance(subtracts, list):
246+
result.append(name)
247+
248+
return result
224249

225250

226251
def get_pseudo_input_variables(sim) -> set:

policyengine_us_data/datasets/cps/local_area_calibration/matrix_tracer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
create_target_groups,
4747
)
4848

49-
5049
logger = logging.getLogger(__name__)
5150

5251

policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py

Lines changed: 128 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,105 @@ def __init__(
3838
self.time_period = time_period
3939
self.cds_to_calibrate = cds_to_calibrate
4040
self.dataset_path = dataset_path
41+
self._entity_rel_cache = None
42+
43+
def _build_entity_relationship(self, sim) -> pd.DataFrame:
    """
    Build (or fetch from cache) the person-level entity ID table.

    Each row is one person; the columns hold that person's own ID and the
    IDs of the household, tax unit and SPM unit obtained by calculating the
    corresponding ``*_id`` variable mapped to the person entity. The table
    is used to evaluate constraints at person level and then aggregate them
    to household level, regardless of which entity a constraint variable
    is defined on.

    The result is memoised on ``self._entity_rel_cache`` and returned as-is
    on subsequent calls. The cache is NOT keyed on ``sim``: a caller that
    switches to a different simulation must reset
    ``self._entity_rel_cache`` to ``None`` before calling this again,
    otherwise the table built for the previous simulation is returned.

    Args:
        sim: Simulation object exposing ``calculate(name, map_to=...)``
            whose result has a ``.values`` attribute.

    Returns:
        DataFrame with columns person_id, household_id, tax_unit_id and
        spm_unit_id (in that order), one row per person.
    """
    if self._entity_rel_cache is not None:
        return self._entity_rel_cache

    # Column order is part of the contract; dicts preserve insertion order.
    id_columns = (
        "person_id",
        "household_id",
        "tax_unit_id",
        "spm_unit_id",
    )
    self._entity_rel_cache = pd.DataFrame(
        {
            column: sim.calculate(column, map_to="person").values
            for column in id_columns
        }
    )
    return self._entity_rel_cache
74+
75+
def _evaluate_constraints_entity_aware(
    self, state_sim, constraints: List[dict], n_households: int
) -> np.ndarray:
    """
    Evaluate non-geographic constraints per person, then reduce to a
    per-household boolean mask with "any person qualifies" semantics.

    Every constraint variable is calculated mapped to the person entity and
    the resulting person-level conditions are AND-ed together. A household
    is then marked True when at least one of its persons satisfies ALL
    constraints. Evaluating at person level avoids the problem of
    household-level aggregation summing values of variables defined on
    smaller entities (e.g. a tax-unit flag becoming 2, 3, ... for
    households with several tax units).

    Args:
        state_sim: Simulation object exposing ``calculate(name,
            map_to=...)`` whose result has a ``.values`` attribute.
        constraints: Constraint dicts with ``variable``, ``operation`` and
            ``value`` keys. Geographic constraints are expected to have
            been filtered out by the caller.
        n_households: Number of households; only used to size the
            all-True mask returned when ``constraints`` is empty.

    Returns:
        Boolean array with one entry per household, in the order of
        ``state_sim.calculate("household_id", map_to="household")`` (or of
        length ``n_households`` when there are no constraints).
    """
    if not constraints:
        return np.ones(n_households, dtype=bool)

    entity_rel = self._build_entity_relationship(state_sim)
    person_mask = np.ones(len(entity_rel), dtype=bool)

    for constraint in constraints:
        variable = constraint["variable"]
        operation = constraint["operation"]
        value = constraint["value"]

        person_values = state_sim.calculate(
            variable, map_to="person"
        ).values
        person_mask &= apply_op(person_values, operation, value)

    # A household qualifies iff its ID appears among the household IDs of
    # the persons that passed every constraint. np.isin does this in one
    # vectorised pass, keeps the simulation's household ordering, and
    # always yields a bool array (including for zero households).
    qualifying_household_ids = entity_rel["household_id"].to_numpy()[
        person_mask
    ]
    household_ids = state_sim.calculate(
        "household_id", map_to="household"
    ).values
    return np.isin(household_ids, qualifying_household_ids)
41140

42141
def _query_targets(self, target_filter: dict) -> pd.DataFrame:
43142
"""Query targets based on filter criteria using OR logic."""
@@ -166,6 +265,9 @@ def build_matrix(
166265
cds_by_state[state].append((cd_idx, cd))
167266

168267
for state, cd_list in cds_by_state.items():
268+
# Clear entity relationship cache when creating new simulation
269+
self._entity_rel_cache = None
270+
169271
if self.dataset_path:
170272
state_sim = self._create_state_sim(state, n_households)
171273
else:
@@ -184,35 +286,43 @@ def build_matrix(
184286
for row_idx, (_, target) in enumerate(targets_df.iterrows()):
185287
constraints = self._get_constraints(target["stratum_id"])
186288

187-
mask = np.ones(n_households, dtype=bool)
289+
geo_constraints = []
290+
non_geo_constraints = []
188291
for c in constraints:
292+
if c["variable"] in (
293+
"state_fips",
294+
"congressional_district_geoid",
295+
):
296+
geo_constraints.append(c)
297+
else:
298+
non_geo_constraints.append(c)
299+
300+
# Check geographic constraints first (quick fail)
301+
geo_mask = np.ones(n_households, dtype=bool)
302+
for c in geo_constraints:
189303
if c["variable"] == "congressional_district_geoid":
190304
if (
191305
c["operation"] in ("==", "=")
192306
and c["value"] != cd
193307
):
194-
mask[:] = False
308+
geo_mask[:] = False
195309
elif c["variable"] == "state_fips":
196310
if (
197311
c["operation"] in ("==", "=")
198312
and int(c["value"]) != state
199313
):
200-
mask[:] = False
201-
else:
202-
try:
203-
values = state_sim.calculate(
204-
c["variable"], map_to="household"
205-
).values
206-
mask &= apply_op(
207-
values, c["operation"], c["value"]
208-
)
209-
except Exception as e:
210-
# Variable may not exist or may not be
211-
# calculable at household level - skip
212-
logger.debug(
213-
f"Could not evaluate constraint "
214-
f"{c['variable']}: {e}"
215-
)
314+
geo_mask[:] = False
315+
316+
if not geo_mask.any():
317+
continue
318+
319+
# Evaluate non-geographic constraints at entity level
320+
entity_mask = self._evaluate_constraints_entity_aware(
321+
state_sim, non_geo_constraints, n_households
322+
)
323+
324+
# Combine geographic and entity-aware masks
325+
mask = geo_mask & entity_mask
216326

217327
if not mask.any():
218328
continue

policyengine_us_data/datasets/puf/puf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
create_policyengine_uprating_factors_table,
1616
)
1717

18-
1918
rng = np.random.default_rng(seed=64)
2019

2120
# Get Qualified Business Income simulation parameters ---

policyengine_us_data/datasets/puf/uprate_puf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import numpy as np
33
from policyengine_us_data.storage import STORAGE_FOLDER
44

5-
65
ITMDED_GROW_RATE = 0.02 # annual growth rate in itemized deduction amounts
76

87
USE_VARIABLE_SPECIFIC_POPULATION_GROWTH_DIVISORS = False

policyengine_us_data/db/create_database_tables.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
from policyengine_us_data.storage import STORAGE_FOLDER
1717

18-
1918
logging.basicConfig(
2019
level=logging.INFO,
2120
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",

policyengine_us_data/db/etl_age.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
)
1212
from policyengine_us_data.utils.census import get_census_docs, pull_acs_table
1313

14-
1514
LABEL_TO_SHORT = {
1615
"Estimate!!Total!!Total population!!AGE!!Under 5 years": "0-4",
1716
"Estimate!!Total!!Total population!!AGE!!5 to 9 years": "5-9",

0 commit comments

Comments
 (0)