Skip to content

Commit 6809160

Browse files
TOPAPECRecTools Dev
authored andcommitted
Simplify UniSRec: remove 3-phase training, hash IDs, ranking.py
- Remove ranking.py (duplicates TorchRanker) - Remove hash ID mapping from build_sequences/align_embeddings - Simplify UniSRecModel to single joint training phase (adaptor + transformer) - Rename gpu_data.py -> sequence_data.py, GPUBatchDataset -> SequenceBatchDataset - Vectorize map_item_ids with torch.searchsorted - Fix device default (None -> auto-detect from input tensor) - Fix double torch.unique call - Add empty dataset validation in fit() - Add **kwargs to make_dataloader - Add dataloader_num_workers passthrough - Move benchmark script to benchmark/ folder - Add KION training demo with Qwen3-Embedding-0.6B results - Update tests for simplified API - Clean up CHANGELOG and .gitignore
1 parent f2fdfe5 commit 6809160

16 files changed

Lines changed: 548 additions & 909 deletions

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,5 @@ benchmark_results/
9898
catboost_info/
9999

100100
# Dev artifacts
101-
training_folder/
102101
*.pt
103102
data/*

CHANGELOG.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111
### Added
1212
- `rectools.fast_transformers` module — standalone transformer-based sequential recommenders that work directly with torch tensors, bypassing the `Dataset`/pandas pipeline. GPU-native sequence building via `build_sequences()` gives ~30x preprocessing speedup over `SASRecDataPreparator` on ML-20M
1313
- `FlatSASRec` network and `FlatSASRecModel` — flat SASRec implementation without the ItemNet hierarchy. Pre-norm transformer encoder with id-embeddings, causal masking, softmax and BCE losses. Integrates with RecTools `ModelBase` for compatibility with the standard `fit`/`recommend` API
14-
- `UniSRec` network and `UniSRecModel` — sequential recommender with pretrained text embeddings (e.g. Qwen) and a learnable PCA/BN adaptor. Three-phase training: (1) SASRec warm-up on ID embeddings, (2) adaptor-only with frozen transformer, (3) full fine-tune on pretrained embeddings. Configurable losses (softmax, BCE, gBCE, sampled_softmax), optimizers (Adam, AdamW), cosine warmup scheduler, early stopping, checkpoint save/load. `UniSRecModel.fit()` accepts raw `(user_ids, item_ids, timestamps)` tensors
15-
- `rank_topk()` utility for batched top-k scoring with CSR-based viewed-item filtering and item whitelist support
14+
- `UniSRec` network and `UniSRecModel` — sequential recommender with pretrained text embeddings (e.g. Qwen) and a learnable PCA/BN adaptor. Joint training of adaptor + transformer on pretrained embeddings. Configurable losses (softmax, BCE, gBCE, sampled_softmax), optimizers (Adam, AdamW), cosine warmup scheduler, early stopping, checkpoint save/load. `UniSRecModel.fit()` accepts raw `(user_ids, item_ids, timestamps)` tensors
1615
- `align_embeddings()` for mapping pretrained embedding matrices to internal item ID order
17-
- `GPUBatchDataset` and `make_dataloader()` — lightweight torch Dataset/DataLoader wrappers for sequence training data
16+
- `SequenceBatchDataset` and `make_dataloader()` — lightweight torch Dataset/DataLoader wrappers for sequence training data
1817
- Configurable FFN blocks in `UniSRec`: `conv1d` (original paper), `linear_gelu`, `linear_relu` with adjustable expansion factor
19-
- Tests for all `fast_transformers` submodules (143 tests)
2018

2119

2220
## [0.18.0] - 21.02.2026
Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from rectools import Columns
1919
from rectools.dataset import Dataset
2020
from rectools.fast_transformers import UniSRecModel
21-
from rectools.fast_transformers.gpu_data import build_sequences
21+
from rectools.fast_transformers.sequence_data import build_sequences
2222
from rectools.models import SASRecModel
2323

2424
DATA_DIR = Path("data/ml-20m")
@@ -406,10 +406,8 @@ def sasrec_val_mask(interactions_df, **kwargs):
406406
adaptor_dropout=0.2,
407407
adaptor_type="pca",
408408
use_adaptor_ffn=True,
409-
phase1_epochs=EPOCHS,
410-
phase2_epochs=0,
411-
phase3_epochs=0,
412-
phase1_lr=LR,
409+
epochs=EPOCHS,
410+
lr=LR,
413411
optimizer="adam",
414412
grad_clip=1.0,
415413
weight_decay=0.0,
File renamed without changes.

rectools/fast_transformers/__init__.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,25 @@
11
"""Fast Transformers: flat sequential recommenders without ItemNet hierarchy."""
22

3-
from .gpu_data import GPUBatchDataset, align_embeddings, build_sequences, hash_item_ids, make_dataloader
43
from .net import FlatSASRec, SASRecBlock
5-
from .ranking import rank_topk
4+
from .sequence_data import (
5+
GPUBatchDataset,
6+
SequenceBatchDataset,
7+
align_embeddings,
8+
build_sequences,
9+
make_dataloader,
10+
)
611
from .unisrec_lightning import UniSRecLightning
712
from .unisrec_model import UniSRecModel
813
from .unisrec_net import FeedForward, UniSRec
914

1015
__all__ = [
1116
"build_sequences",
1217
"align_embeddings",
13-
"hash_item_ids",
18+
"SequenceBatchDataset",
1419
"GPUBatchDataset",
1520
"make_dataloader",
1621
"FlatSASRec",
1722
"SASRecBlock",
18-
"rank_topk",
1923
"UniSRec",
2024
"FeedForward",
2125
"UniSRecLightning",
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
# UniSRec Training Demo: KION Dataset
2+
3+
This guide demonstrates training a UniSRec sequential recommender on the KION movie dataset using real text embeddings from movie descriptions.
4+
5+
## Overview
6+
7+
UniSRec jointly trains a PCA-based adaptor and a SASRec transformer encoder on frozen pretrained text embeddings. This allows the model to leverage semantic item representations without requiring collaborative item IDs.
8+
9+
## Prerequisites
10+
11+
```bash
12+
pip install torch pytorch-lightning sentence-transformers
13+
```
14+
15+
## 1. Prepare Data
16+
17+
### Download the KION dataset
18+
19+
```bash
20+
git clone https://github.com/irsafilo/KION_DATASET kion_data
21+
```
22+
23+
### Load and filter interactions
24+
25+
```python
26+
import pandas as pd
27+
import torch
28+
29+
# Load interactions
30+
interactions = pd.read_csv("kion_data/interactions.csv")
31+
interactions = interactions.rename(columns={"last_watch_dt": "timestamp"})
32+
interactions["timestamp"] = pd.to_datetime(interactions["timestamp"]).astype(int) // 10**9
33+
34+
# Filter: min 5 interactions per item, min 2 per user
35+
item_counts = interactions.groupby("item_id").size()
36+
interactions = interactions[interactions["item_id"].isin(item_counts[item_counts >= 5].index)]
37+
user_counts = interactions.groupby("user_id").size()
38+
interactions = interactions[interactions["user_id"].isin(user_counts[user_counts >= 2].index)]
39+
40+
print(f"Interactions: {len(interactions):,}")
41+
print(f"Users: {interactions['user_id'].nunique():,}")
42+
print(f"Items: {interactions['item_id'].nunique():,}")
43+
# Interactions: 643,786
44+
# Users: 201,851
45+
# Items: 6,228
46+
```
47+
48+
### Leave-last-out split
49+
50+
```python
51+
interactions = interactions.sort_values(["user_id", "timestamp"])
52+
test = interactions.groupby("user_id").tail(1)
53+
train_val = interactions.drop(test.index)
54+
55+
print(f"Train+Val: {len(train_val):,}, Test: {len(test):,}")
56+
# Train+Val: 441,935, Test: 201,851
57+
```
58+
59+
## 2. Generate Text Embeddings
60+
61+
Use English movie descriptions from the dataset with Qwen3-Embedding-0.6B:
62+
63+
```bash
64+
pip install transformers
65+
```
66+
67+
```python
68+
from transformers import AutoTokenizer, AutoModel
69+
70+
# Load item metadata (English descriptions)
71+
items = pd.read_csv("kion_data/data_en/items_en.csv")
72+
items = items.set_index("item_id")
73+
74+
# Build description text
75+
texts = {}
76+
for item_id, row in items.iterrows():
77+
parts = [str(row.get("title", ""))]
78+
if pd.notna(row.get("description")):
79+
parts.append(str(row["description"]))
80+
if pd.notna(row.get("genres")):
81+
parts.append(f"Genres: {row['genres']}")
82+
texts[item_id] = " ".join(parts)
83+
84+
# Encode with Qwen3-Embedding-0.6B
85+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B")
86+
encoder = AutoModel.from_pretrained("Qwen/Qwen3-Embedding-0.6B", dtype=torch.float16)
87+
encoder.cuda().eval()
88+
89+
max_item_id = items.index.max()
90+
embeddings = torch.zeros(max_item_id + 1, 1024)
91+
92+
item_ids_list = list(texts.keys())
93+
text_list = list(texts.values())
94+
95+
with torch.no_grad():
96+
for start in range(0, len(text_list), 32):
97+
batch_texts = text_list[start:start + 32]
98+
batch_ids = item_ids_list[start:start + 32]
99+
encoded = tokenizer(batch_texts, return_tensors="pt", padding=True,
100+
truncation=True, max_length=512).to("cuda")
101+
outputs = encoder(**encoded)
102+
mask = encoded["attention_mask"].unsqueeze(-1).half()
103+
pooled = (outputs.last_hidden_state * mask).sum(1) / mask.sum(1)
104+
pooled = torch.nn.functional.normalize(pooled, p=2, dim=-1)
105+
for i, item_id in enumerate(batch_ids):
106+
embeddings[item_id] = pooled[i].cpu().float()
107+
108+
torch.save(embeddings, "item_embeddings.pt")
109+
print(f"Embeddings: {embeddings.shape}")
110+
# Embeddings: torch.Size([16519, 1024])
111+
```
112+
113+
## 3. Train UniSRec
114+
115+
```python
116+
from rectools.fast_transformers import UniSRecModel
117+
118+
embeddings = torch.load("item_embeddings.pt", weights_only=True)
119+
120+
user_ids = torch.tensor(train_val["user_id"].values, dtype=torch.long)
121+
item_ids = torch.tensor(train_val["item_id"].values, dtype=torch.long)
122+
timestamps = torch.tensor(train_val["timestamp"].values, dtype=torch.long)
123+
124+
model = UniSRecModel(
125+
pretrained_item_embeddings=embeddings,
126+
# Architecture
127+
n_factors=256,
128+
projection_hidden=512,
129+
n_blocks=2,
130+
n_heads=2,
131+
session_max_len=50,
132+
dropout=0.1,
133+
adaptor_dropout=0.2,
134+
adaptor_type="pca",
135+
use_adaptor_ffn=True,
136+
ffn_type="conv1d",
137+
ffn_expansion=1,
138+
# Training
139+
epochs=10,
140+
lr=1e-4,
141+
lr_head=0.3,
142+
lr_wp=0.1,
143+
lr_transformer=3.0,
144+
optimizer="adamw",
145+
scheduler="cosine_warmup",
146+
warmup_ratio=0.05,
147+
min_lr_ratio=0.1,
148+
grad_clip=1.0,
149+
weight_decay=0.01,
150+
loss="softmax",
151+
batch_size=128,
152+
train_min_user_interactions=2,
153+
verbose=1,
154+
)
155+
156+
model.fit(user_ids, item_ids, timestamps)
157+
# Training: ~194s on RTX 3090 (10 epochs)
158+
```
159+
160+
### Save / load checkpoint
161+
162+
```python
163+
model.save_checkpoint("unisrec_kion.pt")
164+
165+
# Later:
166+
model2 = UniSRecModel(pretrained_item_embeddings=embeddings, n_factors=256, ...)
167+
model2.load_checkpoint("unisrec_kion.pt", device="cuda")
168+
```
169+
170+
## 4. Evaluate
171+
172+
Leave-last-out evaluation with HR@K and NDCG@K:
173+
174+
```python
175+
import numpy as np
176+
177+
net = model.net
178+
net.eval().cuda()
179+
device = torch.device("cuda")
180+
181+
# Get projected item embeddings
182+
item_embs = net.project_all()
183+
unique_items = model.item_id_mapping
184+
ext_to_int = {int(unique_items[i].item()): i + 1 for i in range(len(unique_items))}
185+
186+
# Build user histories
187+
train_grouped = train_val.sort_values("timestamp").groupby("user_id")["item_id"].agg(list).to_dict()
188+
test_grouped = test.groupby("user_id")["item_id"].first().to_dict()
189+
test_users = list(test_grouped.keys())
190+
191+
hits10, ndcg10, total = 0, 0.0, 0
192+
maxlen = model.session_max_len
193+
194+
with torch.no_grad():
195+
for start in range(0, len(test_users), 256):
196+
batch_users = test_users[start:start + 256]
197+
seqs, targets = [], []
198+
for uid in batch_users:
199+
history = train_grouped.get(uid, [])
200+
mapped = [ext_to_int[iid] for iid in history if iid in ext_to_int]
201+
if not mapped:
202+
continue
203+
seq = mapped[-maxlen:]
204+
seqs.append([0] * (maxlen - len(seq)) + seq)
205+
targets.append(ext_to_int.get(test_grouped[uid]))
206+
if not seqs:
207+
continue
208+
x = torch.tensor(seqs, dtype=torch.long, device=device)
209+
h = net.encode_last(x, use_id=False)
210+
scores = h @ item_embs.T
211+
scores[:, 0] = float("-inf")
212+
for i, target_int in enumerate(targets):
213+
if target_int is None:
214+
continue
215+
_, topk = scores[i].topk(10)
216+
topk = topk.cpu().tolist()
217+
if target_int in topk:
218+
rank = topk.index(target_int)
219+
hits10 += 1
220+
ndcg10 += 1.0 / np.log2(rank + 2)
221+
total += 1
222+
223+
print(f"HR@10 = {hits10/total:.4f}")
224+
print(f"NDCG@10 = {ndcg10/total:.4f}")
225+
```
226+
227+
## 5. Results
228+
229+
Trained on NVIDIA RTX 3090, 10 epochs, same architecture (256d, 2 blocks, 2 heads, max_len=50):
230+
231+
| Model | Embedder | HR@5 | NDCG@5 | HR@10 | NDCG@10 | Train Time |
232+
|-------|----------|------|--------|-------|---------|------------|
233+
| **UniSRec** | all-MiniLM-L6-v2 (384d) | 0.1421 | 0.0988 | 0.1896 | 0.1145 | ~194s |
234+
| **UniSRec** | Qwen3-Embedding-0.6B (1024d) | 0.1529 | 0.1012 | 0.2018 | 0.1171 | ~178s |
235+
| **SASRec** (RecTools) | ID embeddings | 0.1606 | 0.1081 | 0.2175 | 0.1265 | ~166s |
236+
237+
Qwen3-Embedding-0.6B closes most of the gap to SASRec (HR@10 delta: 1.6pp vs 2.8pp with MiniLM). SASRec with learned ID embeddings is stronger when sufficient interaction data is available. UniSRec's advantage is in cold-start and transfer scenarios where text embeddings provide semantic signal for items with no interaction history.
238+
239+
## Key Parameters
240+
241+
| Parameter | Description | Default |
242+
|-----------|-------------|---------|
243+
| `n_factors` | Hidden dimension of the transformer | 256 |
244+
| `adaptor_type` | Adaptor type: `"pca"` or `"bn"` | `"pca"` |
245+
| `session_max_len` | Maximum sequence length | 200 |
246+
| `epochs` | Number of training epochs | 10 |
247+
| `lr` | Base learning rate (adaptor layernorms) | 1e-4 |
248+
| `lr_wp` | Multiplier for PCA whitening projection | 0.1 |
249+
| `lr_transformer` | Multiplier for transformer layers | 3.0 |
250+
| `lr_head` | Multiplier for head layer | 0.3 |
251+
| `loss` | Loss function: `"softmax"`, `"BCE"`, `"gBCE"`, `"sampled_softmax"` | `"softmax"` |
252+
| `patience` | Early stopping patience (None = disabled) | None |
253+
| `scheduler` | LR scheduler: `None` or `"cosine_warmup"` | None |
254+
255+
## ONNX Export
256+
257+
```python
258+
model.export_to_onnx(
259+
encoder_path="unisrec_encoder.onnx",
260+
items_path="unisrec_items.onnx",
261+
)
262+
```

0 commit comments

Comments
 (0)