|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -import sys |
16 | 15 | import typing as tp |
17 | 16 | from functools import partial |
18 | 17 |
|
19 | 18 | import numpy as np |
20 | 19 | import pandas as pd |
21 | 20 | import pytest |
22 | 21 | import torch |
23 | | - |
24 | | -try: |
25 | | - from pytorch_lightning import Trainer, seed_everything |
26 | | -except ImportError: |
27 | | - Trainer = object # type: ignore |
| 22 | +from pytorch_lightning import Trainer, seed_everything |
28 | 23 |
|
29 | 24 | from rectools import ExternalIds |
30 | 25 | from rectools.columns import Columns |
31 | 26 | from rectools.dataset import Dataset |
32 | | - |
33 | | -try: |
34 | | - from rectools.models import BERT4RecModel |
35 | | - from rectools.models.nn.item_net import IdEmbeddingsItemNet, SumOfEmbeddingsConstructor |
36 | | - from rectools.models.nn.transformers.base import ( |
37 | | - LearnableInversePositionalEncoding, |
38 | | - PreLNTransformerLayers, |
39 | | - TrainerCallable, |
40 | | - TransformerLightningModule, |
41 | | - ) |
42 | | - from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable |
43 | | -except ImportError: |
44 | | - TrainerCallable = object # type: ignore |
45 | | - BERT4RecDataPreparator = object # type: ignore |
| 27 | +from rectools.models import BERT4RecModel |
| 28 | +from rectools.models.nn.item_net import IdEmbeddingsItemNet, SumOfEmbeddingsConstructor |
| 29 | +from rectools.models.nn.transformers.base import ( |
| 30 | + LearnableInversePositionalEncoding, |
| 31 | + PreLNTransformerLayers, |
| 32 | + TrainerCallable, |
| 33 | + TransformerLightningModule, |
| 34 | +) |
| 35 | +from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable |
46 | 36 | from tests.models.data import DATASET |
47 | 37 | from tests.models.utils import ( |
48 | 38 | assert_default_config_and_default_model_params_are_the_same, |
49 | 39 | assert_second_fit_refits_model, |
50 | 40 | ) |
51 | 41 |
|
52 | | -try: |
53 | | - from .utils import custom_trainer, leave_one_out_mask |
54 | | -except NameError: |
55 | | - pass |
| 42 | +from .utils import custom_trainer, leave_one_out_mask |
56 | 43 |
|
57 | 44 |
|
58 | | -@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13") |
59 | 45 | class TestBERT4RecModel: |
60 | 46 | def setup_method(self) -> None: |
61 | 47 | self._seed_everything() |
@@ -627,7 +613,6 @@ def _collate_fn_train( |
627 | 613 | ) |
628 | 614 |
|
629 | 615 |
|
630 | | -@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13") |
631 | 616 | class TestBERT4RecDataPreparator: |
632 | 617 |
|
633 | 618 | def setup_method(self) -> None: |
@@ -807,7 +792,6 @@ def test_get_dataloader_val( |
807 | 792 | assert torch.equal(value, val_batch[key]) |
808 | 793 |
|
809 | 794 |
|
810 | | -@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13") |
811 | 795 | class TestBERT4RecModelConfiguration: |
812 | 796 | def setup_method(self) -> None: |
813 | 797 | self._seed_everything() |
|
0 commit comments