|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -import sys |
16 | 15 | import typing as tp |
17 | 16 | from functools import partial |
18 | 17 |
|
19 | 18 | import numpy as np |
20 | 19 | import pandas as pd |
21 | 20 | import pytest |
22 | 21 | import torch |
23 | | - |
24 | | -try: |
25 | | - from pytorch_lightning import Trainer, seed_everything |
26 | | -except ImportError: |
27 | | - Trainer = object # type: ignore |
| 22 | +from pytorch_lightning import Trainer, seed_everything |
28 | 23 |
|
29 | 24 | from rectools import ExternalIds |
30 | 25 | from rectools.columns import Columns |
31 | 26 | from rectools.dataset import Dataset |
32 | | - |
33 | | -try: |
34 | | - from rectools.models import BERT4RecModel |
35 | | - from rectools.models.nn.item_net import IdEmbeddingsItemNet, SumOfEmbeddingsConstructor |
36 | | - from rectools.models.nn.transformers.base import ( |
37 | | - LearnableInversePositionalEncoding, |
38 | | - PreLNTransformerLayers, |
39 | | - TrainerCallable, |
40 | | - TransformerLightningModule, |
41 | | - ) |
42 | | - from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable |
43 | | -except ImportError: |
44 | | - TrainerCallable = object # type: ignore |
45 | | - BERT4RecDataPreparator = object # type: ignore |
46 | | -from tests.models.data import DATASET |
47 | | -from tests.models.utils import ( |
48 | | - assert_default_config_and_default_model_params_are_the_same, |
49 | | - assert_second_fit_refits_model, |
| 27 | +from rectools.models import BERT4RecModel |
| 28 | +from rectools.models.nn.item_net import IdEmbeddingsItemNet, SumOfEmbeddingsConstructor |
| 29 | +from rectools.models.nn.transformers.base import ( |
| 30 | + LearnableInversePositionalEncoding, |
| 31 | + PreLNTransformerLayers, |
| 32 | + TrainerCallable, |
| 33 | + TransformerLightningModule, |
50 | 34 | ) |
| 35 | +from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable |
51 | 36 |
|
52 | | -try: |
53 | | - from .utils import custom_trainer, leave_one_out_mask |
54 | | -except NameError: |
55 | | - pass |
| 37 | +from .utils import custom_trainer, leave_one_out_mask |
56 | 38 |
|
57 | 39 |
|
58 | | -@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13") |
59 | 40 | class TestBERT4RecModel: |
60 | 41 | def setup_method(self) -> None: |
61 | 42 | self._seed_everything() |
@@ -627,7 +608,6 @@ def _collate_fn_train( |
627 | 608 | ) |
628 | 609 |
|
629 | 610 |
|
630 | | -@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13") |
631 | 611 | class TestBERT4RecDataPreparator: |
632 | 612 |
|
633 | 613 | def setup_method(self) -> None: |
@@ -807,7 +787,6 @@ def test_get_dataloader_val( |
807 | 787 | assert torch.equal(value, val_batch[key]) |
808 | 788 |
|
809 | 789 |
|
810 | | -@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`pytorch_lightning` is not compatible with Python >= 3.13") |
811 | 790 | class TestBERT4RecModelConfiguration: |
812 | 791 | def setup_method(self) -> None: |
813 | 792 | self._seed_everything() |
|
0 commit comments