Skip to content

Commit d159567

Browse files
committed
Fix capital gains basis read backfill
1 parent a581a14 commit d159567

3 files changed

Lines changed: 223 additions & 2 deletions

File tree

policyengine_us_data/datasets/puf/puf.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,6 +1340,88 @@ def _ensure_sstb_split_inputs(self) -> dict[str, np.ndarray]:
13401340

13411341
return overrides
13421342

1343+
def _capital_gains_basis_overrides(
    self,
    existing_overrides: dict[str, np.ndarray] | None = None,
) -> dict[str, np.ndarray]:
    """Build values for capital gains basis variables absent from the file.

    Args:
        existing_overrides: Arrays produced earlier in the read path. Their
            keys count as available alongside the datasets stored in the
            file, and the mapping is forwarded to
            ``_values_from_file_or_overrides`` on every lookup.

    Returns:
        One array per name in ``CAPITAL_GAINS_BASIS_VARIABLES`` that is
        neither stored in the file nor present in ``existing_overrides``.
        Empty when ``has_policyengine_us_variables`` rejects the basis
        variables, when the file does not exist, when nothing is missing,
        or when ``long_term_capital_gains`` or ``person_tax_unit_id`` is
        unavailable.
    """
    if not (
        has_policyengine_us_variables(*CAPITAL_GAINS_BASIS_VARIABLES)
        and self.file_path.exists()
    ):
        return {}

    known_overrides = existing_overrides or {}
    required_inputs = ("long_term_capital_gains", "person_tax_unit_id")
    optional_inputs = (
        "person_id",
        "household_weight",
        "person_household_id",
        "household_id",
        *CAPITAL_GAINS_BASIS_VARIABLES,
    )

    with h5py.File(self.file_path, "r") as file_handle:
        available = set(file_handle.keys()) | set(known_overrides)
        missing = [
            name for name in CAPITAL_GAINS_BASIS_VARIABLES if name not in available
        ]
        if not missing:
            return {}
        if any(name not in available for name in required_inputs):
            return {}

        def read(name, default_length):
            return self._values_from_file_or_overrides(
                file_handle,
                name,
                known_overrides,
                default_length,
            )

        # The gains array is read first; its length is passed along with
        # every later lookup.
        gains = read("long_term_capital_gains", 0)
        row_count = len(gains)
        inputs = {
            "long_term_capital_gains": gains,
            "person_tax_unit_id": read("person_tax_unit_id", row_count),
        }
        inputs.update(
            {
                name: read(name, row_count)
                for name in optional_inputs
                if name in available
            }
        )

        completed = _with_capital_gains_basis_inputs(inputs, self.time_period)
        return {
            name: np.asarray(completed[name])
            for name in missing
            if name in completed
        }
1401+
def _ensure_capital_gains_basis_inputs(
1402+
self,
1403+
existing_overrides: dict[str, np.ndarray] | None = None,
1404+
) -> dict[str, np.ndarray]:
1405+
overrides = self._capital_gains_basis_overrides(existing_overrides)
1406+
if not overrides:
1407+
return {}
1408+
1409+
try:
1410+
with h5py.File(self.file_path, "r+") as file_handle:
1411+
for key, values in overrides.items():
1412+
self._replace_array(file_handle, key, values)
1413+
except OSError:
1414+
pass
1415+
1416+
return overrides
1417+
1418+
def _ensure_read_overrides(self) -> dict[str, np.ndarray]:
    """Collect every read-time override into a single mapping.

    The SSTB split overrides are computed first and handed to the capital
    gains step as its existing overrides. If both steps produce the same
    key, the capital gains value is the one kept.
    """
    sstb = self._ensure_sstb_split_inputs()
    merged = dict(sstb)
    merged.update(self._ensure_capital_gains_basis_inputs(sstb))
    return merged
1424+
13431425
class _OverrideView:
13441426
def __init__(self, backing, overrides: dict[str, np.ndarray]):
13451427
self._backing = backing
@@ -1393,15 +1475,15 @@ def __getattr__(self, name):
13931475

13941476
def load(self, key=None, mode="r"):
    """Load one key or the whole dataset, applying read overrides.

    Args:
        key: Name of the array to load, or ``None`` for the full reader.
        mode: Access mode passed to the parent ``load``. Overrides are
            only consulted when it is ``"r"``.

    Returns:
        The override array when ``key`` names one; an ``_OverrideView``
        wrapping the parent's result when ``key`` is ``None`` and
        overrides exist; otherwise whatever the parent ``load`` returns.
    """
    if mode != "r":
        return super().load(key=key, mode=mode)

    # A single call gathers all overrides; calling the SSTB step on its own
    # beforehand would only repeat work that this call already performs.
    overrides = self._ensure_read_overrides()
    if key in overrides:
        return overrides[key]

    backing = super().load(key=key, mode=mode)
    if key is None and overrides:
        return self._OverrideView(backing, overrides)
    return backing
14021484

14031485
def load_dataset(self):
    """Load all arrays from the parent dataset with read overrides applied.

    Returns:
        The mapping returned by the parent ``load_dataset``, updated in
        place so that every override replaces or adds its key.
    """
    # Overrides are gathered before the parent load, and through the single
    # combined entry point so the SSTB step is not run a second time.
    overrides = self._ensure_read_overrides()
    arrays = super().load_dataset()
    arrays.update(overrides)
    return arrays

tests/unit/calibration/test_calibration_puf_impute.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,59 @@ def test_capital_gains_basis_fields_are_stage_one_outputs(self):
216216
assert expected <= set(IMPUTED_VARIABLES)
217217
assert expected <= set(DETERMINISTIC_IMPUTED_VARIABLES)
218218

219+
def test_qrf_excludes_deterministic_capital_gains_basis_outputs(
    self,
    monkeypatch,
):
    """The QRF step must never be asked for deterministic outputs.

    ``Microsimulation`` and ``_sequential_qrf`` are replaced with stand-ins
    so the output variable lists passed to the QRF can be recorded.
    """
    import policyengine_us

    class FakeCalculation:
        values = np.array([100.0, 200.0, 300.0, 400.0], dtype=np.float32)

    class FakeMicrosimulation:
        def __init__(self, dataset):
            self.dataset = dataset

        def calculate(self, variable, map_to=None):
            return FakeCalculation()

        def calculate_dataframe(self, variables):
            columns = {
                variable: np.arange(4, dtype=np.float32) for variable in variables
            }
            return pd.DataFrame(columns)

    requested = []

    def fake_sequential_qrf(X_train, X_test, predictors, output_vars):
        # Record what was requested, then answer with zeros for each output.
        requested.append(tuple(output_vars))
        n_rows = len(X_test)
        return {
            variable: np.zeros(n_rows, dtype=np.float32) for variable in output_vars
        }

    monkeypatch.setattr(policyengine_us, "Microsimulation", FakeMicrosimulation)
    monkeypatch.setattr(puf_impute_module, "_sequential_qrf", fake_sequential_qrf)

    mock_data = _make_mock_data(n_persons=4, n_households=2)
    puf_impute_module._run_qrf_imputation(
        data=mock_data,
        time_period=2024,
        puf_dataset=object(),
    )

    deterministic = set(DETERMINISTIC_IMPUTED_VARIABLES)
    assert requested
    assert all(deterministic.isdisjoint(output_vars) for output_vars in requested)
    assert set(requested[0]) == set(IMPUTED_VARIABLES) - deterministic
271+
219272
def test_overridden_subset_of_imputed(self):
220273
for var in OVERRIDDEN_IMPUTED_VARIABLES:
221274
assert var in IMPUTED_VARIABLES

tests/unit/datasets/test_irs_puf.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
import pytest
44

5+
from policyengine_us_data.datasets.puf import puf as puf_module
56
from policyengine_us_data.datasets.puf.puf import (
67
PUF,
78
QBI_SIMULATION_VERSION,
@@ -14,6 +15,19 @@ def _mark_current_qbi_simulation(file_handle):
1415
file_handle.attrs[QBI_SIMULATION_VERSION_ATTR] = QBI_SIMULATION_VERSION
1516

1617

18+
def _write_capital_gains_basis_source_file(path):
    """Write a small HDF5 fixture that has gains but no basis datasets.

    The file holds four persons in two tax units and two households, with
    long-term capital gains of 100, -40, 0 and 200.
    """
    datasets = {
        "person_id": np.array([1, 2, 3, 4]),
        "person_tax_unit_id": np.array([1, 1, 2, 2]),
        "person_household_id": np.array([1, 1, 2, 2]),
        "household_id": np.array([1, 2]),
        "household_weight": np.array([100.0, 200.0]),
        "long_term_capital_gains": np.array([100.0, -40.0, 0.0, 200.0]),
    }
    with h5py.File(path, "w") as file_handle:
        for name, values in datasets.items():
            file_handle.create_dataset(name, data=values)
29+
30+
1731
@pytest.mark.skip(reason="This test requires private data.")
1832
@pytest.mark.parametrize("year", [2015])
1933
def test_irs_puf_generates(year: int):
@@ -50,6 +64,78 @@ def test_puf_person_split_keeps_capital_gains_holding_period_collapsed():
5064
)
5165

5266

67+
def test_puf_load_dataset_backfills_capital_gains_basis_inputs(
    tmp_path,
    monkeypatch,
):
    """``load_dataset`` returns basis inputs and writes them to the file."""
    monkeypatch.setattr(
        puf_module,
        "has_policyengine_us_variables",
        lambda *variables: True,
    )

    class DummyPUF(PUF):
        label = "Dummy PUF"
        name = "dummy_puf"
        time_period = 2024
        file_path = tmp_path / "dummy_puf.h5"

    _write_capital_gains_basis_source_file(DummyPUF.file_path)

    loaded = DummyPUF().load_dataset()
    basis = loaded["long_term_capital_gains_basis"]
    years = loaded["long_term_capital_gains_years_held"]
    has_gain = loaded["long_term_capital_gains"] != 0

    # Rows with a non-zero gain get positive values; zero-gain rows stay zero.
    assert np.all(basis[has_gain] > 0)
    assert np.all(years[has_gain] > 0)
    assert np.all(basis[~has_gain] == 0)
    assert np.all(years[~has_gain] == 0)

    with h5py.File(DummyPUF.file_path, "r") as file_handle:
        stored = set(file_handle.keys())
    assert "long_term_capital_gains_basis" in stored
    assert "long_term_capital_gains_years_held" in stored
99+
100+
101+
def test_puf_load_key_backfills_read_only_capital_gains_basis_inputs(
    tmp_path,
    monkeypatch,
):
    """Keyed and full loads return basis inputs for a read-only file."""
    monkeypatch.setattr(
        puf_module,
        "has_policyengine_us_variables",
        lambda *variables: True,
    )

    class DummyPUF(PUF):
        label = "Dummy PUF"
        name = "dummy_puf"
        time_period = 2024
        file_path = tmp_path / "dummy_puf.h5"

    _write_capital_gains_basis_source_file(DummyPUF.file_path)
    DummyPUF.file_path.chmod(0o444)

    dataset = DummyPUF()
    try:
        basis = dataset.load("long_term_capital_gains_basis")
        years = dataset.load("long_term_capital_gains_years_held")
        reader = dataset.load()
        try:
            np.testing.assert_array_equal(
                reader["long_term_capital_gains_basis"],
                basis,
            )
        finally:
            # Close the reader even when the comparison fails, so a failing
            # assertion does not leave the file open.
            reader.close()
    finally:
        # Restore write permission once the loads are done.
        DummyPUF.file_path.chmod(0o644)

    assert np.all(basis[[0, 1, 3]] > 0)
    assert basis[2] == 0
    assert np.all(years[[0, 1, 3]] > 0)
    assert years[2] == 0
137+
138+
53139
def test_puf_load_dataset_backfills_sstb_split_inputs(tmp_path):
54140
class DummyPUF(PUF):
55141
label = "Dummy PUF"

0 commit comments

Comments
 (0)