11# -*- coding: utf-8 -*-
22
3- import importlib
43import logging
54
65__all__ = ["get_progress_bar" ]
76
87logger = logging .getLogger (__name__ )
98
109try :
11- import tqdm
12- import tqdm . auto
10+ from rich . console import Console
11+ from rich . progress import BarColumn , Progress , TaskProgressColumn , TextColumn
1312except ImportError :
14- tqdm = None
13+ Progress = None
1514
1615
1716class _NoOpPBar (object ):
@@ -30,31 +29,68 @@ def update(self, count):
3029 pass
3130
3231
32+ class _RichPBar (object ):
33+ """A wrapper that provides emcee's progress-bar interface over rich."""
34+
35+ def __init__ (self , total , ** kwargs ):
36+ self .total = total
37+ self .description = kwargs .pop ("desc" , "Sampling" )
38+ leave = kwargs .pop ("leave" , True )
39+ self .progress = None
40+ self .task_id = None
41+
42+ # leave=False means clearing the bar when complete.
43+ self .transient = not leave
44+
45+ # Preserve legacy behavior by writing to stderr by default.
46+ self .console = kwargs .pop ("console" , Console (stderr = True ))
47+
48+ if kwargs :
49+ logger .warning (
50+ "Ignoring unsupported progress bar kwargs for rich backend: %s" ,
51+ ", " .join (sorted (kwargs .keys ())),
52+ )
53+
54+ def __enter__ (self , * args , ** kwargs ):
55+ self .progress = Progress (
56+ TextColumn ("{task.description}" ),
57+ BarColumn (),
58+ TaskProgressColumn (),
59+ console = self .console ,
60+ transient = self .transient ,
61+ )
62+ self .progress .__enter__ ()
63+ self .task_id = self .progress .add_task (self .description , total = self .total )
64+ return self
65+
66+ def __exit__ (self , * args , ** kwargs ):
67+ self .progress .__exit__ (* args , ** kwargs )
68+
69+ def update (self , count ):
70+ self .progress .update (self .task_id , advance = count )
71+
72+
3373def get_progress_bar (display , total , ** kwargs ):
3474 """Get a progress bar interface with given properties
3575
36- If the tqdm library is not installed, this will always return a "progress
76+ If the rich library is not installed, this will always return a "progress
3777 bar" that does nothing.
3878
3979 Args:
40- display (bool or str): Should the bar actually show the progress? Or a
41- string to indicate which tqdm bar (subomdule) to use.
80+ display (bool or str): Should the bar actually show the progress?
4281 total (int): The total size of the progress bar.
43- kwargs (dict): Optional keyword arguments to be passed to the tqdm call.
82+ kwargs (dict): Optional keyword arguments to be passed to the progress
83+ bar implementation.
4484
4585 """
4686 if display :
47- if tqdm is None :
87+ if Progress is None :
4888 logger .warning (
49- "You must install the tqdm library to use progress "
89+ "You must install the rich library to use progress "
5090 "indicators with emcee"
5191 )
5292 return _NoOpPBar ()
5393 else :
54- if display is True :
55- return tqdm .auto .tqdm (total = total , ** kwargs )
56- else :
57- tqdm_submodule = importlib .import_module (f"tqdm.{ display } " )
58- return tqdm_submodule .tqdm (total = total , ** kwargs )
94+ return _RichPBar (total = total , ** kwargs )
5995
6096 return _NoOpPBar ()
0 commit comments