Skip to content

Commit ec9d6c2

Browse files
committed
add yambda timestamped
1 parent 0b28254 commit ec9d6c2

33 files changed

Lines changed: 7760 additions & 425 deletions

notebooks/AmazonBeautyDatasetStatistics.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@
405405
],
406406
"metadata": {
407407
"kernelspec": {
408-
"display_name": ".venv",
408+
"display_name": "Python 3 (ipykernel)",
409409
"language": "python",
410410
"name": "python3"
411411
},
@@ -419,7 +419,7 @@
419419
"name": "python",
420420
"nbconvert_exporter": "python",
421421
"pygments_lexer": "ipython3",
422-
"version": "3.12.6"
422+
"version": "3.10.12"
423423
}
424424
},
425425
"nbformat": 4,

scripts/plum-yambda/callbacks.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import torch
2+
3+
import irec.callbacks as cb
4+
from irec.runners import TrainingRunner, TrainingRunnerContext
5+
6+
class InitCodebooks(cb.TrainingCallback):
    """Training callback that seeds every codebook from encoded samples.

    For codebook ``i`` a batch is drawn from the dataloader, a random subset
    of its rows is encoded, the nearest vectors of codebooks ``0..i-1`` are
    subtracted in turn, and the remaining residuals become codebook ``i``.
    """

    def __init__(self, dataloader):
        """Store the dataloader whose batches expose an ``'embedding'`` entry."""
        super().__init__()
        self._dataloader = dataloader

    @torch.no_grad()
    def before_run(self, runner: TrainingRunner):
        """Overwrite ``runner.model.codebooks`` level by level, without gradients."""
        model = runner.model
        for level, codebook in enumerate(model.codebooks):
            # A fresh iterator is created for every level, so each level
            # works from the first batch that iterator yields.
            first_batch = next(iter(self._dataloader))
            embeddings = first_batch['embedding']

            shuffled = torch.randperm(embeddings.shape[0], device=embeddings.device)
            # NOTE(review): if the batch has fewer rows than the codebook,
            # fewer vectors are selected and the codebook ends up smaller.
            chosen = shuffled[:len(codebook)]
            residual = model.encoder(embeddings[chosen])

            # Remove what the already-initialised earlier levels explain.
            for earlier in range(level):
                earlier_codebook = model.codebooks[earlier]
                nearest = model.get_codebook_indices(residual, earlier_codebook)
                residual = residual - earlier_codebook[nearest]

            codebook.data = residual.detach()
24+
25+
26+
class FixDeadCentroids(cb.TrainingCallback):
    """Training callback that re-seeds codebook vectors nobody is assigned to.

    After every step each codebook is checked against the whole dataloader;
    vectors that are never the nearest neighbour of any residual ("dead"
    centroids) are overwritten with residuals of randomly chosen samples.
    """

    def __init__(self, dataloader):
        """Store the dataloader whose batches expose an ``'embedding'`` entry."""
        super().__init__()
        self._dataloader = dataloader

    def after_step(self, runner: TrainingRunner, context: TrainingRunnerContext):
        """Repair dead centroids and publish the per-codebook dead counts."""
        for i, num_fixed in enumerate(self.fix_dead_codebooks(runner)):
            context.metrics[f'num_dead/{i}'] = num_fixed

    @staticmethod
    def _residual_before(model, embeddings, level):
        """Encode *embeddings* and subtract the nearest vectors of codebooks ``0..level-1``."""
        remainder = model.encoder(embeddings)
        for previous in range(level):
            previous_codebook = model.codebooks[previous]
            ind = model.get_codebook_indices(remainder, previous_codebook)
            remainder = remainder - previous_codebook[ind]
        return remainder

    @torch.no_grad()
    def fix_dead_codebooks(self, runner: TrainingRunner):
        """Replace unused centroids in every codebook.

        Returns:
            A list with one entry per codebook: the number of dead centroids
            found in that codebook during this call.
        """
        model = runner.model
        num_fixed = []
        for codebook_idx, codebook in enumerate(model.codebooks):
            centroid_counts = torch.zeros(codebook.shape[0], dtype=torch.long, device=codebook.device)
            # Drawn before the counting pass, as a source of replacement vectors.
            random_batch = next(iter(self._dataloader))['embedding']

            # Count how often each centroid of this level is selected.
            for batch in self._dataloader:
                remainder = self._residual_before(model, batch['embedding'], codebook_idx)
                indices = model.get_codebook_indices(remainder, codebook)
                centroid_counts.scatter_add_(0, indices, torch.ones_like(indices))

            dead_mask = (centroid_counts == 0)
            num_dead = int(dead_mask.sum().item())
            num_fixed.append(num_dead)
            if num_dead == 0:
                continue

            candidates = self._residual_before(model, random_batch, codebook_idx)
            candidates = candidates[torch.randperm(candidates.shape[0], device=codebook.device)][:num_dead]

            # The batch may provide fewer rows than there are dead centroids;
            # assigning through the boolean mask would then fail with a shape
            # mismatch. Re-seed as many dead centroids as we have candidates
            # for and leave the rest for a later call.
            dead_indices = dead_mask.nonzero(as_tuple=True)[0][:candidates.shape[0]]
            codebook[dead_indices] = candidates.detach()

        return num_fixed

scripts/plum-yambda/cooc_data.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import json
2+
import pickle
3+
from collections import defaultdict, Counter
4+
5+
import numpy as np
6+
from loguru import logger
7+
8+
9+
import pickle
10+
from collections import defaultdict, Counter
11+
12+
class CoocMappingDataset:
    """Per-user training sequences together with item co-occurrence counts.

    Instances are normally built through :meth:`create` or
    :meth:`create_from_split_part`, which read a JSON object mapping a user id
    (string) to that user's list of item ids.
    """

    def __init__(
            self,
            train_sampler,
            num_items,
            cooccur_counter_mapping=None
    ):
        """Store the training records, the item count and the co-occurrence mapping."""
        self._train_sampler = train_sampler
        self._num_items = num_items
        self._cooccur_counter_mapping = cooccur_counter_mapping

    @staticmethod
    def _load_interactions(json_path):
        """Read the ``{user_id: [item_id, ...]}`` JSON file at *json_path*."""
        # Explicit UTF-8 instead of the platform-dependent default encoding.
        with open(json_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @classmethod
    def create(cls, inter_json_path, window_size):
        """Build the dataset from full user histories.

        The last two items of every user are dropped from the training
        sequence. ``num_items`` is derived from the largest item id seen in
        the complete histories, including the dropped items.

        Raises:
            AssertionError: if a user has fewer than 5 items.
        """
        max_item_id = 0
        train_dataset = []

        user_interactions = cls._load_interactions(inter_json_path)

        for user_id_str, item_ids in user_interactions.items():
            user_id = int(user_id_str)
            if item_ids:
                max_item_id = max(max_item_id, max(item_ids))
            # Raised explicitly (same exception type as the former ``assert``)
            # so the check is not stripped when Python runs with -O.
            if len(item_ids) < 5:
                raise AssertionError(
                    f'Core-5 dataset is used, user {user_id} has only {len(item_ids)} items'
                )
            train_dataset.append({
                'user.ids': [user_id],
                'item.ids': item_ids[:-2],
            })

        cooccur_counter_mapping = cls.build_cooccur_counter_mapping(train_dataset, window_size=window_size)
        logger.debug(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items but max_item_id is {max_item_id}')

        return cls(
            train_sampler=train_dataset,
            num_items=max_item_id + 1,
            cooccur_counter_mapping=cooccur_counter_mapping
        )

    @classmethod
    def create_from_split_part(
            cls,
            train_inter_json_path,
            window_size
    ):
        """Build the dataset from a file that already holds only the train part.

        Item sequences are used as they are; no items are dropped and no
        minimum length is enforced.
        """
        max_item_id = 0
        train_dataset = []

        train_interactions = cls._load_interactions(train_inter_json_path)

        # Process the TRAIN part.
        for user_id_str, item_ids in train_interactions.items():
            user_id = int(user_id_str)
            if item_ids:
                max_item_id = max(max_item_id, max(item_ids))

            train_dataset.append({
                'user.ids': [user_id],
                'item.ids': item_ids,
            })

        logger.debug(f'Train: {len(train_dataset)} users')
        logger.debug(f'Max item ID: {max_item_id}')

        cooccur_counter_mapping = cls.build_cooccur_counter_mapping(
            train_dataset,
            window_size=window_size
        )

        logger.debug(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items')

        return cls(
            train_sampler=train_dataset,
            num_items=max_item_id + 1,
            cooccur_counter_mapping=cooccur_counter_mapping
        )

    @staticmethod
    def build_cooccur_counter_mapping(train_dataset, window_size):
        """Count, for every item, the items seen within *window_size* positions of it.

        Args:
            train_dataset: iterable of records with an ``'item.ids'`` sequence.
            window_size: number of positions considered on each side.

        Returns:
            ``defaultdict(Counter)`` mapping an item id to a counter of its
            neighbours. Items without any neighbour get no entry.
        """
        # TODO: pass timestamps and build the window from them.
        cooccur_counts = defaultdict(Counter)
        for session in train_dataset:
            items = session['item.ids']
            num_positions = len(items)
            for i, item_i in enumerate(items):
                window_start = max(0, i - window_size)
                window_end = min(num_positions, i + window_size + 1)
                for j in range(window_start, window_end):
                    if i != j:
                        cooccur_counts[item_i][items[j]] += 1
        return cooccur_counts

    @property
    def cooccur_counter_mapping(self):
        """The co-occurrence mapping given at construction time (may be ``None``)."""
        return self._cooccur_counter_mapping

scripts/plum-yambda/data.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import numpy as np
2+
import pickle
3+
4+
from irec.data.base import BaseDataset
5+
from irec.data.transforms import Transform
6+
7+
8+
import polars as pl
9+
import numpy as np
10+
import torch
11+
12+
class EmbeddingDatasetParquet(BaseDataset):
    """Dataset of item embeddings loaded from a parquet file.

    The file is read with polars and must provide an ``item_id`` column and
    an ``embedding`` column; both are materialised as NumPy arrays
    (``int64`` ids, ``float32`` embeddings).
    """

    def __init__(self, data_path):
        """Load the parquet file at *data_path* fully into memory."""
        self.df = pl.read_parquet(data_path)
        self.item_ids = np.array(self.df['item_id'], dtype=np.int64)
        self.embeddings = np.array(self.df['embedding'].to_list(), dtype=np.float32)
        # Report the per-item shape from the array's shape instead of
        # indexing row 0, which raised IndexError for an empty file.
        print(f"embedding dim: {self.embeddings.shape[1:]}")

    def __getitem__(self, idx):
        """Return the item id, its embedding and the embedding length for *idx*."""
        item_id = self.item_ids[idx]
        embedding = self.embeddings[idx]
        return {
            'item_id': item_id,
            'embedding': embedding,
            'embedding_dim': len(embedding)
        }

    def __len__(self):
        """Number of stored embeddings."""
        return len(self.embeddings)
30+
31+
32+
class EmbeddingDataset(BaseDataset):
    """Dataset of item embeddings loaded from a pickle file.

    The unpickled object is indexed with ``'item_id'`` and ``'embedding'``;
    both entries are converted to NumPy arrays (``int64`` / ``float32``).
    """

    def __init__(self, data_path):
        """Unpickle *data_path* and build the id and embedding arrays."""
        self.data_path = data_path
        # NOTE(review): pickle can execute arbitrary code while loading;
        # only point this at files from a trusted source.
        with open(data_path, 'rb') as source:
            self.data = pickle.load(source)

        raw_ids = self.data['item_id']
        self.item_ids = np.array(raw_ids, dtype=np.int64)
        raw_embeddings = self.data['embedding']
        self.embeddings = np.array(raw_embeddings, dtype=np.float32)

    def __getitem__(self, idx):
        """Return the item id, its embedding and the embedding length for *idx*."""
        item_id = self.item_ids[idx]
        embedding = self.embeddings[idx]
        return dict(
            item_id=item_id,
            embedding=embedding,
            embedding_dim=len(embedding),
        )

    def __len__(self):
        """Number of stored embeddings."""
        return len(self.embeddings)
52+
53+
54+
class ProcessEmbeddings(Transform):
    """Transform that reshapes selected batch entries to ``(-1, embedding_dim)``."""

    def __init__(self, embedding_dim, keys):
        """Remember the row width and the batch keys to reshape."""
        self.embedding_dim = embedding_dim
        self.keys = keys

    def __call__(self, batch):
        """Reshape ``batch[key]`` in place for every configured key and return the batch."""
        for name in self.keys:
            values = batch[name]
            batch[name] = values.reshape(-1, self.embedding_dim)
        return batch

scripts/plum-yambda/models.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
class PlumRQVAE(nn.Module):
    """Residual-quantised autoencoder with an optional contrastive loss.

    An encoder maps the input embedding to a latent vector, which is
    quantised by a sequence of codebooks: each codebook encodes the residual
    left by the previous ones. A decoder reconstructs the input from the sum
    of the selected codebook vectors. When a batch carries a
    ``'cooccurrence_embedding'`` entry, an in-batch contrastive loss between
    the quantised latents of both embeddings is added.

    Codebooks start as zeros and are expected to be seeded before training,
    e.g. from encoded data samples.
    """

    def __init__(
            self,
            input_dim,
            num_codebooks,
            codebook_size,
            embedding_dim,
            beta=0.25,
            quant_loss_weight=1.0,
            contrastive_loss_weight=1.0,
            temperature=0.0,
    ):
        """Build encoder, decoder and ``num_codebooks`` codebooks.

        Args:
            input_dim: size of the input embeddings.
            num_codebooks: number of residual quantisation levels.
            codebook_size: number of vectors per codebook.
            embedding_dim: size of the latent space and of codebook vectors.
            beta: weight of the commitment term of the quantisation loss.
            quant_loss_weight: weight of the quantisation loss in the total.
            contrastive_loss_weight: weight of the contrastive loss in the total.
            temperature: divisor applied to similarities in the contrastive
                loss; must be positive whenever that loss is computed.
        """
        super().__init__()
        self.register_buffer('beta', torch.tensor(beta))
        self.temperature = temperature

        self.input_dim = input_dim
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        self.embedding_dim = embedding_dim
        self.quant_loss_weight = quant_loss_weight

        self.contrastive_loss_weight = contrastive_loss_weight

        self.encoder = self.make_encoding_tower(input_dim, embedding_dim)
        self.decoder = self.make_encoding_tower(embedding_dim, input_dim)

        self.codebooks = torch.nn.ParameterList()
        for _ in range(num_codebooks):
            # Zeros instead of uninitialised memory: the legacy
            # ``torch.FloatTensor(n, m)`` constructor could leave NaN/Inf in
            # the codebook when no external initialisation runs.
            codebook = nn.Parameter(
                torch.zeros(codebook_size, embedding_dim, dtype=torch.float32)
            )
            self.codebooks.append(codebook)

    @staticmethod
    def make_encoding_tower(d1, d2, bias=False):
        """Return a three-layer MLP ``d1 -> d1 -> d2 -> d2`` with ReLU in between.

        ``bias`` only controls the last linear layer.
        """
        return torch.nn.Sequential(
            nn.Linear(d1, d1),
            nn.ReLU(),
            nn.Linear(d1, d2),
            nn.ReLU(),
            nn.Linear(d2, d2, bias=bias)
        )

    @staticmethod
    def get_codebook_indices(remainder, codebook):
        """Return, for every row of *remainder*, the index of the nearest codebook vector."""
        dist = torch.cdist(remainder, codebook)
        return dist.argmin(dim=-1)

    def _quantize_representation(self, latent_vector):
        """Quantise *latent_vector* through all codebooks and return the restored latent.

        Uses the same straight-through construction as :meth:`forward` but
        computes no losses.
        """
        latent_restored = 0
        remainder = latent_vector

        for codebook in self.codebooks:
            codebook_indices = self.get_codebook_indices(remainder, codebook)
            quantized = codebook[codebook_indices]
            # Straight-through: value of ``quantized``, gradient of ``remainder``.
            codebook_vectors = remainder + (quantized - remainder).detach()
            latent_restored += codebook_vectors
            remainder = remainder - codebook_vectors

        return latent_restored

    def contrastive_loss(self, p_i, p_i_star):
        """Cross-entropy over in-batch similarities; row ``k`` of both inputs is the positive pair.

        Raises:
            ValueError: if ``self.temperature`` is not positive.
        """
        if self.temperature <= 0:
            # Dividing by a zero temperature silently produced inf/NaN losses.
            raise ValueError(
                f'temperature must be positive to compute the contrastive loss, got {self.temperature}'
            )

        N_b = p_i.size(0)

        # TODO: try without normalisation.
        p_i = F.normalize(p_i, p=2, dim=-1)
        p_i_star = F.normalize(p_i_star, p=2, dim=-1)

        similarities = torch.matmul(p_i, p_i_star.T) / self.temperature

        labels = torch.arange(N_b, dtype=torch.long, device=p_i.device)

        # Softmax is taken along the last dimension only.
        return F.cross_entropy(similarities, labels)

    def forward(self, inputs):
        """Run encode -> residual quantisation -> decode and compute the losses.

        Args:
            inputs: mapping with an ``'embedding'`` tensor and optionally a
                ``'cooccurrence_embedding'`` tensor.

        Returns:
            ``(loss, info)`` where ``info`` holds the scalar loss values, the
            per-codebook assignment counts (``'clusters_counts'``), the
            stacked and transposed code indices (``'clusters'``) and the
            reconstruction (``'embedding_hat'``).
        """
        latent_vector = self.encoder(inputs['embedding'])

        latent_restored = 0
        rqvae_loss = 0
        clusters = []
        remainder = latent_vector

        for codebook in self.codebooks:
            codebook_indices = self.get_codebook_indices(remainder, codebook)
            clusters.append(codebook_indices)

            quantized = codebook[codebook_indices]
            # Straight-through: value of ``quantized``, gradient of ``remainder``.
            codebook_vectors = remainder + (quantized - remainder).detach()

            # Commitment term (weighted by beta) plus codebook term.
            # NOTE(review): after the first level ``remainder`` carries no
            # gradient, so the commitment term only affects level 0 — confirm
            # this is intended.
            rqvae_loss += self.beta * F.mse_loss(remainder, quantized.detach())
            rqvae_loss += F.mse_loss(quantized, remainder.detach())

            latent_restored += codebook_vectors
            remainder = remainder - codebook_vectors

        embeddings_restored = self.decoder(latent_restored)
        recon_loss = F.mse_loss(embeddings_restored, inputs['embedding'])

        if 'cooccurrence_embedding' in inputs:
            cooccurrence_latent = self.encoder(
                inputs['cooccurrence_embedding'].to(latent_restored.device)
            )
            cooccurrence_restored = self._quantize_representation(cooccurrence_latent)
            con_loss = self.contrastive_loss(latent_restored, cooccurrence_restored)
        else:
            con_loss = torch.as_tensor(0.0, device=latent_vector.device)

        loss = (
            recon_loss
            + self.quant_loss_weight * rqvae_loss
            + self.contrastive_loss_weight * con_loss
        ).mean()

        clusters_counts = [
            torch.bincount(cluster, minlength=self.codebook_size)
            for cluster in clusters
        ]

        return loss, {
            'loss': loss.item(),
            'recon_loss': recon_loss.mean().item(),
            'rqvae_loss': rqvae_loss.mean().item(),
            'con_loss': con_loss.item(),

            'clusters_counts': clusters_counts,
            'clusters': torch.stack(clusters).T,
            'embedding_hat': embeddings_restored,
        }

0 commit comments

Comments
 (0)