diff --git a/pyproject.toml b/pyproject.toml
index 35395dba..7254f804 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,8 +54,8 @@ dependencies = [
     "matplotlib>3.1",
     "ipywidgets",
     "einops>=0.6.0,<0.8.0",
-    "rich>=11.0.0",
     "fsspec>=2022.5.0,<2024.4.0; python_version == '3.8'",
+    "rich",
 ]
diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py
index 999c2c4a..ad7bda85 100644
--- a/src/pytorch_tabular/config/config.py
+++ b/src/pytorch_tabular/config/config.py
@@ -540,7 +540,7 @@ class TrainerConfig:
         },
     )
     progress_bar: str = field(
-        default="rich",
-        metadata={"help": "Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`."},
+        default="simple",
+        metadata={"help": "Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `simple`."},
     )
     precision: str = field(
diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py
index 0b34adf4..5c402d1a 100644
--- a/src/pytorch_tabular/tabular_model.py
+++ b/src/pytorch_tabular/tabular_model.py
@@ -12,7 +12,7 @@
 from collections import defaultdict
 from functools import partial
 from pathlib import Path
-from pprint import pformat
+from pprint import pformat, pprint
 from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import joblib
@@ -32,8 +32,6 @@
 from pytorch_lightning.tuner.tuning import Tuner
 from pytorch_lightning.utilities.model_summary import summarize
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
-from rich import print as rich_print
-from rich.pretty import pprint
 from sklearn.base import TransformerMixin
 from sklearn.model_selection import BaseCrossValidator, KFold, StratifiedKFold
 from torch import nn
@@ -1675,14 +1673,14 @@ def summary(self, model=None, max_depth: int = -1) -> None:
         elif self.has_model:
             print(summarize(self.model, max_depth=max_depth))
         else:
-            rich_print(f"[bold green]{self.__class__.__name__}[/bold green]")
-            rich_print("-" * 100)
-            rich_print("[bold yellow]Config[/bold yellow]")
-            rich_print("-" * 100)
+            print(self.__class__.__name__)
+            print("-" * 100)
+            print("Config")
+            print("-" * 100)
             pprint(self.config.__dict__["_content"])
-            rich_print(
-                ":triangular_flag:[bold red]Full Model Summary once model has "
-                "been initialized or passed in as an argument[/bold red]"
+            print(
+                "⚠ Full Model Summary once model has been initialized "
+                "or passed in as an argument"
             )
 
     def ret_summary(self, model=None, max_depth: int = -1) -> str:
diff --git a/src/pytorch_tabular/utils/logger.py b/src/pytorch_tabular/utils/logger.py
index 1e690caf..41e1506b 100644
--- a/src/pytorch_tabular/utils/logger.py
+++ b/src/pytorch_tabular/utils/logger.py
@@ -1,18 +1,22 @@
 import logging
 import os
-
-from rich.logging import RichHandler
+import sys
 
 
 def get_logger(name):
     logger = logging.getLogger(name)
     # ch = logging.StreamHandler()
     logger.setLevel(level=os.environ.get("PT_LOGLEVEL", "INFO"))
-    formatter = logging.Formatter("%(asctime)s - {%(name)s:%(lineno)d} - %(levelname)s - %(message)s")
-    if not logger.hasHandlers():
-        ch = RichHandler(show_level=False, show_time=False, show_path=False, rich_tracebacks=True)
-        ch.setLevel(logging.DEBUG)
-        ch.setFormatter(formatter)
-        logger.addHandler(ch)
+
+    if not logger.handlers:
+        handler = logging.StreamHandler(sys.stderr)
+        handler.setLevel(logging.DEBUG)
+
+        fmt = "%(asctime)s - {%(name)s:%(lineno)d} - %(levelname)s - %(message)s"
+        formatter = logging.Formatter(fmt)
+        handler.setFormatter(formatter)
+
+        logger.addHandler(handler)
     logger.propagate = False
+    return logger