Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ dependencies = [
"omegaconf>=2.3.0",
"torchmetrics>=0.10.0,<1.9.0",
"einops>=0.6.0,<0.9.0",
"rich>=11.0.0",
"scikit-base",
]

Expand All @@ -73,6 +72,7 @@ extra = [
"kaleido>=0.2.0,<0.3.0",
"captum>=0.5.0,<0.8.0",
"pytorch-tabnet<4.2",
"rich>=11.0.0",
]

notebooks = [
Expand Down
5 changes: 2 additions & 3 deletions src/pytorch_tabular/categorical_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pickle

import numpy as np
from rich.progress import track
from pytorch_tabular.utils.progress import get_progress_tracker
from sklearn.base import BaseEstimator, TransformerMixin

from pytorch_tabular.utils import get_logger
Expand Down Expand Up @@ -234,10 +234,9 @@ def transform(self, X: DataFrame, y=None) -> DataFrame:
assert all(c in X.columns for c in self.cols)

X_encoded = X.copy(deep=True)
for col, mapping in track(
for col, mapping in get_progress_tracker("none")(
self._mapping.items(),
description="Encoding the data...",
total=len(self._mapping.values()),
):
for dim in range(mapping[self.NAN_CATEGORY].shape[0]):
X_encoded.loc[:, f"{col}_embed_dim_{dim}"] = (
Expand Down
4 changes: 2 additions & 2 deletions src/pytorch_tabular/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ class TrainerConfig:
track_grad_norm (int): Track and Log Gradient Norms in the logger. -1 by default means no tracking.
1 for the L1 norm, 2 for L2 norm, etc.

progress_bar (str): Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`.
progress_bar (str): Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `simple`.

precision (str): Precision of the model. Defaults to `32`. See
https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision
Expand Down Expand Up @@ -541,7 +541,7 @@ class TrainerConfig:
)
progress_bar: str = field(
default="simple",
metadata={"help": "Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`."},
metadata={"help": "Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `simple`."},
)
precision: str = field(
default="32",
Expand Down
4 changes: 2 additions & 2 deletions src/pytorch_tabular/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import defaultdict

import pandas as pd
from rich.progress import track
from pytorch_tabular.utils.progress import get_progress_tracker
from sklearn.base import BaseEstimator, TransformerMixin

from pytorch_tabular.models import NODEModel, TabNetModel
Expand Down Expand Up @@ -65,7 +65,7 @@ def transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame:
self.tabular_model.model.eval()
inference_dataloader = self.tabular_model.datamodule.prepare_inference_dataloader(X_encoded)
logits_predictions = defaultdict(list)
for batch in track(inference_dataloader, description="Generating Features..."):
for batch in get_progress_tracker("none")(inference_dataloader, description="Generating Features..."):
for k, v in batch.items():
if isinstance(v, list) and (len(v) == 0):
# Skipping empty list
Expand Down
31 changes: 14 additions & 17 deletions src/pytorch_tabular/tabular_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from omegaconf.dictconfig import DictConfig
from pandas import DataFrame
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_tabular.utils.progress import get_progress_bar_callback
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import (
GradientAccumulationScheduler,
)
Expand Down Expand Up @@ -319,8 +319,14 @@ def _prepare_callbacks(self, callbacks=None) -> List:
self.config.enable_checkpointing = True
else:
self.config.enable_checkpointing = False
if self.config.progress_bar == "rich" and self.config.trainer_kwargs.get("enable_progress_bar", True):
callbacks.append(RichProgressBar())
progress_callback = get_progress_bar_callback(
self.config.progress_bar,
self.config.trainer_kwargs.get("enable_progress_bar", True)
)
if progress_callback is not None:
callbacks.append(progress_callback)
elif self.config.progress_bar == "none":
self.config.trainer_kwargs["enable_progress_bar"] = False
if self.verbose:
logger.debug(f"Callbacks used: {callbacks}")
return callbacks
Expand Down Expand Up @@ -1230,13 +1236,13 @@ def _generate_predictions(
quantiles,
n_samples,
ret_logits,
progress_bar,
progress_tracker,
is_probabilistic,
):
point_predictions = []
quantile_predictions = []
logits_predictions = defaultdict(list)
for batch in progress_bar(inference_dataloader):
for batch in progress_tracker(inference_dataloader):
for k, v in batch.items():
if isinstance(v, list) and (len(v) == 0):
continue # Skipping empty list
Expand Down Expand Up @@ -1373,23 +1379,14 @@ def _predict(
inference_dataloader = self.datamodule.prepare_inference_dataloader(test)
is_probabilistic = hasattr(model.hparams, "_probabilistic") and model.hparams._probabilistic

if progress_bar == "rich":
from rich.progress import track

progress_bar = partial(track, description="Generating Predictions...")
elif progress_bar == "tqdm":
from tqdm.auto import tqdm

progress_bar = partial(tqdm, description="Generating Predictions...")
else:
progress_bar = lambda it: it # E731
progress_tracker = get_progress_tracker(progress_bar or "none", description="Generating Predictions...")
point_predictions, quantile_predictions, logits_predictions = self._generate_predictions(
model,
inference_dataloader,
quantiles,
n_samples,
ret_logits,
progress_bar,
progress_tracker,
is_probabilistic,
)
pred_df = self._format_predicitons(
Expand Down Expand Up @@ -1501,7 +1498,7 @@ def add_noise(module, input, output):
ret_logits,
include_input_features=False,
device=device,
progress_bar=progress_bar or "None",
progress_bar=progress_bar or "none",
)
pred_idx = pred_df.index
if self.config.task == "classification":
Expand Down
6 changes: 3 additions & 3 deletions src/pytorch_tabular/tabular_model_sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np
import pandas as pd
from rich.progress import Progress, track
from pytorch_tabular.utils.progress import get_progress_context, get_progress_tracker
from skbase.utils.dependencies import _check_soft_dependencies

from pytorch_tabular import TabularModel, models
Expand Down Expand Up @@ -321,8 +321,8 @@ def _init_tabular_model(m):
best_model = None
is_lower_better = rank_metric[1] == "lower_is_better"
best_score = 1e9 if is_lower_better else -1e9
it = track(model_list, description="Sweeping Models") if progress_bar else model_list
ctx = Progress() if progress_bar else nullcontext()
it = get_progress_tracker("simple" if progress_bar else "none")(model_list, description="Sweeping Models")
ctx = get_progress_context("simple" if progress_bar else "none")
with ctx as progress:
if progress_bar:
task_p = progress.add_task("Sweeping Models", total=len(model_list))
Expand Down
4 changes: 2 additions & 2 deletions src/pytorch_tabular/tabular_model_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pandas as pd
from omegaconf.dictconfig import DictConfig
from pandas import DataFrame
from rich.progress import Progress
from pytorch_tabular.utils.progress import get_progress_context
from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler

from pytorch_tabular.config import (
Expand Down Expand Up @@ -255,7 +255,7 @@ def tune(

verbose_tabular_model = self.tabular_model_init_kwargs.pop("verbose", False)

with Progress() as progress:
with get_progress_context("simple" if progress_bar else "none") as progress:
model_config_iterator = range(len(self.model_config))
if progress_bar:
model_config_iterator = progress.track(
Expand Down
100 changes: 100 additions & 0 deletions src/pytorch_tabular/utils/progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Progress bar utilities for PyTorch Tabular."""

from contextlib import nullcontext
from functools import partial
from typing import Any, Callable, Iterator, Optional


class DummyProgress:
    """No-op stand-in for ``rich.progress.Progress``.

    Implements the subset of the rich ``Progress`` API used in this package
    (``add_task`` / ``update`` / ``track`` and the context-manager protocol)
    so callers can use a single code path whether or not a progress bar is
    actually displayed.
    """

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        pass

    def add_task(self, *args, **kwargs):
        # No real task is created; callers must tolerate a ``None`` handle.
        return None

    def update(self, *args, **kwargs):
        # Nothing to refresh.
        pass

    def track(self, iterable, *args, **kwargs):
        # Pass the iterable through untouched — no progress rendering.
        return iterable


def get_progress_tracker(backend: str = "none", description: Optional[str] = None) -> Callable[..., Iterator]:
    """Get a progress tracker function based on the backend.

    Args:
        backend: The progress bar backend. Can be 'rich', 'tqdm', or 'none'.
            An unrecognized value, or a backend whose optional dependency is
            not installed, falls back to a no-op tracker.
        description: Default description for the progress bar. A
            ``description=`` keyword passed at call time takes precedence.

    Returns:
        A callable that takes an iterable (plus optional keyword arguments
        such as ``description`` and ``total``) and returns an iterator with
        progress tracking. All backends accept the same keywords: the shared
        ``description`` name is translated to tqdm's ``desc`` so call sites
        need not know which backend is active.
    """

    def _no_op(iterable, **kwargs):
        # Accept (and ignore) progress kwargs so every backend shares one
        # call signature.
        return iterable

    if backend == "rich":
        try:
            from rich.progress import track
        except ImportError:
            # rich is an optional dependency; degrade to plain iteration.
            return _no_op

        def _rich_track(iterable, description=description, **kwargs):
            if description is None:
                return track(iterable, **kwargs)
            return track(iterable, description=description, **kwargs)

        return _rich_track

    if backend == "tqdm":
        try:
            from tqdm.auto import tqdm
        except ImportError:
            return _no_op

        def _tqdm_track(iterable, description=description, **kwargs):
            # tqdm spells the label ``desc``; translate the shared name so a
            # call-time ``description=`` does not raise TypeError.
            return tqdm(iterable, desc=description, **kwargs)

        return _tqdm_track

    return _no_op  # 'none' or any unrecognized backend


def get_progress_context(backend: str = "none"):
    """Get a progress context manager based on the backend.

    Args:
        backend: The progress bar backend. Can be 'rich', 'simple', 'tqdm',
            or 'none'. Callers in this package pass 'simple' when a progress
            bar is requested, so it must be treated as a bar-producing
            backend rather than falling through to the no-op default.

    Returns:
        A context manager with a ``track`` method: ``rich.progress.Progress``
        when a bar is requested and rich is installed, otherwise a
        ``DummyProgress`` no-op.
    """
    if backend in ("rich", "simple", "tqdm"):
        # rich's Progress is the only available implementation of the
        # add_task/update/track context-manager API; tqdm has no direct
        # equivalent, so every bar-producing backend tries rich first.
        try:
            from rich.progress import Progress

            return Progress()
        except ImportError:
            # rich is an optional dependency; degrade to a silent no-op.
            return DummyProgress()
    return DummyProgress()


def get_progress_bar_callback(backend: str = "simple", enable_progress_bar: bool = True):
    """Get the appropriate PyTorch Lightning progress bar callback based on backend.

    Args:
        backend: The progress bar backend. Can be 'rich', 'simple', or 'none'.
        enable_progress_bar: Whether progress bar is enabled in trainer kwargs.

    Returns:
        A PyTorch Lightning progress bar callback, or None when the bar is
        disabled, the backend is 'none', or no suitable callback can be
        constructed.
    """
    # Guard clause: nothing to build when the bar is off.
    if not enable_progress_bar or backend == "none":
        return None
    if backend == "rich":
        try:
            from pytorch_lightning.callbacks import RichProgressBar

            # Instantiation (not just import) may raise when the optional
            # rich package is missing, so it stays inside the try block.
            return RichProgressBar()
        except ImportError:
            # rich unavailable: fall through to the tqdm-based bar so the
            # explicitly requested progress display is not silently lost.
            pass
    if backend in ("rich", "simple"):
        try:
            from pytorch_lightning.callbacks import TQDMProgressBar

            return TQDMProgressBar()
        except ImportError:
            return None
    # Unrecognized backend: behave like 'none'.
    return None