import numpy as np
from scipy.special import expit
23
34
45class Replicated_Softmax :
@@ -32,13 +33,16 @@ def softmax(self, x):
3233 return np .exp (x - lse )
3334
def softmax_vec(self, array):
    """Numerically stable softmax for a single document / 1D vector.

    The maximum entry is subtracted before exponentiating, so the largest
    exponent is exactly 0 and ``np.exp`` cannot overflow. The shift cancels
    in the normalisation, so the result is unchanged mathematically.

    Parameters
    ----------
    array : array-like
        Scores for a single document; converted with ``np.asarray``.

    Returns
    -------
    numpy.ndarray
        Non-negative values that sum to 1 over all entries of ``array``.

    Raises
    ------
    ValueError
        If ``array`` is empty (``np.max`` has no identity for empty input).
    """
    scores = np.asarray(array)
    # Shift by the max for stability: exp(x - max) is always in (0, 1].
    exp_shifted = np.exp(scores - np.max(scores))
    return exp_shifted / exp_shifted.sum()
3840
41+
def sigmoid(self, x):
    """Numerically stable logistic sigmoid, ``1 / (1 + exp(-x))``.

    Delegates to ``scipy.special.expit``, which avoids the overflow that
    the naive formula hits for large negative ``x``.

    Parameters
    ----------
    x : scalar or array-like
        Input value(s); applied element-wise.

    Returns
    -------
    scalar or numpy.ndarray
        Values in ``[0, 1]`` with the same shape as ``x``.
    """
    return expit(x)
45+
4246
4347 def multinomial_sample (self , probs , N ):
4448 """
@@ -74,13 +78,13 @@ def interaction_penalty(self, vel_vh, w_vh):
7478 function to adjust the gradient of the
7579 topic-word interaction weights during a training iteration
7680 of a RS model by a penalty factor.
77- The model shoud have the attributes:
81+ The model should have the attributes:
7882 - penalty : bool : if the penalization should be applied
7983 - penL1: bool : if the penalty is of type L1 or L2
8084 - local_penalty : bool : if the penalty should be local or global
8185 - decay : float : the penalty factor to use
8286 This function also requires two numpy arrays as arguments:
83- - the interaction weigths matrix w_vh, that connects topics to words
87+ - the interaction weights matrix w_vh, that connects topics to words
8488 - the respective gradients vel_vh (also a matrix)
8589 """
8690 if self .penalty :
@@ -100,42 +104,28 @@ def interaction_penalty(self, vel_vh, w_vh):
100104
101105 ############### likelihood utils
102106
def neg_free_energy(self, v):
    """Compute the log pdf of ``v`` under the replicated softmax.

    Accepts both a 1D bag-of-words vector (single document) and a 2D
    document-term matrix (batch, one document per row).

    For each document the value is::

        v . w_v + sum_j log(1 + exp(D * w_h[j] + v . w_vh[:, j]))

    where ``D`` is the document length (sum of its counts) and ``j`` runs
    over the first ``self.hidden`` topics.

    Parameters
    ----------
    v : numpy.ndarray
        Word counts, shape ``(V,)`` or ``(n_docs, V)``.

    Returns
    -------
    numpy scalar or numpy.ndarray
        A scalar for 1D input, an array of shape ``(n_docs,)`` for 2D input.
    """
    w_vh, w_v, w_h = self.W
    T = self.hidden
    # Summing over the last axis gives a scalar for 1D input and one
    # length per row for 2D input.
    D = v.sum(axis=-1)
    # NOTE(review): assumes w_h holds one bias per topic, shape (T,) or
    # (T, 1) -- confirm against where self.W is initialised.
    hidden_bias = np.reshape(w_h, -1)[:T]
    # Pre-activations for all topics at once: shape (T,) or (n_docs, T).
    arg = np.expand_dims(D, axis=-1) * hidden_bias + np.dot(v, w_vh[:, :T])
    # logaddexp(0, arg) == log(1 + exp(arg)) without overflow for large arg.
    return np.dot(v, w_v) + np.logaddexp(0, arg).sum(axis=-1)
132125
def marginal_pdf(self, v):
    """Return ``exp(self.neg_free_energy(v))`` for a document or batch ``v``."""
    log_density = self.neg_free_energy(v)
    return np.exp(log_density)
135128
136- def marginal_pdf_single_doc (self , v ):
137- return np .exp (self .neg_free_energy_single_doc (v ))
138-
139129 ############ octis output functions
140130
141131 def topic_words (self , topk , id2word = None ):
@@ -170,10 +160,10 @@ def _get_topics(self, topk):
170160 def _get_topic_word_matrix (self ):
171161 """
172162 Returns the topic representation of the words.
173- Uses min-max normalization by topic of the interaction weigths
163+ Uses min-max normalization by topic of the interaction weights
174164 matrix w_vh. The ranking of the words using this matrix
175165 is equivalent to the ranking obtained from the unnormalized
176- matrix of weigths w_vh.
166+ matrix of weights w_vh.
177167 """
178168 w_vh , w_v , w_h = self .W
179169 topic_word_matrix = w_vh .T
@@ -208,7 +198,7 @@ def set_structure_from_dtm(
208198 monitor_loglik = False ,
209199 logdtm = False ,
210200 ):
211- """function to initialize the weigths matrices
201+ """function to initialize the weights matrices
212202 given the dtm and the number of topics"""
213203 doval = val_dtm is not None
214204
@@ -222,7 +212,19 @@ def set_structure_from_dtm(
222212 self .val_dtm = val_dtm
223213
224214 D = self .dtm .sum (axis = 1 )
225- assert not np .any (D == 0 ), "all the documents should have positive length"
215+ if np .any (D == 0 ):
216+ raise ValueError (
217+ "All training documents must have positive length; "
218+ f"found { (D == 0 ).sum ()} empty document(s)."
219+ )
220+
221+ if doval :
222+ D_val = self .val_dtm .sum (axis = 1 )
223+ if np .any (D_val == 0 ):
224+ raise ValueError (
225+ "All validation documents must have positive length; "
226+ f"found { (D_val == 0 ).sum ()} empty document(s)."
227+ )
226228
227229 self .hidden = num_topics
228230 self .F = num_topics
0 commit comments