Skip to content

Commit d32468b

Browse files
committed
sync deploy to local git
1 parent f737e02 commit d32468b

4 files changed

Lines changed: 29 additions & 35 deletions

File tree

rectools/models/nn/transformers/base.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
ItemNetConstructorBase,
3939
SumOfEmbeddingsConstructor,
4040
)
41-
from .data_preparator import TransformerDataPreparatorBase
41+
from .data_preparator import InitKwargs, TransformerDataPreparatorBase
4242
from .lightning import TransformerLightningModule, TransformerLightningModuleBase
4343
from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
4444
from .net_blocks import (
@@ -50,8 +50,6 @@
5050
from .similarity import DistanceSimilarityModule, SimilarityModuleBase
5151
from .torch_backbone import TransformerBackboneBase, TransformerTorchBackbone
5252

53-
InitKwargs = tp.Dict[str, tp.Any]
54-
5553
# #### -------------- Transformer Model Config -------------- #### #
5654

5755

rectools/models/nn/transformers/bert4rec.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
SumOfEmbeddingsConstructor,
2828
)
2929
from .base import (
30-
InitKwargs,
3130
TrainerCallable,
3231
TransformerDataPreparatorType,
3332
TransformerLightningModule,
@@ -37,7 +36,7 @@
3736
ValMaskCallable,
3837
)
3938
from .constants import MASKING_VALUE, PADDING_VALUE
40-
from .data_preparator import TransformerDataPreparatorBase
39+
from .data_preparator import InitKwargs, TransformerDataPreparatorBase
4140
from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
4241
from .net_blocks import (
4342
LearnableInversePositionalEncoding,
@@ -445,7 +444,7 @@ def _init_data_preparator(self) -> None:
445444
train_min_user_interactions=self.train_min_user_interactions,
446445
mask_prob=self.mask_prob,
447446
get_val_mask_func=self.get_val_mask_func,
448-
get_val_mask_func_kwargs= self.get_val_mask_func_kwargs,
447+
get_val_mask_func_kwargs=self.get_val_mask_func_kwargs,
449448
shuffle_train=True,
450449
**self._get_kwargs(self.data_preparator_kwargs),
451450
)

rectools/models/nn/transformers/sasrec.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
SumOfEmbeddingsConstructor,
2828
)
2929
from .base import (
30-
InitKwargs,
3130
TrainerCallable,
3231
TransformerDataPreparatorType,
3332
TransformerLayersType,
@@ -37,7 +36,7 @@
3736
TransformerModelConfig,
3837
ValMaskCallable,
3938
)
40-
from .data_preparator import TransformerDataPreparatorBase
39+
from .data_preparator import InitKwargs, TransformerDataPreparatorBase
4140
from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
4241
from .net_blocks import (
4342
LearnableInversePositionalEncoding,
@@ -493,8 +492,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
493492
backbone_type=backbone_type,
494493
get_val_mask_func=get_val_mask_func,
495494
get_trainer_func=get_trainer_func,
496-
get_val_mask_func_kwargs = get_val_mask_func_kwargs,
497-
get_trainer_func_kwargs = get_trainer_func_kwargs,
495+
get_val_mask_func_kwargs=get_val_mask_func_kwargs,
496+
get_trainer_func_kwargs=get_trainer_func_kwargs,
498497
data_preparator_kwargs=data_preparator_kwargs,
499498
transformer_layers_kwargs=transformer_layers_kwargs,
500499
item_net_constructor_kwargs=item_net_constructor_kwargs,

tests/models/nn/transformers/test_bert4rec.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313
# limitations under the License.
1414

1515
import typing as tp
16-
from typing import TypedDict
17-
from typing_extensions import Unpack
1816
from functools import partial
17+
from typing import TypedDict
1918

2019
import numpy as np
2120
import pandas as pd
2221
import pytest
2322
import torch
2423
from pytorch_lightning import Trainer, seed_everything
24+
from typing_extensions import Unpack
2525

2626
from rectools import ExternalIds
2727
from rectools.columns import Columns
@@ -35,6 +35,7 @@
3535
TransformerLightningModule,
3636
)
3737
from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable
38+
from rectools.models.nn.transformers.data_preparator import InitKwargs
3839
from rectools.models.nn.transformers.negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
3940
from rectools.models.nn.transformers.similarity import DistanceSimilarityModule
4041
from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone
@@ -46,12 +47,12 @@
4647

4748
from .utils import custom_trainer, leave_one_out_mask
4849

49-
InitKwargs = tp.Dict[str, tp.Any]
5050

5151
class KwargsSpec(TypedDict):
5252
max_epochs: int
5353
accelerator: str
5454

55+
5556
class TestBERT4RecModel:
5657
def setup_method(self) -> None:
5758
self._seed_everything()
@@ -125,8 +126,8 @@ def get_trainer() -> Trainer:
125126
def get_custom_trainer_func(self) -> TrainerCallable:
126127
def get_trainer_func(**kwargs: Unpack[KwargsSpec]) -> Trainer:
127128
# internal logic for kwargs
128-
max_epochs=kwargs["max_epochs"]
129-
accelerator=kwargs["accelerator"]
129+
max_epochs = kwargs["max_epochs"]
130+
accelerator = kwargs["accelerator"]
130131
return Trainer(
131132
max_epochs=max_epochs,
132133
min_epochs=2,
@@ -139,7 +140,7 @@ def get_trainer_func(**kwargs: Unpack[KwargsSpec]) -> Trainer:
139140
return get_trainer_func
140141

141142
@pytest.fixture
142-
def get_custom_val_mask_func(self)->ValMaskCallable:
143+
def get_custom_val_mask_func(self) -> ValMaskCallable:
143144
def get_val_mask_func(interactions: pd.DataFrame, **kwargs: InitKwargs) -> np.ndarray:
144145
val_users = kwargs["val_users"]
145146
rank = (
@@ -150,9 +151,8 @@ def get_val_mask_func(interactions: pd.DataFrame, **kwargs: InitKwargs) -> np.nd
150151
)
151152
val_mask = (interactions[Columns.User].isin(val_users)) & (rank <= 1)
152153
return val_mask.values
153-
return get_val_mask_func
154-
155154

155+
return get_val_mask_func
156156

157157
@pytest.mark.parametrize(
158158
"accelerator,n_devices,recommend_torch_device",
@@ -593,17 +593,17 @@ def test_recommend_for_cold_user_with_hot_item(
593593
(
594594
(
595595
{
596-
"max_epochs":2,
597-
"accelerator":"cpu",
596+
"max_epochs": 2,
597+
"accelerator": "cpu",
598598
},
599-
{"val_users": [30,40]}
599+
{"val_users": [30, 40]},
600600
),
601601
(
602602
{
603603
"max_epochs": 3,
604-
"accelerator":"gpu",
604+
"accelerator": "gpu",
605605
},
606-
{"val_users": [20,30]}
606+
{"val_users": [20, 30]},
607607
),
608608
),
609609
)
@@ -773,7 +773,6 @@ def data_preparator(self) -> BERT4RecDataPreparator:
773773
mask_prob=0.5,
774774
)
775775

776-
777776
@pytest.fixture
778777
def data_preparator_val_mask(self) -> BERT4RecDataPreparator:
779778
def get_val_mask(interactions: pd.DataFrame, val_users: ExternalIds) -> np.ndarray:
@@ -877,7 +876,6 @@ def test_get_dataloader_recommend(
877876
),
878877
),
879878
)
880-
881879
def test_get_dataloader_val(
882880
self, dataset: Dataset, data_preparator_val_mask: BERT4RecDataPreparator, val_batch: tp.List
883881
) -> None:
@@ -928,24 +926,24 @@ def get_custom_val_mask_func(interactions: pd.DataFrame, **kwargs: tp.Dict[str,
928926
val_mask = (interactions[Columns.User].isin(val_users)) & (rank <= 1)
929927
return val_mask.values
930928

931-
932929
get_custom_val_mask_func_kwargs = {"val_users": val_users}
933930
data_preparator_val_mask = BERT4RecDataPreparator(
934-
session_max_len=4,
935-
n_negatives=2,
936-
train_min_user_interactions=2,
937-
mask_prob=0.5,
938-
batch_size=4,
939-
dataloader_num_workers=0,
940-
get_val_mask_func=get_custom_val_mask_func,
941-
get_val_mask_func_kwargs=get_custom_val_mask_func_kwargs,
942-
)
931+
session_max_len=4,
932+
n_negatives=2,
933+
train_min_user_interactions=2,
934+
mask_prob=0.5,
935+
batch_size=4,
936+
dataloader_num_workers=0,
937+
get_val_mask_func=get_custom_val_mask_func,
938+
get_val_mask_func_kwargs=get_custom_val_mask_func_kwargs,
939+
)
943940
data_preparator_val_mask.process_dataset_train(dataset)
944941
dataloader = data_preparator_val_mask.get_dataloader_val()
945942
actual = next(iter(dataloader)) # type: ignore
946943
for key, value in actual.items():
947944
assert torch.equal(value, val_batch[key])
948945

946+
949947
class TestBERT4RecModelConfiguration:
950948
def setup_method(self) -> None:
951949
self._seed_everything()

0 commit comments

Comments
 (0)