Skip to content

Commit 4dd0882

Browse files
committed
add yambda timestamped
1 parent 0b28254 commit 4dd0882

40 files changed

Lines changed: 8854 additions & 425 deletions

notebooks/AmazonBeautyDatasetStatistics.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@
405405
],
406406
"metadata": {
407407
"kernelspec": {
408-
"display_name": ".venv",
408+
"display_name": "Python 3 (ipykernel)",
409409
"language": "python",
410410
"name": "python3"
411411
},
@@ -419,7 +419,7 @@
419419
"name": "python",
420420
"nbconvert_exporter": "python",
421421
"pygments_lexer": "ipython3",
422-
"version": "3.12.6"
422+
"version": "3.10.12"
423423
}
424424
},
425425
"nbformat": 4,
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
from loguru import logger
2+
import os
3+
4+
import torch
5+
6+
import pickle
7+
8+
import irec.callbacks as cb
9+
from irec.data.dataloader import DataLoader
10+
from irec.data.transforms import Collate, ToTorch, ToDevice
11+
from irec.runners import TrainingRunner
12+
13+
from irec.utils import fix_random_seed
14+
15+
from callbacks import InitCodebooks, FixDeadCentroids
16+
from data import EmbeddingDataset, ProcessEmbeddings
17+
from models import PlumRQVAE
18+
from transforms import AddWeightedCooccurrenceEmbeddings
19+
from cooc_data import CoocMappingDataset
20+
21+
SEED_VALUE = 42
22+
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
23+
24+
NUM_EPOCHS = 500
25+
BATCH_SIZE = 1024
26+
27+
INPUT_DIM = 4096
28+
HIDDEN_DIM = 32
29+
CODEBOOK_SIZE = 256
30+
NUM_CODEBOOKS = 3
31+
BETA = 0.25
32+
LR = 1e-4
33+
WINDOW_SIZE = 2
34+
35+
EXPERIMENT_NAME = f'4-1_yambda_quantile_ws_{WINDOW_SIZE}'
36+
INTER_TRAIN_PATH = "/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4-1_0.9_inter_semantics_train.json"
37+
EMBEDDINGS_PATH = "/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl"
38+
IREC_PATH = '../../'
39+
40+
print(INTER_TRAIN_PATH)
41+
def main():
42+
fix_random_seed(SEED_VALUE)
43+
44+
data = CoocMappingDataset.create_from_split_part(
45+
train_inter_json_path=INTER_TRAIN_PATH,
46+
window_size=WINDOW_SIZE
47+
)
48+
49+
dataset = EmbeddingDataset(
50+
data_path=EMBEDDINGS_PATH
51+
)
52+
53+
item_id_to_embedding = {}
54+
all_item_ids = []
55+
for idx in range(len(dataset)):
56+
sample = dataset[idx]
57+
item_id = int(sample['item_id'])
58+
item_id_to_embedding[item_id] = torch.tensor(sample['embedding'])
59+
all_item_ids.append(item_id)
60+
61+
add_cooc_transform = AddWeightedCooccurrenceEmbeddings(
62+
data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids)
63+
64+
train_dataloader = DataLoader(
65+
dataset,
66+
batch_size=BATCH_SIZE,
67+
shuffle=True,
68+
drop_last=True,
69+
).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
70+
ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
71+
).map(add_cooc_transform).repeat(NUM_EPOCHS)
72+
73+
valid_dataloader = DataLoader(
74+
dataset,
75+
batch_size=BATCH_SIZE,
76+
shuffle=False,
77+
drop_last=False,
78+
).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).map(add_cooc_transform)
79+
80+
LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS)
81+
82+
model = PlumRQVAE(
83+
input_dim=INPUT_DIM,
84+
num_codebooks=NUM_CODEBOOKS,
85+
codebook_size=CODEBOOK_SIZE,
86+
embedding_dim=HIDDEN_DIM,
87+
beta=BETA,
88+
quant_loss_weight=1.0,
89+
contrastive_loss_weight=1.0,
90+
temperature=1.0
91+
).to(DEVICE)
92+
93+
total_params = sum(p.numel() for p in model.parameters())
94+
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
95+
96+
logger.debug(f'Overall parameters: {total_params:,}')
97+
logger.debug(f'Trainable parameters: {trainable_params:,}')
98+
99+
optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True)
100+
101+
callbacks = [
102+
InitCodebooks(valid_dataloader),
103+
104+
cb.BatchMetrics(metrics=lambda model_outputs, batch: {
105+
'loss': model_outputs['loss'],
106+
'recon_loss': model_outputs['recon_loss'],
107+
'rqvae_loss': model_outputs['rqvae_loss'],
108+
'con_loss': model_outputs['con_loss']
109+
}, name='train'),
110+
111+
FixDeadCentroids(valid_dataloader),
112+
113+
cb.MetricAccumulator(
114+
accumulators={
115+
'train/loss': cb.MeanAccumulator(),
116+
'train/recon_loss': cb.MeanAccumulator(),
117+
'train/rqvae_loss': cb.MeanAccumulator(),
118+
'train/con_loss': cb.MeanAccumulator(),
119+
'num_dead/0': cb.MeanAccumulator(),
120+
'num_dead/1': cb.MeanAccumulator(),
121+
'num_dead/2': cb.MeanAccumulator(),
122+
},
123+
reset_every_num_steps=LOG_EVERY_NUM_STEPS
124+
),
125+
126+
cb.Validation(
127+
dataset=valid_dataloader,
128+
callbacks=[
129+
cb.BatchMetrics(metrics=lambda model_outputs, batch: {
130+
'loss': model_outputs['loss'],
131+
'recon_loss': model_outputs['recon_loss'],
132+
'rqvae_loss': model_outputs['rqvae_loss'],
133+
'con_loss': model_outputs['con_loss']
134+
}, name='valid'),
135+
cb.MetricAccumulator(
136+
accumulators={
137+
'valid/loss': cb.MeanAccumulator(),
138+
'valid/recon_loss': cb.MeanAccumulator(),
139+
'valid/rqvae_loss': cb.MeanAccumulator(),
140+
'valid/con_loss': cb.MeanAccumulator()
141+
}
142+
),
143+
],
144+
).every_num_steps(LOG_EVERY_NUM_STEPS),
145+
146+
cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS),
147+
cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),
148+
149+
cb.EarlyStopping(
150+
metric='valid/recon_loss',
151+
patience=40,
152+
minimize=True,
153+
model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME)
154+
).every_num_steps(LOG_EVERY_NUM_STEPS),
155+
]
156+
157+
logger.debug('Everything is ready for training process!')
158+
159+
runner = TrainingRunner(
160+
model=model,
161+
optimizer=optimizer,
162+
dataset=train_dataloader,
163+
callbacks=callbacks,
164+
)
165+
runner.run()
166+
167+
168+
if __name__ == '__main__':
169+
main()
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
from loguru import logger
2+
import os
3+
4+
import torch
5+
6+
import pickle
7+
8+
import irec.callbacks as cb
9+
from irec.data.dataloader import DataLoader
10+
from irec.data.transforms import Collate, ToTorch, ToDevice
11+
from irec.runners import TrainingRunner
12+
13+
from irec.utils import fix_random_seed
14+
15+
from callbacks import InitCodebooks, FixDeadCentroids
16+
from data import EmbeddingDataset, ProcessEmbeddings
17+
from models import PlumRQVAE
18+
from transforms import AddWeightedCooccurrenceEmbeddings
19+
from cooc_data import CoocMappingDataset
20+
21+
SEED_VALUE = 42
22+
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
23+
24+
NUM_EPOCHS = 500
25+
BATCH_SIZE = 1024
26+
27+
INPUT_DIM = 4096
28+
HIDDEN_DIM = 32
29+
CODEBOOK_SIZE = 256
30+
NUM_CODEBOOKS = 3
31+
BETA = 0.25
32+
LR = 1e-4
33+
WINDOW_SIZE = 2
34+
35+
EXPERIMENT_NAME = f'4-2_updated_quantile_plum_rqvae_beauty_ws_{WINDOW_SIZE}'
36+
INTER_TRAIN_PATH = "/home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4-2_0.8_inter_semantics_train.json"
37+
EMBEDDINGS_PATH = "/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl"
38+
IREC_PATH = '../../'
39+
40+
print(INTER_TRAIN_PATH)
41+
def main():
42+
fix_random_seed(SEED_VALUE)
43+
44+
data = CoocMappingDataset.create_from_split_part(
45+
train_inter_json_path=INTER_TRAIN_PATH,
46+
window_size=WINDOW_SIZE
47+
)
48+
49+
dataset = EmbeddingDataset(
50+
data_path=EMBEDDINGS_PATH
51+
)
52+
53+
item_id_to_embedding = {}
54+
all_item_ids = []
55+
for idx in range(len(dataset)):
56+
sample = dataset[idx]
57+
item_id = int(sample['item_id'])
58+
item_id_to_embedding[item_id] = torch.tensor(sample['embedding'])
59+
all_item_ids.append(item_id)
60+
61+
add_cooc_transform = AddWeightedCooccurrenceEmbeddings(
62+
data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids)
63+
64+
train_dataloader = DataLoader(
65+
dataset,
66+
batch_size=BATCH_SIZE,
67+
shuffle=True,
68+
drop_last=True,
69+
).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
70+
ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
71+
).map(add_cooc_transform).repeat(NUM_EPOCHS)
72+
73+
valid_dataloader = DataLoader(
74+
dataset,
75+
batch_size=BATCH_SIZE,
76+
shuffle=False,
77+
drop_last=False,
78+
).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).map(add_cooc_transform)
79+
80+
LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS)
81+
82+
model = PlumRQVAE(
83+
input_dim=INPUT_DIM,
84+
num_codebooks=NUM_CODEBOOKS,
85+
codebook_size=CODEBOOK_SIZE,
86+
embedding_dim=HIDDEN_DIM,
87+
beta=BETA,
88+
quant_loss_weight=1.0,
89+
contrastive_loss_weight=1.0,
90+
temperature=1.0
91+
).to(DEVICE)
92+
93+
total_params = sum(p.numel() for p in model.parameters())
94+
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
95+
96+
logger.debug(f'Overall parameters: {total_params:,}')
97+
logger.debug(f'Trainable parameters: {trainable_params:,}')
98+
99+
optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True)
100+
101+
callbacks = [
102+
InitCodebooks(valid_dataloader),
103+
104+
cb.BatchMetrics(metrics=lambda model_outputs, batch: {
105+
'loss': model_outputs['loss'],
106+
'recon_loss': model_outputs['recon_loss'],
107+
'rqvae_loss': model_outputs['rqvae_loss'],
108+
'con_loss': model_outputs['con_loss']
109+
}, name='train'),
110+
111+
FixDeadCentroids(valid_dataloader),
112+
113+
cb.MetricAccumulator(
114+
accumulators={
115+
'train/loss': cb.MeanAccumulator(),
116+
'train/recon_loss': cb.MeanAccumulator(),
117+
'train/rqvae_loss': cb.MeanAccumulator(),
118+
'train/con_loss': cb.MeanAccumulator(),
119+
'num_dead/0': cb.MeanAccumulator(),
120+
'num_dead/1': cb.MeanAccumulator(),
121+
'num_dead/2': cb.MeanAccumulator(),
122+
},
123+
reset_every_num_steps=LOG_EVERY_NUM_STEPS
124+
),
125+
126+
cb.Validation(
127+
dataset=valid_dataloader,
128+
callbacks=[
129+
cb.BatchMetrics(metrics=lambda model_outputs, batch: {
130+
'loss': model_outputs['loss'],
131+
'recon_loss': model_outputs['recon_loss'],
132+
'rqvae_loss': model_outputs['rqvae_loss'],
133+
'con_loss': model_outputs['con_loss']
134+
}, name='valid'),
135+
cb.MetricAccumulator(
136+
accumulators={
137+
'valid/loss': cb.MeanAccumulator(),
138+
'valid/recon_loss': cb.MeanAccumulator(),
139+
'valid/rqvae_loss': cb.MeanAccumulator(),
140+
'valid/con_loss': cb.MeanAccumulator()
141+
}
142+
),
143+
],
144+
).every_num_steps(LOG_EVERY_NUM_STEPS),
145+
146+
cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS),
147+
cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),
148+
149+
cb.EarlyStopping(
150+
metric='valid/recon_loss',
151+
patience=40,
152+
minimize=True,
153+
model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME)
154+
).every_num_steps(LOG_EVERY_NUM_STEPS),
155+
]
156+
157+
logger.debug('Everything is ready for training process!')
158+
159+
runner = TrainingRunner(
160+
model=model,
161+
optimizer=optimizer,
162+
dataset=train_dataloader,
163+
callbacks=callbacks,
164+
)
165+
runner.run()
166+
167+
168+
if __name__ == '__main__':
169+
main()

0 commit comments

Comments
 (0)