| 1 | +# UniSRec Training Demo: KION Dataset |
| 2 | + |
| 3 | +This guide demonstrates training a UniSRec sequential recommender on the KION movie dataset using real text embeddings from movie descriptions. |
| 4 | + |
| 5 | +## Overview |
| 6 | + |
| 7 | +UniSRec jointly trains a PCA-based adaptor and a SASRec transformer encoder on frozen pretrained text embeddings. This allows the model to leverage semantic item representations without requiring collaborative item IDs. |
| 8 | + |
| 9 | +## Prerequisites |
| 10 | + |
| 11 | +```bash |
| 12 | +pip install torch pytorch-lightning sentence-transformers pandas numpy |
| 13 | +``` |
| 14 | + |
| 15 | +## 1. Prepare Data |
| 16 | + |
| 17 | +### Download the KION dataset |
| 18 | + |
| 19 | +```bash |
| 20 | +git clone https://github.com/irsafilo/KION_DATASET kion_data |
| 21 | +``` |
| 22 | + |
| 23 | +### Load and filter interactions |
| 24 | + |
| 25 | +```python |
| 26 | +import pandas as pd |
| 27 | +import torch |
| 28 | + |
| 29 | +# Load interactions |
| 30 | +interactions = pd.read_csv("kion_data/interactions.csv") |
| 31 | +interactions = interactions.rename(columns={"last_watch_dt": "timestamp"}) |
| 32 | +interactions["timestamp"] = pd.to_datetime(interactions["timestamp"]).astype("int64") // 10**9 |
| 33 | + |
| 34 | +# Filter: min 5 interactions per item, min 2 per user |
| 35 | +item_counts = interactions.groupby("item_id").size() |
| 36 | +interactions = interactions[interactions["item_id"].isin(item_counts[item_counts >= 5].index)] |
| 37 | +user_counts = interactions.groupby("user_id").size() |
| 38 | +interactions = interactions[interactions["user_id"].isin(user_counts[user_counts >= 2].index)] |
| 39 | + |
| 40 | +print(f"Interactions: {len(interactions):,}") |
| 41 | +print(f"Users: {interactions['user_id'].nunique():,}") |
| 42 | +print(f"Items: {interactions['item_id'].nunique():,}") |
| 43 | +# Interactions: 643,786 |
| 44 | +# Users: 201,851 |
| 45 | +# Items: 6,228 |
| 46 | +``` |
| 47 | + |
| 48 | +### Leave-last-out split |
| 49 | + |
| 50 | +```python |
| 51 | +interactions = interactions.sort_values(["user_id", "timestamp"]) |
| 52 | +test = interactions.groupby("user_id").tail(1) |
| 53 | +train_val = interactions.drop(test.index) |
| 54 | + |
| 55 | +print(f"Train+Val: {len(train_val):,}, Test: {len(test):,}") |
| 56 | +# Train+Val: 441,935, Test: 201,851 |
| 57 | +``` |
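The split logic can be checked on a toy log before running it on the full dataset. The following is a minimal pure-Python sketch of the same idea (each user's most recent interaction goes to the test set); the helper name `leave_last_out` and the toy rows are illustrative assumptions, not part of the guide's pipeline.

```python
from collections import defaultdict

def leave_last_out(rows):
    """Split (user, item, timestamp) rows: each user's latest row goes to test."""
    by_user = defaultdict(list)
    for user, item, ts in rows:
        by_user[user].append((ts, item))
    train, test = [], []
    for user, events in by_user.items():
        events.sort(key=lambda e: e[0])  # order by timestamp only
        *history, last = events
        train.extend((user, item, ts) for ts, item in history)
        test.append((user, last[1], last[0]))
    return train, test

# Hypothetical toy interactions: (user, item, timestamp)
rows = [
    ("u1", 10, 100), ("u1", 11, 200), ("u1", 12, 300),
    ("u2", 20, 50), ("u2", 21, 60),
]
train, test = leave_last_out(rows)
print(test)
```

As in the printed counts above, the test set holds exactly one row per user, and train plus test add up to the original number of interactions.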
| 58 | + |
| 59 | +## 2. Generate Text Embeddings |
| 60 | + |
| 61 | +Encode each item's English title, description, and genres with Qwen3-Embedding-0.6B, using masked mean pooling over the last hidden state followed by L2 normalization: |
| 62 | + |
| 63 | +```bash |
| 64 | +pip install transformers |
| 65 | +``` |
| 66 | + |
| 67 | +```python |
| 68 | +from transformers import AutoTokenizer, AutoModel |
| 69 | + |
| 70 | +# Load item metadata (English descriptions) |
| 71 | +items = pd.read_csv("kion_data/data_en/items_en.csv") |
| 72 | +items = items.set_index("item_id") |
| 73 | + |
| 74 | +# Build description text |
| 75 | +texts = {} |
| 76 | +for item_id, row in items.iterrows(): |
| 77 | +    parts = [str(row["title"])] if pd.notna(row.get("title")) else []  # skip missing titles |
| 78 | +    if pd.notna(row.get("description")): |
| 79 | +        parts.append(str(row["description"])) |
| 80 | +    if pd.notna(row.get("genres")): |
| 81 | +        parts.append(f"Genres: {row['genres']}") |
| 82 | +    texts[item_id] = " ".join(parts) |
| 83 | + |
| 84 | +# Encode with Qwen3-Embedding-0.6B |
| 85 | +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B") |
| 86 | +encoder = AutoModel.from_pretrained("Qwen/Qwen3-Embedding-0.6B", dtype=torch.float16) |
| 87 | +encoder.cuda().eval() |
| 88 | + |
| 89 | +max_item_id = items.index.max() |
| 90 | +embeddings = torch.zeros(max_item_id + 1, 1024) |
| 91 | + |
| 92 | +item_ids_list = list(texts.keys()) |
| 93 | +text_list = list(texts.values()) |
| 94 | + |
| 95 | +with torch.no_grad(): |
| 96 | +    for start in range(0, len(text_list), 32): |
| 97 | +        batch_texts = text_list[start:start + 32] |
| 98 | +        batch_ids = item_ids_list[start:start + 32] |
| 99 | +        encoded = tokenizer(batch_texts, return_tensors="pt", padding=True, |
| 100 | +                            truncation=True, max_length=512).to("cuda") |
| 101 | +        outputs = encoder(**encoded) |
| 102 | +        mask = encoded["attention_mask"].unsqueeze(-1).half() |
| 103 | +        pooled = (outputs.last_hidden_state * mask).sum(1) / mask.sum(1) |
| 104 | +        pooled = torch.nn.functional.normalize(pooled, p=2, dim=-1) |
| 105 | +        for i, item_id in enumerate(batch_ids): |
| 106 | +            embeddings[item_id] = pooled[i].cpu().float() |
| 107 | + |
| 108 | +torch.save(embeddings, "item_embeddings.pt") |
| 109 | +print(f"Embeddings: {embeddings.shape}") |
| 110 | +# Embeddings: torch.Size([16519, 1024]) |
| 111 | +``` |
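The pooling step above averages only the non-padding token vectors and then scales the result to unit length. A minimal pure-Python sketch of that computation on a single toy sequence follows; the function name `masked_mean_pool` and the toy vectors are assumptions made for illustration, and the real pipeline does this on batched tensors.

```python
import math

def masked_mean_pool(token_vectors, attention_mask):
    """Average the vectors whose mask is 1, then L2-normalize the mean."""
    dim = len(token_vectors[0])
    summed = [0.0] * dim
    count = 0
    for vec, keep in zip(token_vectors, attention_mask):
        if keep:
            count += 1
            for j, value in enumerate(vec):
                summed[j] += value
    mean = [s / count for s in summed]
    norm = math.sqrt(sum(m * m for m in mean))
    return [m / norm for m in mean]

# Two real tokens and one padding token (mask 0) that must be ignored.
tokens = [[3.0, 0.0], [3.0, 8.0], [100.0, 100.0]]
mask = [1, 1, 0]
pooled = masked_mean_pool(tokens, mask)
print(pooled)  # mean of the real tokens is [3, 4]; dividing by its norm 5 gives ~[0.6, 0.8]
```

Because the padding vector is masked out, its large values do not affect the result.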
| 112 | + |
| 113 | +## 3. Train UniSRec |
| 114 | + |
| 115 | +```python |
| 116 | +from rectools.fast_transformers import UniSRecModel |
| 117 | + |
| 118 | +embeddings = torch.load("item_embeddings.pt", weights_only=True) |
| 119 | + |
| 120 | +user_ids = torch.tensor(train_val["user_id"].values, dtype=torch.long) |
| 121 | +item_ids = torch.tensor(train_val["item_id"].values, dtype=torch.long) |
| 122 | +timestamps = torch.tensor(train_val["timestamp"].values, dtype=torch.long) |
| 123 | + |
| 124 | +model = UniSRecModel( |
| 125 | +    pretrained_item_embeddings=embeddings, |
| 126 | +    # Architecture |
| 127 | +    n_factors=256, |
| 128 | +    projection_hidden=512, |
| 129 | +    n_blocks=2, |
| 130 | +    n_heads=2, |
| 131 | +    session_max_len=50, |
| 132 | +    dropout=0.1, |
| 133 | +    adaptor_dropout=0.2, |
| 134 | +    adaptor_type="pca", |
| 135 | +    use_adaptor_ffn=True, |
| 136 | +    ffn_type="conv1d", |
| 137 | +    ffn_expansion=1, |
| 138 | +    # Training |
| 139 | +    epochs=10, |
| 140 | +    lr=1e-4, |
| 141 | +    lr_head=0.3, |
| 142 | +    lr_wp=0.1, |
| 143 | +    lr_transformer=3.0, |
| 144 | +    optimizer="adamw", |
| 145 | +    scheduler="cosine_warmup", |
| 146 | +    warmup_ratio=0.05, |
| 147 | +    min_lr_ratio=0.1, |
| 148 | +    grad_clip=1.0, |
| 149 | +    weight_decay=0.01, |
| 150 | +    loss="softmax", |
| 151 | +    batch_size=128, |
| 152 | +    train_min_user_interactions=2, |
| 153 | +    verbose=1, |
| 154 | +) |
| 155 | + |
| 156 | +model.fit(user_ids, item_ids, timestamps) |
| 157 | +# Training: roughly 3 minutes on an RTX 3090 (10 epochs) |
| 158 | +``` |
| 159 | + |
| 160 | +### Save / load checkpoint |
| 161 | + |
| 162 | +```python |
| 163 | +model.save_checkpoint("unisrec_kion.pt") |
| 164 | + |
| 165 | +# Later: |
| 166 | +model2 = UniSRecModel(pretrained_item_embeddings=embeddings, n_factors=256)  # ... plus the remaining constructor arguments (elided) |
| 167 | +model2.load_checkpoint("unisrec_kion.pt", device="cuda") |
| 168 | +``` |
| 169 | + |
| 170 | +## 4. Evaluate |
| 171 | + |
| 172 | +Leave-last-out evaluation: each user's held-out last item is ranked against the full catalogue, and HR@10 and NDCG@10 are accumulated: |
| 173 | + |
| 174 | +```python |
| 175 | +import numpy as np |
| 176 | + |
| 177 | +net = model.net |
| 178 | +net.eval().cuda() |
| 179 | +device = torch.device("cuda") |
| 180 | + |
| 181 | +# Get projected item embeddings |
| 182 | +item_embs = net.project_all() |
| 183 | +unique_items = model.item_id_mapping |
| 184 | +ext_to_int = {int(unique_items[i].item()): i + 1 for i in range(len(unique_items))} |
| 185 | + |
| 186 | +# Build user histories |
| 187 | +train_grouped = train_val.sort_values("timestamp").groupby("user_id")["item_id"].agg(list).to_dict() |
| 188 | +test_grouped = test.groupby("user_id")["item_id"].first().to_dict() |
| 189 | +test_users = list(test_grouped.keys()) |
| 190 | + |
| 191 | +hits10, ndcg10, total = 0, 0.0, 0 |
| 192 | +maxlen = model.session_max_len |
| 193 | + |
| 194 | +with torch.no_grad(): |
| 195 | +    for start in range(0, len(test_users), 256): |
| 196 | +        batch_users = test_users[start:start + 256] |
| 197 | +        seqs, targets = [], [] |
| 198 | +        for uid in batch_users: |
| 199 | +            history = train_grouped.get(uid, []) |
| 200 | +            mapped = [ext_to_int[iid] for iid in history if iid in ext_to_int] |
| 201 | +            if not mapped: |
| 202 | +                continue |
| 203 | +            seq = mapped[-maxlen:] |
| 204 | +            seqs.append([0] * (maxlen - len(seq)) + seq) |
| 205 | +            targets.append(ext_to_int.get(test_grouped[uid])) |
| 206 | +        if not seqs: |
| 207 | +            continue |
| 208 | +        x = torch.tensor(seqs, dtype=torch.long, device=device) |
| 209 | +        h = net.encode_last(x, use_id=False) |
| 210 | +        scores = h @ item_embs.T |
| 211 | +        scores[:, 0] = float("-inf") |
| 212 | +        for i, target_int in enumerate(targets): |
| 213 | +            if target_int is None: |
| 214 | +                continue |
| 215 | +            _, topk = scores[i].topk(10) |
| 216 | +            topk = topk.cpu().tolist() |
| 217 | +            if target_int in topk: |
| 218 | +                rank = topk.index(target_int) |
| 219 | +                hits10 += 1 |
| 220 | +                ndcg10 += 1.0 / np.log2(rank + 2) |
| 221 | +            total += 1 |
| 222 | + |
| 223 | +print(f"HR@10 = {hits10/total:.4f}") |
| 224 | +print(f"NDCG@10 = {ndcg10/total:.4f}") |
| 225 | +``` |
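With a single held-out item per user, the two metrics reduce to simple per-user quantities: HR@K is 1 if the target appears in the top K, and NDCG@K is `1 / log2(rank + 2)` for a 0-based rank (the ideal DCG is 1). The sketch below isolates that per-user computation; the helper name `hit_and_ndcg_at_k` and the toy ranking are illustrative assumptions.

```python
import math

def hit_and_ndcg_at_k(ranked_items, target, k):
    """Return (hit, ndcg) for one user with exactly one relevant item."""
    top_k = ranked_items[:k]
    if target not in top_k:
        return 0, 0.0
    rank = top_k.index(target)  # 0-based position in the ranking
    return 1, 1.0 / math.log2(rank + 2)

ranked = [42, 7, 13, 99]  # hypothetical item ids, best score first
hit, ndcg = hit_and_ndcg_at_k(ranked, target=13, k=10)
print(hit, ndcg)  # target sits at 0-based rank 2, so NDCG = 1 / log2(4) = 0.5
```

Averaging these per-user values over all evaluated users gives the reported HR@K and NDCG@K.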
| 226 | + |
| 227 | +## 5. Results |
| 228 | + |
| 229 | +All models were trained on an NVIDIA RTX 3090 for 10 epochs with the same architecture (256d, 2 blocks, 2 heads, max_len=50): |
| 230 | + |
| 231 | +| Model | Embedder | HR@5 | NDCG@5 | HR@10 | NDCG@10 | Train Time | |
| 232 | +|-------|----------|------|--------|-------|---------|------------| |
| 233 | +| **UniSRec** | all-MiniLM-L6-v2 (384d) | 0.1421 | 0.0988 | 0.1896 | 0.1145 | ~194s | |
| 234 | +| **UniSRec** | Qwen3-Embedding-0.6B (1024d) | 0.1529 | 0.1012 | 0.2018 | 0.1171 | ~178s | |
| 235 | +| **SASRec** (RecTools) | ID embeddings | 0.1606 | 0.1081 | 0.2175 | 0.1265 | ~166s | |
| 236 | + |
| 237 | +Qwen3-Embedding-0.6B narrows the gap to SASRec: the HR@10 deficit shrinks from 2.8pp with MiniLM to 1.6pp. SASRec with learned ID embeddings remains stronger when sufficient interaction data is available. UniSRec's advantage is expected in cold-start and transfer scenarios, where text embeddings provide semantic signal for items with no interaction history. |
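The percentage-point gaps quoted above follow directly from the HR@10 column of the results table. A short sketch of the arithmetic, using only the table values (the dictionary keys are labels chosen here for readability):

```python
# HR@10 values copied from the results table.
hr10 = {
    "UniSRec + all-MiniLM-L6-v2": 0.1896,
    "UniSRec + Qwen3-Embedding-0.6B": 0.2018,
    "SASRec (ID embeddings)": 0.2175,
}

baseline = hr10["SASRec (ID embeddings)"]
# Gap to SASRec in percentage points for each UniSRec variant.
gaps_pp = {
    name: (baseline - value) * 100
    for name, value in hr10.items()
    if name.startswith("UniSRec")
}

for name, gap in gaps_pp.items():
    print(f"{name}: {gap:.1f}pp behind SASRec")
```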
| 238 | + |
| 239 | +## Key Parameters |
| 240 | + |
| 241 | +| Parameter | Description | Default | |
| 242 | +|-----------|-------------|---------| |
| 243 | +| `n_factors` | Hidden dimension of the transformer | 256 | |
| 244 | +| `adaptor_type` | Adaptor type: `"pca"` or `"bn"` | `"pca"` | |
| 245 | +| `session_max_len` | Maximum sequence length | 200 | |
| 246 | +| `epochs` | Number of training epochs | 10 | |
| 247 | +| `lr` | Base learning rate (adaptor layernorms) | 1e-4 | |
| 248 | +| `lr_wp` | Multiplier for PCA whitening projection | 0.1 | |
| 249 | +| `lr_transformer` | Multiplier for transformer layers | 3.0 | |
| 250 | +| `lr_head` | Multiplier for head layer | 0.3 | |
| 251 | +| `loss` | Loss function: `"softmax"`, `"BCE"`, `"gBCE"`, `"sampled_softmax"` | `"softmax"` | |
| 252 | +| `patience` | Early stopping patience (None = disabled) | None | |
| 253 | +| `scheduler` | LR scheduler: `None` or `"cosine_warmup"` | None | |
| 254 | + |
| 255 | +## ONNX Export |
| 256 | + |
| 257 | +```python |
| 258 | +model.export_to_onnx( |
| 259 | +    encoder_path="unisrec_encoder.onnx", |
| 260 | +    items_path="unisrec_items.onnx", |
| 261 | +) |
| 262 | +``` |