Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pull_request.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jobs:
runs-on: ubuntu-latest
env:
HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}
PYTHONUNBUFFERED: "1"
steps:
- name: Checkout code
uses: actions/checkout@v4
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/push.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ jobs:
id-token: "write"
env:
HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}
PYTHONUNBUFFERED: "1"
steps:
- name: Checkout code
uses: actions/checkout@v4
Expand Down
1 change: 1 addition & 0 deletions changelog.d/ci-progress-heartbeats.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make long-running dataset builds emit plain CI heartbeat logs so release workflows are less likely to die silently during calibration.
38 changes: 38 additions & 0 deletions policyengine_uk_data/tests/test_progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from policyengine_uk_data.utils.progress import ProcessingProgress


def test_track_dataset_creation_logs_in_ci(monkeypatch, capsys):
monkeypatch.setenv("GITHUB_ACTIONS", "true")

progress = ProcessingProgress()

with progress.track_dataset_creation(["Build base", "Save final"]) as (
update_dataset,
nested_progress,
):
assert nested_progress is None
update_dataset("Build base", "processing")
update_dataset("Build base", "completed")
update_dataset("Save final", "processing")
update_dataset("Save final", "completed")

output = capsys.readouterr().out
assert "[dataset] starting: Build base" in output
assert "[dataset] completed (1/2): Build base" in output
assert "[dataset] completed (2/2): Save final" in output


def test_track_calibration_logs_heartbeats_in_ci(monkeypatch, capsys):
monkeypatch.setenv("CI", "true")

progress = ProcessingProgress()

with progress.track_calibration(12) as update_calibration:
for iteration in range(1, 13):
update_calibration(iteration, calculating_loss=True)
update_calibration(iteration, loss_value=iteration / 10)

output = capsys.readouterr().out
assert "[calibration] epoch 1/12: calculating loss" in output
assert "[calibration] epoch 10/12: loss=1.000000" in output
assert "[calibration] epoch 12/12: loss=1.200000" in output
46 changes: 45 additions & 1 deletion policyengine_uk_data/utils/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
"""

from contextlib import contextmanager
import os
from typing import Any, Dict, List, Optional, Union
import time

from rich.console import Console
from rich.progress import (
Expand Down Expand Up @@ -187,6 +187,12 @@ def __init__(self, console: Optional[Console] = None):
"""
self.console = console or Console()
self.progress_manager: Optional[RichProgress] = None
self._plain_output = (
os.environ.get("GITHUB_ACTIONS") == "true" or os.environ.get("CI") == "true"
)

def _emit(self, message: str):
print(message, flush=True)

@contextmanager
def track_dataset_creation(self, datasets: List[str]):
Expand All @@ -198,6 +204,23 @@ def track_dataset_creation(self, datasets: List[str]):
Yields:
Tuple of (update_dataset function, progress manager for nested tasks).
"""
if self._plain_output:
completed_count = 0

def update_dataset(dataset_name: str, status: str = "processing"):
nonlocal completed_count

if status == "processing":
self._emit(f"[dataset] starting: {dataset_name}")
elif status == "completed":
completed_count += 1
self._emit(
f"[dataset] completed ({completed_count}/{len(datasets)}): {dataset_name}"
)

yield update_dataset, None
return

with create_progress(self.console) as progress:
# Main dataset creation progress
main_task = progress.add_task(
Expand Down Expand Up @@ -265,6 +288,27 @@ def track_calibration(self, iterations: int, nested_progress=None):
Yields:
Function to update calibration progress.
"""
if self._plain_output:

def update_calibration(
iteration: int,
loss_value: Optional[float] = None,
calculating_loss: bool = False,
):
if calculating_loss:
self._emit(
f"[calibration] epoch {iteration}/{iterations}: calculating loss"
)
elif loss_value is not None and (
iteration == 1 or iteration == iterations or iteration % 10 == 0
):
self._emit(
f"[calibration] epoch {iteration}/{iterations}: loss={loss_value:.6f}"
)

yield update_calibration
return

if nested_progress:
# Add calibration as a nested task in existing progress
calibration_task = nested_progress.add_task(
Expand Down
Loading