Skip to content

Commit d2d21da

Browse files
authored
&phoeenniixx [BUG] change progress bar default to simple and decouple rich core dependency to avoid default failure (#601)
changes progress bar default to `simple` and remove `rich` core dependency in the model progress bar. Fixes #600, acting on the suspicion of @phoeenniixx that the `rich` based progress bar is the culprit. No deprecation is needed for the change in default, since progress bars are purely visual, and the `rich` based one is replaced by on-board python stdout.
1 parent b32cca9 commit d2d21da

4 files changed

Lines changed: 22 additions & 20 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ dependencies = [
5454
"matplotlib>3.1",
5555
"ipywidgets",
5656
"einops>=0.6.0,<0.8.0",
57-
"rich>=11.0.0",
5857
"fsspec>=2022.5.0,<2024.4.0; python_version == '3.8'",
58+
"rich",
5959
]
6060

6161

src/pytorch_tabular/config/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ class TrainerConfig:
540540
},
541541
)
542542
progress_bar: str = field(
543-
default="rich",
543+
default="simple",
544544
metadata={"help": "Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`."},
545545
)
546546
precision: str = field(

src/pytorch_tabular/tabular_model.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from collections import defaultdict
1313
from functools import partial
1414
from pathlib import Path
15-
from pprint import pformat
15+
from pprint import pformat, pprint
1616
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
1717

1818
import joblib
@@ -32,8 +32,6 @@
3232
from pytorch_lightning.tuner.tuning import Tuner
3333
from pytorch_lightning.utilities.model_summary import summarize
3434
from pytorch_lightning.utilities.rank_zero import rank_zero_only
35-
from rich import print as rich_print
36-
from rich.pretty import pprint
3735
from sklearn.base import TransformerMixin
3836
from sklearn.model_selection import BaseCrossValidator, KFold, StratifiedKFold
3937
from torch import nn
@@ -1675,14 +1673,14 @@ def summary(self, model=None, max_depth: int = -1) -> None:
16751673
elif self.has_model:
16761674
print(summarize(self.model, max_depth=max_depth))
16771675
else:
1678-
rich_print(f"[bold green]{self.__class__.__name__}[/bold green]")
1679-
rich_print("-" * 100)
1680-
rich_print("[bold yellow]Config[/bold yellow]")
1681-
rich_print("-" * 100)
1676+
print(self.__class__.__name__)
1677+
print("-" * 100)
1678+
print("Config")
1679+
print("-" * 100)
16821680
pprint(self.config.__dict__["_content"])
1683-
rich_print(
1684-
":triangular_flag:[bold red]Full Model Summary once model has "
1685-
"been initialized or passed in as an argument[/bold red]"
1681+
print(
1682+
"Full Model Summary once model has been initialized "
1683+
"or passed in as an argument"
16861684
)
16871685

16881686
def ret_summary(self, model=None, max_depth: int = -1) -> str:
Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
import logging
22
import os
3-
4-
from rich.logging import RichHandler
3+
import sys
54

65

76
def get_logger(name):
87
logger = logging.getLogger(name)
98
# ch = logging.StreamHandler()
109
logger.setLevel(level=os.environ.get("PT_LOGLEVEL", "INFO"))
11-
formatter = logging.Formatter("%(asctime)s - {%(name)s:%(lineno)d} - %(levelname)s - %(message)s")
12-
if not logger.hasHandlers():
13-
ch = RichHandler(show_level=False, show_time=False, show_path=False, rich_tracebacks=True)
14-
ch.setLevel(logging.DEBUG)
15-
ch.setFormatter(formatter)
16-
logger.addHandler(ch)
10+
11+
if not logger.handlers:
12+
handler = logging.StreamHandler(sys.stderr)
13+
handler.setLevel(logging.DEBUG)
14+
15+
fmt = "%(asctime)s - {%(name)s:%(lineno)d} - %(levelname)s - %(message)s"
16+
formatter = logging.Formatter(fmt)
17+
handler.setFormatter(formatter)
18+
19+
logger.addHandler(handler)
1720
logger.propagate = False
21+
1822
return logger

0 commit comments

Comments
 (0)