From fa371e11d4d16ae38d272acf25672714b178b04c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 24 Jan 2026 23:34:35 +0100 Subject: [PATCH 1/4] simple progr bar --- pyproject.toml | 1 - src/pytorch_tabular/config/config.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 35395dba..c7028f03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,6 @@ dependencies = [ "matplotlib>3.1", "ipywidgets", "einops>=0.6.0,<0.8.0", - "rich>=11.0.0", "fsspec>=2022.5.0,<2024.4.0; python_version == '3.8'", ] diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py index 999c2c4a..ad7bda85 100644 --- a/src/pytorch_tabular/config/config.py +++ b/src/pytorch_tabular/config/config.py @@ -540,7 +540,7 @@ class TrainerConfig: }, ) progress_bar: str = field( - default="rich", + default="simple", metadata={"help": "Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`."}, ) precision: str = field( From 7e3dfe3c1e7ef5eda950ee7304373a462b62c5ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 24 Jan 2026 23:40:20 +0100 Subject: [PATCH 2/4] Update logger.py --- src/pytorch_tabular/utils/logger.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/pytorch_tabular/utils/logger.py b/src/pytorch_tabular/utils/logger.py index 1e690caf..41e1506b 100644 --- a/src/pytorch_tabular/utils/logger.py +++ b/src/pytorch_tabular/utils/logger.py @@ -1,18 +1,22 @@ import logging import os - -from rich.logging import RichHandler +import sys def get_logger(name): logger = logging.getLogger(name) # ch = logging.StreamHandler() logger.setLevel(level=os.environ.get("PT_LOGLEVEL", "INFO")) - formatter = logging.Formatter("%(asctime)s - {%(name)s:%(lineno)d} - %(levelname)s - %(message)s") - if not logger.hasHandlers(): - ch = RichHandler(show_level=False, show_time=False, show_path=False, rich_tracebacks=True) - ch.setLevel(logging.DEBUG) - ch.setFormatter(formatter) - logger.addHandler(ch) + + if not logger.handlers: + handler = logging.StreamHandler(sys.stderr) + handler.setLevel(logging.DEBUG) + + fmt = "%(asctime)s - {%(name)s:%(lineno)d} - %(levelname)s - %(message)s" + formatter = logging.Formatter(fmt) + handler.setFormatter(formatter) + + logger.addHandler(handler) logger.propagate = False + return logger From 01beea5f6f0e8cedc0b9a93cd43b19d383859899 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 24 Jan 2026 23:43:48 +0100 Subject: [PATCH 3/4] Update tabular_model.py --- src/pytorch_tabular/tabular_model.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index 0b34adf4..5c402d1a 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -12,7 +12,7 @@ from collections import defaultdict from functools import partial from pathlib import Path -from pprint import pformat +from pprint import pformat, pprint from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union import joblib @@ -32,8 +32,6 @@ from pytorch_lightning.tuner.tuning import Tuner from pytorch_lightning.utilities.model_summary import summarize from pytorch_lightning.utilities.rank_zero import rank_zero_only -from rich import print as rich_print -from rich.pretty import pprint from sklearn.base import TransformerMixin from sklearn.model_selection import BaseCrossValidator, KFold, StratifiedKFold from torch import nn @@ -1675,14 +1673,14 @@ def summary(self, model=None, max_depth: int = -1) -> None: elif self.has_model: print(summarize(self.model, max_depth=max_depth)) else: - rich_print(f"[bold green]{self.__class__.__name__}[/bold green]") - rich_print("-" * 100) - rich_print("[bold yellow]Config[/bold yellow]") - rich_print("-" * 100) + print(self.__class__.__name__) + print("-" * 100) + print("Config") + print("-" * 100) pprint(self.config.__dict__["_content"]) - rich_print( - ":triangular_flag:[bold red]Full Model Summary once model has " - "been initialized or passed in as an argument[/bold red]" + print( + "⚠ Full Model Summary once model has been initialized " + "or passed in as an argument" ) def ret_summary(self, model=None, max_depth: int = -1) -> str: From 80db43315da5637bd2352fa6f1ccdc1dbc9e3149 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 24 Jan 2026 23:46:23 +0100 Subject: [PATCH 4/4] Update pyproject.toml --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index c7028f03..7254f804 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ dependencies = [ "ipywidgets", "einops>=0.6.0,<0.8.0", "fsspec>=2022.5.0,<2024.4.0; python_version == '3.8'", + "rich", ]