Skip to content

Commit ce48e86

Browse files
committed
Use PyMC progress bar with nutpie
1 parent ed33d37 commit ce48e86

4 files changed

Lines changed: 175 additions & 12 deletions

File tree

pymc/progress_bar/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from pymc.progress_bar.progress import (
1616
MCMCProgressBarManager,
17+
NutpieProgressBarManager,
1718
ProgressBarManager,
1819
ProgressBarOptions,
1920
SMCProgressBarManager,

pymc/progress_bar/progress.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,71 @@ def update(self, chain_idx: int, is_last: bool, draw: int, tuning: bool, stats)
329329
)
330330

331331

332+
class NutpieProgressBarManager(ProgressBarManager):
333+
"""Progress bar manager for nutpie NUTS sampling.
334+
335+
Bridges ``nutpie.sample``'s ``progress_callback`` (a callable that receives a
336+
list of ``nutpie.ChainProgress`` objects) to PyMC's progress bar backends,
337+
so nutpie draws through the same UI as the pymc sampler.
338+
"""
339+
340+
step_name: str = "Draw"
341+
342+
def __init__(
343+
self,
344+
chains: int,
345+
draws: int,
346+
tune: int,
347+
progressbar: bool | ProgressBarOptions = True,
348+
progressbar_theme: Theme | str | None = None,
349+
):
350+
from pymc.step_methods.hmc import NUTS
351+
352+
super().__init__(
353+
n_bars=chains,
354+
progressbar=progressbar,
355+
progressbar_theme=progressbar_theme,
356+
)
357+
358+
progress_columns, progress_stats = NUTS._progressbar_config(chains)
359+
progress_stats["draw"] = [0] * chains
360+
361+
self.total_draws = draws + tune
362+
self._previous_finished = [0] * chains
363+
364+
self._backend = self._create_backend(
365+
total=self.total_draws * chains if self.combined_progress else self.total_draws,
366+
progress_columns=progress_columns,
367+
progress_stats=progress_stats,
368+
)
369+
370+
def update(self, chain_progresses) -> None:
371+
"""Consume a list of ``nutpie.ChainProgress`` objects and advance each bar."""
372+
if not self._show_progress:
373+
return
374+
375+
for chain_idx, cp in enumerate(chain_progresses):
376+
delta = cp.finished_draws - self._previous_finished[chain_idx]
377+
if delta <= 0:
378+
continue
379+
self._previous_finished[chain_idx] = cp.finished_draws
380+
is_last = cp.finished_draws >= cp.total_draws
381+
stats = {
382+
"divergences": cp.divergences,
383+
"step_size": cp.step_size,
384+
"tree_size": cp.latest_num_steps,
385+
"draw": cp.finished_draws,
386+
}
387+
task_id = 0 if self.combined_progress else chain_idx
388+
self._backend.update(
389+
task_id=task_id,
390+
advance=delta,
391+
failing=cp.divergences > 0,
392+
stats=stats,
393+
is_last=is_last,
394+
)
395+
396+
332397
class SMCProgressBarManager(ProgressBarManager):
333398
"""Progress bar manager for SMC sampling.
334399

pymc/sampling/mcmc.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import contextlib
1616
import importlib.util
17+
import inspect
1718
import logging
1819
import multiprocessing
1920
import pickle
@@ -54,6 +55,7 @@
5455
from pymc.model import Model, modelcontext
5556
from pymc.progress_bar import (
5657
MCMCProgressBarManager,
58+
NutpieProgressBarManager,
5759
ProgressBarOptions,
5860
default_progress_theme,
5961
)
@@ -333,7 +335,8 @@ def _sample_external_nuts(
333335
initvals: StartDict | Sequence[StartDict | None] | None,
334336
model: Model,
335337
var_names: Sequence[str] | None,
336-
progressbar: bool,
338+
progressbar: bool | ProgressBarOptions,
339+
progressbar_theme: Theme | None,
337340
quiet: bool,
338341
idata_kwargs: dict | None,
339342
compute_convergence_checks: bool,
@@ -387,16 +390,38 @@ def _sample_external_nuts(
387390
var_names=var_names,
388391
**compile_kwargs,
389392
)
390-
t_start = time.time()
391-
idata = nutpie.sample(
392-
compiled_model,
393-
draws=draws,
394-
tune=tune,
395-
chains=chains,
396-
seed=_get_seeds_per_chain(random_seed, 1)[0],
397-
progress_bar=progressbar,
398-
**nuts_kwargs,
393+
394+
nutpie_supports_progress_callback = (
395+
"progress_callback" in inspect.signature(nutpie.sample).parameters
399396
)
397+
if nutpie_supports_progress_callback:
398+
pb_manager = NutpieProgressBarManager(
399+
chains=chains,
400+
draws=draws,
401+
tune=tune,
402+
progressbar=progressbar,
403+
progressbar_theme=progressbar_theme,
404+
)
405+
nutpie_progress_kwargs = {
406+
"progress_bar": False,
407+
"progress_callback": pb_manager.update,
408+
}
409+
progress_cm: contextlib.AbstractContextManager = pb_manager
410+
else:
411+
nutpie_progress_kwargs = {"progress_bar": bool(progressbar)}
412+
progress_cm = contextlib.nullcontext()
413+
414+
t_start = time.time()
415+
with progress_cm:
416+
idata = nutpie.sample(
417+
compiled_model,
418+
draws=draws,
419+
tune=tune,
420+
chains=chains,
421+
seed=_get_seeds_per_chain(random_seed, 1)[0],
422+
**nutpie_progress_kwargs,
423+
**nuts_kwargs,
424+
)
400425
t_sample = time.time() - t_start
401426
patch_nutpie_idata(
402427
idata,
@@ -438,7 +463,7 @@ def _sample_external_nuts(
438463
initvals=initvals,
439464
model=model,
440465
var_names=var_names,
441-
progressbar=progressbar,
466+
progressbar=bool(progressbar),
442467
quiet=quiet,
443468
nuts_sampler=sampler,
444469
nuts_kwargs=jax_nuts_kwargs,
@@ -877,7 +902,8 @@ def sample(
877902
initvals=initvals,
878903
model=model,
879904
var_names=var_names,
880-
progressbar=progress_bool,
905+
progressbar=progressbar if nuts_sampler == "nutpie" else progress_bool,
906+
progressbar_theme=progressbar_theme,
881907
quiet=quiet,
882908
idata_kwargs=idata_kwargs,
883909
compute_convergence_checks=compute_convergence_checks,

tests/sampling/test_mcmc_external.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,17 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import unittest.mock as mock
16+
17+
from types import SimpleNamespace
18+
1519
import numpy as np
1620
import numpy.testing as npt
1721
import pytest
1822
import xarray as xr
1923

2024
from pymc import Data, Deterministic, HalfNormal, Model, Normal, sample
25+
from pymc.progress_bar import NutpieProgressBarManager
2126

2227

2328
# temporarily skip nutpie
@@ -147,3 +152,69 @@ def test_sample_var_names(nuts_sampler):
147152
assert var in idata_2.posterior
148153

149154
xr.testing.assert_allclose(idata_1.posterior[var], idata_2.posterior[var])
155+
156+
157+
def test_nutpie_progress_bar_manager_update():
158+
pb = NutpieProgressBarManager(chains=2, draws=10, tune=10, progressbar=False)
159+
pb._backend = mock.Mock()
160+
pb._show_progress = True # force the update path even without a real backend
161+
162+
cp0 = SimpleNamespace(
163+
finished_draws=5,
164+
total_draws=20,
165+
divergences=0,
166+
step_size=0.5,
167+
latest_num_steps=3,
168+
)
169+
cp1 = SimpleNamespace(
170+
finished_draws=4,
171+
total_draws=20,
172+
divergences=1,
173+
step_size=0.4,
174+
latest_num_steps=7,
175+
)
176+
pb.update([cp0, cp1])
177+
assert pb._backend.update.call_count == 2
178+
first_call = pb._backend.update.call_args_list[0].kwargs
179+
assert first_call["task_id"] == 0
180+
assert first_call["advance"] == 5
181+
assert first_call["stats"]["divergences"] == 0
182+
second_call = pb._backend.update.call_args_list[1].kwargs
183+
assert second_call["task_id"] == 1
184+
assert second_call["advance"] == 4
185+
assert second_call["failing"] is True
186+
187+
# A second update only advances by the delta since the previous call.
188+
cp0.finished_draws = 20
189+
cp1.finished_draws = 20
190+
pb._backend.update.reset_mock()
191+
pb.update([cp0, cp1])
192+
deltas = [c.kwargs["advance"] for c in pb._backend.update.call_args_list]
193+
is_last_flags = [c.kwargs["is_last"] for c in pb._backend.update.call_args_list]
194+
assert deltas == [15, 16]
195+
assert is_last_flags == [True, True]
196+
197+
198+
def test_nutpie_end_to_end():
199+
# Released nutpie 0.16.8 references `arviz.InferenceData` which arviz 1.0 removed,
200+
# so `import nutpie` raises AttributeError on the current CI matrix. Skip until a
201+
# nutpie release compatible with arviz 1.0 ships.
202+
try:
203+
import nutpie # noqa: F401
204+
except (ImportError, AttributeError):
205+
pytest.skip("nutpie unavailable or incompatible with the installed arviz")
206+
with Model() as m:
207+
HalfNormal("sigma")
208+
Normal("mu")
209+
Normal("y", mu=0, sigma=1, observed=[1.0, 2.0, 3.0])
210+
idata = sample(
211+
nuts_sampler="nutpie",
212+
tune=20,
213+
draws=20,
214+
chains=2,
215+
progressbar=False,
216+
random_seed=1411,
217+
)
218+
assert {"posterior", "sample_stats", "observed_data"} <= set(idata.children)
219+
assert set(idata.posterior.data_vars) == {"mu", "sigma"}
220+
assert idata.posterior.sizes == {"chain": 2, "draw": 20}

0 commit comments

Comments
 (0)