1313# limitations under the License.
1414
1515import typing as tp
16- from typing import TypedDict
17- from typing_extensions import Unpack
1816from functools import partial
17+ from typing import TypedDict
1918
2019import numpy as np
2120import pandas as pd
2221import pytest
2322import torch
2423from pytorch_lightning import Trainer , seed_everything
24+ from typing_extensions import Unpack
2525
2626from rectools import ExternalIds
2727from rectools .columns import Columns
3535 TransformerLightningModule ,
3636)
3737from rectools .models .nn .transformers .bert4rec import MASKING_VALUE , BERT4RecDataPreparator , ValMaskCallable
38+ from rectools .models .nn .transformers .data_preparator import InitKwargs
3839from rectools .models .nn .transformers .negative_sampler import CatalogUniformSampler , TransformerNegativeSamplerBase
3940from rectools .models .nn .transformers .similarity import DistanceSimilarityModule
4041from rectools .models .nn .transformers .torch_backbone import TransformerTorchBackbone
4647
4748from .utils import custom_trainer , leave_one_out_mask
4849
49- InitKwargs = tp .Dict [str , tp .Any ]
5050
class KwargsSpec(TypedDict):
    """Typed shape of the keyword arguments taken by the custom trainer factory.

    Used as ``**kwargs: Unpack[KwargsSpec]`` so that a type checker knows which
    keys the factory may read and what type each value has.
    """

    # Read by the trainer factory and passed to ``Trainer(max_epochs=...)``.
    max_epochs: int
    # Read by the trainer factory; presumably forwarded to ``Trainer`` as the
    # accelerator name (e.g. "cpu" / "gpu") -- confirm against the factory body.
    accelerator: str
5454
55+
5556class TestBERT4RecModel :
5657 def setup_method (self ) -> None :
5758 self ._seed_everything ()
@@ -125,8 +126,8 @@ def get_trainer() -> Trainer:
125126 def get_custom_trainer_func (self ) -> TrainerCallable :
126127 def get_trainer_func (** kwargs : Unpack [KwargsSpec ]) -> Trainer :
127128 # internal logic for kwargs
128- max_epochs = kwargs ["max_epochs" ]
129- accelerator = kwargs ["accelerator" ]
129+ max_epochs = kwargs ["max_epochs" ]
130+ accelerator = kwargs ["accelerator" ]
130131 return Trainer (
131132 max_epochs = max_epochs ,
132133 min_epochs = 2 ,
@@ -139,7 +140,7 @@ def get_trainer_func(**kwargs: Unpack[KwargsSpec]) -> Trainer:
139140 return get_trainer_func
140141
141142 @pytest .fixture
142- def get_custom_val_mask_func (self )-> ValMaskCallable :
143+ def get_custom_val_mask_func (self ) -> ValMaskCallable :
143144 def get_val_mask_func (interactions : pd .DataFrame , ** kwargs : InitKwargs ) -> np .ndarray :
144145 val_users = kwargs ["val_users" ]
145146 rank = (
@@ -150,9 +151,8 @@ def get_val_mask_func(interactions: pd.DataFrame, **kwargs: InitKwargs) -> np.nd
150151 )
151152 val_mask = (interactions [Columns .User ].isin (val_users )) & (rank <= 1 )
152153 return val_mask .values
153- return get_val_mask_func
154-
155154
155+ return get_val_mask_func
156156
157157 @pytest .mark .parametrize (
158158 "accelerator,n_devices,recommend_torch_device" ,
@@ -593,17 +593,17 @@ def test_recommend_for_cold_user_with_hot_item(
593593 (
594594 (
595595 {
596- "max_epochs" :2 ,
597- "accelerator" :"cpu" ,
596+ "max_epochs" : 2 ,
597+ "accelerator" : "cpu" ,
598598 },
599- {"val_users" : [30 ,40 ]}
599+ {"val_users" : [30 , 40 ]},
600600 ),
601601 (
602602 {
603603 "max_epochs" : 3 ,
604- "accelerator" :"gpu" ,
604+ "accelerator" : "gpu" ,
605605 },
606- {"val_users" : [20 ,30 ]}
606+ {"val_users" : [20 , 30 ]},
607607 ),
608608 ),
609609 )
@@ -773,7 +773,6 @@ def data_preparator(self) -> BERT4RecDataPreparator:
773773 mask_prob = 0.5 ,
774774 )
775775
776-
777776 @pytest .fixture
778777 def data_preparator_val_mask (self ) -> BERT4RecDataPreparator :
779778 def get_val_mask (interactions : pd .DataFrame , val_users : ExternalIds ) -> np .ndarray :
@@ -877,7 +876,6 @@ def test_get_dataloader_recommend(
877876 ),
878877 ),
879878 )
880-
881879 def test_get_dataloader_val (
882880 self , dataset : Dataset , data_preparator_val_mask : BERT4RecDataPreparator , val_batch : tp .List
883881 ) -> None :
@@ -928,24 +926,24 @@ def get_custom_val_mask_func(interactions: pd.DataFrame, **kwargs: tp.Dict[str,
928926 val_mask = (interactions [Columns .User ].isin (val_users )) & (rank <= 1 )
929927 return val_mask .values
930928
931-
932929 get_custom_val_mask_func_kwargs = {"val_users" : val_users }
933930 data_preparator_val_mask = BERT4RecDataPreparator (
934- session_max_len = 4 ,
935- n_negatives = 2 ,
936- train_min_user_interactions = 2 ,
937- mask_prob = 0.5 ,
938- batch_size = 4 ,
939- dataloader_num_workers = 0 ,
940- get_val_mask_func = get_custom_val_mask_func ,
941- get_val_mask_func_kwargs = get_custom_val_mask_func_kwargs ,
942- )
931+ session_max_len = 4 ,
932+ n_negatives = 2 ,
933+ train_min_user_interactions = 2 ,
934+ mask_prob = 0.5 ,
935+ batch_size = 4 ,
936+ dataloader_num_workers = 0 ,
937+ get_val_mask_func = get_custom_val_mask_func ,
938+ get_val_mask_func_kwargs = get_custom_val_mask_func_kwargs ,
939+ )
943940 data_preparator_val_mask .process_dataset_train (dataset )
944941 dataloader = data_preparator_val_mask .get_dataloader_val ()
945942 actual = next (iter (dataloader )) # type: ignore
946943 for key , value in actual .items ():
947944 assert torch .equal (value , val_batch [key ])
948945
946+
949947class TestBERT4RecModelConfiguration :
950948 def setup_method (self ) -> None :
951949 self ._seed_everything ()
0 commit comments