@@ -931,6 +931,7 @@ def __init__(
931931 self .time_period = time_period
932932 self .dataset_path = dataset_path
933933 self ._entity_rel_cache = None
934+ self ._target_overview_columns = None
934935
935936 # ---------------------------------------------------------------
936937 # Entity relationships
@@ -959,8 +960,8 @@ def _build_state_values(
959960 sim ,
960961 target_vars : set ,
961962 constraint_vars : set ,
962- reform_vars : set ,
963- geography ,
963+ reform_vars : set = None ,
964+ geography = None ,
964965 rerandomize_takeup : bool = True ,
965966 workers : int = 1 ,
966967 ) -> dict :
@@ -997,6 +998,9 @@ def _build_state_values(
997998 TAKEUP_AFFECTED_TARGETS ,
998999 )
9991000
1001+ if geography is None :
1002+ raise ValueError ("geography is required" )
1003+
10001004 unique_states = sorted (set (int (s ) for s in geography .state_fips ))
10011005 n_hh = geography .n_records
10021006
@@ -1022,7 +1026,7 @@ def _build_state_values(
10221026 # Convert sets to sorted lists for deterministic iteration
10231027 target_vars_list = sorted (target_vars )
10241028 constraint_vars_list = sorted (constraint_vars )
1025- reform_vars_list = sorted (reform_vars )
1029+ reform_vars_list = sorted (reform_vars or set () )
10261030
10271031 state_values = {}
10281032
@@ -1518,63 +1522,103 @@ def _get_stratum_constraints(self, stratum_id: int) -> List[dict]:
15181522 )
15191523 return df .to_dict ("records" )
15201524
1525+ def _get_target_overview_columns (self ) -> set :
1526+ if self ._target_overview_columns is None :
1527+ with self .engine .connect () as conn :
1528+ rows = conn .execute (
1529+ text ("PRAGMA table_info(target_overview)" )
1530+ ).fetchall ()
1531+ self ._target_overview_columns = {row [1 ] for row in rows }
1532+ return self ._target_overview_columns
1533+
15211534 def _query_targets (self , target_filter : dict ) -> pd .DataFrame :
15221535 """Query targets via target_overview view with
15231536 best-period selection."""
1524- or_conditions = []
1537+ and_conditions = []
15251538
15261539 if "domain_variables" in target_filter :
15271540 dvs = target_filter ["domain_variables" ]
15281541 ph = "," .join (f"'{ dv } '" for dv in dvs )
1529- or_conditions .append (f"tv.domain_variable IN ({ ph } )" )
1542+ and_conditions .append (f"tv.domain_variable IN ({ ph } )" )
15301543
15311544 if "variables" in target_filter :
15321545 vs = "," .join (f"'{ v } '" for v in target_filter ["variables" ])
1533- or_conditions .append (f"tv.variable IN ({ vs } )" )
1546+ and_conditions .append (f"tv.variable IN ({ vs } )" )
15341547
15351548 if "target_ids" in target_filter :
15361549 ids = "," .join (map (str , target_filter ["target_ids" ]))
1537- or_conditions .append (f"tv.target_id IN ({ ids } )" )
1550+ and_conditions .append (f"tv.target_id IN ({ ids } )" )
15381551
15391552 if "stratum_ids" in target_filter :
15401553 ids = "," .join (map (str , target_filter ["stratum_ids" ]))
1541- or_conditions .append (f"tv.stratum_id IN ({ ids } )" )
1554+ and_conditions .append (f"tv.stratum_id IN ({ ids } )" )
15421555
1543- if not or_conditions :
1556+ if not and_conditions :
15441557 where_clause = "1=1"
15451558 else :
1546- where_clause = " OR " .join (f"({ c } )" for c in or_conditions )
1547-
1548- query = f"""
1549- WITH filtered_targets AS (
1550- SELECT tv.target_id, tv.stratum_id, tv.variable, tv.reform_id,
1551- tv.value, tv.period, tv.geo_level,
1552- tv.geographic_id, tv.domain_variable
1553- FROM target_overview tv
1554- WHERE tv.active = 1
1555- AND ({ where_clause } )
1556- ),
1557- best_periods AS (
1558- SELECT stratum_id, variable, reform_id,
1559- CASE
1560- WHEN MAX(CASE WHEN period <= :time_period
1561- THEN period END) IS NOT NULL
1562- THEN MAX(CASE WHEN period <= :time_period
1563- THEN period END)
1564- ELSE MIN(period)
1565- END as best_period
1566- FROM filtered_targets
1567- GROUP BY stratum_id, variable, reform_id
1568- )
1569- SELECT ft.*
1570- FROM filtered_targets ft
1571- JOIN best_periods bp
1572- ON ft.stratum_id = bp.stratum_id
1573- AND ft.variable = bp.variable
1574- AND ft.reform_id = bp.reform_id
1575- AND ft.period = bp.best_period
1576- ORDER BY ft.target_id
1577- """
1559+ where_clause = " AND " .join (f"({ c } )" for c in and_conditions )
1560+
1561+ if "reform_id" in self ._get_target_overview_columns ():
1562+ query = f"""
1563+ WITH filtered_targets AS (
1564+ SELECT tv.target_id, tv.stratum_id, tv.variable, tv.reform_id,
1565+ tv.value, tv.period, tv.geo_level,
1566+ tv.geographic_id, tv.domain_variable
1567+ FROM target_overview tv
1568+ WHERE tv.active = 1
1569+ AND ({ where_clause } )
1570+ ),
1571+ best_periods AS (
1572+ SELECT stratum_id, variable, reform_id,
1573+ CASE
1574+ WHEN MAX(CASE WHEN period <= :time_period
1575+ THEN period END) IS NOT NULL
1576+ THEN MAX(CASE WHEN period <= :time_period
1577+ THEN period END)
1578+ ELSE MIN(period)
1579+ END as best_period
1580+ FROM filtered_targets
1581+ GROUP BY stratum_id, variable, reform_id
1582+ )
1583+ SELECT ft.*
1584+ FROM filtered_targets ft
1585+ JOIN best_periods bp
1586+ ON ft.stratum_id = bp.stratum_id
1587+ AND ft.variable = bp.variable
1588+ AND ft.reform_id = bp.reform_id
1589+ AND ft.period = bp.best_period
1590+ ORDER BY ft.target_id
1591+ """
1592+ else :
1593+ query = f"""
1594+ WITH filtered_targets AS (
1595+ SELECT tv.target_id, tv.stratum_id, tv.variable,
1596+ 0 AS reform_id, tv.value, tv.period, tv.geo_level,
1597+ tv.geographic_id, tv.domain_variable
1598+ FROM target_overview tv
1599+ WHERE tv.active = 1
1600+ AND ({ where_clause } )
1601+ ),
1602+ best_periods AS (
1603+ SELECT stratum_id, variable,
1604+ CASE
1605+ WHEN MAX(CASE WHEN period <= :time_period
1606+ THEN period END) IS NOT NULL
1607+ THEN MAX(CASE WHEN period <= :time_period
1608+ THEN period END)
1609+ ELSE MIN(period)
1610+ END as best_period
1611+ FROM filtered_targets
1612+ GROUP BY stratum_id, variable
1613+ )
1614+ SELECT ft.*
1615+ FROM filtered_targets ft
1616+ JOIN best_periods bp
1617+ ON ft.stratum_id = bp.stratum_id
1618+ AND ft.variable = bp.variable
1619+ AND ft.period = bp.best_period
1620+ ORDER BY ft.target_id
1621+ """
15781622
15791623 with self .engine .connect () as conn :
15801624 return pd .read_sql (
0 commit comments