Skip to content

Commit faf3acb

Browse files
committed
Support pytorch_lightning with torch 2.6
1 parent 697fe39 commit faf3acb

9 files changed

Lines changed: 75 additions & 146 deletions

File tree

poetry.lock

Lines changed: 33 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,10 @@ torch = [
8888
{version = ">=1.6.0, <3.0.0", python = "<3.13", optional = true},
8989
{version = ">=2.6.0, <3.0.0", python = ">=3.13", optional = true},
9090
]
91-
pytorch-lightning = {version = ">=1.6.0, <3.0.0", python = "<3.13", optional = true}
91+
pytorch-lightning = [
92+
{version = ">=1.6.0, <3.0.0", python = "<3.13", optional = true},
93+
{version = ">=2.5.1, <3.0.0", python = ">=3.13", optional = true},
94+
]
9295

9396
ipywidgets = {version = ">=7.7,<8.2", optional = true}
9497
plotly = {version="^5.22.0", optional = true}

tests/models/nn/test_dssm.py

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,46 +12,26 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import sys
1615
import typing as tp
1716

1817
import numpy as np
1918
import pandas as pd
2019
import pytest
21-
22-
try:
23-
from lightning_fabric import seed_everything
24-
except ImportError:
25-
pass
26-
27-
try:
28-
import pytorch_lightning # noqa # pylint: disable=unused-import
29-
30-
filter_warnings_decorator = pytest.mark.filterwarnings(
31-
"ignore::pytorch_lightning.utilities.warnings.PossibleUserWarning"
32-
)
33-
except ImportError:
34-
35-
def filter_warnings_decorator(func): # type: ignore
36-
return func
37-
20+
import pytorch_lightning # noqa # pylint: disable=unused-import
21+
from lightning_fabric import seed_everything
3822

3923
from rectools.columns import Columns
4024
from rectools.dataset import Dataset
4125
from rectools.exceptions import NotFittedError
42-
43-
try:
44-
from rectools.models import DSSMModel
45-
from rectools.models.nn.dssm import DSSM
46-
except ModuleNotFoundError:
47-
pass
26+
from rectools.models import DSSMModel
27+
from rectools.models.nn.dssm import DSSM
4828
from rectools.models.vector import ImplicitRanker
4929
from tests.models.utils import assert_dumps_loads_do_not_change_model, assert_second_fit_refits_model
5030

5131
from ..data import INTERACTIONS
5232

53-
pytestmark = pytest.mark.skipif(
54-
sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13"
33+
filter_warnings_decorator = pytest.mark.filterwarnings(
34+
"ignore::pytorch_lightning.utilities.warnings.PossibleUserWarning"
5535
)
5636

5737

tests/models/nn/test_item_net.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,7 @@
1919
import pandas as pd
2020
import pytest
2121
import torch
22-
23-
try:
24-
from pytorch_lightning import seed_everything
25-
except ImportError:
26-
pass
22+
from pytorch_lightning import seed_everything
2723

2824
from rectools.columns import Columns
2925
from rectools.dataset import Dataset
@@ -38,10 +34,6 @@
3834

3935
from ..data import DATASET, INTERACTIONS
4036

41-
pytestmark = pytest.mark.skipif(
42-
sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13"
43-
)
44-
4537

4638
class TestIdEmbeddingsItemNet:
4739
def setup_method(self) -> None:

tests/models/nn/transformers/test_base.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,42 +13,25 @@
1313
# limitations under the License.
1414

1515
import os
16-
import sys
1716
import typing as tp
1817
from tempfile import NamedTemporaryFile
1918

2019
import pandas as pd
2120
import pytest
2221
import torch
2322
from pytest import FixtureRequest
24-
25-
try:
26-
from pytorch_lightning import Trainer, seed_everything
27-
from pytorch_lightning.loggers import CSVLogger
28-
29-
except ImportError:
30-
Trainer = object # type: ignore
23+
from pytorch_lightning import Trainer, seed_everything
24+
from pytorch_lightning.loggers import CSVLogger
3125

3226
from rectools import Columns
3327
from rectools.dataset import Dataset
34-
35-
try:
36-
from rectools.models import BERT4RecModel, SASRecModel, load_model
37-
from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet
38-
from rectools.models.nn.transformers.base import TransformerModelBase
39-
except ImportError:
40-
TransformerModelBase = object # type: ignore
28+
from rectools.models import BERT4RecModel, SASRecModel, load_model
29+
from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet
30+
from rectools.models.nn.transformers.base import TransformerModelBase
4131
from tests.models.data import INTERACTIONS
4232
from tests.models.utils import assert_save_load_do_not_change_model
4333

44-
try:
45-
from .utils import custom_trainer, custom_trainer_ckpt, custom_trainer_multiple_ckpt, leave_one_out_mask
46-
except NameError:
47-
pass
48-
49-
pytestmark = pytest.mark.skipif(
50-
sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13"
51-
)
34+
from .utils import custom_trainer, custom_trainer_ckpt, custom_trainer_multiple_ckpt, leave_one_out_mask
5235

5336

5437
class TestTransformerModelBase:

tests/models/nn/transformers/test_bert4rec.py

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,50 +12,31 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import sys
1615
import typing as tp
1716
from functools import partial
1817

1918
import numpy as np
2019
import pandas as pd
2120
import pytest
2221
import torch
23-
24-
try:
25-
from pytorch_lightning import Trainer, seed_everything
26-
except ImportError:
27-
Trainer = object # type: ignore
22+
from pytorch_lightning import Trainer, seed_everything
2823

2924
from rectools import ExternalIds
3025
from rectools.columns import Columns
3126
from rectools.dataset import Dataset
32-
33-
try:
34-
from rectools.models import BERT4RecModel
35-
from rectools.models.nn.item_net import IdEmbeddingsItemNet, SumOfEmbeddingsConstructor
36-
from rectools.models.nn.transformers.base import (
37-
LearnableInversePositionalEncoding,
38-
PreLNTransformerLayers,
39-
TrainerCallable,
40-
TransformerLightningModule,
41-
)
42-
from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable
43-
except ImportError:
44-
TrainerCallable = object # type: ignore
45-
BERT4RecDataPreparator = object # type: ignore
46-
from tests.models.data import DATASET
47-
from tests.models.utils import (
48-
assert_default_config_and_default_model_params_are_the_same,
49-
assert_second_fit_refits_model,
27+
from rectools.models import BERT4RecModel
28+
from rectools.models.nn.item_net import IdEmbeddingsItemNet, SumOfEmbeddingsConstructor
29+
from rectools.models.nn.transformers.base import (
30+
LearnableInversePositionalEncoding,
31+
PreLNTransformerLayers,
32+
TrainerCallable,
33+
TransformerLightningModule,
5034
)
35+
from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable
5136

52-
try:
53-
from .utils import custom_trainer, leave_one_out_mask
54-
except NameError:
55-
pass
37+
from .utils import custom_trainer, leave_one_out_mask
5638

5739

58-
@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13")
5940
class TestBERT4RecModel:
6041
def setup_method(self) -> None:
6142
self._seed_everything()
@@ -627,7 +608,6 @@ def _collate_fn_train(
627608
)
628609

629610

630-
@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13")
631611
class TestBERT4RecDataPreparator:
632612

633613
def setup_method(self) -> None:
@@ -807,7 +787,6 @@ def test_get_dataloader_val(
807787
assert torch.equal(value, val_batch[key])
808788

809789

810-
@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13")
811790
class TestBERT4RecModelConfiguration:
812791
def setup_method(self) -> None:
813792
self._seed_everything()

tests/models/nn/transformers/test_sasrec.py

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,52 +14,35 @@
1414

1515
# pylint: disable=too-many-lines
1616

17-
import sys
1817
import typing as tp
1918
from functools import partial
2019

2120
import numpy as np
2221
import pandas as pd
2322
import pytest
2423
import torch
25-
26-
try:
27-
from pytorch_lightning import Trainer, seed_everything
28-
except ImportError:
29-
Trainer = object # type: ignore
24+
from pytorch_lightning import Trainer, seed_everything
3025

3126
from rectools import ExternalIds
3227
from rectools.columns import Columns
3328
from rectools.dataset import Dataset, IdMap, Interactions
34-
35-
try:
36-
from rectools.models import SASRecModel
37-
from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, SumOfEmbeddingsConstructor
38-
from rectools.models.nn.transformers.base import (
39-
LearnableInversePositionalEncoding,
40-
TrainerCallable,
41-
TransformerLightningModule,
42-
TransformerTorchBackbone,
43-
)
44-
from rectools.models.nn.transformers.sasrec import SASRecDataPreparator, SASRecTransformerLayers
45-
except ImportError:
46-
TrainerCallable = object # type: ignore
47-
SASRecDataPreparator = object # type: ignore
29+
from rectools.models import SASRecModel
30+
from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, SumOfEmbeddingsConstructor
31+
from rectools.models.nn.transformers.base import (
32+
LearnableInversePositionalEncoding,
33+
TrainerCallable,
34+
TransformerLightningModule,
35+
TransformerTorchBackbone,
36+
)
37+
from rectools.models.nn.transformers.sasrec import SASRecDataPreparator, SASRecTransformerLayers
4838
from tests.models.data import DATASET
4939
from tests.models.utils import (
5040
assert_default_config_and_default_model_params_are_the_same,
5141
assert_second_fit_refits_model,
5242
)
5343
from tests.testing_utils import assert_id_map_equal, assert_interactions_set_equal
5444

55-
try:
56-
from .utils import custom_trainer, leave_one_out_mask
57-
except NameError:
58-
pass
59-
60-
pytestmark = pytest.mark.skipif(
61-
sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13"
62-
)
45+
from .utils import custom_trainer, leave_one_out_mask
6346

6447

6548
class TestSASRecModel:

tests/models/nn/transformers/utils.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import sys
16-
1715
import pandas as pd
18-
import pytest
19-
20-
try:
21-
from pytorch_lightning import Trainer
22-
from pytorch_lightning.callbacks import ModelCheckpoint
23-
except ImportError:
24-
pass
16+
from pytorch_lightning import Trainer
17+
from pytorch_lightning.callbacks import ModelCheckpoint
2518

2619
from rectools import Columns
2720

28-
pytestmark = pytest.mark.skipif(
29-
sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13"
30-
)
31-
3221

3322
def leave_one_out_mask(interactions: pd.DataFrame) -> pd.Series:
3423
rank = (

0 commit comments

Comments
 (0)