Skip to content

Commit 03cebe9

Browse files
committed
Skip more torch
1 parent 8b97751 commit 03cebe9

9 files changed

Lines changed: 203 additions & 65 deletions

File tree

tests/models/nn/test_item_net.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,42 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import sys
1516
import typing as tp
1617

1718
import numpy as np
1819
import pandas as pd
1920
import pytest
20-
import torch
21-
from pytorch_lightning import seed_everything
21+
22+
try:
23+
import torch
24+
from pytorch_lightning import seed_everything
25+
except ImportError:
26+
pass
2227

2328
from rectools.columns import Columns
2429
from rectools.dataset import Dataset
2530
from rectools.dataset.dataset import DatasetSchema, EntitySchema
26-
from rectools.models.nn.item_net import (
27-
CatFeaturesItemNet,
28-
IdEmbeddingsItemNet,
29-
ItemNetBase,
30-
ItemNetConstructorBase,
31-
SumOfEmbeddingsConstructor,
32-
)
31+
32+
try:
33+
from rectools.models.nn.item_net import (
34+
CatFeaturesItemNet,
35+
IdEmbeddingsItemNet,
36+
ItemNetBase,
37+
ItemNetConstructorBase,
38+
SumOfEmbeddingsConstructor,
39+
)
40+
except ImportError:
41+
CatFeaturesItemNet = object # type: ignore
42+
IdEmbeddingsItemNet = object # type: ignore
43+
ItemNetBase = object # type: ignore
44+
ItemNetConstructorBase = object # type: ignore
45+
SumOfEmbeddingsConstructor = object # type: ignore
3346

3447
from ..data import DATASET, INTERACTIONS
3548

49+
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
50+
3651

3752
class TestIdEmbeddingsItemNet:
3853
def setup_method(self) -> None:

tests/models/nn/transformers/test_base.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,40 @@
1313
# limitations under the License.
1414

1515
import os
16+
import sys
1617
import typing as tp
1718
from tempfile import NamedTemporaryFile
1819

1920
import pandas as pd
2021
import pytest
21-
import torch
2222
from pytest import FixtureRequest
23-
from pytorch_lightning import Trainer, seed_everything
24-
from pytorch_lightning.loggers import CSVLogger
23+
24+
try:
25+
import torch
26+
from pytorch_lightning import Trainer, seed_everything
27+
from pytorch_lightning.loggers import CSVLogger
28+
29+
except ImportError:
30+
Trainer = object # type: ignore
2531

2632
from rectools import Columns
2733
from rectools.dataset import Dataset
28-
from rectools.models import BERT4RecModel, SASRecModel, load_model
29-
from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet
30-
from rectools.models.nn.transformers.base import TransformerModelBase
34+
35+
try:
36+
from rectools.models import BERT4RecModel, SASRecModel, load_model
37+
from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet
38+
from rectools.models.nn.transformers.base import TransformerModelBase
39+
except ImportError:
40+
TransformerModelBase = object # type: ignore
3141
from tests.models.data import INTERACTIONS
3242
from tests.models.utils import assert_save_load_do_not_change_model
3343

34-
from .utils import custom_trainer, custom_trainer_ckpt, custom_trainer_multiple_ckpt, leave_one_out_mask
44+
try:
45+
from .utils import custom_trainer, custom_trainer_ckpt, custom_trainer_multiple_ckpt, leave_one_out_mask
46+
except NameError:
47+
pass
48+
49+
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
3550

3651

3752
class TestTransformerModelBase:

tests/models/nn/transformers/test_bert4rec.py

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,36 +12,59 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import sys
16+
import types
1517
import typing as tp
1618
from functools import partial
1719

1820
import numpy as np
1921
import pandas as pd
2022
import pytest
21-
import torch
22-
from pytorch_lightning import Trainer, seed_everything
23+
24+
try:
25+
import torch
26+
from pytorch_lightning import Trainer, seed_everything
27+
except ImportError:
28+
torch = types.ModuleType("torch")
29+
torch.Tensor = object # type: ignore
30+
torch.float = object # type: ignore
31+
Trainer = object # type: ignore
32+
33+
def tensor(*args: tp.Any, **kwargs: tp.Any) -> tp.Any:
34+
return object()
35+
36+
torch.tensor = tensor
2337

2438
from rectools import ExternalIds
2539
from rectools.columns import Columns
2640
from rectools.dataset import Dataset
27-
from rectools.models import BERT4RecModel
28-
from rectools.models.nn.item_net import IdEmbeddingsItemNet, SumOfEmbeddingsConstructor
29-
from rectools.models.nn.transformers.base import (
30-
LearnableInversePositionalEncoding,
31-
PreLNTransformerLayers,
32-
TrainerCallable,
33-
TransformerLightningModule,
34-
)
35-
from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable
41+
42+
try:
43+
from rectools.models import BERT4RecModel
44+
from rectools.models.nn.item_net import IdEmbeddingsItemNet, SumOfEmbeddingsConstructor
45+
from rectools.models.nn.transformers.base import (
46+
LearnableInversePositionalEncoding,
47+
PreLNTransformerLayers,
48+
TrainerCallable,
49+
TransformerLightningModule,
50+
)
51+
from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable
52+
except ImportError:
53+
TrainerCallable = object # type: ignore
54+
BERT4RecDataPreparator = object # type: ignore
3655
from tests.models.data import DATASET
3756
from tests.models.utils import (
3857
assert_default_config_and_default_model_params_are_the_same,
3958
assert_second_fit_refits_model,
4059
)
4160

42-
from .utils import custom_trainer, leave_one_out_mask
61+
try:
62+
from .utils import custom_trainer, leave_one_out_mask
63+
except NameError:
64+
pass
4365

4466

67+
@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
4568
class TestBERT4RecModel:
4669
def setup_method(self) -> None:
4770
self._seed_everything()
@@ -119,27 +142,33 @@ def get_trainer() -> Trainer:
119142
"cpu",
120143
1,
121144
"cuda",
122-
marks=pytest.mark.skipif(torch.cuda.is_available() is False, reason="GPU is not available"),
145+
marks=pytest.mark.skipif(
146+
sys.version_info >= (3, 13) or torch.cuda.is_available() is False, reason="GPU is not available"
147+
),
123148
),
124149
("cpu", 2, "cpu"),
125150
pytest.param(
126151
"gpu",
127152
1,
128153
"cpu",
129-
marks=pytest.mark.skipif(torch.cuda.is_available() is False, reason="GPU is not available"),
154+
marks=pytest.mark.skipif(
155+
sys.version_info >= (3, 13) or torch.cuda.is_available() is False, reason="GPU is not available"
156+
),
130157
),
131158
pytest.param(
132159
"gpu",
133160
1,
134161
"cuda",
135-
marks=pytest.mark.skipif(torch.cuda.is_available() is False, reason="GPU is not available"),
162+
marks=pytest.mark.skipif(
163+
sys.version_info >= (3, 13) or torch.cuda.is_available() is False, reason="GPU is not available"
164+
),
136165
),
137166
pytest.param(
138167
"gpu",
139168
2,
140169
"cpu",
141170
marks=pytest.mark.skipif(
142-
torch.cuda.is_available() is False or torch.cuda.device_count() < 2,
171+
sys.version_info >= (3, 13) or torch.cuda.is_available() is False or torch.cuda.device_count() < 2,
143172
reason="GPU is not available or there is only one gpu device",
144173
),
145174
),
@@ -613,6 +642,7 @@ def _collate_fn_train(
613642
)
614643

615644

645+
@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
616646
class TestBERT4RecDataPreparator:
617647

618648
def setup_method(self) -> None:
@@ -792,6 +822,7 @@ def test_get_dataloader_val(
792822
assert torch.equal(value, val_batch[key])
793823

794824

825+
@pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
795826
class TestBERT4RecModelConfiguration:
796827
def setup_method(self) -> None:
797828
self._seed_everything()

tests/models/nn/transformers/test_data_preparator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import sys
1516
import typing as tp
1617

1718
import numpy as np
@@ -21,9 +22,15 @@
2122
from rectools.columns import Columns
2223
from rectools.dataset import Dataset, IdMap, Interactions
2324
from rectools.dataset.features import DenseFeatures
24-
from rectools.models.nn.transformers.data_preparator import SequenceDataset, TransformerDataPreparatorBase
25+
26+
try:
27+
from rectools.models.nn.transformers.data_preparator import SequenceDataset, TransformerDataPreparatorBase
28+
except ImportError:
29+
TransformerDataPreparatorBase = object # type: ignore
2530
from tests.testing_utils import assert_feature_set_equal, assert_id_map_equal, assert_interactions_set_equal
2631

32+
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
33+
2734

2835
class TestSequenceDataset:
2936

tests/models/nn/transformers/test_sasrec.py

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,35 +14,54 @@
1414

1515
# pylint: disable=too-many-lines
1616

17+
import sys
18+
import types
1719
import typing as tp
1820
from functools import partial
1921

2022
import numpy as np
2123
import pandas as pd
2224
import pytest
23-
import torch
24-
from pytorch_lightning import Trainer, seed_everything
25+
26+
try:
27+
import torch
28+
from pytorch_lightning import Trainer, seed_everything
29+
except ImportError:
30+
torch = types.ModuleType("torch")
31+
torch.tensor = lambda x: None # type: ignore
32+
torch.Tensor = object # type: ignore
33+
Trainer = object # type: ignore
2534

2635
from rectools import ExternalIds
2736
from rectools.columns import Columns
2837
from rectools.dataset import Dataset, IdMap, Interactions
29-
from rectools.models import SASRecModel
30-
from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, SumOfEmbeddingsConstructor
31-
from rectools.models.nn.transformers.base import (
32-
LearnableInversePositionalEncoding,
33-
TrainerCallable,
34-
TransformerLightningModule,
35-
TransformerTorchBackbone,
36-
)
37-
from rectools.models.nn.transformers.sasrec import SASRecDataPreparator, SASRecTransformerLayers
38+
39+
try:
40+
from rectools.models import SASRecModel
41+
from rectools.models.nn.item_net import CatFeaturesItemNet, IdEmbeddingsItemNet, SumOfEmbeddingsConstructor
42+
from rectools.models.nn.transformers.base import (
43+
LearnableInversePositionalEncoding,
44+
TrainerCallable,
45+
TransformerLightningModule,
46+
TransformerTorchBackbone,
47+
)
48+
from rectools.models.nn.transformers.sasrec import SASRecDataPreparator, SASRecTransformerLayers
49+
except ImportError:
50+
TrainerCallable = object # type: ignore
51+
SASRecDataPreparator = object # type: ignore
3852
from tests.models.data import DATASET
3953
from tests.models.utils import (
4054
assert_default_config_and_default_model_params_are_the_same,
4155
assert_second_fit_refits_model,
4256
)
4357
from tests.testing_utils import assert_id_map_equal, assert_interactions_set_equal
4458

45-
from .utils import custom_trainer, leave_one_out_mask
59+
try:
60+
from .utils import custom_trainer, leave_one_out_mask
61+
except NameError:
62+
pass
63+
64+
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
4665

4766

4867
class TestSASRecModel:
@@ -164,27 +183,33 @@ def get_trainer() -> Trainer:
164183
"cpu",
165184
1,
166185
"cuda",
167-
marks=pytest.mark.skipif(torch.cuda.is_available() is False, reason="GPU is not available"),
186+
marks=pytest.mark.skipif(
187+
sys.version_info >= (3, 13) or torch.cuda.is_available() is False, reason="GPU is not available"
188+
),
168189
),
169190
("cpu", 2, "cpu"),
170191
pytest.param(
171192
"gpu",
172193
1,
173194
"cpu",
174-
marks=pytest.mark.skipif(torch.cuda.is_available() is False, reason="GPU is not available"),
195+
marks=pytest.mark.skipif(
196+
sys.version_info >= (3, 13) or torch.cuda.is_available() is False, reason="GPU is not available"
197+
),
175198
),
176199
pytest.param(
177200
"gpu",
178201
1,
179202
"cuda",
180-
marks=pytest.mark.skipif(torch.cuda.is_available() is False, reason="GPU is not available"),
203+
marks=pytest.mark.skipif(
204+
sys.version_info >= (3, 13) or torch.cuda.is_available() is False, reason="GPU is not available"
205+
),
181206
),
182207
pytest.param(
183208
"gpu",
184209
[0, 1],
185210
"cpu",
186211
marks=pytest.mark.skipif(
187-
torch.cuda.is_available() is False or torch.cuda.device_count() < 2,
212+
sys.version_info >= (3, 13) or torch.cuda.is_available() is False or torch.cuda.device_count() < 2,
188213
reason="GPU is not available or there is only one gpu device",
189214
),
190215
),

tests/models/nn/transformers/utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,21 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import sys
16+
1517
import pandas as pd
16-
from pytorch_lightning import Trainer
17-
from pytorch_lightning.callbacks import ModelCheckpoint
18+
import pytest
19+
20+
try:
21+
from pytorch_lightning import Trainer
22+
from pytorch_lightning.callbacks import ModelCheckpoint
23+
except ImportError:
24+
pass
1825

1926
from rectools import Columns
2027

28+
pytestmark = pytest.mark.skipif(sys.version_info >= (3, 13), reason="`torch` is not compatible with Python >= 3.13")
29+
2130

2231
def leave_one_out_mask(interactions: pd.DataFrame) -> pd.Series:
2332
rank = (

0 commit comments

Comments
 (0)