1+ import torch
2+ import torch .nn as nn
3+ import torch .nn .functional as F
4+
5+ class PlumRQVAE (nn .Module ):
6+ def __init__ (
7+ self ,
8+ input_dim ,
9+ num_codebooks ,
10+ codebook_size ,
11+ embedding_dim ,
12+ beta = 0.25 ,
13+ quant_loss_weight = 1.0 ,
14+ contrastive_loss_weight = 1.0 ,
15+ temperature = 0.0 ,
16+ ):
17+ super ().__init__ ()
18+ self .register_buffer ('beta' , torch .tensor (beta ))
19+ self .temperature = temperature
20+
21+ self .input_dim = input_dim
22+ self .num_codebooks = num_codebooks
23+ self .codebook_size = codebook_size
24+ self .embedding_dim = embedding_dim
25+ self .quant_loss_weight = quant_loss_weight
26+
27+ self .contrastive_loss_weight = contrastive_loss_weight
28+
29+ self .encoder = self .make_encoding_tower (input_dim , embedding_dim )
30+ self .decoder = self .make_encoding_tower (embedding_dim , input_dim )
31+
32+ self .codebooks = torch .nn .ParameterList ()
33+ for _ in range (num_codebooks ):
34+ cb = torch .FloatTensor (codebook_size , embedding_dim )
35+ #nn.init.normal_(cb)
36+ self .codebooks .append (cb )
37+
38+ @staticmethod
39+ def make_encoding_tower (d1 , d2 , bias = False ):
40+ return torch .nn .Sequential (
41+ nn .Linear (d1 , d1 ),
42+ nn .ReLU (),
43+ nn .Linear (d1 , d2 ),
44+ nn .ReLU (),
45+ nn .Linear (d2 , d2 , bias = bias )
46+ )
47+
48+ @staticmethod
49+ def get_codebook_indices (remainder , codebook ):
50+ dist = torch .cdist (remainder , codebook )
51+ return dist .argmin (dim = - 1 )
52+
53+ def _quantize_representation (self , latent_vector ):
54+ latent_restored = 0
55+ remainder = latent_vector
56+
57+ for codebook in self .codebooks :
58+ codebook_indices = self .get_codebook_indices (remainder , codebook )
59+ quantized = codebook [codebook_indices ]
60+ codebook_vectors = remainder + (quantized - remainder ).detach ()
61+ latent_restored += codebook_vectors
62+ remainder = remainder - codebook_vectors
63+
64+ return latent_restored
65+
66+ def contrastive_loss (self , p_i , p_i_star ):
67+ N_b = p_i .size (0 )
68+
69+ p_i = F .normalize (p_i , p = 2 , dim = - 1 ) #TODO посмотреть без нормалайза
70+ p_i_star = F .normalize (p_i_star , p = 2 , dim = - 1 )
71+
72+ similarities = torch .matmul (p_i , p_i_star .T ) / self .temperature
73+
74+ labels = torch .arange (N_b , dtype = torch .long , device = p_i .device )
75+
76+ loss = F .cross_entropy (similarities , labels )
77+
78+ return loss #только по последней размерности
79+
80+ def forward (self , inputs ):
81+ latent_vector = self .encoder (inputs ['embedding' ])
82+ # print(f"latent vector shape: {latent_vector.shape}")
83+ # print(f"inputs embedding shape: {inputs['embedding']}")
84+ item_ids = inputs ['item_id' ]
85+
86+ latent_restored = 0
87+ rqvae_loss = 0
88+ clusters = []
89+ remainder = latent_vector
90+
91+ for codebook in self .codebooks :
92+ codebook_indices = self .get_codebook_indices (remainder , codebook )
93+ clusters .append (codebook_indices )
94+
95+ quantized = codebook [codebook_indices ]
96+ codebook_vectors = remainder + (quantized - remainder ).detach ()
97+
98+ rqvae_loss += self .beta * torch .nn .functional .mse_loss (remainder , quantized .detach ())
99+ rqvae_loss += torch .nn .functional .mse_loss (quantized , remainder .detach ())
100+
101+ latent_restored += codebook_vectors
102+ remainder = remainder - codebook_vectors
103+
104+ embeddings_restored = self .decoder (latent_restored )
105+ recon_loss = torch .nn .functional .mse_loss (embeddings_restored , inputs ['embedding' ])
106+
107+ if 'cooccurrence_embedding' in inputs :
108+ # print(f"cooccurrence_embedding shape: {inputs['cooccurrence_embedding'].shape} device {inputs['cooccurrence_embedding'].device}" )
109+ # print(f"latent_restored shape {latent_restored.shape} device {latent_restored.device}")
110+ cooccurrence_latent = self .encoder (inputs ['cooccurrence_embedding' ].to (latent_restored .device ))
111+ cooccurrence_restored = self ._quantize_representation (cooccurrence_latent )
112+ con_loss = self .contrastive_loss (latent_restored , cooccurrence_restored )
113+ else :
114+ con_loss = torch .as_tensor (0.0 , device = latent_vector .device )
115+
116+ loss = (
117+ recon_loss
118+ + self .quant_loss_weight * rqvae_loss
119+ + self .contrastive_loss_weight * con_loss
120+ ).mean ()
121+
122+ clusters_counts = []
123+ for cluster in clusters :
124+ clusters_counts .append (torch .bincount (cluster , minlength = self .codebook_size ))
125+
126+ return loss , {
127+ 'loss' : loss .item (),
128+ 'recon_loss' : recon_loss .mean ().item (),
129+ 'rqvae_loss' : rqvae_loss .mean ().item (),
130+ 'con_loss' : con_loss .item (),
131+
132+ 'clusters_counts' : clusters_counts ,
133+ 'clusters' : torch .stack (clusters ).T ,
134+ 'embedding_hat' : embeddings_restored ,
135+ }
0 commit comments