diff --git a/pyproject.toml b/pyproject.toml index 3dd8a037..618469ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,6 @@ dependencies = [ "omegaconf>=2.3.0", "torchmetrics>=0.10.0,<1.9.0", "einops>=0.6.0,<0.9.0", - "rich>=11.0.0", "scikit-base", ] @@ -73,6 +72,7 @@ extra = [ "kaleido>=0.2.0,<0.3.0", "captum>=0.5.0,<0.8.0", "pytorch-tabnet<4.2", + "rich>=11.0.0", ] notebooks = [ diff --git a/src/pytorch_tabular/categorical_encoders.py b/src/pytorch_tabular/categorical_encoders.py index b3d7a1ee..00a41735 100644 --- a/src/pytorch_tabular/categorical_encoders.py +++ b/src/pytorch_tabular/categorical_encoders.py @@ -12,7 +12,7 @@ import pickle import numpy as np -from rich.progress import track +from pytorch_tabular.utils.progress import get_progress_tracker from sklearn.base import BaseEstimator, TransformerMixin from pytorch_tabular.utils import get_logger @@ -234,10 +234,9 @@ def transform(self, X: DataFrame, y=None) -> DataFrame: assert all(c in X.columns for c in self.cols) X_encoded = X.copy(deep=True) - for col, mapping in track( + for col, mapping in get_progress_tracker("none")( self._mapping.items(), description="Encoding the data...", - total=len(self._mapping.values()), ): for dim in range(mapping[self.NAN_CATEGORY].shape[0]): X_encoded.loc[:, f"{col}_embed_dim_{dim}"] = ( diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py index ad7bda85..6ddc775f 100644 --- a/src/pytorch_tabular/config/config.py +++ b/src/pytorch_tabular/config/config.py @@ -356,7 +356,7 @@ class TrainerConfig: track_grad_norm (int): Track and Log Gradient Norms in the logger. -1 by default means no tracking. 1 for the L1 norm, 2 for L2 norm, etc. - progress_bar (str): Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`. + progress_bar (str): Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `simple`. precision (str): Precision of the model. Defaults to `32`. See https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision @@ -541,7 +541,7 @@ class TrainerConfig: ) progress_bar: str = field( default="simple", - metadata={"help": "Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`."}, + metadata={"help": "Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `simple`."}, ) precision: str = field( default="32", diff --git a/src/pytorch_tabular/feature_extractor.py b/src/pytorch_tabular/feature_extractor.py index 424a03e4..af51bdba 100644 --- a/src/pytorch_tabular/feature_extractor.py +++ b/src/pytorch_tabular/feature_extractor.py @@ -4,7 +4,7 @@ from collections import defaultdict import pandas as pd -from rich.progress import track +from pytorch_tabular.utils.progress import get_progress_tracker from sklearn.base import BaseEstimator, TransformerMixin from pytorch_tabular.models import NODEModel, TabNetModel @@ -65,7 +65,7 @@ def transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame: self.tabular_model.model.eval() inference_dataloader = self.tabular_model.datamodule.prepare_inference_dataloader(X_encoded) logits_predictions = defaultdict(list) - for batch in track(inference_dataloader, description="Generating Features..."): + for batch in get_progress_tracker("none")(inference_dataloader, description="Generating Features..."): for k, v in batch.items(): if isinstance(v, list) and (len(v) == 0): # Skipping empty list diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index 5c402d1a..6938e924 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -25,7 +25,7 @@ from omegaconf.dictconfig import DictConfig from pandas import DataFrame from pytorch_lightning import seed_everything -from pytorch_lightning.callbacks import RichProgressBar +from pytorch_tabular.utils.progress import get_progress_bar_callback from pytorch_lightning.callbacks.gradient_accumulation_scheduler import ( GradientAccumulationScheduler, ) @@ -319,8 +319,14 @@ def _prepare_callbacks(self, callbacks=None) -> List: self.config.enable_checkpointing = True else: self.config.enable_checkpointing = False - if self.config.progress_bar == "rich" and self.config.trainer_kwargs.get("enable_progress_bar", True): - callbacks.append(RichProgressBar()) + progress_callback = get_progress_bar_callback( + self.config.progress_bar, + self.config.trainer_kwargs.get("enable_progress_bar", True) + ) + if progress_callback is not None: + callbacks.append(progress_callback) + elif self.config.progress_bar == "none": + self.config.trainer_kwargs["enable_progress_bar"] = False if self.verbose: logger.debug(f"Callbacks used: {callbacks}") return callbacks @@ -1230,13 +1236,13 @@ def _generate_predictions( quantiles, n_samples, ret_logits, - progress_bar, + progress_tracker, is_probabilistic, ): point_predictions = [] quantile_predictions = [] logits_predictions = defaultdict(list) - for batch in progress_bar(inference_dataloader): + for batch in progress_tracker(inference_dataloader): for k, v in batch.items(): if isinstance(v, list) and (len(v) == 0): continue # Skipping empty list @@ -1373,23 +1379,14 @@ def _predict( inference_dataloader = self.datamodule.prepare_inference_dataloader(test) is_probabilistic = hasattr(model.hparams, "_probabilistic") and model.hparams._probabilistic - if progress_bar == "rich": - from rich.progress import track - - progress_bar = partial(track, description="Generating Predictions...") - elif progress_bar == "tqdm": - from tqdm.auto import tqdm - - progress_bar = partial(tqdm, description="Generating Predictions...") - else: - progress_bar = lambda it: it # E731 + progress_tracker = get_progress_tracker(progress_bar or "none", description="Generating Predictions...") point_predictions, quantile_predictions, logits_predictions = self._generate_predictions( model, inference_dataloader, quantiles, n_samples, ret_logits, - progress_bar, + progress_tracker, is_probabilistic, ) pred_df = self._format_predicitons( @@ -1501,7 +1498,7 @@ def add_noise(module, input, output): ret_logits, include_input_features=False, device=device, - progress_bar=progress_bar or "None", + progress_bar=progress_bar or "none", ) pred_idx = pred_df.index if self.config.task == "classification": diff --git a/src/pytorch_tabular/tabular_model_sweep.py b/src/pytorch_tabular/tabular_model_sweep.py index fc97140e..5b654675 100644 --- a/src/pytorch_tabular/tabular_model_sweep.py +++ b/src/pytorch_tabular/tabular_model_sweep.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd -from rich.progress import Progress, track +from pytorch_tabular.utils.progress import get_progress_context, get_progress_tracker from skbase.utils.dependencies import _check_soft_dependencies from pytorch_tabular import TabularModel, models @@ -321,8 +321,8 @@ def _init_tabular_model(m): best_model = None is_lower_better = rank_metric[1] == "lower_is_better" best_score = 1e9 if is_lower_better else -1e9 - it = track(model_list, description="Sweeping Models") if progress_bar else model_list - ctx = Progress() if progress_bar else nullcontext() + it = get_progress_tracker("simple" if progress_bar else "none")(model_list, description="Sweeping Models") + ctx = get_progress_context("simple" if progress_bar else "none") with ctx as progress: if progress_bar: task_p = progress.add_task("Sweeping Models", total=len(model_list)) diff --git a/src/pytorch_tabular/tabular_model_tuner.py b/src/pytorch_tabular/tabular_model_tuner.py index d199d1fb..35984739 100644 --- a/src/pytorch_tabular/tabular_model_tuner.py +++ b/src/pytorch_tabular/tabular_model_tuner.py @@ -13,7 +13,7 @@ import pandas as pd from omegaconf.dictconfig import DictConfig from pandas import DataFrame -from rich.progress import Progress +from pytorch_tabular.utils.progress import get_progress_context from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler from pytorch_tabular.config import ( @@ -255,7 +255,7 @@ def tune( verbose_tabular_model = self.tabular_model_init_kwargs.pop("verbose", False) - with Progress() as progress: + with get_progress_context("simple" if progress_bar else "none") as progress: model_config_iterator = range(len(self.model_config)) if progress_bar: model_config_iterator = progress.track( diff --git a/src/pytorch_tabular/utils/progress.py b/src/pytorch_tabular/utils/progress.py new file mode 100644 index 00000000..f098deb4 --- /dev/null +++ b/src/pytorch_tabular/utils/progress.py @@ -0,0 +1,100 @@ +"""Progress bar utilities for PyTorch Tabular.""" + +from contextlib import nullcontext +from functools import partial +from typing import Any, Callable, Iterator, Optional + + +class DummyProgress: + """A dummy progress class that mimics rich.Progress but does nothing.""" + + def add_task(self, *args, **kwargs): + return None + + def update(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def track(self, iterable, *args, **kwargs): + return iterable + + +def get_progress_tracker(backend: str = "none", description: Optional[str] = None) -> Callable[[Iterator], Iterator]: + """Get a progress tracker function based on the backend. + + Args: + backend: The progress bar backend. Can be 'rich', 'tqdm', or 'none'. + description: Description for the progress bar. + + Returns: + A function that takes an iterable and returns an iterator with progress tracking. + """ + if backend == "rich": + try: + from rich.progress import track + return partial(track, description=description) if description else track + except ImportError: + # Fallback to none if rich is not available + return lambda it, **kwargs: it + elif backend == "tqdm": + try: + from tqdm.auto import tqdm + return partial(tqdm, desc=description) if description else tqdm + except ImportError: + return lambda it, **kwargs: it + else: # none + return lambda it, **kwargs: it + + +def get_progress_context(backend: str = "none"): + """Get a progress context manager based on the backend. + + Args: + backend: The progress bar backend. Can be 'rich', 'tqdm', or 'none'. + + Returns: + A context manager for progress tracking that has a track method. + """ + if backend == "rich": + try: + from rich.progress import Progress + return Progress() + except ImportError: + return DummyProgress() + elif backend == "tqdm": + # tqdm doesn't have a context manager like rich's Progress + # For now, return DummyProgress + return DummyProgress() + else: + return DummyProgress() + + +def get_progress_bar_callback(backend: str = "simple", enable_progress_bar: bool = True): + """Get the appropriate PyTorch Lightning progress bar callback based on backend. + + Args: + backend: The progress bar backend. Can be 'rich', 'simple', or 'none'. + enable_progress_bar: Whether progress bar is enabled in trainer kwargs. + + Returns: + A PyTorch Lightning callback or None. + """ + if backend == "rich" and enable_progress_bar: + try: + from pytorch_lightning.callbacks import RichProgressBar + return RichProgressBar() + except ImportError: + return None + elif backend == "simple" and enable_progress_bar: + try: + from pytorch_lightning.callbacks import TQDMProgressBar + return TQDMProgressBar() + except ImportError: + return None + else: # none + return None \ No newline at end of file