|
3 | 3 |
|
4 | 4 | from policyengine_us_data.utils.source_quality import ( |
5 | 5 | cap_training_sample, |
| 6 | + filter_positive_finite_weight_rows, |
6 | 7 | observed_source_mask, |
7 | 8 | require_columns_present, |
8 | 9 | sipp_allocation_flag_for, |
@@ -201,3 +202,54 @@ def test_cap_training_sample_rejects_misaligned_filters(): |
201 | 202 | raise AssertionError("Expected misaligned target filters to fail") |
202 | 203 |
|
203 | 204 | assert "target_filters['value']" in message |
| 205 | + |
| 206 | + |
def test_filter_positive_finite_weight_rows_reindexes_target_filters():
    """Check that surviving rows and their target filters share a fresh index.

    Only the rows weighted 1.0 and 5.0 are expected to survive (the zero,
    NaN, and infinite weights are not), and both the returned frame and the
    returned filter series are expected to be indexed 0..1.
    """
    donors = pd.DataFrame(
        {
            "value": [10, 20, 30, 40, 50],
            "household_weight": [1.0, 0.0, np.nan, np.inf, 5.0],
        },
        # Non-default labels so that a missing reindex would be visible.
        index=[10, 11, 12, 13, 14],
    )
    observed = pd.Series([True, True, False, True, True], index=donors.index)

    kept_rows, kept_filters = filter_positive_finite_weight_rows(
        donors,
        weight_col="household_weight",
        target_filters={"value": observed},
        context_name="unit-test donor",
    )

    value_filter = kept_filters["value"]

    assert kept_rows["value"].tolist() == [10, 50]
    assert kept_rows.index.tolist() == [0, 1]
    np.testing.assert_array_equal(value_filter.values, [True, True])
    assert value_filter.index.tolist() == [0, 1]
| 233 | + |
| 234 | + |
def test_filter_positive_finite_weight_rows_requires_observed_target_rows():
    """Check that a ValueError is raised when no observed row has a usable weight.

    The single observed row carries a zero weight and the single
    positive-weight row is flagged as unobserved, so the call is expected
    to fail with a message naming the weight column.
    """
    donors = pd.DataFrame(
        {
            "value": [10, 20],
            "household_weight": [0.0, 1.0],
        }
    )
    observed = pd.Series([True, False], index=donors.index)

    # Stays None only if the call returns without raising ValueError.
    message = None
    try:
        filter_positive_finite_weight_rows(
            donors,
            weight_col="household_weight",
            target_filters={"value": observed},
        )
    except ValueError as error:
        message = str(error)

    if message is None:
        raise AssertionError("Expected all invalid observed weights to fail")

    assert "No observed donor rows with positive finite household_weight" in message
0 commit comments