|
3 | 3 | import pandas as pd |
4 | 4 | import numpy as np |
5 | 5 | import logging |
| 6 | +import sqlite3 |
6 | 7 |
|
7 | | -from policyengine_us_data.storage import CALIBRATION_FOLDER |
| 8 | +from policyengine_us_data.storage import CALIBRATION_FOLDER, STORAGE_FOLDER |
8 | 9 | from policyengine_us_data.storage.calibration_targets.pull_soi_targets import ( |
9 | 10 | STATE_ABBR_TO_FIPS, |
10 | 11 | ) |
@@ -118,6 +119,133 @@ def fmt(x): |
118 | 119 | return f"{x / 1e9:.1f}bn" |
119 | 120 |
|
120 | 121 |
|
| 122 | +def _parse_constraint_value(value): |
| 123 | + if value == "True": |
| 124 | + return True |
| 125 | + if value == "False": |
| 126 | + return False |
| 127 | + try: |
| 128 | + return int(value) |
| 129 | + except (TypeError, ValueError): |
| 130 | + try: |
| 131 | + return float(value) |
| 132 | + except (TypeError, ValueError): |
| 133 | + return value |
| 134 | + |
| 135 | + |
| 136 | +def _apply_constraint(values, operation: str, raw_value: str): |
| 137 | + if operation == "in": |
| 138 | + allowed_values = [part.strip() for part in raw_value.split("|")] |
| 139 | + return np.isin(values, allowed_values) |
| 140 | + |
| 141 | + value = _parse_constraint_value(raw_value) |
| 142 | + if operation in ("equals", "==", "="): |
| 143 | + return values == value |
| 144 | + if operation in ("greater_than", ">"): |
| 145 | + return values > value |
| 146 | + if operation in ("greater_than_or_equal", ">="): |
| 147 | + return values >= value |
| 148 | + if operation in ("less_than", "<"): |
| 149 | + return values < value |
| 150 | + if operation in ("less_than_or_equal", "<="): |
| 151 | + return values <= value |
| 152 | + if operation in ("not_equals", "!=", "<>"): |
| 153 | + return values != value |
| 154 | + |
| 155 | + raise ValueError(f"Unsupported stratum constraint operation: {operation}") |
| 156 | + |
| 157 | + |
| 158 | +def _geo_label_from_ucgid(ucgid_str: str) -> str: |
| 159 | + if ucgid_str in (None, "", "0100000US"): |
| 160 | + return "nation" |
| 161 | + return f"geo/{ucgid_str}" |
| 162 | + |
| 163 | + |
| 164 | +def _add_liheap_targets_from_db(loss_matrix, targets_list, sim, time_period): |
| 165 | + db_path = STORAGE_FOLDER / "calibration" / "policy_data.db" |
| 166 | + if not db_path.exists(): |
| 167 | + return targets_list, loss_matrix |
| 168 | + |
| 169 | + query = """ |
| 170 | + SELECT |
| 171 | + t.target_id, |
| 172 | + t.variable, |
| 173 | + t.value AS target_value, |
| 174 | + s.notes, |
| 175 | + sc.constraint_variable, |
| 176 | + sc.operation, |
| 177 | + sc.value AS constraint_value |
| 178 | + FROM targets t |
| 179 | + JOIN strata s |
| 180 | + ON s.stratum_id = t.stratum_id |
| 181 | + JOIN stratum_constraints sc |
| 182 | + ON sc.stratum_id = s.stratum_id |
| 183 | + WHERE |
| 184 | + t.active = 1 |
| 185 | + AND t.reform_id = 0 |
| 186 | + AND t.period = ? |
| 187 | + AND s.notes LIKE '%LIHEAP%' |
| 188 | + ORDER BY t.target_id |
| 189 | + """ |
| 190 | + |
| 191 | + with sqlite3.connect(db_path) as conn: |
| 192 | + target_rows = pd.read_sql_query(query, conn, params=[time_period]) |
| 193 | + |
| 194 | + if target_rows.empty: |
| 195 | + return targets_list, loss_matrix |
| 196 | + |
| 197 | + household_values_cache = { |
| 198 | + "household_weight": sim.calculate("household_weight").values |
| 199 | + } |
| 200 | + |
| 201 | + def get_household_values(variable: str): |
| 202 | + if variable not in household_values_cache: |
| 203 | + household_values_cache[variable] = sim.calculate( |
| 204 | + variable, |
| 205 | + map_to="household", |
| 206 | + ).values |
| 207 | + return household_values_cache[variable] |
| 208 | + |
| 209 | + n_households = len(household_values_cache["household_weight"]) |
| 210 | + |
| 211 | + for _, target_df in target_rows.groupby("target_id", sort=False): |
| 212 | + mask = np.ones(n_households, dtype=bool) |
| 213 | + for row in target_df.itertuples(index=False): |
| 214 | + if ( |
| 215 | + row.constraint_variable == "ucgid_str" |
| 216 | + and row.constraint_value == "0100000US" |
| 217 | + ): |
| 218 | + continue |
| 219 | + values = get_household_values(row.constraint_variable) |
| 220 | + mask &= _apply_constraint( |
| 221 | + values, |
| 222 | + row.operation, |
| 223 | + row.constraint_value, |
| 224 | + ) |
| 225 | + |
| 226 | + variable = target_df["variable"].iat[0] |
| 227 | + if variable == "household_count": |
| 228 | + metric = mask.astype(float) |
| 229 | + else: |
| 230 | + metric = np.where(mask, get_household_values(variable), 0.0) |
| 231 | + |
| 232 | + ucgid_constraints = target_df.loc[ |
| 233 | + target_df.constraint_variable == "ucgid_str", "constraint_value" |
| 234 | + ] |
| 235 | + geo_label = _geo_label_from_ucgid( |
| 236 | + ucgid_constraints.iat[0] if not ucgid_constraints.empty else None |
| 237 | + ) |
| 238 | + label = f"{geo_label}/db/liheap/{variable}" |
| 239 | + loss_matrix[label] = metric |
| 240 | + targets_list.append(target_df["target_value"].iat[0]) |
| 241 | + |
| 242 | + logging.info( |
| 243 | + f"Loaded {target_rows['target_id'].nunique()} LIHEAP targets from the local targets DB" |
| 244 | + ) |
| 245 | + |
| 246 | + return targets_list, loss_matrix |
| 247 | + |
| 248 | + |
121 | 249 | def build_loss_matrix(dataset: type, time_period): |
122 | 250 | loss_matrix = pd.DataFrame() |
123 | 251 | df = pe_to_soi(dataset, time_period) |
@@ -667,6 +795,10 @@ def build_loss_matrix(dataset: type, time_period): |
667 | 795 | targets_array.extend(snap_state_targets) |
668 | 796 | loss_matrix = _add_snap_metric_columns(loss_matrix, sim) |
669 | 797 |
|
| 798 | + targets_array, loss_matrix = _add_liheap_targets_from_db( |
| 799 | + loss_matrix, targets_array, sim, time_period |
| 800 | + ) |
| 801 | + |
670 | 802 | del sim, df |
671 | 803 | gc.collect() |
672 | 804 |
|
|
0 commit comments