Skip to content

Commit 45ed8ae

Browse files
TOPAPECRecTools Dev
authored andcommitted
Clean up UniSRec: remove dead code, add GPU metrics
- Remove item_emb, use_id, freeze/unfreeze, phase references from net/lightning - Remove GPUBatchDataset alias and make_dataloader wrapper - Reorganize into preprocessing/ and unisrec/ subpackages - Add GPU-friendly HR@K, NDCG@K, MRR@K metrics (tested against RecTools) - Update benchmark, demo, and all tests (102 passed + 28 metric tests)
1 parent 6809160 commit 45ed8ae

15 files changed

Lines changed: 613 additions & 291 deletions

File tree

benchmark/compare_sasrec_unisrec.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from rectools import Columns
1919
from rectools.dataset import Dataset
2020
from rectools.fast_transformers import UniSRecModel
21-
from rectools.fast_transformers.sequence_data import build_sequences
21+
from rectools.fast_transformers.preprocessing import build_sequences
2222
from rectools.models import SASRecModel
2323

2424
DATA_DIR = Path("data/ml-20m")
@@ -78,13 +78,13 @@ def to_tensors(df):
7878

7979

8080
@torch.no_grad()
81-
def evaluate_unisrec(model, train_df, test_df, k=10, batch_size=256, use_id=False):
81+
def evaluate_unisrec(model, train_df, test_df, k=10, batch_size=256):
8282
net = model.net
8383
net.cuda().eval()
8484
device = torch.device("cuda")
8585
maxlen = net.session_max_len
8686

87-
item_embs = net.item_emb.weight if use_id else net.project_all()
87+
item_embs = net.project_all()
8888
unique_items = model.item_id_mapping
8989
ext_to_int = {int(unique_items[i].item()): i + 1 for i in range(len(unique_items))}
9090

@@ -107,7 +107,7 @@ def evaluate_unisrec(model, train_df, test_df, k=10, batch_size=256, use_id=Fals
107107
if not seqs:
108108
continue
109109
x = torch.tensor(seqs, dtype=torch.long, device=device)
110-
h = net.encode_last(x, use_id=use_id)
110+
h = net.encode_last(x)
111111
scores = h @ item_embs.T
112112
scores[:, 0] = float("-inf")
113113
for i, target_int in enumerate(targets):
@@ -430,7 +430,7 @@ def sasrec_val_mask(interactions_df, **kwargs):
430430
# Eval
431431
print(" Evaluating...")
432432
t0 = time.time()
433-
unisrec_metrics = evaluate_unisrec(unisrec_id, train_with_val, test_ratings, use_id=True)
433+
unisrec_metrics = evaluate_unisrec(unisrec_id, train_with_val, test_ratings)
434434
timings["unisrec_eval"] = time.time() - t0
435435
print(f" Eval: {timings['unisrec_eval']:.1f}s")
436436
hr = unisrec_metrics["HR@10"]
Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
11
"""Fast Transformers: flat sequential recommenders without ItemNet hierarchy."""
22

3+
from .metrics import compute_metrics, hitrate_at_k, mrr_at_k, ndcg_at_k
34
from .net import FlatSASRec, SASRecBlock
4-
from .sequence_data import (
5-
GPUBatchDataset,
5+
from .preprocessing import (
66
SequenceBatchDataset,
77
align_embeddings,
88
build_sequences,
9-
make_dataloader,
109
)
11-
from .unisrec_lightning import UniSRecLightning
12-
from .unisrec_model import UniSRecModel
13-
from .unisrec_net import FeedForward, UniSRec
10+
from .unisrec import UniSRec, UniSRecLightning, UniSRecModel
11+
from .unisrec.net import FeedForward
1412

1513
__all__ = [
1614
"build_sequences",
1715
"align_embeddings",
1816
"SequenceBatchDataset",
19-
"GPUBatchDataset",
20-
"make_dataloader",
2117
"FlatSASRec",
2218
"SASRecBlock",
2319
"UniSRec",
2420
"FeedForward",
2521
"UniSRecLightning",
2622
"UniSRecModel",
23+
"hitrate_at_k",
24+
"ndcg_at_k",
25+
"mrr_at_k",
26+
"compute_metrics",
2727
]
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
"""GPU-friendly ranking metrics for leave-one-out evaluation.
2+
3+
All functions operate on PyTorch tensors and stay on the original device
4+
(CPU or CUDA), avoiding numpy/pandas roundtrips. Results are numerically
5+
identical to the corresponding RecTools metrics with default settings:
6+
7+
- :class:`rectools.metrics.HitRate` (k=K)
8+
- :class:`rectools.metrics.NDCG` (k=K, log_base=2, divide_by_achievable=False)
9+
- :class:`rectools.metrics.MRR` (k=K)
10+
11+
These functions assume **leave-one-out** evaluation: each user has exactly
12+
one ground-truth target item.
13+
"""
14+
15+
import typing as tp
16+
17+
import torch
18+
19+
20+
@torch.no_grad()
21+
def hitrate_at_k(
22+
topk_ids: torch.Tensor,
23+
targets: torch.Tensor,
24+
) -> torch.Tensor:
25+
"""Hit Rate @ K (leave-one-out).
26+
27+
Parameters
28+
----------
29+
topk_ids : LongTensor (B, K)
30+
Top-K predicted item IDs per user.
31+
targets : LongTensor (B,)
32+
Ground-truth item ID per user.
33+
34+
Returns
35+
-------
36+
Tensor (scalar)
37+
Mean hit rate across users.
38+
"""
39+
hits = (topk_ids == targets.unsqueeze(1)).any(dim=1)
40+
return hits.float().mean()
41+
42+
43+
@torch.no_grad()
44+
def ndcg_at_k(
45+
topk_ids: torch.Tensor,
46+
targets: torch.Tensor,
47+
log_base: int = 2,
48+
) -> torch.Tensor:
49+
"""NDCG @ K (leave-one-out, divide_by_achievable=False).
50+
51+
Matches :class:`rectools.metrics.NDCG` with default parameters.
52+
IDCG is computed as the maximum possible DCG when all K positions are
53+
relevant (constant across users), which is the RecTools default.
54+
55+
Parameters
56+
----------
57+
topk_ids : LongTensor (B, K)
58+
Top-K predicted item IDs per user.
59+
targets : LongTensor (B,)
60+
Ground-truth item ID per user.
61+
log_base : int, default 2
62+
Logarithm base for the discount factor.
63+
64+
Returns
65+
-------
66+
Tensor (scalar)
67+
Mean NDCG across users.
68+
"""
69+
k = topk_ids.shape[1]
70+
hits = (topk_ids == targets.unsqueeze(1)).float() # (B, K)
71+
ranks = torch.arange(1, k + 1, device=topk_ids.device, dtype=torch.float)
72+
discounts = 1.0 / torch.log(ranks + 1) * (1.0 / _log(log_base))
73+
dcg = (hits * discounts.unsqueeze(0)).sum(dim=1) # (B,)
74+
idcg = discounts.sum()
75+
return (dcg / idcg).mean()
76+
77+
78+
@torch.no_grad()
79+
def mrr_at_k(
80+
topk_ids: torch.Tensor,
81+
targets: torch.Tensor,
82+
) -> torch.Tensor:
83+
"""MRR @ K (leave-one-out).
84+
85+
Parameters
86+
----------
87+
topk_ids : LongTensor (B, K)
88+
Top-K predicted item IDs per user.
89+
targets : LongTensor (B,)
90+
Ground-truth item ID per user.
91+
92+
Returns
93+
-------
94+
Tensor (scalar)
95+
Mean reciprocal rank across users.
96+
"""
97+
hits = (topk_ids == targets.unsqueeze(1)) # (B, K)
98+
# For each user find the rank of the first hit (1-based), 0 if no hit
99+
has_hit = hits.any(dim=1)
100+
# argmax returns the first True index
101+
first_hit_rank = hits.float().argmax(dim=1) + 1 # (B,)
102+
rr = torch.zeros_like(first_hit_rank, dtype=torch.float)
103+
rr[has_hit] = 1.0 / first_hit_rank[has_hit].float()
104+
return rr.mean()
105+
106+
107+
@torch.no_grad()
108+
def compute_metrics(
109+
topk_ids: torch.Tensor,
110+
targets: torch.Tensor,
111+
ks: tp.Optional[tp.List[int]] = None,
112+
log_base: int = 2,
113+
) -> tp.Dict[str, float]:
114+
"""Compute HR, NDCG, MRR at multiple K values.
115+
116+
Parameters
117+
----------
118+
topk_ids : LongTensor (B, K_max)
119+
Top-K_max predicted item IDs per user.
120+
targets : LongTensor (B,)
121+
Ground-truth item ID per user.
122+
ks : list of int, optional
123+
K values to evaluate. Defaults to ``[K_max]``.
124+
log_base : int, default 2
125+
Logarithm base for NDCG discount.
126+
127+
Returns
128+
-------
129+
dict
130+
Keys like ``"HR@10"``, ``"NDCG@10"``, ``"MRR@10"``.
131+
"""
132+
k_max = topk_ids.shape[1]
133+
if ks is None:
134+
ks = [k_max]
135+
results: tp.Dict[str, float] = {}
136+
for k in ks:
137+
if k > k_max:
138+
raise ValueError(f"k={k} exceeds topk_ids width {k_max}")
139+
top = topk_ids[:, :k]
140+
results[f"HR@{k}"] = hitrate_at_k(top, targets).item()
141+
results[f"NDCG@{k}"] = ndcg_at_k(top, targets, log_base=log_base).item()
142+
results[f"MRR@{k}"] = mrr_at_k(top, targets).item()
143+
return results
144+
145+
146+
def _log(base: int) -> float:
147+
"""Natural log of base (cached constant)."""
148+
import math
149+
return math.log(base)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Vectorized sequence preprocessing for transformer recommenders."""
2+
3+
from .sequence_data import (
4+
SequenceBatchDataset,
5+
align_embeddings,
6+
build_sequences,
7+
)
8+
9+
__all__ = [
10+
"build_sequences",
11+
"align_embeddings",
12+
"SequenceBatchDataset",
13+
]

rectools/fast_transformers/sequence_data.py renamed to rectools/fast_transformers/preprocessing/sequence_data.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import typing as tp
88

99
import torch
10-
from torch.utils.data import DataLoader
1110
from torch.utils.data import Dataset as TorchDataset
1211

1312

@@ -171,41 +170,3 @@ def __getitem__(self, idx: int) -> tp.Dict[str, torch.Tensor]:
171170
if self.transform:
172171
batch = self.transform(batch)
173172
return batch
174-
175-
176-
# Keep old name as alias for backwards compatibility
177-
GPUBatchDataset = SequenceBatchDataset
178-
179-
180-
def make_dataloader(
181-
x: torch.Tensor,
182-
y: torch.Tensor,
183-
batch_size: int,
184-
shuffle: bool = True,
185-
transform: tp.Optional[tp.Callable] = None,
186-
num_workers: int = 0,
187-
**kwargs: tp.Any,
188-
) -> DataLoader:
189-
"""Create a DataLoader from prebuilt sequence tensors.
190-
191-
Parameters
192-
----------
193-
x, y : Tensor
194-
Input and target sequences from :func:`build_sequences`.
195-
batch_size : int
196-
Batch size.
197-
shuffle : bool, default True
198-
Whether to shuffle.
199-
transform : callable, optional
200-
Per-sample transform (e.g. negative sampling).
201-
num_workers : int, default 0
202-
Number of DataLoader workers.
203-
**kwargs
204-
Additional keyword arguments passed to :class:`~torch.utils.data.DataLoader`.
205-
206-
Returns
207-
-------
208-
DataLoader
209-
"""
210-
ds = SequenceBatchDataset(x, y, transform=transform)
211-
return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, **kwargs)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""UniSRec: sequential recommender with pretrained text embeddings."""
2+
3+
from .lightning import UniSRecLightning
4+
from .model import UniSRecModel
5+
from .net import FeedForward, UniSRec
6+
7+
__all__ = [
8+
"UniSRec",
9+
"FeedForward",
10+
"UniSRecLightning",
11+
"UniSRecModel",
12+
]

rectools/fast_transformers/demo_kion_unisrec.md renamed to rectools/fast_transformers/unisrec/demo_kion.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ with torch.no_grad():
206206
if not seqs:
207207
continue
208208
x = torch.tensor(seqs, dtype=torch.long, device=device)
209-
h = net.encode_last(x, use_id=False)
209+
h = net.encode_last(x)
210210
scores = h @ item_embs.T
211211
scores[:, 0] = float("-inf")
212212
for i, target_int in enumerate(targets):

rectools/fast_transformers/unisrec_lightning.py renamed to rectools/fast_transformers/unisrec/lightning.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch.nn.functional as F
99
from torch.optim.lr_scheduler import LambdaLR
1010

11-
from .unisrec_net import UniSRec
11+
from .net import UniSRec
1212

1313
SUPPORTED_LOSSES = ("softmax", "BCE", "gBCE", "sampled_softmax")
1414
SUPPORTED_OPTIMIZERS = ("adam", "adamw")
@@ -17,17 +17,16 @@
1717

1818
class UniSRecLightning(pl.LightningModule):
1919
"""
20-
Thin Lightning wrapper reused across all training phases.
20+
Thin Lightning wrapper for joint UniSRec training.
2121
22-
Each phase creates a fresh ``UniSRecLightning`` with appropriate
23-
``param_groups`` and ``use_id`` flag, sharing the same ``net`` instance.
22+
Wraps a :class:`UniSRec` network with configurable loss, optimizer,
23+
and learning-rate scheduler.
2424
"""
2525

2626
def __init__(
2727
self,
2828
net: UniSRec,
2929
param_groups: tp.List[tp.Dict[str, tp.Any]],
30-
use_id: bool = False,
3130
loss: str = "softmax",
3231
n_negatives: tp.Optional[int] = None,
3332
gbce_t: float = 0.2,
@@ -40,7 +39,6 @@ def __init__(
4039
super().__init__()
4140
self.net = net
4241
self._param_groups = param_groups
43-
self.use_id = use_id
4442
self.loss_name = loss
4543
self.n_negatives = n_negatives
4644
self.gbce_t = gbce_t
@@ -53,13 +51,9 @@ def __init__(
5351
# ── helpers ──
5452

5553
def _get_item_embs(self, item_ids: torch.Tensor) -> torch.Tensor:
56-
if self.use_id:
57-
return self.net.item_emb(item_ids)
5854
return self.net._adapt_score(self.net._sample_frozen(item_ids))
5955

6056
def _get_all_embs(self) -> torch.Tensor:
61-
if self.use_id:
62-
return self.net.item_emb.weight
6357
return self.net.project_all()
6458

6559
def _get_pos_neg_logits(
@@ -90,11 +84,7 @@ def _calc_loss(
9084
labels = batch["y"]
9185
has_neg = "negatives" in batch
9286

93-
if self.loss_name == "softmax" and not has_neg:
94-
return self._full_softmax_loss(hidden, labels)
95-
96-
if self.loss_name == "softmax" and has_neg:
97-
# full softmax even if negatives are available
87+
if self.loss_name == "softmax":
9888
return self._full_softmax_loss(hidden, labels)
9989

10090
if not has_neg:
@@ -165,13 +155,13 @@ def _gbce_loss(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
165155
# ── training / validation ──
166156

167157
def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
168-
hidden = self.net(batch["x"], use_id=self.use_id)
158+
hidden = self.net(batch["x"])
169159
loss = self._calc_loss(hidden, batch)
170160
self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
171161
return loss
172162

173163
def validation_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
174-
hidden = self.net(batch["x"], use_id=self.use_id)
164+
hidden = self.net(batch["x"])
175165
# Validation batch has y of shape (B, 1) -- take last hidden position only
176166
hidden = hidden[:, -1:, :]
177167
loss = self._calc_loss(hidden, batch)

0 commit comments

Comments
 (0)