|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import sys |
| 16 | +import types |
15 | 17 | import typing as tp |
16 | 18 | from functools import partial |
17 | 19 |
|
18 | 20 | import numpy as np |
19 | 21 | import pandas as pd |
20 | 22 | import pytest |
21 | | -import torch |
22 | | -from pytorch_lightning import Trainer, seed_everything |
| 23 | + |
| 24 | +try: |
| 25 | + import torch |
| 26 | + from pytorch_lightning import Trainer, seed_everything |
| 27 | +except ImportError: |
| 28 | + torch = types.ModuleType("torch") |
| 29 | + torch.Tensor = object # type: ignore |
| 30 | + torch.float = object # type: ignore |
| 31 | + Trainer = object # type: ignore |
| 32 | + |
| 33 | + def tensor(*args: tp.Any, **kwargs: tp.Any) -> tp.Any: |
| 34 | + return object() |
| 35 | + |
| 36 | + torch.tensor = tensor |
23 | 37 |
|
24 | 38 | from rectools import ExternalIds |
25 | 39 | from rectools.columns import Columns |
26 | 40 | from rectools.dataset import Dataset |
27 | | -from rectools.models import BERT4RecModel |
28 | | -from rectools.models.nn.item_net import IdEmbeddingsItemNet, SumOfEmbeddingsConstructor |
29 | | -from rectools.models.nn.transformers.base import ( |
30 | | - LearnableInversePositionalEncoding, |
31 | | - PreLNTransformerLayers, |
32 | | - TrainerCallable, |
33 | | - TransformerLightningModule, |
34 | | -) |
35 | | -from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable |
| 41 | + |
| 42 | +try: |
| 43 | + from rectools.models import BERT4RecModel |
| 44 | + from rectools.models.nn.item_net import IdEmbeddingsItemNet, SumOfEmbeddingsConstructor |
| 45 | + from rectools.models.nn.transformers.base import ( |
| 46 | + LearnableInversePositionalEncoding, |
| 47 | + PreLNTransformerLayers, |
| 48 | + TrainerCallable, |
| 49 | + TransformerLightningModule, |
| 50 | + ) |
| 51 | + from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable |
| 52 | +except ImportError: |
| 53 | + TrainerCallable = object # type: ignore |
| 54 | + BERT4RecDataPreparator = object # type: ignore |
36 | 55 | from tests.models.data import DATASET |
37 | 56 | from tests.models.utils import ( |
38 | 57 | assert_default_config_and_default_model_params_are_the_same, |
39 | 58 | assert_second_fit_refits_model, |
40 | 59 | ) |
41 | 60 |
|
42 | | -from .utils import custom_trainer, leave_one_out_mask |
| 61 | +try: |
| 62 | + from .utils import custom_trainer, leave_one_out_mask |
| 63 | +except NameError: |
| 64 | + pass |
43 | 65 |
|
44 | 66 |
|
| 67 | +@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13") |
45 | 68 | class TestBERT4RecModel: |
46 | 69 | def setup_method(self) -> None: |
47 | 70 | self._seed_everything() |
@@ -119,27 +142,33 @@ def get_trainer() -> Trainer: |
119 | 142 | "cpu", |
120 | 143 | 1, |
121 | 144 | "cuda", |
122 | | - marks=pytest.mark.skipif(torch.cuda.is_available() is False, reason="GPU is not available"), |
| 145 | + marks=pytest.mark.skipif( |
| 146 | + sys.version_info >= (3, 13) or torch.cuda.is_available() is False, reason="GPU is not available" |
| 147 | + ), |
123 | 148 | ), |
124 | 149 | ("cpu", 2, "cpu"), |
125 | 150 | pytest.param( |
126 | 151 | "gpu", |
127 | 152 | 1, |
128 | 153 | "cpu", |
129 | | - marks=pytest.mark.skipif(torch.cuda.is_available() is False, reason="GPU is not available"), |
| 154 | + marks=pytest.mark.skipif( |
| 155 | + sys.version_info >= (3, 13) or torch.cuda.is_available() is False, reason="GPU is not available" |
| 156 | + ), |
130 | 157 | ), |
131 | 158 | pytest.param( |
132 | 159 | "gpu", |
133 | 160 | 1, |
134 | 161 | "cuda", |
135 | | - marks=pytest.mark.skipif(torch.cuda.is_available() is False, reason="GPU is not available"), |
| 162 | + marks=pytest.mark.skipif( |
| 163 | + sys.version_info >= (3, 13) or torch.cuda.is_available() is False, reason="GPU is not available" |
| 164 | + ), |
136 | 165 | ), |
137 | 166 | pytest.param( |
138 | 167 | "gpu", |
139 | 168 | 2, |
140 | 169 | "cpu", |
141 | 170 | marks=pytest.mark.skipif( |
142 | | - torch.cuda.is_available() is False or torch.cuda.device_count() < 2, |
| 171 | + sys.version_info >= (3, 13) or torch.cuda.is_available() is False or torch.cuda.device_count() < 2, |
143 | 172 | reason="GPU is not available or there is only one gpu device", |
144 | 173 | ), |
145 | 174 | ), |
@@ -613,6 +642,7 @@ def _collate_fn_train( |
613 | 642 | ) |
614 | 643 |
|
615 | 644 |
|
| 645 | +@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13") |
616 | 646 | class TestBERT4RecDataPreparator: |
617 | 647 |
|
618 | 648 | def setup_method(self) -> None: |
@@ -792,6 +822,7 @@ def test_get_dataloader_val( |
792 | 822 | assert torch.equal(value, val_batch[key]) |
793 | 823 |
|
794 | 824 |
|
| 825 | +@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13") |
795 | 826 | class TestBERT4RecModelConfiguration: |
796 | 827 | def setup_method(self) -> None: |
797 | 828 | self._seed_everything() |
|
0 commit comments