Skip to content

Commit 0c0a45f

Browse files
vahid-ahmadiclaude
authored andcommitted
Fix asymmetric loss function that biased optimiser toward overshoot
The min-of-two-ratios SRE loss penalised undershoot more than overshoot of the same magnitude (e.g. 6% overshoot cost 89% of 6% undershoot). Across ~11k targets this systematically inflated weights, causing the ~6% population overshoot. Replace with squared log-ratio which is perfectly symmetric: log(a/b)² = log(b/a)². Also remove redundant Scotland children/babies targets that overlapped with regional age bands. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 0d8972e commit 0c0a45f

2 files changed

Lines changed: 9 additions & 104 deletions

File tree

policyengine_uk_data/targets/sources/ons_demographics.py

Lines changed: 1 addition & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -205,33 +205,6 @@ def _parse_regional_from_csv() -> list[Target]:
205205
return targets
206206

207207

208-
# Scotland-specific (from NRS/census — not in ONS projections)
209-
_SCOTLAND_CHILDREN_UNDER_16 = {
210-
y: v * 1e3
211-
for y, v in {
212-
2022: 904,
213-
2023: 900,
214-
2024: 896,
215-
2025: 892,
216-
2026: 888,
217-
2027: 884,
218-
2028: 880,
219-
}.items()
220-
}
221-
222-
_SCOTLAND_BABIES_UNDER_1 = {
223-
y: v * 1e3
224-
for y, v in {
225-
2022: 46,
226-
2023: 46,
227-
2024: 46,
228-
2025: 46,
229-
2026: 46,
230-
2027: 46,
231-
2028: 46,
232-
}.items()
233-
}
234-
235208
_SCOTLAND_HOUSEHOLDS_3PLUS_CHILDREN = {
236209
y: v * 1e3
237210
for y, v in {
@@ -263,38 +236,7 @@ def get_targets() -> list[Target]:
263236
# Regional age bands from demographics.csv
264237
targets.extend(_parse_regional_from_csv())
265238

266-
# Scotland-specific (NRS/census — small number of static values)
267-
targets.append(
268-
Target(
269-
name="ons/scotland_children_under_16",
270-
variable="age",
271-
source="nrs",
272-
unit=Unit.COUNT,
273-
values=_SCOTLAND_CHILDREN_UNDER_16,
274-
is_count=True,
275-
geographic_level=GeographicLevel.COUNTRY,
276-
geo_code="S",
277-
geo_name="Scotland",
278-
reference_url=_REF_NRS,
279-
)
280-
)
281-
targets.append(
282-
Target(
283-
name="ons/scotland_babies_under_1",
284-
variable="age",
285-
source="nrs",
286-
unit=Unit.COUNT,
287-
values=_SCOTLAND_BABIES_UNDER_1,
288-
is_count=True,
289-
geographic_level=GeographicLevel.COUNTRY,
290-
geo_code="S",
291-
geo_name="Scotland",
292-
reference_url=(
293-
"https://www.nrscotland.gov.uk/publications/"
294-
"vital-events-reference-tables-2024/"
295-
),
296-
)
297-
)
239+
# Scotland households (census-derived, no overlap with age bands)
298240
targets.append(
299241
Target(
300242
name="ons/scotland_households_3plus_children",

policyengine_uk_data/utils/calibrate.py

Lines changed: 8 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@
1313

1414
logger = logging.getLogger(__name__)
1515

16-
# Population gets this multiplier in the national loss so the optimiser
17-
# keeps it on target rather than letting it drift ~6% high.
18-
POPULATION_LOSS_WEIGHT = 10.0
19-
2016
def load_weights(
2117
weight_file: Union[str, Path],
2218
dataset_key: str = "2025",
@@ -89,27 +85,6 @@ def load_weights(
8985
return arr
9086

9187

92-
def _build_national_target_weights(
93-
national_matrix,
94-
population_weight: float = POPULATION_LOSS_WEIGHT,
95-
) -> np.ndarray:
96-
"""Build per-target weight vector for the national loss.
97-
98-
Every target gets weight 1.0 except ``ons/uk_population`` which gets
99-
``population_weight``. This ensures the optimiser treats population
100-
accuracy as a hard constraint rather than 1-of-N soft targets.
101-
"""
102-
pop_col_name = "ons/uk_population"
103-
if hasattr(national_matrix, "columns"):
104-
n = len(national_matrix.columns)
105-
w = np.ones(n, dtype=np.float32)
106-
cols = list(national_matrix.columns)
107-
if pop_col_name in cols:
108-
w[cols.index(pop_col_name)] = population_weight
109-
return w
110-
# Fallback: no column names available — equal weights
111-
return np.ones(national_matrix.shape[1], dtype=np.float32)
112-
11388

11489
def calibrate_local_areas(
11590
dataset: UKSingleYearDataset,
@@ -211,19 +186,13 @@ def track_stage(stage_name: str):
211186
)
212187
r = torch.tensor(r, dtype=torch.float32)
213188

214-
# Per-target weights for the national loss (population gets boosted)
215-
national_target_weights = torch.tensor(
216-
_build_national_target_weights(m_national),
217-
dtype=torch.float32,
218-
)
219-
220189
def sre(x, y):
221-
one_way = ((1 + x) / (1 + y) - 1) ** 2
222-
other_way = ((1 + y) / (1 + x) - 1) ** 2
223-
return torch.min(one_way, other_way)
224-
225-
def weighted_mean(values, weights):
226-
return (values * weights).sum() / weights.sum()
190+
"""Squared log-ratio loss — symmetric so overshoot and undershoot
191+
of the same magnitude incur identical cost. The previous
192+
min-of-two-ratios formulation penalised undershoot more than
193+
overshoot, which systematically biased the optimiser toward
194+
inflating weights (root cause of the ~6 % population overshoot)."""
195+
return torch.log((1 + x) / (1 + y)) ** 2
227196

228197
def loss(w, validation: bool = False):
229198
pred_local = (w.unsqueeze(-1) * metrics.unsqueeze(0)).sum(dim=1)
@@ -244,15 +213,9 @@ def loss(w, validation: bool = False):
244213
else:
245214
mask = ~validation_targets_national
246215
pred_national = pred_national[mask]
247-
mse_national = weighted_mean(
248-
sre(pred_national, y_national[mask]),
249-
national_target_weights[mask],
250-
)
216+
mse_national = torch.mean(sre(pred_national, y_national[mask]))
251217
else:
252-
mse_national = weighted_mean(
253-
sre(pred_national, y_national),
254-
national_target_weights,
255-
)
218+
mse_national = torch.mean(sre(pred_national, y_national))
256219

257220
return mse_local + mse_national
258221

0 commit comments

Comments
 (0)