Skip to content

Commit 0719ec1

Browse files
committed
Fix unified calibration test drift
1 parent 8a8bbef commit 0719ec1

4 files changed

Lines changed: 156 additions & 43 deletions

File tree

policyengine_us_data/calibration/unified_matrix_builder.py

Lines changed: 85 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,7 @@ def __init__(
931931
self.time_period = time_period
932932
self.dataset_path = dataset_path
933933
self._entity_rel_cache = None
934+
self._target_overview_columns = None
934935

935936
# ---------------------------------------------------------------
936937
# Entity relationships
@@ -959,8 +960,8 @@ def _build_state_values(
959960
sim,
960961
target_vars: set,
961962
constraint_vars: set,
962-
reform_vars: set,
963-
geography,
963+
reform_vars: set = None,
964+
geography=None,
964965
rerandomize_takeup: bool = True,
965966
workers: int = 1,
966967
) -> dict:
@@ -997,6 +998,9 @@ def _build_state_values(
997998
TAKEUP_AFFECTED_TARGETS,
998999
)
9991000

1001+
if geography is None:
1002+
raise ValueError("geography is required")
1003+
10001004
unique_states = sorted(set(int(s) for s in geography.state_fips))
10011005
n_hh = geography.n_records
10021006

@@ -1022,7 +1026,7 @@ def _build_state_values(
10221026
# Convert sets to sorted lists for deterministic iteration
10231027
target_vars_list = sorted(target_vars)
10241028
constraint_vars_list = sorted(constraint_vars)
1025-
reform_vars_list = sorted(reform_vars)
1029+
reform_vars_list = sorted(reform_vars or set())
10261030

10271031
state_values = {}
10281032

@@ -1518,63 +1522,103 @@ def _get_stratum_constraints(self, stratum_id: int) -> List[dict]:
15181522
)
15191523
return df.to_dict("records")
15201524

1525+
def _get_target_overview_columns(self) -> set:
    """Return the set of column names exposed by ``target_overview``.

    The names are read with ``PRAGMA table_info(target_overview)`` on the
    first call and stored on ``self._target_overview_columns``; every later
    call returns that stored set without touching the database again.
    """
    cached = self._target_overview_columns
    if cached is not None:
        return cached

    with self.engine.connect() as conn:
        pragma_rows = conn.execute(
            text("PRAGMA table_info(target_overview)")
        ).fetchall()
        # Position 1 of each ``table_info`` row holds the column name.
        self._target_overview_columns = {info[1] for info in pragma_rows}
    return self._target_overview_columns
1533+
15211534
def _query_targets(self, target_filter: dict) -> pd.DataFrame:
15221535
"""Query targets via target_overview view with
15231536
best-period selection."""
1524-
or_conditions = []
1537+
and_conditions = []
15251538

15261539
if "domain_variables" in target_filter:
15271540
dvs = target_filter["domain_variables"]
15281541
ph = ",".join(f"'{dv}'" for dv in dvs)
1529-
or_conditions.append(f"tv.domain_variable IN ({ph})")
1542+
and_conditions.append(f"tv.domain_variable IN ({ph})")
15301543

15311544
if "variables" in target_filter:
15321545
vs = ",".join(f"'{v}'" for v in target_filter["variables"])
1533-
or_conditions.append(f"tv.variable IN ({vs})")
1546+
and_conditions.append(f"tv.variable IN ({vs})")
15341547

15351548
if "target_ids" in target_filter:
15361549
ids = ",".join(map(str, target_filter["target_ids"]))
1537-
or_conditions.append(f"tv.target_id IN ({ids})")
1550+
and_conditions.append(f"tv.target_id IN ({ids})")
15381551

15391552
if "stratum_ids" in target_filter:
15401553
ids = ",".join(map(str, target_filter["stratum_ids"]))
1541-
or_conditions.append(f"tv.stratum_id IN ({ids})")
1554+
and_conditions.append(f"tv.stratum_id IN ({ids})")
15421555

1543-
if not or_conditions:
1556+
if not and_conditions:
15441557
where_clause = "1=1"
15451558
else:
1546-
where_clause = " OR ".join(f"({c})" for c in or_conditions)
1547-
1548-
query = f"""
1549-
WITH filtered_targets AS (
1550-
SELECT tv.target_id, tv.stratum_id, tv.variable, tv.reform_id,
1551-
tv.value, tv.period, tv.geo_level,
1552-
tv.geographic_id, tv.domain_variable
1553-
FROM target_overview tv
1554-
WHERE tv.active = 1
1555-
AND ({where_clause})
1556-
),
1557-
best_periods AS (
1558-
SELECT stratum_id, variable, reform_id,
1559-
CASE
1560-
WHEN MAX(CASE WHEN period <= :time_period
1561-
THEN period END) IS NOT NULL
1562-
THEN MAX(CASE WHEN period <= :time_period
1563-
THEN period END)
1564-
ELSE MIN(period)
1565-
END as best_period
1566-
FROM filtered_targets
1567-
GROUP BY stratum_id, variable, reform_id
1568-
)
1569-
SELECT ft.*
1570-
FROM filtered_targets ft
1571-
JOIN best_periods bp
1572-
ON ft.stratum_id = bp.stratum_id
1573-
AND ft.variable = bp.variable
1574-
AND ft.reform_id = bp.reform_id
1575-
AND ft.period = bp.best_period
1576-
ORDER BY ft.target_id
1577-
"""
1559+
where_clause = " AND ".join(f"({c})" for c in and_conditions)
1560+
1561+
if "reform_id" in self._get_target_overview_columns():
1562+
query = f"""
1563+
WITH filtered_targets AS (
1564+
SELECT tv.target_id, tv.stratum_id, tv.variable, tv.reform_id,
1565+
tv.value, tv.period, tv.geo_level,
1566+
tv.geographic_id, tv.domain_variable
1567+
FROM target_overview tv
1568+
WHERE tv.active = 1
1569+
AND ({where_clause})
1570+
),
1571+
best_periods AS (
1572+
SELECT stratum_id, variable, reform_id,
1573+
CASE
1574+
WHEN MAX(CASE WHEN period <= :time_period
1575+
THEN period END) IS NOT NULL
1576+
THEN MAX(CASE WHEN period <= :time_period
1577+
THEN period END)
1578+
ELSE MIN(period)
1579+
END as best_period
1580+
FROM filtered_targets
1581+
GROUP BY stratum_id, variable, reform_id
1582+
)
1583+
SELECT ft.*
1584+
FROM filtered_targets ft
1585+
JOIN best_periods bp
1586+
ON ft.stratum_id = bp.stratum_id
1587+
AND ft.variable = bp.variable
1588+
AND ft.reform_id = bp.reform_id
1589+
AND ft.period = bp.best_period
1590+
ORDER BY ft.target_id
1591+
"""
1592+
else:
1593+
query = f"""
1594+
WITH filtered_targets AS (
1595+
SELECT tv.target_id, tv.stratum_id, tv.variable,
1596+
0 AS reform_id, tv.value, tv.period, tv.geo_level,
1597+
tv.geographic_id, tv.domain_variable
1598+
FROM target_overview tv
1599+
WHERE tv.active = 1
1600+
AND ({where_clause})
1601+
),
1602+
best_periods AS (
1603+
SELECT stratum_id, variable,
1604+
CASE
1605+
WHEN MAX(CASE WHEN period <= :time_period
1606+
THEN period END) IS NOT NULL
1607+
THEN MAX(CASE WHEN period <= :time_period
1608+
THEN period END)
1609+
ELSE MIN(period)
1610+
END as best_period
1611+
FROM filtered_targets
1612+
GROUP BY stratum_id, variable
1613+
)
1614+
SELECT ft.*
1615+
FROM filtered_targets ft
1616+
JOIN best_periods bp
1617+
ON ft.stratum_id = bp.stratum_id
1618+
AND ft.variable = bp.variable
1619+
AND ft.period = bp.best_period
1620+
ORDER BY ft.target_id
1621+
"""
15781622

15791623
with self.engine.connect() as conn:
15801624
return pd.read_sql(

policyengine_us_data/db/create_database_tables.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,8 @@ def create_database(
508508

509509
# Create SQL views
510510
with engine.connect() as conn:
511+
conn.execute(text("DROP VIEW IF EXISTS stratum_domain"))
512+
conn.execute(text("DROP VIEW IF EXISTS target_overview"))
511513
conn.execute(text(STRATUM_DOMAIN_VIEW))
512514
conn.execute(text(TARGET_OVERVIEW_VIEW))
513515
conn.commit()

policyengine_us_data/tests/test_calibration/test_unified_calibration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def test_county_var_uses_county_values(self):
459459
person_hh_idx = np.array([0, 1, 2, 3])
460460

461461
builder = UnifiedMatrixBuilder.__new__(UnifiedMatrixBuilder)
462-
hh_vars, _ = builder._assemble_clone_values(
462+
hh_vars, _, _ = builder._assemble_clone_values(
463463
state_values,
464464
clone_states,
465465
person_hh_idx,
@@ -499,7 +499,7 @@ def test_non_county_var_uses_state_values(self):
499499
person_hh_idx = np.array([0, 1, 2, 3])
500500

501501
builder = UnifiedMatrixBuilder.__new__(UnifiedMatrixBuilder)
502-
hh_vars, _ = builder._assemble_clone_values(
502+
hh_vars, _, _ = builder._assemble_clone_values(
503503
state_values,
504504
clone_states,
505505
person_hh_idx,

policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,59 @@ def _create_test_db(db_path):
6161
return db_uri, engine
6262

6363

64+
def _create_legacy_target_overview(engine):
    """Replace ``target_overview`` with a legacy definition lacking ``reform_id``.

    The legacy view exposes ``target_id``, ``stratum_id``, ``variable``,
    ``value``, ``period``, ``active``, ``geo_level``, ``geographic_id`` and
    ``domain_variable`` only, so tests can exercise code paths that must cope
    with a database whose view has no ``reform_id`` column.

    Args:
        engine: SQLAlchemy engine connected to the test database; it must
            already contain the ``targets`` and ``stratum_constraints``
            tables the view selects from.
    """
    legacy_view = """\
        CREATE VIEW target_overview AS
        SELECT
            t.target_id,
            t.stratum_id,
            t.variable,
            t.value,
            t.period,
            t.active,
            CASE
                WHEN MAX(CASE
                    WHEN sc.constraint_variable = 'congressional_district_geoid'
                        THEN 1
                    WHEN sc.constraint_variable = 'ucgid_str'
                        AND length(sc.value) = 13 THEN 1
                    ELSE 0 END) = 1 THEN 'district'
                WHEN MAX(CASE
                    WHEN sc.constraint_variable = 'state_fips' THEN 1
                    WHEN sc.constraint_variable = 'ucgid_str'
                        AND length(sc.value) = 11 THEN 1
                    ELSE 0 END) = 1 THEN 'state'
                ELSE 'national'
            END AS geo_level,
            COALESCE(
                MAX(CASE
                    WHEN sc.constraint_variable = 'congressional_district_geoid'
                        THEN sc.value END),
                MAX(CASE
                    WHEN sc.constraint_variable = 'state_fips'
                        THEN sc.value END),
                MAX(CASE
                    WHEN sc.constraint_variable = 'ucgid_str'
                        THEN sc.value END),
                'US'
            ) AS geographic_id,
            GROUP_CONCAT(DISTINCT CASE
                WHEN sc.constraint_variable NOT IN (
                    'state_fips', 'congressional_district_geoid',
                    'tax_unit_is_filer', 'ucgid_str'
                ) THEN sc.constraint_variable
            END) AS domain_variable
        FROM targets t
        LEFT JOIN stratum_constraints sc ON t.stratum_id = sc.stratum_id
        GROUP BY t.target_id, t.stratum_id, t.variable,
            t.value, t.period, t.active;
    """
    with engine.connect() as conn:
        # IF EXISTS keeps the helper usable even when the current view was
        # never created (or was already dropped), instead of failing before
        # the legacy definition can be installed.
        conn.execute(text("DROP VIEW IF EXISTS target_overview"))
        conn.execute(text(legacy_view))
        conn.commit()
115+
116+
64117
def _insert_aca_ptc_data(engine):
65118
with engine.connect() as conn:
66119
strata = [1, 2, 3, 4, 5, 6, 7, 8, 9]
@@ -217,6 +270,20 @@ def test_inactive_targets_are_excluded(self):
217270
self.assertEqual(int(baseline_rows.iloc[0]["period"]), 2022)
218271
self.assertEqual(float(baseline_rows.iloc[0]["value"]), 10000.0)
219272

273+
def test_legacy_target_overview_without_reform_id(self):
    """Query targets through a ``target_overview`` view with no ``reform_id``.

    The view is swapped for the legacy definition, ``_query_targets`` is run
    for the ``aca_ptc`` domain, and the result must still carry a
    ``reform_id`` column whose values are all ``0``. The current view
    definition is reinstalled afterwards regardless of the outcome.
    """
    _create_legacy_target_overview(self.engine)
    try:
        builder = self._make_builder()
        targets = builder._query_targets({"domain_variables": ["aca_ptc"]})
        self.assertGreater(len(targets), 0)
        self.assertIn("reform_id", targets.columns)
        reform_is_baseline = targets["reform_id"] == 0
        self.assertTrue(reform_is_baseline.all())
    finally:
        # Reinstall the current view so the legacy definition does not stay
        # in place on the shared engine after this test.
        restore_statements = (
            "DROP VIEW target_overview",
            TARGET_OVERVIEW_VIEW,
        )
        with self.engine.connect() as conn:
            for statement in restore_statements:
                conn.execute(text(statement))
            conn.commit()
286+
220287
def test_target_name_adds_expenditure_suffix_for_reforms(self):
221288
name = UnifiedMatrixBuilder._make_target_name(
222289
"salt_deduction",

0 commit comments

Comments
 (0)