 from octis.models.model import AbstractModel
+from octis.models.RS_class import Replicated_Softmax
 import numpy as np
 from tqdm import tqdm
 import gensim.corpora as corpora
@@ -83,6 +84,19 @@ def __init__(
             'rmsprop' for RMSProp optimizer,
             'adam' for Adam optimizer,
             'adagrad' for Adagrad optimizer
+
+
+        Example usage
+        -------------
+
+        from octis.dataset.dataset import Dataset
+        from octis.models.RSM import RSM
+
+        dataset_20ng = Dataset()
+        dataset_20ng.fetch_dataset("20NewsGroup")
+
+        rsm = RSM(num_topics=20, epochs=500, btsz=20, lr=0.0001, cd_type='mfcd', train_optimizer='rmsprop')
+        output_rsm = rsm.train(dataset_20ng)
86100 """
87101 super ().__init__ ()
88102 self .hyperparameters = dict ()
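
The docstring example above mirrors the standard OCTIS workflow. Assuming `train()` returns the usual OCTIS output dictionary (keys such as `'topics'` and `'topic-word-matrix'`; this is not confirmed by the diff), the result could be inspected like this:

```python
# Hypothetical inspection of the output of rsm.train(), assuming the usual
# OCTIS convention of a dict with 'topics' and 'topic-word-matrix' keys.
for topic_id, top_words in enumerate(output_rsm["topics"]):
    print(f"Topic {topic_id}: {' '.join(top_words[:10])}")

# topic-word matrix: one row of word scores per topic
print(output_rsm["topic-word-matrix"].shape)
```
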
@@ -236,21 +250,11 @@ def build_dtm(self, tokenized_corpus, id2word=None):
 
 ############################################################## RSM original class
 
-class RSM_model(object):
+class RSM_model(Replicated_Softmax):
     def __init__(self):
-        self.W = None
+        super().__init__()
 
-    def softmax_vec(self, array):
-        exparr = np.exp(array)
-        return exparr / exparr.sum()
 
-    def softmax(self, array):
-        maxs = np.max(array, axis=1, keepdims=True)
-        lse = maxs + np.log(np.sum(np.exp(array - maxs), axis=1, keepdims=True))
-        return np.exp(array - lse)
-
-    def sigmoid(self, x):
-        return 1 / (1 + np.exp(-x))
 
     ############################## energy and probability
 
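
The softmax and sigmoid helpers deleted here are presumably what the new `Replicated_Softmax` base class in `octis.models.RS_class` now provides. A minimal sketch of such a base class, reconstructed from the removed bodies (only the class name and module path are confirmed by the diff; a max-shift is added to `softmax_vec` for numerical stability, which the removed version lacked):

```python
import numpy as np

class Replicated_Softmax:
    """Shared numerical helpers for RSM-style models (sketch)."""

    def __init__(self):
        self.W = None  # the (w_vh, w_v, w_h) weight triple, set by the subclass

    def softmax_vec(self, array):
        # softmax of a single vector; subtracting the max avoids overflow
        exparr = np.exp(array - np.max(array))
        return exparr / exparr.sum()

    def softmax(self, array):
        # row-wise softmax via the log-sum-exp trick
        maxs = np.max(array, axis=1, keepdims=True)
        lse = maxs + np.log(np.sum(np.exp(array - maxs), axis=1, keepdims=True))
        return np.exp(array - lse)

    def sigmoid(self, x):
        return 1 / (1 + np.exp(-x))
```
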
@@ -263,31 +267,6 @@ def neg_energy(self, v, h):
         en = t1 + t2 + t3
         return en
 
-    def neg_free_energy(self, v):  # it's equivalent to the log pdf
-        w_vh, w_v, w_h = self.W
-        T = self.hidden
-        D = v.sum(axis=1)
-        fren = np.dot(v, w_v)
-        for j in range(T):
-            w_j = w_vh[:, j]
-            a_j = w_h[j]
-            fren += np.log(1 + np.exp(D * a_j + np.dot(v, w_j)))
-        return fren
-
-    def neg_free_energy_single_doc(self, v):  # it's equivalent to the log pdf
-        w_vh, w_v, w_h = self.W
-        T = self.hidden
-        D = v.sum()
-        fren = np.dot(v, w_v)
-        for j in range(T):
-            w_j = w_vh[:, j]
-            a_j = w_h[j]
-            fren += np.log(1 + np.exp(D * a_j + np.dot(v, w_j)))
-        return fren
-
-    def marginal_pdf(self, v):
-        return np.exp(self.neg_free_energy(v))
-
     def visible2hidden_vec(self, v):
         w_vh, w_v, w_h = self.W
         D = v.sum()
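
The free-energy methods removed in this hunk loop over topics one at a time; if they now live in the base class, the same quantity can be computed without the loop. A vectorized sketch equivalent to the deleted `neg_free_energy`, using the same `(w_vh, w_v, w_h)` weight layout (written as a standalone function for illustration):

```python
import numpy as np

def neg_free_energy(v, W):
    # v: (n_docs, vocab_size) count matrix; W = (w_vh, w_v, w_h)
    # Equivalent to the deleted loop:
    #     fren = v @ w_v + sum_j log(1 + exp(D * w_h[j] + v @ w_vh[:, j]))
    w_vh, w_v, w_h = W
    D = v.sum(axis=1, keepdims=True)   # document lengths, shape (n_docs, 1)
    act = D * w_h + v @ w_vh           # hidden activations, shape (n_docs, n_topics)
    # logaddexp(0, x) computes log(1 + exp(x)) without overflow
    return v @ w_v + np.logaddexp(0, act).sum(axis=1)
```
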
@@ -310,79 +289,8 @@ def hidden2visible(self, h):
         energy = np.tile(w_v, (h.shape[0], 1)).T + np.dot(w_vh, h.T)
         return self.softmax(energy.T)
 
-    def topic_words(self, topk, id2word=None):
-        w_vh, w_v, w_h = self.W
-        T = self.hidden
-        if id2word is None:
-            id2word = self.id2word
-        words = np.array([k for k in id2word.token2id.keys()])
-
-        toplist = []
-        for t in range(T):
-            topw = w_vh[:, t]
-            bestwords = words[np.argsort(topw)[::-1]][0:topk]
-            toplist.append(bestwords)
-
-        return toplist
-
-    def _get_topic_word_matrix(self):
-        """
-        Return the topic representation of the words
-        """
-        w_vh, w_v, w_h = self.W
-        topic_word_matrix = w_vh.T
-        normalized = []
-        for words_w in topic_word_matrix:
-            minimum = min(words_w)
-            words = words_w - minimum
-            normalized.append([float(i) / sum(words) for i in words])
-        topic_word_matrix = np.array(normalized)
-        return topic_word_matrix
-
-    def _get_topic_word_matrix0(self):
-        """
-        Return the topic representation of the words
-        """
-        w_vh, w_v, w_h = self.W
-        topic_word_matrix = np.empty(w_vh.T.shape)
-        for t in range(w_vh.T.shape[0]):
-            topic_word_matrix[t, :] = self.softmax_vec(w_vh.T[t, :] - w_v)
-        return topic_word_matrix
-
-    def _get_topic_doc(self, dtm):
-        return self.visible2hidden(dtm).T
-
-    def _get_topics(self, topk):
-        w_vh, w_v, w_h = self.W
-        T = self.hidden
-        words = np.array([k for k in self.id2word.token2id.keys()])
-
-        toplist = []
-        for t in range(T):
-            topw = w_vh[:, t]
-            bestwords = words[np.argsort(topw)[::-1]][0:topk]
-            toplist.append(bestwords)
-
-        return toplist
-
-    # topics_output = []
-    # for topic in result["topic-word-matrix"]:
-    #     top_k = np.argsort(topic)[-top_words:]
-    #     top_k_words = list(reversed([self.id2word[i] for i in top_k]))
-    #     topics_output.append(top_k_words)
-
     ##################################### leapfrog transition operators
 
-    def multinomial_sample(self, probs, N):
-        return np.random.multinomial(N, probs, size=1)[0]
-
-    def unif_reject_sample(self, probs):
-        h_unif = np.random.rand(*probs.shape)
-        h_sample = np.array(h_unif < probs, dtype=int)
-        return h_sample
-
-    def deterministic_sample(self, probs):
-        return (probs > 0.5).astype(int)
 
     def gibbs_transition(self, v):
         D = v.sum(axis=1)
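
The three sampling helpers dropped above feed `gibbs_transition`: a multinomial redraw for the visible (word-count) layer and two ways of binarizing hidden probabilities. A standalone sketch of the deleted behaviour, assuming these now come from `Replicated_Softmax`:

```python
import numpy as np

def multinomial_sample(probs, N):
    # redraw N word tokens from a multinomial over the vocabulary
    return np.random.multinomial(N, probs, size=1)[0]

def unif_reject_sample(probs):
    # Bernoulli sample of each hidden unit: 1 where a uniform draw falls below p
    return (np.random.rand(*probs.shape) < probs).astype(int)

def deterministic_sample(probs):
    # mean-field style hard assignment: activate units with p > 0.5
    return (probs > 0.5).astype(int)
```
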
@@ -432,21 +340,6 @@ def MH_transition_vec(self, state, logpdf):
 
     ################################## gradient descent optimization
 
-    def interaction_penalty(self, vel_vh, w_vh):
-        if self.penalty:
-            if self.penL1:  # L1 penalty
-                if self.local_penalty:
-                    penal = self.decay * np.sign(w_vh)
-                else:
-                    penal = self.decay * np.sum(np.abs(w_vh)) * np.sign(w_vh)
-            else:  # L2 penalty
-                if self.local_penalty:
-                    penal = self.decay * w_vh
-                else:
-                    penal = self.decay * np.sum(w_vh)
-
-            vel_vh = vel_vh - penal
-        return vel_vh
 
     def gradient_simple(self, v1, v2, h1, h2):
         w_vh, w_v, w_h = self.W
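
The deleted `interaction_penalty` applied an optional L1 or L2 weight-decay term to the velocity update of the topic-word interactions. A sketch of the same logic with the `self.*` flags turned into explicit arguments (the argument names are illustrative, not from the source):

```python
import numpy as np

def interaction_penalty(vel_vh, w_vh, decay, use_l1=True, local=True):
    """Subtract a weight-decay term from the velocity update for w_vh."""
    if use_l1:
        # L1: constant-magnitude pull toward zero; the non-local variant
        # scales the pull by the total absolute weight mass
        penal = decay * np.sign(w_vh) if local else decay * np.sum(np.abs(w_vh)) * np.sign(w_vh)
    else:
        # L2: pull proportional to the weight itself (or to the summed weights)
        penal = decay * w_vh if local else decay * np.sum(w_vh)
    return vel_vh - penal
```
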
@@ -782,66 +675,6 @@ def train_epoch(self):
 
         self.t += 1
 
-    def set_structure_from_dtm(
-        self,
-        winit=None,
-        dtm=None,
-        val_dtm=None,
-        softstart=0.001,
-        num_topics=5,
-        epochs=5,
-        monitor_ppl=False,
-        monitor_time=False,
-        monitor_loglik=False,
-        logdtm=False,
-    ):
-        doval = val_dtm is not None
-
-        if logdtm:
-            self.dtm = np.log(1 + dtm)
-            if doval:
-                self.val_dtm = np.log(1 + val_dtm)
-        else:
-            self.dtm = dtm
-            if doval:
-                self.val_dtm = np.log(1 + val_dtm)
-
-        self.hidden = num_topics
-        N, dictsize = dtm.shape
-        self.visible = dictsize
-
-        self.obs_ids = np.arange(N)
-
-        if winit is not None:
-            ### self.W = winit  WRONG: you are referencing the same arrays across runs
-            # defensive copy to avoid sharing mutable numpy arrays across runs
-            try:
-                self.W = tuple(np.array(arr, copy=True) for arr in winit)
-            except Exception:
-                # fallback: keep original if not iterable
-                self.W = winit
-
-        if self.W is None:
-            w_vh = softstart * np.random.randn(dictsize, num_topics)
-            w_v = softstart * np.random.randn(dictsize)
-            w_h = softstart * np.random.randn(num_topics)
-            self.W = w_vh, w_v, w_h
-        else:
-            print("train from already available weights")
-            w_vh, w_v, w_h = self.W
-
-        if monitor_time:
-            self.train_time = np.empty(epochs)
-
-        if monitor_ppl:
-            self.train_ppl = np.empty(epochs)
-            if doval:
-                self.val_ppl = np.empty(epochs)
-
-        if monitor_loglik:
-            self.train_loglik = np.empty(epochs)
-            if doval:
-                self.val_loglik = np.empty(epochs)
 
     def set_train_hyper(
         self,
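
Worth preserving from the deleted `set_structure_from_dtm` is its defensive copy of the initial weights, which prevents two training runs from mutating the same numpy arrays. The pattern in isolation (the helper name is hypothetical):

```python
import numpy as np

def copy_weights(winit):
    # Deep-copy the (w_vh, w_v, w_h) triple so separate runs never share
    # (and silently mutate) the same numpy arrays.
    try:
        return tuple(np.array(arr, copy=True) for arr in winit)
    except TypeError:
        # winit was not iterable; return it unchanged, as the original did
        return winit
```
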
@@ -998,7 +831,10 @@ def log_ppl_approx(self, dtm):
998831 """
999832 mfh = self .visible2hidden (dtm )
1000833 vprob = self .hidden2visible (mfh )
-        lpub = np.exp(-np.nansum(np.log(vprob) * dtm) / np.sum(dtm))
+        vprob = np.clip(vprob, 1e-12, None)
+        sum_dtm = np.sum(dtm)
+        assert sum_dtm > 0, "the sum of the dtm's entries must be positive"
+        lpub = -np.nansum(np.log(vprob) * dtm) / sum_dtm
         return lpub
 
     def ppl_approx(self, testmatrix):
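
The clipping added in this hunk keeps the log well defined when the reconstruction assigns numerically zero probability to an observed word. A self-contained sketch of the same computation (the function name is illustrative):

```python
import numpy as np

def log_ppl(vprob, dtm, eps=1e-12):
    # mean negative log-likelihood per observed token; clipping vprob away
    # from zero prevents log(0) = -inf from propagating into the sum
    vprob = np.clip(vprob, eps, None)
    total = np.sum(dtm)
    assert total > 0, "the document-term matrix must contain at least one count"
    return -np.sum(np.log(vprob) * dtm) / total
```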