From ce5a065929fbd5e9c5ab3adaebb58ca8ea5df3bd Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Fri, 17 Apr 2026 08:37:29 -0400 Subject: [PATCH] Add defensive load_weights helper to normalise calibration weight shape --- changelog.d/weight-shape-load-helper.added.md | 1 + .../tests/test_load_weights.py | 87 +++++++++++++++++++ policyengine_uk_data/utils/calibrate.py | 74 ++++++++++++++++ 3 files changed, 162 insertions(+) create mode 100644 changelog.d/weight-shape-load-helper.added.md create mode 100644 policyengine_uk_data/tests/test_load_weights.py diff --git a/changelog.d/weight-shape-load-helper.added.md b/changelog.d/weight-shape-load-helper.added.md new file mode 100644 index 000000000..d3e4ca1e2 --- /dev/null +++ b/changelog.d/weight-shape-load-helper.added.md @@ -0,0 +1 @@ +Add `policyengine_uk_data.utils.calibrate.load_weights`, a defensive loader that normalises calibration weights to 2D `(n_areas, n_records)` and validates expected shapes so consumers can't silently read the wrong axis layout across the L2 and L0 calibrators. diff --git a/policyengine_uk_data/tests/test_load_weights.py b/policyengine_uk_data/tests/test_load_weights.py new file mode 100644 index 000000000..e8218a9b1 --- /dev/null +++ b/policyengine_uk_data/tests/test_load_weights.py @@ -0,0 +1,87 @@ +"""Tests for `policyengine_uk_data.utils.calibrate.load_weights`. + +Adds a defensive loader that normalises shape across the two calibrator +back-ends that have lived in this module (2D L2 and flat L0), so downstream +consumers cannot silently read the wrong axis layout (bug-hunt finding U4). +""" + +from __future__ import annotations + +import importlib.util + +import numpy as np +import pytest + +if importlib.util.find_spec("h5py") is None: + pytest.skip("h5py not installed", allow_module_level=True) + +import h5py # noqa: E402 + + +def _write_h5(path, key: str, data: np.ndarray): + with h5py.File(path, "w") as f: + f.create_dataset(key, data=data) + + +def test_load_weights_returns_2d_for_2d_input(tmp_path): + from policyengine_uk_data.utils.calibrate import load_weights + + weights = np.arange(6, dtype=float).reshape(2, 3) + path = tmp_path / "w.h5" + _write_h5(path, "2025", weights) + + out = load_weights(path, dataset_key="2025") + assert out.shape == (2, 3) + np.testing.assert_allclose(out, weights) + + +def test_load_weights_promotes_1d_input_to_2d(tmp_path): + from policyengine_uk_data.utils.calibrate import load_weights + + flat = np.arange(4, dtype=float) + path = tmp_path / "w.h5" + _write_h5(path, "2025", flat) + + out = load_weights(path, dataset_key="2025") + # Flat inputs become (1, n_records) so .sum(axis=0) still yields the + # same vector and downstream matrix ops stay consistent. + assert out.shape == (1, 4) + np.testing.assert_allclose(out[0], flat) + + +def test_load_weights_checks_expected_shapes(tmp_path): + from policyengine_uk_data.utils.calibrate import load_weights + + weights = np.ones((3, 5), dtype=float) + path = tmp_path / "w.h5" + _write_h5(path, "2025", weights) + + # Correct expected dims → no exception. + load_weights(path, dataset_key="2025", n_areas=3, n_records=5) + + with pytest.raises(ValueError, match="areas"): + load_weights(path, dataset_key="2025", n_areas=4, n_records=5) + with pytest.raises(ValueError, match="records"): + load_weights(path, dataset_key="2025", n_areas=3, n_records=999) + + +def test_load_weights_missing_key_raises(tmp_path): + from policyengine_uk_data.utils.calibrate import load_weights + + weights = np.ones((2, 2), dtype=float) + path = tmp_path / "w.h5" + _write_h5(path, "2025", weights) + + with pytest.raises(KeyError, match="not found"): + load_weights(path, dataset_key="2099") + + +def test_load_weights_rejects_higher_dim_input(tmp_path): + from policyengine_uk_data.utils.calibrate import load_weights + + weights = np.ones((2, 2, 2), dtype=float) + path = tmp_path / "w.h5" + _write_h5(path, "2025", weights) + + with pytest.raises(ValueError, match="1D or 2D"): + load_weights(path, dataset_key="2025") diff --git a/policyengine_uk_data/utils/calibrate.py b/policyengine_uk_data/utils/calibrate.py index ec4c807e0..c48a6a778 100644 --- a/policyengine_uk_data/utils/calibrate.py +++ b/policyengine_uk_data/utils/calibrate.py @@ -1,4 +1,6 @@ from contextlib import nullcontext +from pathlib import Path +from typing import Optional, Union import torch import pandas as pd @@ -9,6 +11,78 @@ from policyengine_uk_data.utils.progress import ProcessingProgress +def load_weights( + weight_file: Union[str, Path], + dataset_key: str = "2025", + n_areas: Optional[int] = None, + n_records: Optional[int] = None, +) -> np.ndarray: + """Load calibration weights from an h5 file and normalise their shape. + + Two calibration back-ends exist in this repo's history: the L2 + calibrator in `calibrate_local_areas` (this module) saves weights as a + 2D ``(n_areas, n_records)`` array, while the L0-regularised variant + (when present) sometimes saves a flat 1D ``(n_records,)`` array under + the same dataset key. Consumers that are not careful about axes can + therefore silently read the wrong shape. + + This helper centralises loading and always returns a 2D + ``(n_areas, n_records)`` array. A 1D input is reshaped to + ``(1, n_records)`` so downstream ``.sum(axis=0)`` and matrix-multiply + operations behave consistently. Optional ``n_areas`` / ``n_records`` + arguments raise a clear ``ValueError`` on shape mismatch instead of + silently producing wrong answers. + + Args: + weight_file: Path to the h5 file written by a calibrator. If the + path is not absolute it is resolved relative to the package + ``STORAGE_FOLDER``. + dataset_key: H5 dataset key to read. + n_areas: Optional expected number of areas (first axis). When + provided, a 1D input is reshaped and its length checked; a 2D + input has its first axis checked. + n_records: Optional expected number of records (second axis). + Checked against the final axis of the loaded array. + + Returns: + A 2D ``(n_areas, n_records)`` numpy array. + """ + path = Path(weight_file) + if not path.is_absolute(): + path = STORAGE_FOLDER / path + + with h5py.File(path, "r") as f: + if dataset_key not in f: + available = ", ".join(sorted(f.keys())) + raise KeyError( + f"Dataset key {dataset_key!r} not found in {path}; " + f"available keys: {available}" + ) + arr = f[dataset_key][:] + + if arr.ndim == 1: + # Flat (n_records,) layout — promote to (1, n_records) so callers + # can treat all weights as a 2D matrix. + arr = arr.reshape(1, -1) + elif arr.ndim != 2: + raise ValueError( + f"Expected weights at {dataset_key!r} in {path} to be 1D or 2D; " + f"got shape {arr.shape}" + ) + + if n_areas is not None and arr.shape[0] != n_areas: + raise ValueError( + f"Weights at {dataset_key!r} in {path} have {arr.shape[0]} areas, " + f"expected {n_areas}" + ) + if n_records is not None and arr.shape[-1] != n_records: + raise ValueError( + f"Weights at {dataset_key!r} in {path} have {arr.shape[-1]} " + f"records, expected {n_records}" + ) + return arr + + def calibrate_local_areas( dataset: UKSingleYearDataset, matrix_fn,