Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion binder/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ dependencies:
- h5py
- matplotlib
- corner
- tqdm
- rich
- mpi4py
- schwimmbad
- pip
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,5 @@ exclude_lines = [
"logging.warning",
"deprecation_warning",
"deprecated",
"if tqdm is None"
"if Progress is None"
]
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"wheel",
]
EXTRA_REQUIRE = {
"extras": ["h5py", "scipy", "tqdm", "ipywidgets"],
"extras": ["h5py", "scipy", "rich", "ipywidgets"],
"tests": ["pytest", "pytest-cov", "coverage[toml]"],
}

Expand Down
7 changes: 2 additions & 5 deletions src/emcee/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,13 +290,10 @@ def sample(
a file or if you don't need to analyze the samples after the
fact (for burn-in for example) set ``store`` to ``False``.
progress (Optional[bool or str]): If ``True``, a progress bar will
be shown as the sampler progresses. If a string, will select a
specific ``tqdm`` progress bar - most notable is
``'notebook'``, which shows a progress bar suitable for
Jupyter notebooks. If ``False``, no progress bar will be
be shown as the sampler progresses. If ``False``, no progress bar will be
shown.
progress_kwargs (Optional[dict]): A ``dict`` of keyword arguments
to be passed to the tqdm call.
to be passed to the progress bar implementation.
skip_initial_state_check (Optional[bool]): If ``True``, a check
that the initial_state can fully explore the space will be
skipped. (default: ``False``)
Expand Down
116 changes: 101 additions & 15 deletions src/emcee/pbar.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,59 @@
# -*- coding: utf-8 -*-

import importlib
import logging
from datetime import timedelta

__all__ = ["get_progress_bar"]

logger = logging.getLogger(__name__)

try:
import tqdm
import tqdm.auto
from rich.console import Console
from rich.progress import BarColumn, Progress, ProgressColumn, TextColumn
from rich.text import Text

_RICH_AVAILABLE = True
except ImportError:
tqdm = None
Progress = None
_RICH_AVAILABLE = False


def _format_timer(seconds):
if seconds is None:
return "--:--"

total_seconds = max(0, int(seconds))
td = timedelta(seconds=total_seconds)
total_seconds = int(td.total_seconds())
hours, remainder = divmod(total_seconds, 3600)
minutes, secs = divmod(remainder, 60)

if hours > 0:
return f"{hours:d}:{minutes:02d}:{secs:02d}"
return f"{minutes:02d}:{secs:02d}"


if _RICH_AVAILABLE:

class _ElapsedTimeColumn(ProgressColumn):
def render(self, task):
return Text(
f"elapsed {_format_timer(task.elapsed)}",
style="yellow",
)

class _RemainingTimeColumn(ProgressColumn):
def render(self, task):
if task.finished:
return Text("", style="blue")
return Text(
f"left {_format_timer(task.time_remaining)}",
style="blue",
)

else:
_ElapsedTimeColumn = object
_RemainingTimeColumn = object


class _NoOpPBar(object):
Expand All @@ -30,31 +72,75 @@ def update(self, count):
pass


class _RichPBar(object):
"""A wrapper that provides emcee's progress-bar interface over rich."""

def __init__(self, total, **kwargs):
self.total = total
self.description = kwargs.pop("desc", "Sampling")
leave = kwargs.pop("leave", True)
self.progress = None
self.task_id = None

# leave=False means clearing the bar when complete.
self.transient = not leave

# Preserve legacy behavior by writing to stderr by default.
self.console = kwargs.pop("console", Console(stderr=True))

if kwargs:
logger.warning(
"Ignoring unsupported progress bar kwargs for rich backend: %s",
", ".join(sorted(kwargs.keys())),
)

def __enter__(self, *args, **kwargs):
assert Progress is not None
self.progress = Progress(
TextColumn("{task.description}"),
BarColumn(),
TextColumn("{task.completed:.0f}/{task.total:.0f}"),
_ElapsedTimeColumn(),
_RemainingTimeColumn(),
console=self.console,
transient=self.transient,
)
self.progress.__enter__()
self.task_id = self.progress.add_task(
self.description, total=self.total
)
return self

def __exit__(self, *args, **kwargs):
assert self.progress is not None
self.progress.__exit__(*args, **kwargs)

def update(self, count):
assert self.progress is not None
self.progress.update(self.task_id, advance=count)


def get_progress_bar(display, total, **kwargs):
"""Get a progress bar interface with given properties

If the tqdm library is not installed, this will always return a "progress
If the rich library is not installed, this will always return a "progress
bar" that does nothing.

Args:
display (bool or str): Should the bar actually show the progress? Or a
string to indicate which tqdm bar (subomdule) to use.
display (bool or str): Should the bar actually show the progress?
total (int): The total size of the progress bar.
kwargs (dict): Optional keyword arguments to be passed to the tqdm call.
kwargs (dict): Optional keyword arguments to be passed to the progress
bar implementation.

"""
if display:
if tqdm is None:
if Progress is None:
logger.warning(
"You must install the tqdm library to use progress "
"You must install the rich library to use progress "
"indicators with emcee"
)
return _NoOpPBar()
else:
if display is True:
return tqdm.auto.tqdm(total=total, **kwargs)
else:
tqdm_submodule = importlib.import_module(f"tqdm.{display}")
return tqdm_submodule.tqdm(total=total, **kwargs)
return _RichPBar(total=total, **kwargs)

return _NoOpPBar()
45 changes: 31 additions & 14 deletions src/emcee/tests/unit/test_pbar.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,42 @@
import pytest

from emcee.pbar import _NoOpPBar, get_progress_bar
from emcee.pbar import (
_ElapsedTimeColumn,
_NoOpPBar,
_RemainingTimeColumn,
_RichPBar,
get_progress_bar,
)

try:
import tqdm
import rich
except ImportError:
tqdm = None
rich = None


def test_display_false():
assert isinstance(get_progress_bar(False, 100), _NoOpPBar)


@pytest.mark.skipif(tqdm is None, reason="tqdm not available")
def test_tqdm_modes():
assert isinstance(get_progress_bar(True, 1000), tqdm.asyncio.tqdm_asyncio)
assert isinstance(get_progress_bar("std", 1000), tqdm.std.tqdm)
assert isinstance(
get_progress_bar("notebook", 1000), tqdm.notebook.tqdm_notebook
)
assert isinstance(
get_progress_bar("auto", 1000), tqdm.asyncio.tqdm_asyncio
)
assert isinstance(get_progress_bar("autonotebook", 1000), tqdm.std.tqdm)
@pytest.mark.skipif(rich is None, reason="rich not available")
def test_rich_modes():
assert isinstance(get_progress_bar(True, 1000), _RichPBar)
assert isinstance(get_progress_bar("std", 1000), _RichPBar)
assert isinstance(get_progress_bar("notebook", 1000), _RichPBar)
assert isinstance(get_progress_bar("auto", 1000), _RichPBar)
assert isinstance(get_progress_bar("autonotebook", 1000), _RichPBar)


@pytest.mark.skipif(rich is None, reason="rich not available")
def test_rich_progress_includes_elapsed_and_remaining():
pbar = get_progress_bar(True, 1000)

with pbar:
assert any(
isinstance(column, _ElapsedTimeColumn)
for column in pbar.progress.columns
)
assert any(
isinstance(column, _RemainingTimeColumn)
for column in pbar.progress.columns
)