Skip to content

Commit c8638c2

Browse files
committed
Add calibration setup heartbeats for CI
1 parent e743602 commit c8638c2

5 files changed

Lines changed: 228 additions & 60 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add periodic CI heartbeats around long calibration setup stages so dataset release builds do not get canceled while constituency target matrices are being prepared.
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import time
2+
3+
import numpy as np
4+
import pandas as pd
5+
6+
from policyengine_uk_data.utils.calibrate import calibrate_local_areas
7+
8+
9+
class _DummyValues:
    """Minimal test stand-in that exposes a float NumPy array as ``.values``."""

    def __init__(self, values):
        # np.array always copies, so the stub never aliases the caller's data.
        self.values = np.array(values, dtype=float)
12+
13+
14+
class _DummyHousehold:
    """Minimal test stand-in exposing ``household_weight.values``."""

    def __init__(self, weights):
        # Wrapped so the weights are reachable as ``household_weight.values``.
        self.household_weight = _DummyValues(weights)
17+
18+
19+
class _DummyDataset:
    """Minimal dataset stand-in providing ``household`` and ``copy()``.

    Used by the test in this file as the ``dataset`` argument of
    ``calibrate_local_areas``.
    """

    def __init__(self, weights):
        self.household = _DummyHousehold(weights)

    def copy(self):
        """Return a new ``_DummyDataset`` built from a copy of the weights."""
        return _DummyDataset(self.household.household_weight.values.copy())
25+
26+
27+
def test_calibrate_local_areas_logs_setup_stage_heartbeats_in_ci(
    monkeypatch, capsys, tmp_path
):
    """Both matrix-building setup stages log start, heartbeat and completion.

    The stub matrix builders sleep for 0.03s while the heartbeat interval is
    set to 0.01s, so at least one heartbeat is expected per stage.
    """
    # Enable plain CI output and a heartbeat interval shorter than the stubs.
    for env_name, env_value in (
        ("CI", "true"),
        ("POLICYENGINE_PROGRESS_HEARTBEAT_SECONDS", "0.01"),
    ):
        monkeypatch.setenv(env_name, env_value)

    # Point the module's STORAGE_FOLDER at pytest's temporary directory
    # (presumably where the weight file is written -- confirm in calibrate.py).
    monkeypatch.setattr(
        "policyengine_uk_data.utils.calibrate.STORAGE_FOLDER",
        tmp_path,
    )

    def matrix_fn(_dataset):
        time.sleep(0.03)
        local_matrix = pd.DataFrame({"metric": [1.0, 0.0, 1.0]})
        local_targets = pd.DataFrame({"metric": [2.0]})
        area_mask = np.ones((1, 3))
        return local_matrix, local_targets, area_mask

    def national_matrix_fn(_dataset):
        time.sleep(0.03)
        national_matrix = pd.DataFrame({"national_metric": [1.0, 1.0, 1.0]})
        national_targets = pd.Series({"national_metric": 3.0})
        return national_matrix, national_targets

    calibrate_local_areas(
        dataset=_DummyDataset([10.0, 20.0, 30.0]),
        matrix_fn=matrix_fn,
        national_matrix_fn=national_matrix_fn,
        area_count=1,
        weight_file="weights.h5",
        epochs=1,
        verbose=True,
        area_name="Constituency",
    )

    output = capsys.readouterr().out
    stage_names = (
        "Constituency: build local target matrix",
        "Constituency: build national target matrix",
    )
    for stage_name in stage_names:
        for event in ("starting", "heartbeat", "completed"):
            assert f"[calibration] {event}: {stage_name}" in output
    assert "[calibration] epoch 1/1: calculating loss" in output

policyengine_uk_data/tests/test_progress.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
import time
2+
3+
import pytest
4+
15
from policyengine_uk_data.utils.progress import ProcessingProgress
26

37

@@ -36,3 +40,33 @@ def test_track_calibration_logs_heartbeats_in_ci(monkeypatch, capsys):
3640
assert "[calibration] epoch 1/12: calculating loss" in output
3741
assert "[calibration] epoch 10/12: loss=1.000000" in output
3842
assert "[calibration] epoch 12/12: loss=1.200000" in output
43+
44+
45+
def test_track_stage_logs_periodic_heartbeats_in_ci(monkeypatch, capsys):
    """A stage outlasting the heartbeat interval logs start, heartbeat, end."""
    monkeypatch.setenv("CI", "true")
    monkeypatch.setenv("POLICYENGINE_PROGRESS_HEARTBEAT_SECONDS", "0.01")
    stage_name = "Constituency: build local target matrix"

    tracker = ProcessingProgress()
    with tracker.track_stage(stage_name):
        # Three heartbeat intervals, so at least one heartbeat should fire.
        time.sleep(0.03)

    output = capsys.readouterr().out
    for event in ("starting", "heartbeat", "completed"):
        assert f"[calibration] {event}: {stage_name}" in output
58+
59+
60+
def test_track_stage_logs_failures_in_ci(monkeypatch, capsys):
    """An exception raised inside a stage is logged as failed and re-raised."""
    monkeypatch.setenv("CI", "true")
    monkeypatch.setenv("POLICYENGINE_PROGRESS_HEARTBEAT_SECONDS", "0.01")
    stage_name = "Constituency: build local target matrix"

    tracker = ProcessingProgress()
    # pytest.raises is the outer manager, so it sees the error after
    # track_stage has handled and re-raised it.
    with pytest.raises(RuntimeError, match="boom"), tracker.track_stage(
        stage_name
    ):
        time.sleep(0.02)
        raise RuntimeError("boom")

    captured = capsys.readouterr()
    assert f"[calibration] failed: {stage_name}" in captured.out

policyengine_uk_data/utils/calibrate.py

Lines changed: 76 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from contextlib import nullcontext
2+
13
import torch
2-
from policyengine_uk import Microsimulation
34
import pandas as pd
45
import numpy as np
56
import h5py
@@ -39,59 +40,74 @@ def calibrate_local_areas(
3940
verbose: Whether to print progress
4041
area_name: Name of the area type for logging
4142
"""
42-
dataset = dataset.copy()
43-
matrix, y, r = matrix_fn(dataset)
43+
progress_tracker = ProcessingProgress() if verbose else None
44+
45+
def track_stage(stage_name: str):
46+
if progress_tracker is None:
47+
return nullcontext()
48+
return progress_tracker.track_stage(stage_name)
49+
50+
with track_stage(f"{area_name}: copy dataset"):
51+
dataset = dataset.copy()
52+
53+
with track_stage(f"{area_name}: build local target matrix"):
54+
matrix, y, r = matrix_fn(dataset)
4455
m_c, y_c = matrix.copy(), y.copy()
45-
m_national, y_national = national_matrix_fn(dataset)
56+
57+
with track_stage(f"{area_name}: build national target matrix"):
58+
m_national, y_national = national_matrix_fn(dataset)
4659
m_n, y_n = m_national.copy(), y_national.copy()
4760

48-
# Weights - area_count x num_households
49-
# Use country-aware initialization: divide each household's weight by the
50-
# number of areas in its country, not the total area count. This ensures
51-
# households start at approximately correct weight for their country's targets.
52-
# The country_mask r[i,j]=1 iff household j is in same country as area i.
53-
areas_per_household = r.sum(
54-
axis=0
55-
) # number of areas each household can contribute to
56-
areas_per_household = np.maximum(areas_per_household, 1) # avoid division by zero
57-
original_weights = np.log(
58-
dataset.household.household_weight.values / areas_per_household
59-
+ np.random.random(len(dataset.household.household_weight.values)) * 0.01
60-
)
61-
weights = torch.tensor(
62-
np.ones((area_count, len(original_weights))) * original_weights,
63-
dtype=torch.float32,
64-
requires_grad=True,
65-
)
66-
67-
# Set up validation targets if specified
68-
validation_targets_local = (
69-
matrix.columns.isin(excluded_training_targets)
70-
if hasattr(matrix, "columns")
71-
else None
72-
)
73-
validation_targets_national = (
74-
m_national.columns.isin(excluded_training_targets)
75-
if hasattr(m_national, "columns")
76-
else None
77-
)
78-
dropout_targets = len(excluded_training_targets) > 0
79-
80-
# Convert to tensors
81-
metrics = torch.tensor(
82-
matrix.values if hasattr(matrix, "values") else matrix,
83-
dtype=torch.float32,
84-
)
85-
y = torch.tensor(y.values if hasattr(y, "values") else y, dtype=torch.float32)
86-
matrix_national = torch.tensor(
87-
m_national.values if hasattr(m_national, "values") else m_national,
88-
dtype=torch.float32,
89-
)
90-
y_national = torch.tensor(
91-
y_national.values if hasattr(y_national, "values") else y_national,
92-
dtype=torch.float32,
93-
)
94-
r = torch.tensor(r, dtype=torch.float32)
61+
with track_stage(f"{area_name}: prepare tensors and optimizer"):
62+
# Weights - area_count x num_households
63+
# Use country-aware initialization: divide each household's weight by the
64+
# number of areas in its country, not the total area count. This ensures
65+
# households start at approximately correct weight for their country's targets.
66+
# The country_mask r[i,j]=1 iff household j is in same country as area i.
67+
areas_per_household = r.sum(
68+
axis=0
69+
) # number of areas each household can contribute to
70+
areas_per_household = np.maximum(
71+
areas_per_household, 1
72+
) # avoid division by zero
73+
original_weights = np.log(
74+
dataset.household.household_weight.values / areas_per_household
75+
+ np.random.random(len(dataset.household.household_weight.values)) * 0.01
76+
)
77+
weights = torch.tensor(
78+
np.ones((area_count, len(original_weights))) * original_weights,
79+
dtype=torch.float32,
80+
requires_grad=True,
81+
)
82+
83+
# Set up validation targets if specified
84+
validation_targets_local = (
85+
matrix.columns.isin(excluded_training_targets)
86+
if hasattr(matrix, "columns")
87+
else None
88+
)
89+
validation_targets_national = (
90+
m_national.columns.isin(excluded_training_targets)
91+
if hasattr(m_national, "columns")
92+
else None
93+
)
94+
dropout_targets = len(excluded_training_targets) > 0
95+
96+
# Convert to tensors
97+
metrics = torch.tensor(
98+
matrix.values if hasattr(matrix, "values") else matrix,
99+
dtype=torch.float32,
100+
)
101+
y = torch.tensor(y.values if hasattr(y, "values") else y, dtype=torch.float32)
102+
matrix_national = torch.tensor(
103+
m_national.values if hasattr(m_national, "values") else m_national,
104+
dtype=torch.float32,
105+
)
106+
y_national = torch.tensor(
107+
y_national.values if hasattr(y_national, "values") else y_national,
108+
dtype=torch.float32,
109+
)
110+
r = torch.tensor(r, dtype=torch.float32)
95111

96112
def sre(x, y):
97113
one_way = ((1 + x) / (1 + y) - 1) ** 2
@@ -160,8 +176,6 @@ def dropout_weights(weights, p):
160176
final_weights = (torch.exp(weights) * r).detach().numpy()
161177
performance = pd.DataFrame()
162178

163-
progress_tracker = ProcessingProgress() if verbose else None
164-
165179
if verbose and progress_tracker:
166180
with progress_tracker.track_calibration(
167181
epochs, nested_progress
@@ -171,8 +185,8 @@ def dropout_weights(weights, p):
171185

172186
optimizer.zero_grad()
173187
weights_ = torch.exp(dropout_weights(weights, 0.05)) * r
174-
l = loss(weights_)
175-
l.backward()
188+
loss_value = loss(weights_)
189+
loss_value.backward()
176190
optimizer.step()
177191

178192
local_close = pct_close(weights_, local=True, national=False)
@@ -187,7 +201,9 @@ def dropout_weights(weights, p):
187201
)
188202
else:
189203
update_calibration(
190-
epoch + 1, loss_value=l.item(), calculating_loss=False
204+
epoch + 1,
205+
loss_value=loss_value.item(),
206+
calculating_loss=False,
191207
)
192208

193209
if epoch % 10 == 0:
@@ -225,8 +241,8 @@ def dropout_weights(weights, p):
225241
for epoch in range(epochs):
226242
optimizer.zero_grad()
227243
weights_ = torch.exp(dropout_weights(weights, 0.05)) * r
228-
l = loss(weights_)
229-
l.backward()
244+
loss_value = loss(weights_)
245+
loss_value.backward()
230246
optimizer.step()
231247

232248
local_close = pct_close(weights_, local=True, national=False)
@@ -236,12 +252,12 @@ def dropout_weights(weights, p):
236252
if dropout_targets:
237253
validation_loss = loss(weights_, validation=True)
238254
print(
239-
f"Training loss: {l.item():,.3f}, Validation loss: {validation_loss.item():,.3f}, Epoch: {epoch}, "
255+
f"Training loss: {loss_value.item():,.3f}, Validation loss: {validation_loss.item():,.3f}, Epoch: {epoch}, "
240256
f"{area_name}<10%: {local_close:.1%}, National<10%: {national_close:.1%}"
241257
)
242258
else:
243259
print(
244-
f"Loss: {l.item()}, Epoch: {epoch}, {area_name}<10%: {local_close:.1%}, National<10%: {national_close:.1%}"
260+
f"Loss: {loss_value.item()}, Epoch: {epoch}, {area_name}<10%: {local_close:.1%}, National<10%: {national_close:.1%}"
245261
)
246262

247263
if epoch % 10 == 0:

policyengine_uk_data/utils/progress.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
from contextlib import contextmanager
99
import os
10+
import threading
11+
import time
1012
from typing import Any, Dict, List, Optional, Union
1113

1214
from rich.console import Console
@@ -190,10 +192,49 @@ def __init__(self, console: Optional[Console] = None):
190192
self._plain_output = (
191193
os.environ.get("GITHUB_ACTIONS") == "true" or os.environ.get("CI") == "true"
192194
)
195+
self._heartbeat_seconds = float(
196+
os.environ.get("POLICYENGINE_PROGRESS_HEARTBEAT_SECONDS", "60")
197+
)
193198

194199
def _emit(self, message: str) -> None:
    """Print ``message`` and flush immediately so it is not held in a buffer."""
    print(message, flush=True)
196201

202+
@contextmanager
def track_stage(self, stage_name: str, category: str = "calibration"):
    """Track a long-running stage with periodic CI heartbeats.

    When plain output is disabled this is a no-op wrapper around the body.
    Otherwise it emits a ``starting`` line, then a ``heartbeat`` line every
    ``self._heartbeat_seconds`` seconds from a daemon thread while the body
    runs, and finally a ``completed`` or ``failed`` line with the elapsed
    whole seconds.

    Args:
        stage_name: Human-readable stage label included in every line.
        category: Bracketed prefix for every line (default ``"calibration"``).

    Yields:
        None. The ``with`` body runs while heartbeats are being emitted.

    Raises:
        Re-raises, unchanged, any exception raised by the ``with`` body.
    """
    if not self._plain_output:
        yield
        return

    self._emit(f"[{category}] starting: {stage_name}")
    started_at = time.monotonic()
    stop_event = threading.Event()
    interval = self._heartbeat_seconds

    def emit_heartbeats():
        # Event.wait returns False on timeout (emit a heartbeat) and True once
        # the event is set (stop).
        while not stop_event.wait(interval):
            elapsed = int(time.monotonic() - started_at)
            self._emit(f"[{category}] heartbeat: {stage_name} ({elapsed}s elapsed)")

    heartbeat_thread = None
    # A zero or negative interval would make wait() return immediately and
    # flood the log in a busy loop, so heartbeats are disabled in that case
    # (NaN also fails this comparison).
    if interval > 0:
        heartbeat_thread = threading.Thread(
            target=emit_heartbeats,
            name=f"{category}-heartbeat",
            daemon=True,
        )
        heartbeat_thread.start()

    outcome = None
    try:
        yield
    except Exception:
        outcome = "failed"
        raise
    else:
        outcome = "completed"
    finally:
        # Stop the heartbeat thread before the terminal line is printed, so a
        # late heartbeat cannot appear after "completed"/"failed".
        stop_event.set()
        if heartbeat_thread is not None:
            heartbeat_thread.join(timeout=1)
        # outcome stays None for non-Exception BaseExceptions such as
        # KeyboardInterrupt; no terminal line is emitted for those.
        if outcome is not None:
            elapsed = int(time.monotonic() - started_at)
            self._emit(f"[{category}] {outcome}: {stage_name} ({elapsed}s elapsed)")
237+
197238
@contextmanager
198239
def track_dataset_creation(self, datasets: List[str]):
199240
"""Track dataset creation progress with stable display.

0 commit comments

Comments
 (0)