Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/weight-shape-load-helper.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `policyengine_uk_data.utils.calibrate.load_weights`, a defensive loader that normalises calibration weights to 2D `(n_areas, n_records)` and validates expected shapes so consumers can't silently read the wrong axis layout across the L2 and L0 calibrators.
87 changes: 87 additions & 0 deletions policyengine_uk_data/tests/test_load_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Tests for `policyengine_uk_data.utils.calibrate.load_weights`.

`load_weights` is a defensive loader that normalises shape across the two
calibrator back-ends that have lived in `utils.calibrate` (2D L2 and flat L0),
so downstream consumers cannot silently read the wrong axis layout (bug-hunt
finding U4).
"""

from __future__ import annotations

import importlib.util

import numpy as np
import pytest

# Skip the whole module when h5py is unavailable: every test below writes an
# h5 fixture file through `_write_h5`, so none of them can run without it.
if importlib.util.find_spec("h5py") is None:
    pytest.skip("h5py not installed", allow_module_level=True)

# Imported only after the guard, so a missing h5py yields a skip rather than
# an ImportError at collection time (hence the E402 suppression).
import h5py  # noqa: E402


def _write_h5(path, key: str, data: np.ndarray):
    """Write ``data`` as dataset ``key`` into a freshly created h5 file.

    The file at ``path`` is opened in ``"w"`` mode, so any existing file is
    replaced.
    """
    with h5py.File(path, "w") as handle:
        handle.create_dataset(key, data=data)


def test_load_weights_returns_2d_for_2d_input(tmp_path):
    """A 2D dataset comes back with its shape and values unchanged."""
    from policyengine_uk_data.utils.calibrate import load_weights

    h5_file = tmp_path / "w.h5"
    expected = np.arange(6, dtype=float).reshape(2, 3)
    _write_h5(h5_file, "2025", expected)

    loaded = load_weights(h5_file, dataset_key="2025")

    assert loaded.shape == (2, 3)
    np.testing.assert_allclose(loaded, expected)


def test_load_weights_promotes_1d_input_to_2d(tmp_path):
    """A flat dataset is returned as a single-row 2D matrix."""
    from policyengine_uk_data.utils.calibrate import load_weights

    h5_file = tmp_path / "w.h5"
    vector = np.arange(4, dtype=float)
    _write_h5(h5_file, "2025", vector)

    loaded = load_weights(h5_file, dataset_key="2025")

    # A (n_records,) vector is promoted to (1, n_records): summing over
    # axis 0 gives back the original vector, and matrix operations
    # downstream see the same layout as for genuinely 2D weights.
    assert loaded.shape == (1, 4)
    np.testing.assert_allclose(loaded[0], vector)


def test_load_weights_checks_expected_shapes(tmp_path):
    """Matching expected dims pass; a mismatch on either axis raises."""
    from policyengine_uk_data.utils.calibrate import load_weights

    h5_file = tmp_path / "w.h5"
    _write_h5(h5_file, "2025", np.ones((3, 5), dtype=float))

    # Expected dims agree with the stored (3, 5) array: must not raise.
    load_weights(h5_file, dataset_key="2025", n_areas=3, n_records=5)

    # Each mismatch must be reported against the axis that is wrong.
    mismatches = (
        ({"n_areas": 4, "n_records": 5}, "areas"),
        ({"n_areas": 3, "n_records": 999}, "records"),
    )
    for expected_dims, message_fragment in mismatches:
        with pytest.raises(ValueError, match=message_fragment):
            load_weights(h5_file, dataset_key="2025", **expected_dims)


def test_load_weights_missing_key_raises(tmp_path):
    """Requesting a dataset key absent from the file raises ``KeyError``."""
    from policyengine_uk_data.utils.calibrate import load_weights

    h5_file = tmp_path / "w.h5"
    _write_h5(h5_file, "2025", np.ones((2, 2), dtype=float))

    # Only "2025" was written, so "2099" must be reported as missing.
    with pytest.raises(KeyError, match="not found"):
        load_weights(h5_file, dataset_key="2099")


def test_load_weights_rejects_higher_dim_input(tmp_path):
    """A dataset with more than two axes is rejected with ``ValueError``."""
    from policyengine_uk_data.utils.calibrate import load_weights

    h5_file = tmp_path / "w.h5"
    cube = np.ones((2, 2, 2), dtype=float)
    _write_h5(h5_file, "2025", cube)

    with pytest.raises(ValueError, match="1D or 2D"):
        load_weights(h5_file, dataset_key="2025")
74 changes: 74 additions & 0 deletions policyengine_uk_data/utils/calibrate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from contextlib import nullcontext
from pathlib import Path
from typing import Optional, Union

import torch
import pandas as pd
Expand All @@ -9,6 +11,78 @@
from policyengine_uk_data.utils.progress import ProcessingProgress


def load_weights(
    weight_file: Union[str, Path],
    dataset_key: str = "2025",
    n_areas: Optional[int] = None,
    n_records: Optional[int] = None,
) -> np.ndarray:
    """Load calibration weights from an h5 file and normalise their shape.

    Two calibration back-ends exist in this repo's history: the L2
    calibrator in `calibrate_local_areas` (this module) saves weights as a
    2D ``(n_areas, n_records)`` array, while the L0-regularised variant
    (when present) sometimes saves a flat 1D ``(n_records,)`` array under
    the same dataset key. Consumers that are not careful about axes can
    therefore silently read the wrong shape.

    This helper centralises loading and always returns a 2D
    ``(n_areas, n_records)`` array. A 1D input is reshaped to
    ``(1, n_records)`` so downstream ``.sum(axis=0)`` and matrix-multiply
    operations behave consistently. Optional ``n_areas`` / ``n_records``
    arguments raise a clear ``ValueError`` on shape mismatch instead of
    silently producing wrong answers.

    Args:
        weight_file: Path to the h5 file written by a calibrator. If the
            path is not absolute it is resolved relative to the package
            ``STORAGE_FOLDER``.
        dataset_key: H5 dataset key to read.
        n_areas: Optional expected number of areas (first axis of the
            returned array). The check runs after normalisation, so a 1D
            input, which is promoted to a single row, only passes when
            ``n_areas`` is 1.
        n_records: Optional expected number of records (last axis of the
            returned array).

    Returns:
        A 2D ``(n_areas, n_records)`` numpy array.

    Raises:
        KeyError: If ``dataset_key`` is not present in the file. The
            message lists the keys that are available.
        ValueError: If the stored array is neither 1D nor 2D, or if its
            normalised shape disagrees with ``n_areas`` / ``n_records``.
    """
    path = Path(weight_file)
    if not path.is_absolute():
        path = STORAGE_FOLDER / path

    with h5py.File(path, "r") as f:
        if dataset_key not in f:
            available = ", ".join(sorted(f.keys()))
            raise KeyError(
                f"Dataset key {dataset_key!r} not found in {path}; "
                f"available keys: {available}"
            )
        # Read with ``[()]`` rather than ``[:]``: it fetches the whole
        # dataset for any rank, whereas slicing a scalar (0-D) dataset is
        # rejected inside h5py before the shape validation below can
        # report the problem in this function's own terms.
        arr = np.asarray(f[dataset_key][()])

    if arr.ndim == 1:
        # Flat (n_records,) layout — promote to (1, n_records) so callers
        # can treat all weights as a 2D matrix.
        arr = arr.reshape(1, -1)
    elif arr.ndim != 2:
        raise ValueError(
            f"Expected weights at {dataset_key!r} in {path} to be 1D or 2D; "
            f"got shape {arr.shape}"
        )

    if n_areas is not None and arr.shape[0] != n_areas:
        raise ValueError(
            f"Weights at {dataset_key!r} in {path} have {arr.shape[0]} areas, "
            f"expected {n_areas}"
        )
    if n_records is not None and arr.shape[-1] != n_records:
        raise ValueError(
            f"Weights at {dataset_key!r} in {path} have {arr.shape[-1]} "
            f"records, expected {n_records}"
        )
    return arr


def calibrate_local_areas(
dataset: UKSingleYearDataset,
matrix_fn,
Expand Down
Loading