Skip to content

Commit 58817f6

Browse files
committed
[ENH] replacement of tqdm progress bar with a modern rich progress bar
1 parent 8ab6c0f commit 58817f6

6 files changed

Lines changed: 66 additions & 37 deletions

File tree

binder/environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ dependencies:
88
- h5py
99
- matplotlib
1010
- corner
11-
- tqdm
11+
- rich
1212
- mpi4py
1313
- schwimmbad
1414
- pip

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,5 +55,5 @@ exclude_lines = [
5555
"logging.warning",
5656
"deprecation_warning",
5757
"deprecated",
58-
"if tqdm is None"
58+
"if Progress is None"
5959
]

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
"wheel",
3030
]
3131
EXTRA_REQUIRE = {
32-
"extras": ["h5py", "scipy", "tqdm", "ipywidgets"],
32+
"extras": ["h5py", "scipy", "rich", "ipywidgets"],
3333
"tests": ["pytest", "pytest-cov", "coverage[toml]"],
3434
}
3535

src/emcee/ensemble.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -290,13 +290,10 @@ def sample(
290290
a file or if you don't need to analyze the samples after the
291291
fact (for burn-in for example) set ``store`` to ``False``.
292292
progress (Optional[bool or str]): If ``True``, a progress bar will
293-
be shown as the sampler progresses. If a string, will select a
294-
specific ``tqdm`` progress bar - most notable is
295-
``'notebook'``, which shows a progress bar suitable for
296-
Jupyter notebooks. If ``False``, no progress bar will be
293+
be shown as the sampler progresses. If ``False``, no progress bar will be
297294
shown.
298295
progress_kwargs (Optional[dict]): A ``dict`` of keyword arguments
299-
to be passed to the tqdm call.
296+
to be passed to the progress bar implementation.
300297
skip_initial_state_check (Optional[bool]): If ``True``, a check
301298
that the initial_state can fully explore the space will be
302299
skipped. (default: ``False``)

src/emcee/pbar.py

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
# -*- coding: utf-8 -*-
22

3-
import importlib
43
import logging
54

65
__all__ = ["get_progress_bar"]
76

87
logger = logging.getLogger(__name__)
98

109
try:
11-
import tqdm
12-
import tqdm.auto
10+
from rich.console import Console
11+
from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn
1312
except ImportError:
14-
tqdm = None
13+
Progress = None
1514

1615

1716
class _NoOpPBar(object):
@@ -30,31 +29,68 @@ def update(self, count):
3029
pass
3130

3231

32+
class _RichPBar(object):
33+
"""A wrapper that provides emcee's progress-bar interface over rich."""
34+
35+
def __init__(self, total, **kwargs):
36+
self.total = total
37+
self.description = kwargs.pop("desc", "Sampling")
38+
leave = kwargs.pop("leave", True)
39+
self.progress = None
40+
self.task_id = None
41+
42+
# leave=False means clearing the bar when complete.
43+
self.transient = not leave
44+
45+
# Preserve legacy behavior by writing to stderr by default.
46+
self.console = kwargs.pop("console", Console(stderr=True))
47+
48+
if kwargs:
49+
logger.warning(
50+
"Ignoring unsupported progress bar kwargs for rich backend: %s",
51+
", ".join(sorted(kwargs.keys())),
52+
)
53+
54+
def __enter__(self, *args, **kwargs):
55+
self.progress = Progress(
56+
TextColumn("{task.description}"),
57+
BarColumn(),
58+
TaskProgressColumn(),
59+
console=self.console,
60+
transient=self.transient,
61+
)
62+
self.progress.__enter__()
63+
self.task_id = self.progress.add_task(self.description, total=self.total)
64+
return self
65+
66+
def __exit__(self, *args, **kwargs):
67+
self.progress.__exit__(*args, **kwargs)
68+
69+
def update(self, count):
70+
self.progress.update(self.task_id, advance=count)
71+
72+
3373
def get_progress_bar(display, total, **kwargs):
3474
"""Get a progress bar interface with given properties
3575
36-
If the tqdm library is not installed, this will always return a "progress
76+
If the rich library is not installed, this will always return a "progress
3777
bar" that does nothing.
3878
3979
Args:
40-
display (bool or str): Should the bar actually show the progress? Or a
41-
string to indicate which tqdm bar (subomdule) to use.
80+
display (bool or str): Should the bar actually show the progress?
4281
total (int): The total size of the progress bar.
43-
kwargs (dict): Optional keyword arguments to be passed to the tqdm call.
82+
kwargs (dict): Optional keyword arguments to be passed to the progress
83+
bar implementation.
4484
4585
"""
4686
if display:
47-
if tqdm is None:
87+
if Progress is None:
4888
logger.warning(
49-
"You must install the tqdm library to use progress "
89+
"You must install the rich library to use progress "
5090
"indicators with emcee"
5191
)
5292
return _NoOpPBar()
5393
else:
54-
if display is True:
55-
return tqdm.auto.tqdm(total=total, **kwargs)
56-
else:
57-
tqdm_submodule = importlib.import_module(f"tqdm.{display}")
58-
return tqdm_submodule.tqdm(total=total, **kwargs)
94+
return _RichPBar(total=total, **kwargs)
5995

6096
return _NoOpPBar()

src/emcee/tests/unit/test_pbar.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,21 @@
11
import pytest
22

3-
from emcee.pbar import _NoOpPBar, get_progress_bar
3+
from emcee.pbar import _NoOpPBar, _RichPBar, get_progress_bar
44

55
try:
6-
import tqdm
6+
import rich
77
except ImportError:
8-
tqdm = None
8+
rich = None
99

1010

1111
def test_display_false():
1212
assert isinstance(get_progress_bar(False, 100), _NoOpPBar)
1313

1414

15-
@pytest.mark.skipif(tqdm is None, reason="tqdm not available")
16-
def test_tqdm_modes():
17-
assert isinstance(get_progress_bar(True, 1000), tqdm.asyncio.tqdm_asyncio)
18-
assert isinstance(get_progress_bar("std", 1000), tqdm.std.tqdm)
19-
assert isinstance(
20-
get_progress_bar("notebook", 1000), tqdm.notebook.tqdm_notebook
21-
)
22-
assert isinstance(
23-
get_progress_bar("auto", 1000), tqdm.asyncio.tqdm_asyncio
24-
)
25-
assert isinstance(get_progress_bar("autonotebook", 1000), tqdm.std.tqdm)
15+
@pytest.mark.skipif(rich is None, reason="rich not available")
16+
def test_rich_modes():
17+
assert isinstance(get_progress_bar(True, 1000), _RichPBar)
18+
assert isinstance(get_progress_bar("std", 1000), _RichPBar)
19+
assert isinstance(get_progress_bar("notebook", 1000), _RichPBar)
20+
assert isinstance(get_progress_bar("auto", 1000), _RichPBar)
21+
assert isinstance(get_progress_bar("autonotebook", 1000), _RichPBar)

0 commit comments

Comments
 (0)