|
| 1 | +import importlib.util |
| 2 | +import sys |
| 3 | +import types |
| 4 | +from contextlib import contextmanager |
| 5 | +from pathlib import Path |
| 6 | + |
| 7 | +import pandas as pd |
| 8 | +import pytest |
| 9 | + |
| 10 | + |
| 11 | +REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent |
| 12 | +PACKAGE_ROOT = REPO_ROOT / "policyengine_us_data" |
| 13 | + |
| 14 | + |
| 15 | +@contextmanager |
| 16 | +def load_uprate_puf_module(storage_root: Path): |
| 17 | + module_names = [ |
| 18 | + "policyengine_us_data.datasets.puf.uprate_puf", |
| 19 | + "policyengine_us_data.datasets.puf", |
| 20 | + "policyengine_us_data.datasets", |
| 21 | + "policyengine_us_data.storage", |
| 22 | + "policyengine_us_data", |
| 23 | + ] |
| 24 | + original_modules = {name: sys.modules.get(name) for name in module_names} |
| 25 | + for name in module_names: |
| 26 | + sys.modules.pop(name, None) |
| 27 | + |
| 28 | + try: |
| 29 | + package = types.ModuleType("policyengine_us_data") |
| 30 | + package.__path__ = [str(PACKAGE_ROOT)] |
| 31 | + sys.modules["policyengine_us_data"] = package |
| 32 | + |
| 33 | + datasets_package = types.ModuleType("policyengine_us_data.datasets") |
| 34 | + datasets_package.__path__ = [str(PACKAGE_ROOT / "datasets")] |
| 35 | + sys.modules["policyengine_us_data.datasets"] = datasets_package |
| 36 | + |
| 37 | + puf_package = types.ModuleType("policyengine_us_data.datasets.puf") |
| 38 | + puf_package.__path__ = [str(PACKAGE_ROOT / "datasets" / "puf")] |
| 39 | + sys.modules["policyengine_us_data.datasets.puf"] = puf_package |
| 40 | + |
| 41 | + storage_spec = importlib.util.spec_from_file_location( |
| 42 | + "policyengine_us_data.storage", |
| 43 | + PACKAGE_ROOT / "storage" / "__init__.py", |
| 44 | + submodule_search_locations=[str(PACKAGE_ROOT / "storage")], |
| 45 | + ) |
| 46 | + storage_module = importlib.util.module_from_spec(storage_spec) |
| 47 | + assert storage_spec.loader is not None |
| 48 | + sys.modules["policyengine_us_data.storage"] = storage_module |
| 49 | + storage_spec.loader.exec_module(storage_module) |
| 50 | + storage_module.STORAGE_FOLDER = storage_root |
| 51 | + storage_module.CALIBRATION_FOLDER = storage_root / "calibration_targets" |
| 52 | + |
| 53 | + uprate_spec = importlib.util.spec_from_file_location( |
| 54 | + "policyengine_us_data.datasets.puf.uprate_puf", |
| 55 | + PACKAGE_ROOT / "datasets" / "puf" / "uprate_puf.py", |
| 56 | + ) |
| 57 | + uprate_module = importlib.util.module_from_spec(uprate_spec) |
| 58 | + assert uprate_spec.loader is not None |
| 59 | + sys.modules["policyengine_us_data.datasets.puf.uprate_puf"] = uprate_module |
| 60 | + uprate_spec.loader.exec_module(uprate_module) |
| 61 | + yield uprate_module |
| 62 | + finally: |
| 63 | + for name in module_names: |
| 64 | + sys.modules.pop(name, None) |
| 65 | + for name, module in original_modules.items(): |
| 66 | + if module is not None: |
| 67 | + sys.modules[name] = module |
| 68 | + |
| 69 | + |
| 70 | +def write_soi_targets(path: Path) -> None: |
| 71 | + path.parent.mkdir(parents=True, exist_ok=True) |
| 72 | + pd.DataFrame( |
| 73 | + [ |
| 74 | + { |
| 75 | + "Year": 2021, |
| 76 | + "Variable": "employment_income", |
| 77 | + "Filing status": "All", |
| 78 | + "AGI lower bound": float("-inf"), |
| 79 | + "AGI upper bound": float("inf"), |
| 80 | + "Count": False, |
| 81 | + "Taxable only": False, |
| 82 | + "Full population": True, |
| 83 | + "Value": 200.0, |
| 84 | + }, |
| 85 | + { |
| 86 | + "Year": 2021, |
| 87 | + "Variable": "count", |
| 88 | + "Filing status": "All", |
| 89 | + "AGI lower bound": float("-inf"), |
| 90 | + "AGI upper bound": float("inf"), |
| 91 | + "Count": True, |
| 92 | + "Taxable only": False, |
| 93 | + "Full population": True, |
| 94 | + "Value": 100.0, |
| 95 | + }, |
| 96 | + ] |
| 97 | + ).to_csv(path, index=False) |
| 98 | + |
| 99 | + |
| 100 | +def test_get_soi_aggregate_reads_tracked_soi_targets(tmp_path: Path): |
| 101 | + write_soi_targets(tmp_path / "calibration_targets" / "soi_targets.csv") |
| 102 | + with load_uprate_puf_module(tmp_path) as module: |
| 103 | + assert module.get_soi_aggregate("employment_income", 2021, False) == 200.0 |
| 104 | + assert module.get_soi_aggregate("count", 2021, True) == 100.0 |
| 105 | + |
| 106 | + |
| 107 | +def test_get_soi_aggregate_raises_clear_error_when_missing(tmp_path: Path): |
| 108 | + with load_uprate_puf_module(tmp_path) as module: |
| 109 | + with pytest.raises(FileNotFoundError, match="No SOI aggregate file found at"): |
| 110 | + module.load_soi_aggregates() |
0 commit comments