4040)
4141from .data_preparator import TransformerDataPreparatorBase
4242from .lightning import TransformerLightningModule , TransformerLightningModuleBase
43+ from .negative_sampler import CatalogUniformSampler , TransformerNegativeSamplerBase
4344from .net_blocks import (
4445 LearnableInversePositionalEncoding ,
4546 PositionalEncodingBase ,
4647 PreLNTransformerLayers ,
4748 TransformerLayersBase ,
4849)
49- from .torch_backbone import TransformerTorchBackbone
50+ from .similarity import DistanceSimilarityModule , SimilarityModuleBase
51+ from .torch_backbone import TransformerBackboneBase , TransformerTorchBackbone
5052
5153InitKwargs = tp .Dict [str , tp .Any ]
5254
@@ -97,6 +99,26 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
9799 ),
98100]
99101
102+ SimilarityModuleType = tpe .Annotated [
103+ tp .Type [SimilarityModuleBase ],
104+ BeforeValidator (_get_class_obj ),
105+ PlainSerializer (
106+ func = get_class_or_function_full_path ,
107+ return_type = str ,
108+ when_used = "json" ,
109+ ),
110+ ]
111+
112+ TransformerBackboneType = tpe .Annotated [
113+ tp .Type [TransformerBackboneBase ],
114+ BeforeValidator (_get_class_obj ),
115+ PlainSerializer (
116+ func = get_class_or_function_full_path ,
117+ return_type = str ,
118+ when_used = "json" ,
119+ ),
120+ ]
121+
100122TransformerDataPreparatorType = tpe .Annotated [
101123 tp .Type [TransformerDataPreparatorBase ],
102124 BeforeValidator (_get_class_obj ),
@@ -107,6 +129,16 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
107129 ),
108130]
109131
132+ TransformerNegativeSamplerType = tpe .Annotated [
133+ tp .Type [TransformerNegativeSamplerBase ],
134+ BeforeValidator (_get_class_obj ),
135+ PlainSerializer (
136+ func = get_class_or_function_full_path ,
137+ return_type = str ,
138+ when_used = "json" ,
139+ ),
140+ ]
141+
110142
111143ItemNetConstructorType = tpe .Annotated [
112144 tp .Type [ItemNetConstructorBase ],
@@ -183,13 +215,19 @@ class TransformerModelConfig(ModelConfig):
183215 pos_encoding_type : PositionalEncodingType = LearnableInversePositionalEncoding
184216 transformer_layers_type : TransformerLayersType = PreLNTransformerLayers
185217 lightning_module_type : TransformerLightningModuleType = TransformerLightningModule
218+ negative_sampler_type : TransformerNegativeSamplerType = CatalogUniformSampler
219+ similarity_module_type : SimilarityModuleType = DistanceSimilarityModule
220+ backbone_type : TransformerBackboneType = TransformerTorchBackbone
186221 get_val_mask_func : tp .Optional [ValMaskCallableSerialized ] = None
187222 get_trainer_func : tp .Optional [TrainerCallableSerialized ] = None
188223 data_preparator_kwargs : tp .Optional [InitKwargs ] = None
189224 transformer_layers_kwargs : tp .Optional [InitKwargs ] = None
190225 item_net_constructor_kwargs : tp .Optional [InitKwargs ] = None
191226 pos_encoding_kwargs : tp .Optional [InitKwargs ] = None
192227 lightning_module_kwargs : tp .Optional [InitKwargs ] = None
228+ negative_sampler_kwargs : tp .Optional [InitKwargs ] = None
229+ similarity_module_kwargs : tp .Optional [InitKwargs ] = None
230+ backbone_kwargs : tp .Optional [InitKwargs ] = None
193231
194232
195233TransformerModelConfig_T = tp .TypeVar ("TransformerModelConfig_T" , bound = TransformerModelConfig )
@@ -237,13 +275,19 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
237275 item_net_constructor_type : tp .Type [ItemNetConstructorBase ] = SumOfEmbeddingsConstructor ,
238276 pos_encoding_type : tp .Type [PositionalEncodingBase ] = LearnableInversePositionalEncoding ,
239277 lightning_module_type : tp .Type [TransformerLightningModuleBase ] = TransformerLightningModule ,
278+ negative_sampler_type : tp .Type [TransformerNegativeSamplerBase ] = CatalogUniformSampler ,
279+ similarity_module_type : tp .Type [SimilarityModuleBase ] = DistanceSimilarityModule ,
280+ backbone_type : tp .Type [TransformerBackboneBase ] = TransformerTorchBackbone ,
240281 get_val_mask_func : tp .Optional [ValMaskCallable ] = None ,
241282 get_trainer_func : tp .Optional [TrainerCallable ] = None ,
242283 data_preparator_kwargs : tp .Optional [InitKwargs ] = None ,
243284 transformer_layers_kwargs : tp .Optional [InitKwargs ] = None ,
244285 item_net_constructor_kwargs : tp .Optional [InitKwargs ] = None ,
245286 pos_encoding_kwargs : tp .Optional [InitKwargs ] = None ,
246287 lightning_module_kwargs : tp .Optional [InitKwargs ] = None ,
288+ negative_sampler_kwargs : tp .Optional [InitKwargs ] = None ,
289+ similarity_module_kwargs : tp .Optional [InitKwargs ] = None ,
290+ backbone_kwargs : tp .Optional [InitKwargs ] = None ,
247291 ** kwargs : tp .Any ,
248292 ) -> None :
249293 super ().__init__ (verbose = verbose )
@@ -268,17 +312,23 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
268312 self .recommend_batch_size = recommend_batch_size
269313 self .recommend_torch_device = recommend_torch_device
270314 self .train_min_user_interactions = train_min_user_interactions
315+ self .similarity_module_type = similarity_module_type
271316 self .item_net_block_types = item_net_block_types
272317 self .item_net_constructor_type = item_net_constructor_type
273318 self .pos_encoding_type = pos_encoding_type
274319 self .lightning_module_type = lightning_module_type
320+ self .negative_sampler_type = negative_sampler_type
321+ self .backbone_type = backbone_type
275322 self .get_val_mask_func = get_val_mask_func
276323 self .get_trainer_func = get_trainer_func
277324 self .data_preparator_kwargs = data_preparator_kwargs
278325 self .transformer_layers_kwargs = transformer_layers_kwargs
279326 self .item_net_constructor_kwargs = item_net_constructor_kwargs
280327 self .pos_encoding_kwargs = pos_encoding_kwargs
281328 self .lightning_module_kwargs = lightning_module_kwargs
329+ self .negative_sampler_kwargs = negative_sampler_kwargs
330+ self .similarity_module_kwargs = similarity_module_kwargs
331+ self .backbone_kwargs = backbone_kwargs
282332
283333 self ._init_data_preparator ()
284334 self ._init_trainer ()
@@ -295,12 +345,14 @@ def _get_kwargs(actual_kwargs: tp.Optional[InitKwargs]) -> InitKwargs:
295345 return kwargs
296346
297347 def _init_data_preparator (self ) -> None :
348+ requires_negatives = self .lightning_module_type .requires_negatives (self .loss )
298349 self .data_preparator = self .data_preparator_type (
299350 session_max_len = self .session_max_len ,
300351 batch_size = self .batch_size ,
301352 dataloader_num_workers = self .dataloader_num_workers ,
302353 train_min_user_interactions = self .train_min_user_interactions ,
303- n_negatives = self .n_negatives if self .loss != "softmax" else None ,
354+ negative_sampler = self ._init_negative_sampler () if requires_negatives else None ,
355+ n_negatives = self .n_negatives if requires_negatives else None ,
304356 get_val_mask_func = self .get_val_mask_func ,
305357 shuffle_train = True ,
306358 ** self ._get_kwargs (self .data_preparator_kwargs ),
@@ -321,6 +373,12 @@ def _init_trainer(self) -> None:
321373 else :
322374 self ._trainer = self .get_trainer_func ()
323375
376+ def _init_negative_sampler (self ) -> TransformerNegativeSamplerBase :
377+ return self .negative_sampler_type (
378+ n_negatives = self .n_negatives ,
379+ ** self ._get_kwargs (self .negative_sampler_kwargs ),
380+ )
381+
324382 def _construct_item_net (self , dataset : Dataset ) -> ItemNetBase :
325383 return self .item_net_constructor_type .from_dataset (
326384 dataset ,
@@ -356,22 +414,28 @@ def _init_transformer_layers(self) -> TransformerLayersBase:
356414 ** self ._get_kwargs (self .transformer_layers_kwargs ),
357415 )
358416
359- def _init_torch_model (self , item_model : ItemNetBase ) -> TransformerTorchBackbone :
417+ def _init_similarity_module (self ) -> SimilarityModuleBase :
418+ return self .similarity_module_type (** self ._get_kwargs (self .similarity_module_kwargs ))
419+
420+ def _init_torch_model (self , item_model : ItemNetBase ) -> TransformerBackboneBase :
360421 pos_encoding_layer = self ._init_pos_encoding_layer ()
361422 transformer_layers = self ._init_transformer_layers ()
362- return TransformerTorchBackbone (
423+ similarity_module = self ._init_similarity_module ()
424+ return self .backbone_type (
363425 n_heads = self .n_heads ,
364426 dropout_rate = self .dropout_rate ,
365427 item_model = item_model ,
366428 pos_encoding_layer = pos_encoding_layer ,
367429 transformer_layers = transformer_layers ,
430+ similarity_module = similarity_module ,
368431 use_causal_attn = self .use_causal_attn ,
369432 use_key_padding_mask = self .use_key_padding_mask ,
433+ ** self ._get_kwargs (self .backbone_kwargs ),
370434 )
371435
372436 def _init_lightning_model (
373437 self ,
374- torch_model : TransformerTorchBackbone ,
438+ torch_model : TransformerBackboneBase ,
375439 dataset_schema : DatasetSchemaDict ,
376440 item_external_ids : ExternalIds ,
377441 model_config : tp .Dict [str , tp .Any ],
@@ -467,7 +531,7 @@ def _recommend_i2i(
467531 )
468532
469533 @property
470- def torch_model (self ) -> TransformerTorchBackbone :
534+ def torch_model (self ) -> TransformerBackboneBase :
471535 """Pytorch model."""
472536 return self .lightning_model .torch_model
473537
0 commit comments