Skip to content

Commit 7d3850b

Browse files
author
TOPAPEC
committed
Formatting
1 parent e24fec3 commit 7d3850b

17 files changed

Lines changed: 252 additions & 187 deletions

rectools/fast_transformers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
"""Fast Transformers: flat sequential recommenders without ItemNet hierarchy."""
22

3-
from .gpu_data import build_sequences, align_embeddings, GPUBatchDataset, make_dataloader
3+
from .gpu_data import GPUBatchDataset, align_embeddings, build_sequences, make_dataloader
44
from .lightning_wrap import FlatSASRecLightning
55
from .model import FlatSASRecConfig, FlatSASRecModel
66
from .net import FlatSASRec, SASRecBlock
77
from .ranking import rank_topk
8-
from .unisrec_net import UniSRec, FeedForward
98
from .unisrec_lightning import UniSRecLightning
109
from .unisrec_model import UniSRecModel
10+
from .unisrec_net import FeedForward, UniSRec
1111

1212
__all__ = [
1313
"build_sequences",

rectools/fast_transformers/gpu_data.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import typing as tp
44

55
import torch
6-
from torch.utils.data import Dataset as TorchDataset, DataLoader
6+
from torch.utils.data import DataLoader
7+
from torch.utils.data import Dataset as TorchDataset
78

89

910
def build_sequences(
@@ -52,7 +53,9 @@ def build_sequences(
5253
if total_elements > 0:
5354
user_indices = torch.repeat_interleave(torch.arange(n_users, device=device), effective_lens)
5455
cumsum = effective_lens.cumsum(0)
55-
offsets = torch.arange(total_elements, device=device) - torch.repeat_interleave(cumsum - effective_lens, effective_lens)
56+
offsets = torch.arange(total_elements, device=device) - torch.repeat_interleave(
57+
cumsum - effective_lens, effective_lens
58+
)
5659

5760
x_src = torch.repeat_interleave(ends - capped_lens, effective_lens) + offsets
5861
y_src = x_src + 1

rectools/fast_transformers/lightning_wrap.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import typing as tp
44

5-
import torch
65
import pytorch_lightning as pl
6+
import torch
77
from torch import nn
88

99
from .net import FlatSASRec
@@ -47,7 +47,9 @@ def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> to
4747
if self.loss_name == "softmax":
4848
# logits: (B, L, n_items) — full catalog
4949
# targets need to be 0-indexed item ids (subtract 1 since item ids start from 1)
50-
targets = y - 1 # shift to 0-based for CrossEntropyLoss; padding (0) becomes -1 -> ignore_index=0 won't work
50+
targets = (
51+
y - 1
52+
) # shift to 0-based for CrossEntropyLoss; padding (0) becomes -1 -> ignore_index=0 won't work
5153
# Actually, we set ignore_index=0 but padding maps to -1.
5254
# Let's use a different approach: set padding targets to 0 and use ignore_index=0
5355
targets = y.clone()

rectools/fast_transformers/model.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,15 @@
22

33
import typing as tp
44

5-
import numpy as np
65
import pandas as pd
7-
import torch
86
import pytorch_lightning as pl
7+
import torch
98
from scipy import sparse
109

11-
from rectools import Columns
1210
from rectools.dataset import Dataset
13-
from rectools.dataset.identifiers import IdMap
1411
from rectools.models.base import InternalRecoTriplet, ModelBase, ModelConfig
15-
from rectools.models.nn.transformers.sasrec import SASRecDataPreparator
1612
from rectools.models.nn.transformers.negative_sampler import CatalogUniformSampler
13+
from rectools.models.nn.transformers.sasrec import SASRecDataPreparator
1714
from rectools.types import InternalIdsArray
1815
from rectools.utils.config import BaseConfig
1916

@@ -157,10 +154,6 @@ def _fit(self, dataset: Dataset, *args: tp.Any, **kwargs: tp.Any) -> None:
157154
dp.process_dataset_train(dataset)
158155
self._data_preparator = dp
159156

160-
n_items = dp.item_id_map.size # includes extra tokens (padding)
161-
# item ids in the preparator go from 0 (padding) to n_items-1
162-
# FlatSASRec expects n_items = max real item count (embedding table = n_items+1 with padding at 0)
163-
# The preparator's item_id_map.size includes the padding token, so real items = size - 1
164157
n_real_items = dp.item_id_map.size - dp.n_item_extra_tokens
165158

166159
net = FlatSASRec(
@@ -242,7 +235,6 @@ def _recommend_u2i(
242235
sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray],
243236
) -> InternalRecoTriplet:
244237
assert self._data_preparator is not None
245-
device = next(self._net.parameters()).device # type: ignore
246238

247239
user_embs = self._get_user_embeddings(dataset) # (n_users, D)
248240
item_embs = self._get_item_embeddings() # (n_items, D)
@@ -278,7 +270,9 @@ def _recommend_u2i(
278270
whitelist = wl[(wl >= 0) & (wl < item_embs.shape[0])]
279271

280272
u_ids, i_ids, scores = rank_topk(
281-
user_embs, item_embs, k,
273+
user_embs,
274+
item_embs,
275+
k,
282276
filter_csr=filter_csr,
283277
whitelist=whitelist,
284278
batch_size=self.recommend_batch_size,
@@ -298,7 +292,6 @@ def _recommend_i2i(
298292
sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray],
299293
) -> InternalRecoTriplet:
300294
assert self._data_preparator is not None and self._net is not None
301-
device = next(self._net.parameters()).device
302295

303296
item_embs = self._get_item_embeddings() # (n_items, D)
304297
n_extra = self._data_preparator.n_item_extra_tokens
@@ -313,7 +306,9 @@ def _recommend_i2i(
313306
whitelist = wl[(wl >= 0) & (wl < item_embs.shape[0])]
314307

315308
t_ids, i_ids, scores = rank_topk(
316-
target_embs, item_embs, k,
309+
target_embs,
310+
item_embs,
311+
k,
317312
whitelist=whitelist,
318313
batch_size=self.recommend_batch_size,
319314
)

rectools/fast_transformers/net.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -127,19 +127,7 @@ def encode_last(self, x: torch.Tensor) -> torch.Tensor:
127127
Tensor (B, D)
128128
"""
129129
h = self.encode(x) # (B, L, D)
130-
# Find last non-padding position per row
131-
non_pad = (x != self.PADDING_IDX) # (B, L)
132-
# lengths: number of non-pad tokens
133-
lengths = non_pad.sum(dim=1) # (B,)
134-
# Clamp to at least 1 to avoid index -1 for fully-padded rows
135-
last_idx = (lengths - 1).clamp(min=0)
136-
# We use left-padding, so last non-pad is at position (L - 1) if any token exists
137-
# Actually with left padding, non-pad tokens are at the end, so the last position is L-1
138-
# But let's compute correctly: the last non-pad index
139-
# With left-padding: first non-pad is at L - length, last non-pad is at L - 1
140-
B = x.shape[0]
141-
last_pos = x.shape[1] - 1 # last position is always the last for left-padded sequences
142-
return h[:, last_pos, :] # (B, D)
130+
return h[:, -1, :] # left-padded: last position is always rightmost
143131

144132
def all_item_embeddings(self) -> torch.Tensor:
145133
"""

rectools/fast_transformers/unisrec_lightning.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import math
44
import typing as tp
55

6+
import pytorch_lightning as pl
67
import torch
78
import torch.nn.functional as F
8-
import pytorch_lightning as pl
99
from torch.optim.lr_scheduler import LambdaLR
1010

1111
from .unisrec_net import UniSRec
@@ -63,23 +63,29 @@ def _get_all_embs(self) -> torch.Tensor:
6363
return self.net.project_all()
6464

6565
def _get_pos_neg_logits(
66-
self, hidden: torch.Tensor, labels: torch.Tensor, negatives: torch.Tensor,
66+
self,
67+
hidden: torch.Tensor,
68+
labels: torch.Tensor,
69+
negatives: torch.Tensor,
6770
) -> torch.Tensor:
6871
"""Compute (B, L, 1+N) logits where index 0 = positive."""
6972
emb_pos = self._get_item_embs(labels)
7073
logits_pos = (hidden * emb_pos).sum(dim=-1)
7174

7275
emb_neg = self._get_item_embs(negatives)
7376
logits_neg = torch.matmul(
74-
hidden.unsqueeze(2), emb_neg.transpose(2, 3),
77+
hidden.unsqueeze(2),
78+
emb_neg.transpose(2, 3),
7579
).squeeze(2)
7680

7781
return torch.cat([logits_pos.unsqueeze(-1), logits_neg], dim=-1)
7882

7983
# ── losses ──
8084

8185
def _calc_loss(
82-
self, hidden: torch.Tensor, batch: tp.Dict[str, torch.Tensor],
86+
self,
87+
hidden: torch.Tensor,
88+
batch: tp.Dict[str, torch.Tensor],
8389
) -> torch.Tensor:
8490
labels = batch["y"]
8591
has_neg = "negatives" in batch
@@ -114,7 +120,9 @@ def _full_softmax_loss(self, hidden: torch.Tensor, labels: torch.Tensor) -> torc
114120
targets = labels.clone()
115121
targets[targets == 0] = -100
116122
return F.cross_entropy(
117-
logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100,
123+
logits.view(-1, logits.size(-1)),
124+
targets.view(-1),
125+
ignore_index=-100,
118126
)
119127

120128
def _sampled_softmax_loss(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
@@ -123,7 +131,9 @@ def _sampled_softmax_loss(self, logits: torch.Tensor, mask: torch.Tensor) -> tor
123131
logits[:, :, [0, 1]] = logits[:, :, [1, 0]]
124132
targets = mask.long() # 1 where non-padding, 0 where padding
125133
return F.cross_entropy(
126-
logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=0,
134+
logits.view(-1, logits.size(-1)),
135+
targets.view(-1),
136+
ignore_index=0,
127137
)
128138

129139
def _bce_loss(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:

rectools/fast_transformers/unisrec_model.py

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
import typing as tp
44
from pathlib import Path
55

6-
import torch
76
import pytorch_lightning as pl
7+
import torch
88
from pytorch_lightning.callbacks import EarlyStopping
99

10+
from .gpu_data import align_embeddings, build_sequences, make_dataloader
11+
from .unisrec_lightning import SUPPORTED_LOSSES, SUPPORTED_OPTIMIZERS, SUPPORTED_SCHEDULERS, UniSRecLightning
1012
from .unisrec_net import UniSRec
11-
from .unisrec_lightning import UniSRecLightning, SUPPORTED_LOSSES, SUPPORTED_OPTIMIZERS, SUPPORTED_SCHEDULERS
12-
from .gpu_data import build_sequences, align_embeddings, make_dataloader
1313

1414

1515
class UniSRecModel:
@@ -143,7 +143,12 @@ def _make_trainer(self, max_epochs: int, val_dl: tp.Any = None) -> pl.Trainer:
143143
)
144144

145145
def _make_lightning(
146-
self, net: UniSRec, param_groups: tp.List[tp.Dict], use_id: bool, max_epochs: int, train_dl: tp.Any,
146+
self,
147+
net: UniSRec,
148+
param_groups: tp.List[tp.Dict],
149+
use_id: bool,
150+
max_epochs: int,
151+
train_dl: tp.Any,
147152
) -> UniSRecLightning:
148153
total_steps = len(train_dl) * max_epochs if self.scheduler else None
149154
return UniSRecLightning(
@@ -172,16 +177,22 @@ def _phase2_params(self, net: UniSRec) -> tp.List[tp.Dict[str, tp.Any]]:
172177
{"params": [net.whitening_bias], "lr": self.phase2_lr * 10.0, "weight_decay": 0.0},
173178
]
174179
if net.head is not None:
175-
groups.append({
176-
"params": list(net.head.parameters()),
177-
"lr": self.phase2_lr * self.lr_head,
178-
"weight_decay": self.weight_decay,
179-
})
180+
groups.append(
181+
{
182+
"params": list(net.head.parameters()),
183+
"lr": self.phase2_lr * self.lr_head,
184+
"weight_decay": self.weight_decay,
185+
}
186+
)
180187
else:
181188
groups = [
182189
{"params": list(net.bn_input.parameters()), "lr": self.phase2_lr, "weight_decay": 0.0},
183190
{"params": list(net.bn_score.parameters()), "lr": self.phase2_lr, "weight_decay": 0.0},
184-
{"params": list(net.head.parameters()), "lr": self.phase2_lr * self.lr_head, "weight_decay": self.weight_decay},
191+
{
192+
"params": list(net.head.parameters()),
193+
"lr": self.phase2_lr * self.lr_head,
194+
"weight_decay": self.weight_decay,
195+
},
185196
]
186197
return groups
187198

@@ -198,21 +209,27 @@ def _phase3_params(self, net: UniSRec) -> tp.List[tp.Dict[str, tp.Any]]:
198209
]
199210
head: tp.List[tp.Dict[str, tp.Any]] = []
200211
if net.head is not None:
201-
head = [{"params": list(net.head.parameters()), "lr": self.phase3_lr * self.lr_head, "weight_decay": self.weight_decay}]
212+
head = [
213+
{
214+
"params": list(net.head.parameters()),
215+
"lr": self.phase3_lr * self.lr_head,
216+
"weight_decay": self.weight_decay,
217+
}
218+
]
202219
transformer = [
203220
{"params": list(net.pos_emb.parameters()), "lr": self.phase3_lr * self.lr_transformer, "weight_decay": 0.0},
204221
{
205222
"params": (
206-
[p for l in net.attention_layers for p in l.parameters()]
207-
+ [p for l in net.forward_layers for p in l.parameters()]
223+
[p for layer in net.attention_layers for p in layer.parameters()]
224+
+ [p for layer in net.forward_layers for p in layer.parameters()]
208225
),
209226
"lr": self.phase3_lr * self.lr_transformer,
210227
"weight_decay": self.weight_decay,
211228
},
212229
{
213230
"params": (
214-
[p for l in net.attention_layernorms for p in l.parameters()]
215-
+ [p for l in net.forward_layernorms for p in l.parameters()]
231+
[p for layer in net.attention_layernorms for p in layer.parameters()]
232+
+ [p for layer in net.forward_layernorms for p in layer.parameters()]
216233
+ list(net.last_layernorm.parameters())
217234
),
218235
"lr": self.phase3_lr,
@@ -246,7 +263,9 @@ def fit(
246263
self
247264
"""
248265
x, y, unique_items, unique_users = build_sequences(
249-
user_ids, item_ids, timestamps,
266+
user_ids,
267+
item_ids,
268+
timestamps,
250269
max_len=self.session_max_len,
251270
min_interactions=self.train_min_user_interactions,
252271
)
@@ -303,12 +322,15 @@ def _run_phase(param_groups: tp.List[tp.Dict], use_id: bool, max_epochs: int) ->
303322

304323
def save_checkpoint(self, path: tp.Union[str, Path]) -> None:
305324
assert self._net is not None
306-
torch.save({
307-
"net": self._net.state_dict(),
308-
"unique_items": self._unique_items,
309-
"unique_users": self._unique_users,
310-
"n_items": len(self._unique_items),
311-
}, path)
325+
torch.save(
326+
{
327+
"net": self._net.state_dict(),
328+
"unique_items": self._unique_items,
329+
"unique_users": self._unique_users,
330+
"n_items": len(self._unique_items),
331+
},
332+
path,
333+
)
312334

313335
def load_checkpoint(self, path: tp.Union[str, Path], device: str = "cuda") -> None:
314336
ckpt = torch.load(path, map_location=device, weights_only=False)

rectools/fast_transformers/unisrec_net.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,17 @@ def make_ffn(n_factors: int, ffn_type: str, expansion: int, dropout: float) -> n
5151
hidden = n_factors * expansion
5252
if ffn_type == "linear_gelu":
5353
return nn.Sequential(
54-
nn.Linear(n_factors, hidden), nn.GELU(), nn.Dropout(dropout),
55-
nn.Linear(hidden, n_factors), nn.Dropout(dropout),
54+
nn.Linear(n_factors, hidden),
55+
nn.GELU(),
56+
nn.Dropout(dropout),
57+
nn.Linear(hidden, n_factors),
58+
nn.Dropout(dropout),
5659
)
5760
if ffn_type == "linear_relu":
5861
return nn.Sequential(
59-
nn.Linear(n_factors, hidden), nn.ReLU(), nn.Dropout(dropout),
62+
nn.Linear(n_factors, hidden),
63+
nn.ReLU(),
64+
nn.Dropout(dropout),
6065
nn.Linear(hidden, n_factors),
6166
)
6267
raise ValueError(f"Unknown ffn_type: {ffn_type}. Choose from: conv1d, linear_gelu, linear_relu")
@@ -238,8 +243,10 @@ def project_all(self) -> torch.Tensor:
238243
@property
239244
def transformer_params(self) -> tp.List[nn.Parameter]:
240245
modules = (
241-
list(self.attention_layernorms) + list(self.attention_layers)
242-
+ list(self.forward_layernorms) + list(self.forward_layers)
246+
list(self.attention_layernorms)
247+
+ list(self.attention_layers)
248+
+ list(self.forward_layernorms)
249+
+ list(self.forward_layers)
243250
+ [self.last_layernorm, self.pos_emb]
244251
)
245252
return [p for m in modules for p in m.parameters()]
@@ -272,9 +279,9 @@ def _encode(self, seqs: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
272279
seqs = seqs + self.pos_emb(positions)
273280
seqs = self.emb_dropout(seqs)
274281

275-
pad_mask = (input_ids == self.PADDING_IDX) # (B, L)
276-
pad_mask_3d = pad_mask.unsqueeze(-1) # (B, L, 1)
277-
seqs = seqs.masked_fill(pad_mask_3d, 0.0) # zero out padding
282+
pad_mask = input_ids == self.PADDING_IDX # (B, L)
283+
pad_mask_3d = pad_mask.unsqueeze(-1) # (B, L, 1)
284+
seqs = seqs.masked_fill(pad_mask_3d, 0.0) # zero out padding
278285

279286
attn_mask = self._causal_mask(L, seqs.device)
280287
key_padding_mask = pad_mask
@@ -284,7 +291,9 @@ def _encode(self, seqs: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
284291
# Zero padding in Q/K/V so NaN can never appear in dot-products
285292
normed = normed.masked_fill(pad_mask_3d, 0.0)
286293
mha_out, _ = self.attention_layers[i](
287-
normed, normed, normed,
294+
normed,
295+
normed,
296+
normed,
288297
attn_mask=attn_mask,
289298
key_padding_mask=key_padding_mask,
290299
need_weights=False,

0 commit comments

Comments
 (0)