Skip to content

Commit 1333092

Browse files
committed
Add compact PolicyEngine dataset export
1 parent 9da1ab4 commit 1333092

3 files changed

Lines changed: 469 additions & 0 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ microplex-us-build-aca-ptc-multipliers = "microplex_us.targets.aca_ptc:main"
4646
microplex-us-backfill-pe-native-audit = "microplex_us.pipelines.backfill_pe_native_audit:main"
4747
microplex-us-backfill-pe-native-scores = "microplex_us.pipelines.backfill_pe_native_scores:main"
4848
microplex-us-check-site-snapshot = "microplex_us.pipelines.check_site_snapshot:main"
49+
microplex-us-compact-policyengine-dataset = "microplex_us.pipelines.compact_policyengine_dataset:main"
4950
microplex-us-mp300k-artifact-gates = "microplex_us.pipelines.mp300k_artifact_gates:main"
5051
microplex-us-package-mp300k-gate-inputs = "microplex_us.pipelines.mp300k_gate_inputs:main"
5152
microplex-us-pe-dataset-readiness = "microplex_us.pipelines.pe_us_dataset_readiness:main"
Lines changed: 335 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
"""Compact PolicyEngine time-period H5 datasets by household weight."""
2+
3+
from __future__ import annotations
4+
5+
import argparse
6+
import json
7+
from datetime import UTC, datetime
8+
from pathlib import Path
9+
from typing import Any
10+
11+
import h5py
12+
import numpy as np
13+
14+
# Entity name -> name of the variable that stores that entity's ids.
ENTITY_ID_VARIABLES = dict(
    household="household_id",
    person="person_id",
    tax_unit="tax_unit_id",
    spm_unit="spm_unit_id",
    family="family_id",
    marital_unit="marital_unit_id",
)

# Group entity name -> person-level variable linking each person to that
# entity. Every link is named "person_<id variable>"; "person" has no link
# to itself.
PERSON_ENTITY_LINK_VARIABLES = {
    entity: f"person_{id_variable}"
    for entity, id_variable in ENTITY_ID_VARIABLES.items()
    if entity != "person"
}

# Variable name -> entity, for the structural variables (ids, links, weights)
# whose entity is fixed by this table rather than inferred.
STRUCTURAL_VARIABLE_ENTITIES = {
    "household_id": "household",
    "household_weight": "household",
    "person_id": "person",
    # Every person -> group link is stored per person.
    **dict.fromkeys(PERSON_ENTITY_LINK_VARIABLES.values(), "person"),
    "person_weight": "person",
    # Id variables of the remaining group entities map to their own entity.
    **{
        id_variable: entity
        for entity, id_variable in ENTITY_ID_VARIABLES.items()
        if entity not in ("household", "person")
    },
}
46+
47+
48+
def compact_policyengine_dataset_by_household_weight(
49+
*,
50+
input_dataset_path: str | Path,
51+
output_dataset_path: str | Path,
52+
households: int,
53+
period: int = 2024,
54+
weights_path: str | Path | None = None,
55+
rescale_to_total: bool = True,
56+
target_total_weight: float | None = None,
57+
) -> dict[str, Any]:
58+
"""Write a household-subset PE H5, keeping the largest household weights."""
59+
60+
input_path = Path(input_dataset_path).expanduser()
61+
output_path = Path(output_dataset_path).expanduser()
62+
if households <= 0:
63+
raise ValueError("households must be positive")
64+
65+
period_key = str(period)
66+
with h5py.File(input_path, "r") as source:
67+
household_ids = _period_array(source, "household_id", period_key)
68+
source_household_weights = np.asarray(
69+
_period_array(source, "household_weight", period_key),
70+
dtype=np.float64,
71+
)
72+
if household_ids.shape[0] != source_household_weights.shape[0]:
73+
raise ValueError("household_id and household_weight lengths differ")
74+
75+
selection_weights = (
76+
np.load(Path(weights_path).expanduser()).astype(np.float64)
77+
if weights_path is not None
78+
else source_household_weights
79+
)
80+
if selection_weights.ndim != 1:
81+
raise ValueError("selection weights must be a one-dimensional array")
82+
if selection_weights.shape[0] != household_ids.shape[0]:
83+
raise ValueError(
84+
"selection weights length does not match household_id length: "
85+
f"{selection_weights.shape[0]} vs {household_ids.shape[0]}"
86+
)
87+
if households > household_ids.shape[0]:
88+
raise ValueError(
89+
"households cannot exceed source household count: "
90+
f"{households} > {household_ids.shape[0]}"
91+
)
92+
93+
selected_by_weight = np.argsort(-selection_weights, kind="stable")[:households]
94+
selected_source_order = np.sort(selected_by_weight)
95+
selected_household_ids = household_ids[selected_source_order]
96+
selected_weights = source_household_weights[selected_source_order].astype(
97+
np.float64,
98+
copy=True,
99+
)
100+
original_selected_weight_sum = float(selected_weights.sum())
101+
resolved_target_total = (
102+
float(target_total_weight)
103+
if target_total_weight is not None
104+
else float(source_household_weights.sum())
105+
)
106+
if rescale_to_total:
107+
if original_selected_weight_sum <= 0:
108+
raise ValueError("selected household weights sum to zero")
109+
selected_weights *= resolved_target_total / original_selected_weight_sum
110+
111+
metadata = _build_metadata(source, period_key)
112+
masks = _build_entity_masks(metadata, selected_household_ids)
113+
_write_compacted_dataset(
114+
source,
115+
output_path,
116+
period_key=period_key,
117+
metadata=metadata,
118+
masks=masks,
119+
)
120+
121+
with h5py.File(output_path, "r+") as output:
122+
weight_dataset = output["household_weight"][period_key]
123+
weight_dataset[...] = selected_weights.astype(weight_dataset.dtype)
124+
entity_counts = {
125+
entity: int(len(output[variable][period_key]))
126+
for entity, variable in ENTITY_ID_VARIABLES.items()
127+
if variable in output and period_key in output[variable]
128+
}
129+
output_weight_sum = float(
130+
np.asarray(output["household_weight"][period_key], dtype=np.float64).sum()
131+
)
132+
133+
summary = {
134+
"schema_version": 1,
135+
"created_at": datetime.now(UTC).isoformat(),
136+
"input_dataset": str(input_path.resolve()),
137+
"output_dataset": str(output_path.resolve()),
138+
"period": int(period),
139+
"selection_method": "largest_household_weight",
140+
"source_households": int(household_ids.shape[0]),
141+
"selected_households": int(households),
142+
"source_weight_sum": float(source_household_weights.sum()),
143+
"selected_weight_sum_before_rescale": original_selected_weight_sum,
144+
"output_weight_sum": output_weight_sum,
145+
"target_total_weight": resolved_target_total if rescale_to_total else None,
146+
"rescale_to_total": bool(rescale_to_total),
147+
"selection_weight_min_kept": float(selection_weights[selected_by_weight[-1]]),
148+
"selection_weight_max_kept": float(selection_weights[selected_by_weight[0]]),
149+
"entity_counts": entity_counts,
150+
"source_size_bytes": int(input_path.stat().st_size),
151+
"output_size_bytes": int(output_path.stat().st_size),
152+
"source_size_ratio": float(
153+
output_path.stat().st_size / input_path.stat().st_size
154+
),
155+
}
156+
return summary
157+
158+
159+
def _period_array(source: h5py.File, variable: str, period_key: str) -> np.ndarray:
    """Return ``source[variable][period_key]`` as a NumPy array.

    Raises:
        ValueError: If ``variable`` is not in ``source`` or has no
            ``period_key`` entry.
    """
    variable_group = source[variable] if variable in source else None
    if variable_group is None or period_key not in variable_group:
        raise ValueError(f"{source.filename} is missing {variable}/{period_key}")
    return np.asarray(variable_group[period_key])
163+
164+
165+
def _copy_attrs(
    source: h5py.Group | h5py.Dataset, destination: h5py.Group | h5py.Dataset
) -> None:
    """Copy every attribute of ``source`` onto ``destination``.

    Attributes already on ``destination`` under the same name are
    overwritten; other existing attributes are left untouched.
    """
    source_attrs = source.attrs
    destination_attrs = destination.attrs
    for attr_name in source_attrs.keys():
        destination_attrs[attr_name] = source_attrs[attr_name]
170+
171+
172+
def _build_metadata(source: h5py.File, period_key: str) -> dict[str, Any]:
    """Collect entity ids, person links and a variable -> entity map.

    Only variables that have a ``period_key`` entry are considered.

    Returns:
        A dict with ``"entity_ids"`` (entity -> id array), ``"person_links"``
        (entity -> per-person link array) and ``"variable_entities"``
        (variable name -> entity name).

    Raises:
        ValueError: If ``household_id``, ``person_id`` or
            ``person_household_id`` is missing for the period, or if a
            variable's entity cannot be inferred.
    """

    def _arrays_present(variable_by_entity: dict[str, str]) -> dict[str, np.ndarray]:
        # Load only the variables that exist for this period.
        found: dict[str, np.ndarray] = {}
        for entity_name, variable_name in variable_by_entity.items():
            if variable_name in source and period_key in source[variable_name]:
                found[entity_name] = _period_array(source, variable_name, period_key)
        return found

    entity_ids = _arrays_present(ENTITY_ID_VARIABLES)
    person_links = _arrays_present(PERSON_ENTITY_LINK_VARIABLES)

    if not ("household" in entity_ids and "person" in entity_ids):
        raise ValueError("input dataset must include household_id and person_id")
    if "household" not in person_links:
        raise ValueError("input dataset must include person_household_id")

    # Entity sizes, plus the reverse index (size -> entities of that size)
    # used to infer a variable's entity from its array length.
    entity_lengths: dict[str, int] = {}
    length_entities: dict[int, list[str]] = {}
    for entity_name, ids in entity_ids.items():
        count = int(len(ids))
        entity_lengths[entity_name] = count
        length_entities.setdefault(count, []).append(entity_name)

    policyengine_variable_entities = _load_policyengine_variable_entities()

    variable_entities: dict[str, str] = {}
    for variable_name in source.keys():
        variable_group = source[variable_name]
        if period_key not in variable_group:
            continue
        period_dataset = variable_group[period_key]
        # Datasets with an empty shape tuple are reported with length 0.
        array_length = int(len(period_dataset)) if period_dataset.shape else 0
        variable_entities[variable_name] = _infer_variable_entity(
            variable_name,
            array_length,
            entity_lengths=entity_lengths,
            length_entities=length_entities,
            policyengine_variable_entities=policyengine_variable_entities,
        )

    return {
        "entity_ids": entity_ids,
        "person_links": person_links,
        "variable_entities": variable_entities,
    }
213+
214+
215+
def _infer_variable_entity(
    variable: str,
    array_length: int,
    *,
    entity_lengths: dict[str, int],
    length_entities: dict[int, list[str]],
    policyengine_variable_entities: dict[str, str],
) -> str:
    """Resolve which entity ``variable`` belongs to.

    Sources are tried in this order: the structural-variable table, the
    PolicyEngine variable -> entity map (only when that entity is present in
    ``entity_lengths``), then a unique match between ``array_length`` and an
    entity's size.

    Raises:
        ValueError: If none of the sources identifies exactly one entity.
    """
    from_table = STRUCTURAL_VARIABLE_ENTITIES.get(variable)
    if from_table is not None:
        return from_table

    declared_entity = policyengine_variable_entities.get(variable)
    if declared_entity in entity_lengths:
        return declared_entity

    # Fall back to array length; ambiguous when several entities share a size.
    candidates = length_entities.get(array_length, [])
    if len(candidates) != 1:
        raise ValueError(
            f"Could not infer entity for {variable!r} with length {array_length}; "
            f"matches={candidates}"
        )
    return candidates[0]
239+
240+
241+
def _load_policyengine_variable_entities() -> dict[str, str]:
    """Return a variable name -> entity key map from ``policyengine_us``.

    This is best effort: an empty dict is returned when ``policyengine_us``
    cannot be imported or when building the ``Microsimulation`` and reading
    its tax-benefit-system variables fails.
    """
    try:
        from policyengine_us import Microsimulation  # noqa: PLC0415

        system_variables = Microsimulation().tax_benefit_system.variables
    except Exception:
        return {}

    entity_by_variable: dict[str, str] = {}
    for name, definition in system_variables.items():
        entity_by_variable[name] = str(definition.entity.key)
    return entity_by_variable
251+
252+
253+
def _build_entity_masks(
    metadata: dict[str, Any],
    selected_household_ids: np.ndarray,
) -> dict[str, np.ndarray]:
    """Build one boolean keep-mask per entity for the selected households.

    Households are kept when their id is selected, persons when their
    household link is selected, and each other entity present in the dataset
    when at least one kept person links to it.

    Raises:
        ValueError: If an entity has an id array but no person link array.
    """
    entity_ids = metadata["entity_ids"]
    household_mask = np.isin(entity_ids["household"], selected_household_ids)

    person_links = metadata["person_links"]
    person_mask = np.isin(person_links["household"], selected_household_ids)

    masks = {"household": household_mask, "person": person_mask}
    group_entities = ("tax_unit", "spm_unit", "family", "marital_unit")
    for entity in (name for name in group_entities if name in entity_ids):
        if entity not in person_links:
            raise ValueError(
                f"input dataset includes {ENTITY_ID_VARIABLES[entity]} but lacks "
                f"{PERSON_ENTITY_LINK_VARIABLES[entity]}"
            )
        # Ids of this entity referenced by at least one kept person.
        kept_ids = np.unique(person_links[entity][person_mask])
        masks[entity] = np.isin(entity_ids[entity], kept_ids)
    return masks
277+
278+
279+
def _write_compacted_dataset(
    source: h5py.File,
    output_path: Path,
    *,
    period_key: str,
    metadata: dict[str, Any],
    masks: dict[str, np.ndarray],
) -> None:
    """Write the masked copy of ``source`` to a new H5 file at ``output_path``.

    Parent directories are created as needed and an existing file is
    overwritten. Variables without a ``period_key`` entry are skipped. For
    every copied variable, all of its stored periods are written, each
    filtered by the mask of the variable's entity, and file, group and
    dataset attributes are carried over.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with h5py.File(output_path, "w") as output:
        _copy_attrs(source, output)
        for variable_name in source.keys():
            source_group = source[variable_name]
            if period_key not in source_group:
                continue
            entity = metadata["variable_entities"][variable_name]
            output_group = output.create_group(variable_name)
            _copy_attrs(source_group, output_group)
            for stored_period in source_group.keys():
                source_dataset = source_group[stored_period]
                data = np.asarray(source_dataset)
                # Zero-dimensional datasets cannot be masked; copy them as is.
                if data.ndim:
                    data = data[masks[entity]]
                written = output_group.create_dataset(stored_period, data=data)
                _copy_attrs(source_dataset, written)
303+
304+
305+
def main(argv: list[str] | None = None) -> int:
    """Command-line entry point: compact a dataset and report the summary.

    The summary is written to ``--summary-json`` (whose path is then printed)
    when that option is given; otherwise the summary JSON itself is printed.

    Returns:
        ``0`` on success.
    """
    parser = argparse.ArgumentParser(
        description="Compact a PolicyEngine US H5 by keeping top household weights."
    )
    option_specs: tuple[tuple[str, dict[str, Any]], ...] = (
        ("--input-dataset", {"required": True}),
        ("--output-dataset", {"required": True}),
        ("--households", {"type": int, "required": True}),
        ("--period", {"type": int, "default": 2024}),
        ("--weights-npy", {}),
        ("--target-total-weight", {"type": float}),
        ("--no-rescale", {"action": "store_true"}),
        ("--summary-json", {}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    args = parser.parse_args(argv)

    summary = compact_policyengine_dataset_by_household_weight(
        input_dataset_path=args.input_dataset,
        output_dataset_path=args.output_dataset,
        households=args.households,
        period=args.period,
        weights_path=args.weights_npy,
        rescale_to_total=not args.no_rescale,
        target_total_weight=args.target_total_weight,
    )

    if not args.summary_json:
        print(json.dumps(summary, indent=2, sort_keys=True))
        return 0

    summary_path = Path(args.summary_json).expanduser()
    summary_path.parent.mkdir(parents=True, exist_ok=True)
    summary_path.write_text(json.dumps(summary, indent=2, sort_keys=True))
    print(summary_path)
    return 0

0 commit comments

Comments
 (0)