Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ requires-python = ">=3.10,<3.15"
keywords = [
"pytorch",
"tabular",
"pytorch-lightning",
"lightning",
"neural network",
]

Expand Down Expand Up @@ -43,7 +43,7 @@ dependencies = [
"numpy<=3.0.0",
"pandas>=1.1.5,<3.0.0",
"scikit-learn>=1.3.0,<2.0",
"pytorch-lightning>=2.0.0,<2.7.0",
"lightning>=2.0.0,<2.7.0",
"scipy>=1.8,<2.0",
"omegaconf>=2.3.0",
"torchmetrics>=0.10.0,<1.9.0",
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_tabular/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ class TrainerConfig:
seed (int): Seed for random number generators. Defaults to 42

trainer_kwargs (Dict[str, Any]): Additional kwargs to be passed to PyTorch Lightning Trainer. See
https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.trainer.html#pytorch_lightning.trainer.Trainer
https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.trainer.trainer.Trainer.html

"""

Expand Down
4 changes: 2 additions & 2 deletions src/pytorch_tabular/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torchmetrics
Expand Down Expand Up @@ -217,7 +217,7 @@ def load_from_checkpoint(
):
from skbase.utils.dependencies import _check_soft_dependencies

if not _check_soft_dependencies("pytorch_lightning<2.6", severity="none"):
if not _check_soft_dependencies("lightning<2.6", severity="none"):
if "weights_only" not in kwargs:
kwargs["weights_only"] = False
else:
Expand Down
4 changes: 2 additions & 2 deletions src/pytorch_tabular/ssl_models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Dict, Optional, Union
from pathlib import Path

import pytorch_lightning as pl
import lightning.pytorch as pl
import torch
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf
Expand Down Expand Up @@ -162,7 +162,7 @@ def load_from_checkpoint(
):
from skbase.utils.dependencies import _check_soft_dependencies

if not _check_soft_dependencies("pytorch_lightning<2.6", severity="none"):
if not _check_soft_dependencies("lightning<2.6", severity="none"):
if "weights_only" not in kwargs:
kwargs["weights_only"] = False
else:
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_tabular/ssl_models/common/ssl_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import pytorch_lightning as pl
import lightning.pytorch as pl
from torch import Tensor

from pytorch_tabular.models.common import PositionWiseFeedForward
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_tabular/tabular_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import joblib
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import lightning.pytorch as pl
import torch
from omegaconf import DictConfig
from pandas import DataFrame, DatetimeTZDtype, to_datetime
Expand Down
14 changes: 7 additions & 7 deletions src/pytorch_tabular/tabular_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,20 @@
import joblib
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import lightning.pytorch as pl
import torch
import torchmetrics
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from pandas import DataFrame
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import (
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import RichProgressBar
from lightning.pytorch.callbacks.gradient_accumulation_scheduler import (
GradientAccumulationScheduler,
)
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities.model_summary import summarize
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from lightning.pytorch.tuner.tuning import Tuner
from lightning.pytorch.utilities.model_summary import summarize
from lightning.pytorch.utilities.rank_zero import rank_zero_only
from sklearn.base import TransformerMixin
from sklearn.model_selection import BaseCrossValidator, KFold, StratifiedKFold
from torch import nn
Expand Down
4 changes: 2 additions & 2 deletions src/pytorch_tabular/utils/python_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import torch

try: # for 1.8
from pytorch_lightning.utilities.cloud_io import get_filesystem
from lightning.pytorch.utilities.cloud_io import get_filesystem
except ImportError: # for 1.9
from pytorch_lightning.core.saving import get_filesystem
from lightning.pytorch.core.saving import get_filesystem

import pytorch_tabular as root_module

Expand Down