Skip to content

Commit d535148

Browse files
committed
Fix non terminating speed/elapsed in marimo progress bar
1 parent 8d99218 commit d535148

3 files changed

Lines changed: 76 additions & 4 deletions

File tree

pymc/progress_bar/marimo_progress.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __init__(
7171
self._mo_replace: Callable[[object], None] | None = None
7272
self._task_state: list[dict[str, Any]] = []
7373
self._start_times: list[float | None] = []
74+
self._end_times: list[float | None] = []
7475
self._last_render_time: float = 0.0
7576
self._min_render_interval: float = 0.1
7677

@@ -107,6 +108,7 @@ def _initialize_tasks(self) -> None:
107108
# ``cores < chains``, later chains aren't timed from the first chain's
108109
# start.
109110
self._start_times = [None] * self.n_bars
111+
self._end_times = [None] * self.n_bars
110112

111113
def update(
112114
self,
@@ -148,6 +150,8 @@ def update(
148150

149151
if is_last:
150152
self._task_state[task_id]["completed"] = self._task_state[task_id]["total"]
153+
if self._end_times[task_id] is None:
154+
self._end_times[task_id] = perf_counter()
151155

152156
now = perf_counter()
153157
if is_last or (now - self._last_render_time) >= self._min_render_interval:
@@ -215,7 +219,12 @@ def _render_task_row(self, task_id: int, state: dict[str, Any], stat_keys: list[
215219

216220
pct = (completed / total * 100) if total else 0
217221
start_time = self._start_times[task_id]
218-
elapsed = 0.0 if start_time is None else perf_counter() - start_time
222+
if start_time is None:
223+
elapsed = 0.0
224+
else:
225+
end_time = self._end_times[task_id]
226+
now = end_time if end_time is not None else perf_counter()
227+
elapsed = now - start_time
219228

220229
action = self.step_name.lower()
221230
# Wait for a small window of elapsed time before computing speed so the

pymc/progress_bar/progress.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ def __init__(
372372
)
373373
# Used to compute delta draws between calls
374374
self._previous_finished = [0] * chains
375+
self._chain_completed = [False] * chains
375376

376377
progress_columns = [
377378
TextColumn("{task.fields[divergences]}", table_column=Column("Divergences", ratio=1)),
@@ -424,11 +425,14 @@ def update(self, chain_progresses) -> None:
424425
for chain_idx, cp in enumerate(chain_progresses):
425426
# With ``cores < chains`` queued chains haven't started yet;
426427
# skip them so their bar doesn't show progress or elapsed time.
427-
if not cp.started:
428+
if not cp.started or self._chain_completed[chain_idx]:
428429
continue
429430
cp_finished_draws = cp.finished_draws
430431
delta = cp_finished_draws - self._previous_finished[chain_idx]
431432
self._previous_finished[chain_idx] = cp_finished_draws
433+
is_last = cp_finished_draws >= cp.total_draws
434+
if is_last:
435+
self._chain_completed[chain_idx] = True
432436
# Use nutpie's per-chain runtime as the source of truth for
433437
# elapsed/speed, so reads aren't skewed by the wait time
434438
# before this chain started.
@@ -445,7 +449,7 @@ def update(self, chain_progresses) -> None:
445449
advance=delta,
446450
failing=bool(cp.divergent_draws),
447451
stats=stats,
448-
is_last=cp_finished_draws >= cp.total_draws,
452+
is_last=is_last,
449453
total=cp.total_draws,
450454
)
451455

tests/progress_bar/test_marimo.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from time import sleep
16+
from types import SimpleNamespace
1517
from unittest.mock import patch
1618

1719
import pytest
1820

1921
import pymc as pm
2022

21-
from pymc.progress_bar import MCMCProgressBarManager
23+
from pymc.progress_bar import MCMCProgressBarManager, NutpieProgressBarManager
2224
from pymc.progress_bar.marimo_progress import MarimoProgressBackend
2325

2426

@@ -75,6 +77,7 @@ def test_render_html_structure(self, step_method):
7577
{"completed": 75, "total": 150, "failing": True, "stats": {"divergences": 1}},
7678
]
7779
backend._start_times = [0, 0]
80+
backend._end_times = [None, None]
7881

7982
html = backend._render_html()
8083

@@ -103,6 +106,7 @@ def test_render_html_with_stats(self, step_method):
103106
},
104107
]
105108
backend._start_times = [0]
109+
backend._end_times = [None]
106110

107111
html = backend._render_html()
108112

@@ -121,6 +125,25 @@ def test_is_last_sets_completed_to_total(self):
121125
assert backend._task_state[0]["completed"] == 150
122126
assert backend._task_state[1]["completed"] == 0
123127

128+
def test_elapsed_freezes_after_completion(self):
129+
"""Completed chains must not show drifting speed/elapsed on re-renders."""
130+
backend = MarimoProgressBackend(
131+
step_name="Draw", n_bars=2, total=10, combined=False, full_stats=False
132+
)
133+
backend._initialize_tasks()
134+
135+
# Complete chain 0
136+
for i in range(10):
137+
backend.update(task_id=0, advance=1, failing=False, stats={}, is_last=i == 9)
138+
139+
html_at_finish = backend._render_task_row(0, backend._task_state[0], [])
140+
141+
# Let wall-clock advance so unfrozen elapsed would visibly drift
142+
sleep(0.3)
143+
144+
html_after = backend._render_task_row(0, backend._task_state[0], [])
145+
assert html_at_finish == html_after
146+
124147
def test_marimo_smc_progress(self):
125148
backend = MarimoProgressBackend(
126149
step_name="Stage", n_bars=1, total=1.0, combined=False, full_stats=False
@@ -139,3 +162,39 @@ def test_marimo_smc_progress(self):
139162
old = beta
140163

141164
assert backend._task_state[0]["completed"] == 1.0
165+
166+
def test_nutpie_elapsed_freezes_after_completion(self):
167+
"""Completed nutpie chains must not show drifting elapsed in marimo."""
168+
with patch("pymc.progress_bar.progress.in_marimo_notebook", return_value=True):
169+
manager = NutpieProgressBarManager(chains=2, draws=100, progressbar=True)
170+
assert isinstance(manager._backend, MarimoProgressBackend)
171+
total = 1100
172+
173+
backend = manager._backend
174+
backend._initialize_tasks()
175+
176+
def cp(finished, runtime_ms, started=True):
177+
return SimpleNamespace(
178+
finished_draws=finished,
179+
total_draws=total,
180+
runtime_ms=runtime_ms,
181+
started=started,
182+
divergent_draws=[],
183+
step_size=0.5,
184+
latest_num_steps=7,
185+
)
186+
187+
# Chain 0 finishes, chain 1 halfway
188+
manager.update([cp(total, runtime_ms=5000), cp(500, runtime_ms=2500)])
189+
190+
html_at_finish = backend._render_task_row(0, backend._task_state[0], [])
191+
192+
# Let wall-clock advance so unfrozen elapsed would visibly drift
193+
sleep(0.3)
194+
195+
# More callbacks arrive while chain 1 is still running
196+
manager.update([cp(total, runtime_ms=5000), cp(800, runtime_ms=4000)])
197+
198+
# Chain 0's row must be identical — elapsed and speed frozen
199+
html_after = backend._render_task_row(0, backend._task_state[0], [])
200+
assert html_at_finish == html_after

0 commit comments

Comments
 (0)