diff --git a/pyproject.toml b/pyproject.toml index 3dd8a037..9db2baf3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ requires-python = ">=3.10,<3.15" keywords = [ "pytorch", "tabular", - "pytorch-lightning", + "lightning", "neural network", ] @@ -43,7 +43,7 @@ dependencies = [ "numpy<=3.0.0", "pandas>=1.1.5,<3.0.0", "scikit-learn>=1.3.0,<2.0", - "pytorch-lightning>=2.0.0,<2.7.0", + "lightning>=2.0.0,<2.7.0", "scipy>=1.8,<2.0", "omegaconf>=2.3.0", "torchmetrics>=0.10.0,<1.9.0", diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py index ad7bda85..b23e66e2 100644 --- a/src/pytorch_tabular/config/config.py +++ b/src/pytorch_tabular/config/config.py @@ -364,7 +364,7 @@ class TrainerConfig: seed (int): Seed for random number generators. Defaults to 42 trainer_kwargs (Dict[str, Any]): Additional kwargs to be passed to PyTorch Lightning Trainer. See - https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.trainer.html#pytorch_lightning.trainer.Trainer + https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.trainer.trainer.Trainer.html """ diff --git a/src/pytorch_tabular/models/base_model.py b/src/pytorch_tabular/models/base_model.py index 380f4cab..6eed59b9 100644 --- a/src/pytorch_tabular/models/base_model.py +++ b/src/pytorch_tabular/models/base_model.py @@ -11,7 +11,7 @@ from pathlib import Path import numpy as np -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch.nn as nn import torchmetrics @@ -217,7 +217,7 @@ def load_from_checkpoint( ): from skbase.utils.dependencies import _check_soft_dependencies - if not _check_soft_dependencies("pytorch_lightning<2.6", severity="none"): + if not _check_soft_dependencies("lightning<2.6", severity="none"): if "weights_only" not in kwargs: kwargs["weights_only"] = False else: diff --git a/src/pytorch_tabular/ssl_models/base_model.py b/src/pytorch_tabular/ssl_models/base_model.py index 4c4b0257..42683e99 100644 --- a/src/pytorch_tabular/ssl_models/base_model.py +++ b/src/pytorch_tabular/ssl_models/base_model.py @@ -8,7 +8,7 @@ from typing import Dict, Optional, Union from pathlib import Path -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch.nn as nn from omegaconf import DictConfig, OmegaConf @@ -162,7 +162,7 @@ def load_from_checkpoint( ): from skbase.utils.dependencies import _check_soft_dependencies - if not _check_soft_dependencies("pytorch_lightning<2.6", severity="none"): + if not _check_soft_dependencies("lightning<2.6", severity="none"): if "weights_only" not in kwargs: kwargs["weights_only"] = False else: diff --git a/src/pytorch_tabular/ssl_models/common/ssl_utils.py b/src/pytorch_tabular/ssl_models/common/ssl_utils.py index b6f03488..9ffcadd7 100644 --- a/src/pytorch_tabular/ssl_models/common/ssl_utils.py +++ b/src/pytorch_tabular/ssl_models/common/ssl_utils.py @@ -1,4 +1,4 @@ -import pytorch_lightning as pl +import lightning.pytorch as pl from torch import Tensor from pytorch_tabular.models.common import PositionWiseFeedForward diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py index 2baebecb..cf91e6a1 100644 --- a/src/pytorch_tabular/tabular_datamodule.py +++ b/src/pytorch_tabular/tabular_datamodule.py @@ -12,7 +12,7 @@ import joblib import numpy as np import pandas as pd -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import DictConfig from pandas import DataFrame, DatetimeTZDtype, to_datetime diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index 5c402d1a..6575797c 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -18,20 +18,20 @@ import joblib import numpy as np import pandas as pd -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torchmetrics from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig from pandas import DataFrame -from pytorch_lightning import seed_everything -from pytorch_lightning.callbacks import RichProgressBar -from pytorch_lightning.callbacks.gradient_accumulation_scheduler import ( +from lightning.pytorch import seed_everything +from lightning.pytorch.callbacks import RichProgressBar +from lightning.pytorch.callbacks.gradient_accumulation_scheduler import ( GradientAccumulationScheduler, ) -from pytorch_lightning.tuner.tuning import Tuner -from pytorch_lightning.utilities.model_summary import summarize -from pytorch_lightning.utilities.rank_zero import rank_zero_only +from lightning.pytorch.tuner.tuning import Tuner +from lightning.pytorch.utilities.model_summary import summarize +from lightning.pytorch.utilities.rank_zero import rank_zero_only from sklearn.base import TransformerMixin from sklearn.model_selection import BaseCrossValidator, KFold, StratifiedKFold from torch import nn diff --git a/src/pytorch_tabular/utils/python_utils.py b/src/pytorch_tabular/utils/python_utils.py index e08503ed..340169e1 100644 --- a/src/pytorch_tabular/utils/python_utils.py +++ b/src/pytorch_tabular/utils/python_utils.py @@ -7,9 +7,9 @@ import torch try: # for 1.8 - from pytorch_lightning.utilities.cloud_io import get_filesystem + from lightning.pytorch.utilities.cloud_io import get_filesystem except ImportError: # for 1.9 - from pytorch_lightning.core.saving import get_filesystem + from lightning.pytorch.core.saving import get_filesystem import pytorch_tabular as root_module