diff --git a/binder/environment.yml b/binder/environment.yml index 81495815..1c8e4d5f 100644 --- a/binder/environment.yml +++ b/binder/environment.yml @@ -8,7 +8,7 @@ dependencies: - h5py - matplotlib - corner - - tqdm + - rich - mpi4py - schwimmbad - pip diff --git a/pyproject.toml b/pyproject.toml index 6e314f05..bf9f26a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,5 +55,5 @@ exclude_lines = [ "logging.warning", "deprecation_warning", "deprecated", - "if tqdm is None" + "if Progress is None" ] diff --git a/setup.py b/setup.py index e2b1a91f..dc9a3dc7 100755 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ "wheel", ] EXTRA_REQUIRE = { - "extras": ["h5py", "scipy", "tqdm", "ipywidgets"], + "extras": ["h5py", "scipy", "rich", "ipywidgets"], "tests": ["pytest", "pytest-cov", "coverage[toml]"], } diff --git a/src/emcee/ensemble.py b/src/emcee/ensemble.py index c71f6ee5..0ceb0211 100644 --- a/src/emcee/ensemble.py +++ b/src/emcee/ensemble.py @@ -290,13 +290,10 @@ def sample( a file or if you don't need to analyze the samples after the fact (for burn-in for example) set ``store`` to ``False``. progress (Optional[bool or str]): If ``True``, a progress bar will - be shown as the sampler progresses. If a string, will select a - specific ``tqdm`` progress bar - most notable is - ``'notebook'``, which shows a progress bar suitable for - Jupyter notebooks. If ``False``, no progress bar will be + be shown as the sampler progresses. If ``False``, no progress bar will be shown. progress_kwargs (Optional[dict]): A ``dict`` of keyword arguments - to be passed to the tqdm call. + to be passed to the progress bar implementation. skip_initial_state_check (Optional[bool]): If ``True``, a check that the initial_state can fully explore the space will be skipped. (default: ``False``) diff --git a/src/emcee/pbar.py b/src/emcee/pbar.py index a0636349..eaae2ff2 100644 --- a/src/emcee/pbar.py +++ b/src/emcee/pbar.py @@ -1,17 +1,59 @@ # -*- coding: utf-8 -*- -import importlib import logging +from datetime import timedelta __all__ = ["get_progress_bar"] logger = logging.getLogger(__name__) try: - import tqdm - import tqdm.auto + from rich.console import Console + from rich.progress import BarColumn, Progress, ProgressColumn, TextColumn + from rich.text import Text + + _RICH_AVAILABLE = True except ImportError: - tqdm = None + Progress = None + _RICH_AVAILABLE = False + + +def _format_timer(seconds): + if seconds is None: + return "--:--" + + total_seconds = max(0, int(seconds)) + td = timedelta(seconds=total_seconds) + total_seconds = int(td.total_seconds()) + hours, remainder = divmod(total_seconds, 3600) + minutes, secs = divmod(remainder, 60) + + if hours > 0: + return f"{hours:d}:{minutes:02d}:{secs:02d}" + return f"{minutes:02d}:{secs:02d}" + + +if _RICH_AVAILABLE: + + class _ElapsedTimeColumn(ProgressColumn): + def render(self, task): + return Text( + f"elapsed {_format_timer(task.elapsed)}", + style="yellow", + ) + + class _RemainingTimeColumn(ProgressColumn): + def render(self, task): + if task.finished: + return Text("", style="blue") + return Text( + f"left {_format_timer(task.time_remaining)}", + style="blue", + ) + +else: + _ElapsedTimeColumn = object + _RemainingTimeColumn = object class _NoOpPBar(object): @@ -30,31 +72,75 @@ def update(self, count): pass +class _RichPBar(object): + """A wrapper that provides emcee's progress-bar interface over rich.""" + + def __init__(self, total, **kwargs): + self.total = total + self.description = kwargs.pop("desc", "Sampling") + leave = kwargs.pop("leave", True) + self.progress = None + self.task_id = None + + # leave=False means clearing the bar when complete. + self.transient = not leave + + # Preserve legacy behavior by writing to stderr by default. + self.console = kwargs.pop("console", Console(stderr=True)) + + if kwargs: + logger.warning( + "Ignoring unsupported progress bar kwargs for rich backend: %s", + ", ".join(sorted(kwargs.keys())), + ) + + def __enter__(self, *args, **kwargs): + assert Progress is not None + self.progress = Progress( + TextColumn("{task.description}"), + BarColumn(), + TextColumn("{task.completed:.0f}/{task.total:.0f}"), + _ElapsedTimeColumn(), + _RemainingTimeColumn(), + console=self.console, + transient=self.transient, + ) + self.progress.__enter__() + self.task_id = self.progress.add_task( + self.description, total=self.total + ) + return self + + def __exit__(self, *args, **kwargs): + assert self.progress is not None + self.progress.__exit__(*args, **kwargs) + + def update(self, count): + assert self.progress is not None + self.progress.update(self.task_id, advance=count) + + def get_progress_bar(display, total, **kwargs): """Get a progress bar interface with given properties - If the tqdm library is not installed, this will always return a "progress + If the rich library is not installed, this will always return a "progress bar" that does nothing. Args: - display (bool or str): Should the bar actually show the progress? Or a - string to indicate which tqdm bar (subomdule) to use. + display (bool or str): Should the bar actually show the progress? total (int): The total size of the progress bar. - kwargs (dict): Optional keyword arguments to be passed to the tqdm call. + kwargs (dict): Optional keyword arguments to be passed to the progress + bar implementation. """ if display: - if tqdm is None: + if Progress is None: logger.warning( - "You must install the tqdm library to use progress " + "You must install the rich library to use progress " "indicators with emcee" ) return _NoOpPBar() else: - if display is True: - return tqdm.auto.tqdm(total=total, **kwargs) - else: - tqdm_submodule = importlib.import_module(f"tqdm.{display}") - return tqdm_submodule.tqdm(total=total, **kwargs) + return _RichPBar(total=total, **kwargs) return _NoOpPBar() diff --git a/src/emcee/tests/unit/test_pbar.py b/src/emcee/tests/unit/test_pbar.py index 2ffa9bf0..115f3931 100644 --- a/src/emcee/tests/unit/test_pbar.py +++ b/src/emcee/tests/unit/test_pbar.py @@ -1,25 +1,42 @@ import pytest -from emcee.pbar import _NoOpPBar, get_progress_bar +from emcee.pbar import ( + _ElapsedTimeColumn, + _NoOpPBar, + _RemainingTimeColumn, + _RichPBar, + get_progress_bar, +) try: - import tqdm + import rich except ImportError: - tqdm = None + rich = None def test_display_false(): assert isinstance(get_progress_bar(False, 100), _NoOpPBar) -@pytest.mark.skipif(tqdm is None, reason="tqdm not available") -def test_tqdm_modes(): - assert isinstance(get_progress_bar(True, 1000), tqdm.asyncio.tqdm_asyncio) - assert isinstance(get_progress_bar("std", 1000), tqdm.std.tqdm) - assert isinstance( - get_progress_bar("notebook", 1000), tqdm.notebook.tqdm_notebook - ) - assert isinstance( - get_progress_bar("auto", 1000), tqdm.asyncio.tqdm_asyncio - ) - assert isinstance(get_progress_bar("autonotebook", 1000), tqdm.std.tqdm) +@pytest.mark.skipif(rich is None, reason="rich not available") +def test_rich_modes(): + assert isinstance(get_progress_bar(True, 1000), _RichPBar) + assert isinstance(get_progress_bar("std", 1000), _RichPBar) + assert isinstance(get_progress_bar("notebook", 1000), _RichPBar) + assert isinstance(get_progress_bar("auto", 1000), _RichPBar) + assert isinstance(get_progress_bar("autonotebook", 1000), _RichPBar) + + +@pytest.mark.skipif(rich is None, reason="rich not available") +def test_rich_progress_includes_elapsed_and_remaining(): + pbar = get_progress_bar(True, 1000) + + with pbar: + assert any( + isinstance(column, _ElapsedTimeColumn) + for column in pbar.progress.columns + ) + assert any( + isinstance(column, _RemainingTimeColumn) + for column in pbar.progress.columns + )