Skip to content

Commit b08a196

Browse files
committed
Merge remote-tracking branch 'origin/main' into feature/py313
2 parents a4e5418 + 6ffd25b commit b08a196

20 files changed

Lines changed: 1088 additions & 439 deletions

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111

1212
- Python 3.13 support ([#227](https://github.com/MobileTeleSystems/RecTools/pull/227))
1313

14+
## [0.13.0] - 10.04.2025
15+
16+
### Added
17+
- `TransformerNegativeSamplerBase` and `CatalogUniformSampler` classes, `negative_sampler_type` and `negative_sampler_kwargs` parameters to transformer-based models ([#275](https://github.com/MobileTeleSystems/RecTools/pull/275))
18+
- `SimilarityModuleBase` and `DistanceSimilarityModule` classes, similarity module to `TransformerTorchBackbone`, `similarity_module_type` and `similarity_module_kwargs` parameters to transformer-based models ([#272](https://github.com/MobileTeleSystems/RecTools/pull/272))
19+
- `out_dim` property to `IdEmbeddingsItemNet`, `CatFeaturesItemNet` and `SumOfEmbeddingsConstructor` ([#276](https://github.com/MobileTeleSystems/RecTools/pull/276))
20+
- `TransformerBackboneBase` class, `backbone_type` and `backbone_kwargs` parameters to transformer-based models ([#277](https://github.com/MobileTeleSystems/RecTools/pull/277))
21+
- `sampled_softmax` loss option for transformer models ([#274](https://github.com/MobileTeleSystems/RecTools/pull/274))
22+
23+
1424
## [0.12.0] - 24.02.2025
1525

1626
### Added

examples/tutorials/transformers_advanced_training_guide.ipynb

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -412,15 +412,15 @@
412412
"name": "stdout",
413413
"output_type": "stream",
414414
"text": [
415-
"epoch,step,train_loss,val_loss\r",
415+
"epoch,step,train_loss,val_loss\r\n",
416416
"\r\n",
417-
"0,1,,22.365339279174805\r",
417+
"0,1,,22.365339279174805\r\n",
418418
"\r\n",
419-
"0,1,22.38391876220703,\r",
419+
"0,1,22.38391876220703,\r\n",
420420
"\r\n",
421-
"1,3,,22.189851760864258\r",
421+
"1,3,,22.189851760864258\r\n",
422422
"\r\n",
423-
"1,3,22.898216247558594,\r",
423+
"1,3,22.898216247558594,\r\n",
424424
"\r\n"
425425
]
426426
}
@@ -526,23 +526,23 @@
526526
"name": "stdout",
527527
"output_type": "stream",
528528
"text": [
529-
"epoch,step,train_loss,val_loss\r",
529+
"epoch,step,train_loss,val_loss\r\n",
530530
"\r\n",
531-
"0,1,,22.343637466430664\r",
531+
"0,1,,22.343637466430664\r\n",
532532
"\r\n",
533-
"0,1,22.36273765563965,\r",
533+
"0,1,22.36273765563965,\r\n",
534534
"\r\n",
535-
"1,3,,22.159835815429688\r",
535+
"1,3,,22.159835815429688\r\n",
536536
"\r\n",
537-
"1,3,22.33755874633789,\r",
537+
"1,3,22.33755874633789,\r\n",
538538
"\r\n",
539-
"2,5,,21.94308853149414\r",
539+
"2,5,,21.94308853149414\r\n",
540540
"\r\n",
541-
"2,5,22.244243621826172,\r",
541+
"2,5,22.244243621826172,\r\n",
542542
"\r\n",
543-
"3,7,,21.702259063720703\r",
543+
"3,7,,21.702259063720703\r\n",
544544
"\r\n",
545-
"3,7,22.196012496948242,\r",
545+
"3,7,22.196012496948242,\r\n",
546546
"\r\n"
547547
]
548548
}
@@ -898,7 +898,7 @@
898898
" ) -> None:\n",
899899
" logits = outputs[\"logits\"]\n",
900900
" if logits is None:\n",
901-
" logits = pl_module.torch_model.encode_sessions(batch[\"x\"], pl_module.item_embs)[:, -1, :]\n",
901+
" logits = pl_module.torch_model.encode_sessions(batch, pl_module.item_embs)[:, -1, :]\n",
902902
" _, sorted_batch_recos = logits.topk(k=self.top_k)\n",
903903
"\n",
904904
" batch_recos = sorted_batch_recos.tolist()\n",
@@ -2039,9 +2039,9 @@
20392039
],
20402040
"metadata": {
20412041
"kernelspec": {
2042-
"display_name": "rectools",
2042+
"display_name": ".venv",
20432043
"language": "python",
2044-
"name": "rectools"
2044+
"name": "python3"
20452045
},
20462046
"language_info": {
20472047
"codemirror_mode": {
@@ -2053,7 +2053,7 @@
20532053
"name": "python",
20542054
"nbconvert_exporter": "python",
20552055
"pygments_lexer": "ipython3",
2056-
"version": "3.9.12"
2056+
"version": "3.10.13"
20572057
}
20582058
},
20592059
"nbformat": 4,

examples/tutorials/transformers_tutorial.ipynb

Lines changed: 129 additions & 191 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "RecTools"
3-
version = "0.12.0"
3+
version = "0.13.0"
44
description = "An easy-to-use Python library for building recommendation systems"
55
license = "Apache-2.0"
66
authors = [

rectools/models/nn/item_net.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ def get_all_embeddings(self) -> torch.Tensor:
4646
"""Return item embeddings."""
4747
raise NotImplementedError()
4848

49+
@property
50+
def out_dim(self) -> int:
51+
"""Return item embedding output dimension."""
52+
raise NotImplementedError()
53+
4954
@property
5055
def device(self) -> torch.device:
5156
"""Return ItemNet device."""
@@ -222,6 +227,11 @@ def from_dataset_schema(
222227
)
223228
return None
224229

230+
@property
231+
def out_dim(self) -> int:
232+
"""Return categorical item embedding output dimension."""
233+
return self.embedding_bag.embedding_dim
234+
225235

226236
class IdEmbeddingsItemNet(ItemNetBase):
227237
"""
@@ -317,6 +327,11 @@ def from_dataset_schema(
317327
n_items = dataset_schema.items.n_hot
318328
return cls(n_factors, n_items, dropout_rate)
319329

330+
@property
331+
def out_dim(self) -> int:
332+
"""Return item embedding output dimension."""
333+
return self.ids_emb.embedding_dim
334+
320335

321336
class ItemNetConstructorBase(ItemNetBase):
322337
"""
@@ -467,3 +482,8 @@ def forward(self, items: torch.Tensor) -> torch.Tensor:
467482
item_emb = self.item_net_blocks[idx_block](items)
468483
item_embs.append(item_emb)
469484
return torch.sum(torch.stack(item_embs, dim=0), dim=0)
485+
486+
@property
487+
def out_dim(self) -> int:
488+
"""Return item net constructor output dimension."""
489+
return self.item_net_blocks[0].out_dim

rectools/models/nn/transformers/base.py

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,15 @@
4040
)
4141
from .data_preparator import TransformerDataPreparatorBase
4242
from .lightning import TransformerLightningModule, TransformerLightningModuleBase
43+
from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
4344
from .net_blocks import (
4445
LearnableInversePositionalEncoding,
4546
PositionalEncodingBase,
4647
PreLNTransformerLayers,
4748
TransformerLayersBase,
4849
)
49-
from .torch_backbone import TransformerTorchBackbone
50+
from .similarity import DistanceSimilarityModule, SimilarityModuleBase
51+
from .torch_backbone import TransformerBackboneBase, TransformerTorchBackbone
5052

5153
InitKwargs = tp.Dict[str, tp.Any]
5254

@@ -97,6 +99,26 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
9799
),
98100
]
99101

102+
SimilarityModuleType = tpe.Annotated[
103+
tp.Type[SimilarityModuleBase],
104+
BeforeValidator(_get_class_obj),
105+
PlainSerializer(
106+
func=get_class_or_function_full_path,
107+
return_type=str,
108+
when_used="json",
109+
),
110+
]
111+
112+
TransformerBackboneType = tpe.Annotated[
113+
tp.Type[TransformerBackboneBase],
114+
BeforeValidator(_get_class_obj),
115+
PlainSerializer(
116+
func=get_class_or_function_full_path,
117+
return_type=str,
118+
when_used="json",
119+
),
120+
]
121+
100122
TransformerDataPreparatorType = tpe.Annotated[
101123
tp.Type[TransformerDataPreparatorBase],
102124
BeforeValidator(_get_class_obj),
@@ -107,6 +129,16 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
107129
),
108130
]
109131

132+
TransformerNegativeSamplerType = tpe.Annotated[
133+
tp.Type[TransformerNegativeSamplerBase],
134+
BeforeValidator(_get_class_obj),
135+
PlainSerializer(
136+
func=get_class_or_function_full_path,
137+
return_type=str,
138+
when_used="json",
139+
),
140+
]
141+
110142

111143
ItemNetConstructorType = tpe.Annotated[
112144
tp.Type[ItemNetConstructorBase],
@@ -183,13 +215,19 @@ class TransformerModelConfig(ModelConfig):
183215
pos_encoding_type: PositionalEncodingType = LearnableInversePositionalEncoding
184216
transformer_layers_type: TransformerLayersType = PreLNTransformerLayers
185217
lightning_module_type: TransformerLightningModuleType = TransformerLightningModule
218+
negative_sampler_type: TransformerNegativeSamplerType = CatalogUniformSampler
219+
similarity_module_type: SimilarityModuleType = DistanceSimilarityModule
220+
backbone_type: TransformerBackboneType = TransformerTorchBackbone
186221
get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None
187222
get_trainer_func: tp.Optional[TrainerCallableSerialized] = None
188223
data_preparator_kwargs: tp.Optional[InitKwargs] = None
189224
transformer_layers_kwargs: tp.Optional[InitKwargs] = None
190225
item_net_constructor_kwargs: tp.Optional[InitKwargs] = None
191226
pos_encoding_kwargs: tp.Optional[InitKwargs] = None
192227
lightning_module_kwargs: tp.Optional[InitKwargs] = None
228+
negative_sampler_kwargs: tp.Optional[InitKwargs] = None
229+
similarity_module_kwargs: tp.Optional[InitKwargs] = None
230+
backbone_kwargs: tp.Optional[InitKwargs] = None
193231

194232

195233
TransformerModelConfig_T = tp.TypeVar("TransformerModelConfig_T", bound=TransformerModelConfig)
@@ -237,13 +275,19 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
237275
item_net_constructor_type: tp.Type[ItemNetConstructorBase] = SumOfEmbeddingsConstructor,
238276
pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
239277
lightning_module_type: tp.Type[TransformerLightningModuleBase] = TransformerLightningModule,
278+
negative_sampler_type: tp.Type[TransformerNegativeSamplerBase] = CatalogUniformSampler,
279+
similarity_module_type: tp.Type[SimilarityModuleBase] = DistanceSimilarityModule,
280+
backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
240281
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
241282
get_trainer_func: tp.Optional[TrainerCallable] = None,
242283
data_preparator_kwargs: tp.Optional[InitKwargs] = None,
243284
transformer_layers_kwargs: tp.Optional[InitKwargs] = None,
244285
item_net_constructor_kwargs: tp.Optional[InitKwargs] = None,
245286
pos_encoding_kwargs: tp.Optional[InitKwargs] = None,
246287
lightning_module_kwargs: tp.Optional[InitKwargs] = None,
288+
negative_sampler_kwargs: tp.Optional[InitKwargs] = None,
289+
similarity_module_kwargs: tp.Optional[InitKwargs] = None,
290+
backbone_kwargs: tp.Optional[InitKwargs] = None,
247291
**kwargs: tp.Any,
248292
) -> None:
249293
super().__init__(verbose=verbose)
@@ -268,17 +312,23 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
268312
self.recommend_batch_size = recommend_batch_size
269313
self.recommend_torch_device = recommend_torch_device
270314
self.train_min_user_interactions = train_min_user_interactions
315+
self.similarity_module_type = similarity_module_type
271316
self.item_net_block_types = item_net_block_types
272317
self.item_net_constructor_type = item_net_constructor_type
273318
self.pos_encoding_type = pos_encoding_type
274319
self.lightning_module_type = lightning_module_type
320+
self.negative_sampler_type = negative_sampler_type
321+
self.backbone_type = backbone_type
275322
self.get_val_mask_func = get_val_mask_func
276323
self.get_trainer_func = get_trainer_func
277324
self.data_preparator_kwargs = data_preparator_kwargs
278325
self.transformer_layers_kwargs = transformer_layers_kwargs
279326
self.item_net_constructor_kwargs = item_net_constructor_kwargs
280327
self.pos_encoding_kwargs = pos_encoding_kwargs
281328
self.lightning_module_kwargs = lightning_module_kwargs
329+
self.negative_sampler_kwargs = negative_sampler_kwargs
330+
self.similarity_module_kwargs = similarity_module_kwargs
331+
self.backbone_kwargs = backbone_kwargs
282332

283333
self._init_data_preparator()
284334
self._init_trainer()
@@ -295,12 +345,14 @@ def _get_kwargs(actual_kwargs: tp.Optional[InitKwargs]) -> InitKwargs:
295345
return kwargs
296346

297347
def _init_data_preparator(self) -> None:
348+
requires_negatives = self.lightning_module_type.requires_negatives(self.loss)
298349
self.data_preparator = self.data_preparator_type(
299350
session_max_len=self.session_max_len,
300351
batch_size=self.batch_size,
301352
dataloader_num_workers=self.dataloader_num_workers,
302353
train_min_user_interactions=self.train_min_user_interactions,
303-
n_negatives=self.n_negatives if self.loss != "softmax" else None,
354+
negative_sampler=self._init_negative_sampler() if requires_negatives else None,
355+
n_negatives=self.n_negatives if requires_negatives else None,
304356
get_val_mask_func=self.get_val_mask_func,
305357
shuffle_train=True,
306358
**self._get_kwargs(self.data_preparator_kwargs),
@@ -321,6 +373,12 @@ def _init_trainer(self) -> None:
321373
else:
322374
self._trainer = self.get_trainer_func()
323375

376+
def _init_negative_sampler(self) -> TransformerNegativeSamplerBase:
377+
return self.negative_sampler_type(
378+
n_negatives=self.n_negatives,
379+
**self._get_kwargs(self.negative_sampler_kwargs),
380+
)
381+
324382
def _construct_item_net(self, dataset: Dataset) -> ItemNetBase:
325383
return self.item_net_constructor_type.from_dataset(
326384
dataset,
@@ -356,22 +414,28 @@ def _init_transformer_layers(self) -> TransformerLayersBase:
356414
**self._get_kwargs(self.transformer_layers_kwargs),
357415
)
358416

359-
def _init_torch_model(self, item_model: ItemNetBase) -> TransformerTorchBackbone:
417+
def _init_similarity_module(self) -> SimilarityModuleBase:
418+
return self.similarity_module_type(**self._get_kwargs(self.similarity_module_kwargs))
419+
420+
def _init_torch_model(self, item_model: ItemNetBase) -> TransformerBackboneBase:
360421
pos_encoding_layer = self._init_pos_encoding_layer()
361422
transformer_layers = self._init_transformer_layers()
362-
return TransformerTorchBackbone(
423+
similarity_module = self._init_similarity_module()
424+
return self.backbone_type(
363425
n_heads=self.n_heads,
364426
dropout_rate=self.dropout_rate,
365427
item_model=item_model,
366428
pos_encoding_layer=pos_encoding_layer,
367429
transformer_layers=transformer_layers,
430+
similarity_module=similarity_module,
368431
use_causal_attn=self.use_causal_attn,
369432
use_key_padding_mask=self.use_key_padding_mask,
433+
**self._get_kwargs(self.backbone_kwargs),
370434
)
371435

372436
def _init_lightning_model(
373437
self,
374-
torch_model: TransformerTorchBackbone,
438+
torch_model: TransformerBackboneBase,
375439
dataset_schema: DatasetSchemaDict,
376440
item_external_ids: ExternalIds,
377441
model_config: tp.Dict[str, tp.Any],
@@ -467,7 +531,7 @@ def _recommend_i2i(
467531
)
468532

469533
@property
470-
def torch_model(self) -> TransformerTorchBackbone:
534+
def torch_model(self) -> TransformerBackboneBase:
471535
"""Pytorch model."""
472536
return self.lightning_model.torch_model
473537

0 commit comments

Comments
 (0)