|
3 | 3 | from scipy import optimize, sparse |
4 | 4 |
|
5 | 5 |
|
| 6 | +class GregCalibrator: |
| 7 | + """Small adapter around svy's GREG calibration workflow.""" |
| 8 | + |
| 9 | + _base_weight_column = "_policyengine_base_weight" |
| 10 | + _calibrated_weight_column = "_policyengine_greg_weight" |
| 11 | + |
| 12 | + def __init__(self): |
| 13 | + try: |
| 14 | + import polars as pl |
| 15 | + import svy |
| 16 | + except ImportError as e: # pragma: no cover - exercised without extra |
| 17 | + raise ImportError( |
| 18 | + "svy is required for GREG calibration. " |
| 19 | + "Install with: pip install policyengine-us-data[calibration]" |
| 20 | + ) from e |
| 21 | + |
| 22 | + self._pl = pl |
| 23 | + self._svy = svy |
| 24 | + |
| 25 | + def calibrate(self, *, samp_weight, aux_vars, control): |
| 26 | + control = {str(name): float(target) for name, target in control.items()} |
| 27 | + aux_df = self._auxiliary_dataframe(aux_vars, list(control)) |
| 28 | + aux_df[self._base_weight_column] = np.asarray(samp_weight, dtype=float) |
| 29 | + |
| 30 | + sample = self._svy.Sample( |
| 31 | + self._pl.from_pandas(aux_df), |
| 32 | + design=self._svy.Design(wgt=self._base_weight_column), |
| 33 | + ) |
| 34 | + sample.weighting.calibrate( |
| 35 | + controls=control, |
| 36 | + wgt_name=self._calibrated_weight_column, |
| 37 | + ) |
| 38 | + return ( |
| 39 | + sample.data.get_column(self._calibrated_weight_column) |
| 40 | + .to_numpy() |
| 41 | + .astype(float) |
| 42 | + ) |
| 43 | + |
| 44 | + def _auxiliary_dataframe(self, aux_vars, control_names): |
| 45 | + if isinstance(aux_vars, pd.DataFrame): |
| 46 | + aux_df = aux_vars.copy() |
| 47 | + aux_df.columns = [str(column) for column in aux_df.columns] |
| 48 | + return aux_df |
| 49 | + |
| 50 | + if sparse.issparse(aux_vars): |
| 51 | + aux_array = aux_vars.toarray() |
| 52 | + else: |
| 53 | + aux_array = np.asarray(aux_vars) |
| 54 | + |
| 55 | + if aux_array.ndim == 1: |
| 56 | + aux_array = aux_array.reshape(-1, 1) |
| 57 | + |
| 58 | + if aux_array.shape[1] != len(control_names): |
| 59 | + raise ValueError( |
| 60 | + "aux_vars column count must match the number of GREG controls" |
| 61 | + ) |
| 62 | + |
| 63 | + return pd.DataFrame(aux_array.astype(float), columns=control_names) |
| 64 | + |
| 65 | + |
6 | 66 | def _pct_error(achieved, target): |
7 | 67 | if target == 0: |
8 | 68 | return 0.0 if achieved == 0 else float("inf") |
@@ -106,10 +166,10 @@ def calibrate_greg( |
106 | 166 | n_ages=86, |
107 | 167 | ): |
108 | 168 | """ |
109 | | - Calibrate weights using GREG method via samplics. |
| 169 | + Calibrate weights using GREG method via svy. |
110 | 170 |
|
111 | 171 | Args: |
112 | | - calibrator: SampleWeight instance from samplics |
| 172 | + calibrator: GregCalibrator instance |
113 | 173 | X: Design matrix (n_households x n_ages) |
114 | 174 | y_target: Target age distribution |
115 | 175 | baseline_weights: Initial household weights |
|
0 commit comments