Skip to content
This repository was archived by the owner on Jun 19, 2026. It is now read-only.

Commit b99d11d

Browse files
vahid-ahmadiclaude
andcommitted
Add cross-year smoothness penalty to calibrate_local_areas (#345 step 5)
Adds an opt-in log-space L2 penalty to the training loss in `calibrate_local_areas` that pulls the optimised weights towards a prior year's weights. This is the regulariser that makes a sequence of per-year calibrations statistically coherent as a panel — without it, the same household can represent, say, 500 units in 2024 and 50 in 2025. Design choices: - The penalty is factored out into a pure helper `compute_log_weight_smoothness_penalty(log_weights, prior_weights)` so it can be unit-tested thoroughly. Entries where the prior is zero (households outside an area's country) are excluded from the mean so they neither pull nor inflate the penalty. - `calibrate_local_areas` gains two keyword-only kwargs, `prior_weights` and `smoothness_penalty`, both defaulting to values that reproduce the pre-step-5 training loop exactly. - Shape mismatches raise a clear `ValueError` rather than failing deep inside the optimiser. - The penalty is computed from the underlying log-space weights (not the dropout-augmented tensor fed into the fit loss) so the regulariser does not double-count the dropout noise. Tests (15 new, all in two files): - 10 unit tests on the helper covering zero-when-equal, quadratic scaling, masking of zero-prior entries, gradient masking, shape validation, symmetric log deviation, differentiability, dtype round-trip and a hand-computed heterogeneous case. - 5 integration tests on `calibrate_local_areas` with a three-household fake dataset: default kwargs reproduce pre-step-5 behaviour, shape mismatch raises, `None` prior + penalty is a no-op, zero penalty + prior is a no-op, and a large penalty measurably pulls weights towards the prior versus a no-smoothness run. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 8c3c587 commit b99d11d

4 files changed

Lines changed: 386 additions & 1 deletion

File tree

changelog.d/345.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
Add panel ID contract, `create_yearly_snapshots` helper, `age_dataset` demographic ageing module and year-aware loss matrices with a documented `resolve_target_value` fallback policy as the first four steps towards per-year snapshots (#345).
1+
Add panel ID contract, `create_yearly_snapshots` helper, `age_dataset` demographic ageing module, year-aware loss matrices with a documented `resolve_target_value` fallback policy, and a cross-year smoothness penalty on `calibrate_local_areas` as the first five steps towards per-year snapshots (#345).
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
"""Integration tests for the smoothness-penalty wiring in calibrate_local_areas.
2+
3+
The unit tests for ``compute_log_weight_smoothness_penalty`` live in
4+
``test_smoothness_penalty.py``. The tests here exercise the surrounding
5+
plumbing: validation of the new kwargs, that default behaviour is
6+
unchanged, and that a large penalty actually pulls the optimised weights
7+
towards the prior.
8+
"""
9+
10+
from pathlib import Path
11+
12+
import numpy as np
13+
import pandas as pd
14+
import pytest
15+
from policyengine_uk.data import UKSingleYearDataset
16+
17+
from policyengine_uk_data.utils import calibrate as calibrate_module
18+
from policyengine_uk_data.utils.calibrate import calibrate_local_areas
19+
20+
21+
# ---------------------------------------------------------------------------
22+
# Fixtures
23+
# ---------------------------------------------------------------------------
24+
25+
26+
def _tiny_dataset() -> UKSingleYearDataset:
27+
"""Three-household dataset just big enough for calibration shapes."""
28+
household = pd.DataFrame(
29+
{
30+
"household_id": [1, 2, 3],
31+
"household_weight": [1000.0, 1000.0, 1000.0],
32+
}
33+
)
34+
benunit = pd.DataFrame({"benunit_id": [101, 201, 301]})
35+
person = pd.DataFrame(
36+
{
37+
"person_id": [1001, 2001, 3001],
38+
"person_benunit_id": [101, 201, 301],
39+
"person_household_id": [1, 2, 3],
40+
"age": [30, 40, 50],
41+
}
42+
)
43+
return UKSingleYearDataset(
44+
person=person, benunit=benunit, household=household, fiscal_year=2025
45+
)
46+
47+
48+
AREA_COUNT = 2
49+
50+
51+
def _fake_local_matrix(dataset):
52+
"""Two areas, three households, one target per area.
53+
54+
Each target is the sum of household_weight over the households in
55+
that area. With default initial weights the target is easy to learn.
56+
"""
57+
matrix = pd.DataFrame({"pop/area_size": [1.0, 1.0, 1.0]})
58+
y = pd.DataFrame({"pop/area_size": [3000.0, 3000.0]})
59+
# Simple country mask: both areas include all households.
60+
r = np.ones((AREA_COUNT, 3))
61+
return matrix, y, r
62+
63+
64+
def _fake_national_matrix(dataset):
65+
matrix = pd.DataFrame({"pop/national": [1.0, 1.0, 1.0]})
66+
y = pd.DataFrame({"pop/national": [6000.0]})
67+
return matrix, y
68+
69+
70+
@pytest.fixture
71+
def patched_storage(tmp_path: Path, monkeypatch):
72+
"""Redirect the hard-coded STORAGE_FOLDER write in calibrate.py."""
73+
monkeypatch.setattr(calibrate_module, "STORAGE_FOLDER", tmp_path)
74+
return tmp_path
75+
76+
77+
# ---------------------------------------------------------------------------
78+
# Tests
79+
# ---------------------------------------------------------------------------
80+
81+
82+
def test_default_kwargs_reproduce_pre_step5_behaviour(patched_storage):
83+
"""No prior + zero penalty ⇒ the smoothness branch must be inert."""
84+
# NB: calibrate_local_areas only flushes the weight file when the
85+
# final epoch index is a multiple of 10 (the function saves every 10
86+
# epochs). Use 11 epochs so the final epoch = 10 triggers a save.
87+
np.random.seed(0)
88+
import torch
89+
90+
torch.manual_seed(0)
91+
calibrate_local_areas(
92+
dataset=_tiny_dataset(),
93+
matrix_fn=_fake_local_matrix,
94+
national_matrix_fn=_fake_national_matrix,
95+
area_count=AREA_COUNT,
96+
weight_file="test_weights.h5",
97+
epochs=11,
98+
)
99+
assert (patched_storage / "test_weights.h5").exists()
100+
101+
102+
def test_shape_mismatch_in_prior_raises(patched_storage):
103+
bogus_prior = np.ones((AREA_COUNT, 99)) # wrong household count
104+
with pytest.raises(ValueError, match="prior_weights shape"):
105+
calibrate_local_areas(
106+
dataset=_tiny_dataset(),
107+
matrix_fn=_fake_local_matrix,
108+
national_matrix_fn=_fake_national_matrix,
109+
area_count=AREA_COUNT,
110+
weight_file="test_weights.h5",
111+
epochs=1,
112+
prior_weights=bogus_prior,
113+
smoothness_penalty=1.0,
114+
)
115+
116+
117+
def test_none_prior_with_penalty_is_noop(patched_storage):
118+
"""A penalty coefficient without a prior must not crash."""
119+
calibrate_local_areas(
120+
dataset=_tiny_dataset(),
121+
matrix_fn=_fake_local_matrix,
122+
national_matrix_fn=_fake_national_matrix,
123+
area_count=AREA_COUNT,
124+
weight_file="test_weights.h5",
125+
epochs=1,
126+
prior_weights=None,
127+
smoothness_penalty=10.0,
128+
)
129+
130+
131+
def test_zero_penalty_with_prior_is_noop(patched_storage):
132+
"""A prior without a penalty coefficient must not crash either."""
133+
prior = np.ones((AREA_COUNT, 3)) * 500.0
134+
calibrate_local_areas(
135+
dataset=_tiny_dataset(),
136+
matrix_fn=_fake_local_matrix,
137+
national_matrix_fn=_fake_national_matrix,
138+
area_count=AREA_COUNT,
139+
weight_file="test_weights.h5",
140+
epochs=1,
141+
prior_weights=prior,
142+
smoothness_penalty=0.0,
143+
)
144+
145+
146+
def test_large_penalty_keeps_weights_near_prior(patched_storage):
147+
"""With a huge penalty, the optimised weights should stay near the prior."""
148+
import h5py
149+
150+
# Prior that is deliberately far from what the fit-loss alone would
151+
# drive us to (fit alone wants ~1000 per household per area to match
152+
# the area target; this prior has 10x larger values).
153+
prior = np.ones((AREA_COUNT, 3)) * 10_000.0
154+
155+
np.random.seed(0)
156+
import torch
157+
158+
torch.manual_seed(0)
159+
calibrate_local_areas(
160+
dataset=_tiny_dataset(),
161+
matrix_fn=_fake_local_matrix,
162+
national_matrix_fn=_fake_national_matrix,
163+
area_count=AREA_COUNT,
164+
weight_file="with_smoothness.h5",
165+
# 21 epochs ⇒ final index 20 is a multiple of 10 → save triggers.
166+
epochs=21,
167+
prior_weights=prior,
168+
smoothness_penalty=1e6,
169+
)
170+
171+
with h5py.File(patched_storage / "with_smoothness.h5", "r") as f:
172+
final_with = np.array(f["2025"])
173+
174+
# And the same run without the smoothness penalty.
175+
np.random.seed(0)
176+
torch.manual_seed(0)
177+
calibrate_local_areas(
178+
dataset=_tiny_dataset(),
179+
matrix_fn=_fake_local_matrix,
180+
national_matrix_fn=_fake_national_matrix,
181+
area_count=AREA_COUNT,
182+
weight_file="without_smoothness.h5",
183+
epochs=21,
184+
)
185+
186+
with h5py.File(patched_storage / "without_smoothness.h5", "r") as f:
187+
final_without = np.array(f["2025"])
188+
189+
# With the huge penalty, weights should be closer (in log-space) to
190+
# the prior than the no-smoothness run.
191+
log_dev_with = np.mean((np.log(final_with + 1e-8) - np.log(prior)) ** 2)
192+
log_dev_without = np.mean((np.log(final_without + 1e-8) - np.log(prior)) ** 2)
193+
assert log_dev_with < log_dev_without, (
194+
f"Smoothness failed to pull weights towards prior: "
195+
f"with={log_dev_with:.4f} vs without={log_dev_without:.4f}"
196+
)
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
"""Tests for the cross-year smoothness penalty (step 5 of #345)."""
2+
3+
import pytest
4+
import torch
5+
6+
from policyengine_uk_data.utils.calibrate import (
7+
compute_log_weight_smoothness_penalty,
8+
)
9+
10+
11+
def test_zero_when_log_weights_match_log_prior():
12+
"""If current weights already equal the prior, the penalty is zero."""
13+
prior = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
14+
log_weights = torch.log(prior)
15+
penalty = compute_log_weight_smoothness_penalty(log_weights, prior)
16+
assert penalty.item() == pytest.approx(0.0)
17+
18+
19+
def test_penalty_scales_with_squared_log_deviation():
20+
"""A log-ratio of ln(2) on every entry → penalty = (ln 2)**2."""
21+
prior = torch.ones(3, 4)
22+
# log_weights = log(2 * prior) = log(2)
23+
log_weights = torch.full((3, 4), float(torch.log(torch.tensor(2.0))))
24+
penalty = compute_log_weight_smoothness_penalty(log_weights, prior)
25+
assert penalty.item() == pytest.approx(
26+
float(torch.log(torch.tensor(2.0))) ** 2, rel=1e-6
27+
)
28+
29+
30+
def test_zero_prior_entries_are_excluded_from_mean():
31+
"""Households outside an area's country (prior == 0) must not inflate the penalty."""
32+
prior = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
33+
log_weights = torch.zeros_like(prior) # log(1) on the valid entries
34+
penalty = compute_log_weight_smoothness_penalty(log_weights, prior)
35+
# Only two entries are valid and both match the prior → penalty is zero.
36+
assert penalty.item() == pytest.approx(0.0)
37+
38+
39+
def test_zero_prior_entries_do_not_pull_gradient():
40+
"""Gradient w.r.t. a masked-out entry must be exactly zero."""
41+
prior = torch.tensor([[0.0, 2.0]])
42+
log_weights = torch.tensor([[100.0, 0.0]], requires_grad=True)
43+
penalty = compute_log_weight_smoothness_penalty(log_weights, prior)
44+
penalty.backward()
45+
# First entry is masked out → grad should be zero regardless of value.
46+
assert log_weights.grad[0, 0].item() == pytest.approx(0.0)
47+
# Second entry pulled towards log(2).
48+
assert log_weights.grad[0, 1].item() != 0.0
49+
50+
51+
def test_all_zero_prior_returns_zero_without_nan():
52+
"""No valid entries → zero, not NaN."""
53+
prior = torch.zeros(2, 2)
54+
log_weights = torch.randn(2, 2)
55+
penalty = compute_log_weight_smoothness_penalty(log_weights, prior)
56+
assert penalty.item() == 0.0
57+
assert not torch.isnan(penalty)
58+
59+
60+
def test_shape_mismatch_raises_valueerror():
61+
prior = torch.ones(3, 4)
62+
log_weights = torch.zeros(3, 5)
63+
with pytest.raises(ValueError, match="shape"):
64+
compute_log_weight_smoothness_penalty(log_weights, prior)
65+
66+
67+
def test_symmetric_log_deviation():
68+
"""Doubling the prior and halving it produce the same penalty magnitude."""
69+
prior = torch.ones(2, 2)
70+
log_weights_double = torch.full((2, 2), float(torch.log(torch.tensor(2.0))))
71+
log_weights_half = torch.full((2, 2), -float(torch.log(torch.tensor(2.0))))
72+
a = compute_log_weight_smoothness_penalty(log_weights_double, prior)
73+
b = compute_log_weight_smoothness_penalty(log_weights_half, prior)
74+
assert a.item() == pytest.approx(b.item())
75+
76+
77+
def test_penalty_is_differentiable():
78+
"""The result must carry a grad so Adam can actually use it."""
79+
prior = torch.ones(2, 3)
80+
log_weights = torch.randn(2, 3, requires_grad=True)
81+
penalty = compute_log_weight_smoothness_penalty(log_weights, prior)
82+
assert penalty.requires_grad
83+
penalty.backward()
84+
assert log_weights.grad is not None
85+
# Some entry must see a non-zero gradient for a non-trivial prior.
86+
assert torch.any(log_weights.grad != 0)
87+
88+
89+
def test_device_and_dtype_round_trip():
90+
"""The output dtype matches the log_weights dtype (not the prior's)."""
91+
prior = torch.ones(2, 2, dtype=torch.float32)
92+
log_weights = torch.zeros(2, 2, dtype=torch.float64)
93+
penalty = compute_log_weight_smoothness_penalty(log_weights, prior)
94+
assert penalty.dtype == torch.float64
95+
96+
97+
def test_heterogeneous_mask_and_values():
98+
"""Explicit hand-computed example to lock in the arithmetic."""
99+
# prior = [[1, 0], [4, e]] ⇒ valid entries are (0,0), (1,0), (1,1).
100+
e = float(torch.e)
101+
prior = torch.tensor([[1.0, 0.0], [4.0, e]])
102+
# log_weights = [[0, any], [0, 0]] ⇒ deviations on valid entries
103+
# are: (0 - log 1)=0, (0 - log 4)=-2 log 2, (0 - log e)=-1.
104+
log_weights = torch.tensor([[0.0, 999.0], [0.0, 0.0]])
105+
penalty = compute_log_weight_smoothness_penalty(log_weights, prior)
106+
expected = (0.0**2 + (2 * torch.log(torch.tensor(2.0))).item() ** 2 + 1.0**2) / 3
107+
assert penalty.item() == pytest.approx(expected, rel=1e-5)

0 commit comments

Comments
 (0)