Skip to content

Commit f2fdfe5

Browse files
TOPAPECRecTools Dev
authored and committed
feat: add ONNX export, hash ID mapping, and map_item_ids
- Add hash-based ID mapping (splitmix64) as an alternative to dense torch.unique mapping in build_sequences and align_embeddings.
- Add UniSRecModel.export_to_onnx() for native ONNX export of encoder and item embeddings (project_all).
- Add UniSRecModel.map_item_ids() for external→internal ID conversion at inference time (works for both dense and hash modes).
- Remove FlatSASRecModel/FlatSASRecLightning (RecTools-coupled wrappers that duplicated UniSRecModel functionality).
- Add tests: hash mapping (including string-derived IDs), ONNX export roundtrip, map_item_ids for both modes.
1 parent 7d3850b commit f2fdfe5

11 files changed

Lines changed: 605 additions & 706 deletions

File tree

rectools/fast_transformers/__init__.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
"""Fast Transformers: flat sequential recommenders without ItemNet hierarchy."""
22

3-
from .gpu_data import GPUBatchDataset, align_embeddings, build_sequences, make_dataloader
4-
from .lightning_wrap import FlatSASRecLightning
5-
from .model import FlatSASRecConfig, FlatSASRecModel
3+
from .gpu_data import GPUBatchDataset, align_embeddings, build_sequences, hash_item_ids, make_dataloader
64
from .net import FlatSASRec, SASRecBlock
75
from .ranking import rank_topk
86
from .unisrec_lightning import UniSRecLightning
@@ -12,13 +10,11 @@
1210
__all__ = [
1311
"build_sequences",
1412
"align_embeddings",
13+
"hash_item_ids",
1514
"GPUBatchDataset",
1615
"make_dataloader",
1716
"FlatSASRec",
1817
"SASRecBlock",
19-
"FlatSASRecLightning",
20-
"FlatSASRecModel",
21-
"FlatSASRecConfig",
2218
"rank_topk",
2319
"UniSRec",
2420
"FeedForward",

rectools/fast_transformers/gpu_data.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,49 @@
77
from torch.utils.data import Dataset as TorchDataset
88

99

10+
def _logical_rshift(x: torch.Tensor, bits: int) -> torch.Tensor:
    """Shift int64 ``x`` right by ``bits``, filling the vacated high bits with zeros."""
    # ``>>`` on a signed int64 tensor is an arithmetic (sign-extending) shift.
    # Masking off the replicated sign bits yields the unsigned shift that the
    # reference splitmix64 mixer is defined with.
    return (x >> bits) & ((1 << (64 - bits)) - 1)


def _splitmix64(x: torch.Tensor) -> torch.Tensor:
    """Vectorized splitmix64 bit-mixer: element-wise int64 hash over a torch tensor.

    Standard library hashes (``hash()``, ``hashlib``) operate on scalar Python objects
    and cannot be vectorized across GPU tensors. Splitmix64 is pure int64 arithmetic,
    so it maps naturally to ``torch.Tensor`` ops and runs on any device.

    The reference algorithm works on *unsigned* 64-bit integers. Torch only offers
    signed int64, so every right shift is made logical explicitly. With a plain
    arithmetic shift, ``x ^ (x >> k)`` maps ``x`` and ``~x`` to the same value
    (e.g. IDs ``5`` and ``-6`` would always collide) and the mixer stops being a
    bijection. Multiplication relies on int64 wrap-around, which matches unsigned
    multiplication modulo 2**64 bit for bit.

    NOTE: only the mixing (finalizer) steps are applied; the additive state
    increment of the full generator is not, so ``0`` hashes to ``0``.

    Reference: https://xorshift.di.unimi.it/splitmix64.c (Vigna, 2015).

    Parameters
    ----------
    x : torch.Tensor
        Integer tensor of any shape; it is converted to int64 before mixing.

    Returns
    -------
    torch.Tensor
        int64 tensor of the same shape holding the mixed bit patterns
        (values may be negative when interpreted as signed integers).
    """
    x = x.long()
    x = (x ^ _logical_rshift(x, 30)) * (-4658895280553007687)  # 0xbf58476d1ce4e5b9 as signed int64
    x = (x ^ _logical_rshift(x, 27)) * (-7723592293110705685)  # 0x94d049bb133111eb as signed int64
    return x ^ _logical_rshift(x, 31)


def hash_item_ids(item_ids: torch.Tensor, dict_size: int) -> torch.Tensor:
    """Map arbitrary integer item IDs to [1, dict_size] via splitmix64 hash.

    Parameters
    ----------
    item_ids : torch.Tensor
        Integer item IDs (any shape, negative values allowed).
    dict_size : int
        Number of hash buckets; expected to be positive.

    Returns
    -------
    torch.Tensor
        int64 tensor shaped like ``item_ids`` with values in ``[1, dict_size]``.
        Distinct IDs may share a bucket (hash collisions).
    """
    # ``%`` on tensors takes the sign of the divisor, so a negative hash still
    # lands in [0, dict_size); the +1 shifts the range to start at 1.
    return _splitmix64(item_ids) % dict_size + 1
28+
29+
1030
def build_sequences(
1131
user_ids: torch.Tensor,
1232
item_ids: torch.Tensor,
1333
timestamps: torch.Tensor,
1434
max_len: int,
1535
min_interactions: int = 2,
1636
device: str = "cuda",
37+
id_mapping: str = "dense",
1738
) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
1839
user_ids = user_ids.to(device)
1940
item_ids = item_ids.to(device)
2041
timestamps = timestamps.to(device)
2142

22-
unique_items, item_inv = torch.unique(item_ids, return_inverse=True)
23-
internal_items = item_inv + 1
43+
unique_items = torch.unique(item_ids)
44+
n_unique = len(unique_items)
45+
46+
if id_mapping == "dense":
47+
_, item_inv = torch.unique(item_ids, return_inverse=True)
48+
internal_items = item_inv + 1
49+
elif id_mapping == "hash":
50+
internal_items = hash_item_ids(item_ids, n_unique)
51+
else:
52+
raise ValueError(f"Unknown id_mapping: {id_mapping}. Use 'dense' or 'hash'")
2453

2554
unique_users, user_inv = torch.unique(user_ids, return_inverse=True)
2655

@@ -74,16 +103,23 @@ def align_embeddings(
74103
pretrained: torch.Tensor,
75104
unique_items: torch.Tensor,
76105
n_items: int,
106+
id_mapping: str = "dense",
77107
) -> torch.Tensor:
78108
idx = unique_items.long().cpu()
79109
valid = (idx >= 0) & (idx < pretrained.shape[0])
80110

81111
if pretrained.ndim == 2:
82112
aligned = torch.zeros(n_items + 1, pretrained.shape[1])
83-
aligned[1:][valid] = pretrained[idx[valid]]
84113
else:
85114
aligned = torch.zeros(n_items + 1, pretrained.shape[1], pretrained.shape[2])
115+
116+
if id_mapping == "dense":
86117
aligned[1:][valid] = pretrained[idx[valid]]
118+
elif id_mapping == "hash":
119+
positions = hash_item_ids(idx, n_items)
120+
aligned[positions[valid]] = pretrained[idx[valid]]
121+
else:
122+
raise ValueError(f"Unknown id_mapping: {id_mapping}. Use 'dense' or 'hash'")
87123

88124
return aligned
89125

rectools/fast_transformers/lightning_wrap.py

Lines changed: 0 additions & 76 deletions
This file was deleted.

0 commit comments

Comments
 (0)