Skip to content

Commit 3c3bb6a

Browse files
committed
corrections for rs class
1 parent 608463d commit 3c3bb6a

2 files changed

Lines changed: 122 additions & 109 deletions

File tree

octis/models/RS_class.py

Lines changed: 34 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
from scipy.special import expit
23

34

45
class Replicated_Softmax:
@@ -32,13 +33,16 @@ def softmax(self, x):
3233
return np.exp(x - lse)
3334

3435
def softmax_vec(self, array):
35-
"""simple softmax activation for an array vector (single document)"""
36-
exparr = np.exp(array)
36+
"""Numerically stable softmax for a single document / 1D vector (uses the same LSE trick as softmax)."""
37+
shifted = array - np.max(array) # subtract max for stability
38+
exparr = np.exp(shifted)
3739
return exparr / exparr.sum()
3840

41+
3942
def sigmoid(self, x):
40-
"""basic sigmoid activation"""
41-
return 1 / (1 + np.exp(-x))
43+
"""Numerically stable sigmoid activation"""
44+
return expit(x)
45+
4246

4347
def multinomial_sample(self, probs, N):
4448
"""
@@ -74,13 +78,13 @@ def interaction_penalty(self, vel_vh, w_vh):
7478
function to adjust the gradient of the
7579
topic-word interaction weights during a training iteration
7680
of a RS model by a penalty factor.
77-
The model shoud have the attributes:
81+
The model should have the attributes:
7882
- penalty : bool : if the penalization should be applied
7983
- penL1: bool : if the penalty is of type L1 or L2
8084
- local_penalty : bool : if the penalty should be local or global
8185
- decay : float : the penalty factor to use
8286
This function also requires two numpy arrays as arguments:
83-
- the interaction weigths matrix w_vh, that connects topics to words
87+
- the interaction weights matrix w_vh, that connects topics to words
8488
- the respective gradients vel_vh (also a matrix)
8589
"""
8690
if self.penalty:
@@ -100,42 +104,28 @@ def interaction_penalty(self, vel_vh, w_vh):
100104

101105
############### likelihood utils
102106

103-
def neg_free_energy(self, v):
104-
"""
105-
given an array v similar to the dtm, computes the
106-
log pdf under the replicated softmax
107-
"""
108-
w_vh, w_v, w_h = self.W
109-
T = self.hidden
110-
D = v.sum(axis=1)
111-
fren = np.dot(v, w_v)
112-
for j in range(T):
113-
w_j = w_vh[:, j]
114-
a_j = w_h[j]
115-
fren += np.log(1 + np.exp(D * a_j + np.dot(v, w_j)))
116-
return fren
117107

118-
def neg_free_energy_single_doc(self, v):
108+
def neg_free_energy(self, v):
119109
"""
120-
given a one dimensional Bow vector v representing a single document,
121-
computes the log pdf under the replicated softmax
110+
Given a BoW vector or document-term matrix v, computes the
111+
log pdf under the replicated softmax.
112+
Accepts both a 1D array (single document) and 2D array (batch).
122113
"""
123114
w_vh, w_v, w_h = self.W
124115
T = self.hidden
125-
D = v.sum()
116+
D = v.sum(axis=-1) # works for both 1D and 2D
126117
fren = np.dot(v, w_v)
127118
for j in range(T):
128119
w_j = w_vh[:, j]
129120
a_j = w_h[j]
130-
fren += np.log(1 + np.exp(D * a_j + np.dot(v, w_j)))
121+
#fren += np.log(1 + np.exp(D * a_j + np.dot(v, w_j)))
122+
arg = D * a_j + np.dot(v, w_j)
123+
fren += np.logaddexp(0, arg) # = log(1 + exp(arg)), numerically stable
131124
return fren
132125

133126
def marginal_pdf(self, v):
134127
return np.exp(self.neg_free_energy(v))
135128

136-
def marginal_pdf_single_doc(self, v):
137-
return np.exp(self.neg_free_energy_single_doc(v))
138-
139129
############ octis output functions
140130

141131
def topic_words(self, topk, id2word=None):
@@ -170,10 +160,10 @@ def _get_topics(self, topk):
170160
def _get_topic_word_matrix(self):
171161
"""
172162
Returns the topic representation of the words.
173-
Uses min-max normalization by topic of the interaction weigths
163+
Uses min-max normalization by topic of the interaction weights
174164
matrix w_vh. The ranking of the words using this matrix
175165
is equivalent to the ranking obtained from the unnormalized
176-
matrix of weigths w_vh.
166+
matrix of weights w_vh.
177167
"""
178168
w_vh, w_v, w_h = self.W
179169
topic_word_matrix = w_vh.T
@@ -208,7 +198,7 @@ def set_structure_from_dtm(
208198
monitor_loglik=False,
209199
logdtm=False,
210200
):
211-
"""function to initialize the weigths matrices
201+
"""function to initialize the weights matrices
212202
given the dtm and the number of topics"""
213203
doval = val_dtm is not None
214204

@@ -222,7 +212,19 @@ def set_structure_from_dtm(
222212
self.val_dtm = val_dtm
223213

224214
D = self.dtm.sum(axis=1)
225-
assert not np.any(D == 0), "all the documents should have positive length"
215+
if np.any(D == 0):
216+
raise ValueError(
217+
"All training documents must have positive length; "
218+
f"found {(D == 0).sum()} empty document(s)."
219+
)
220+
221+
if doval:
222+
D_val = self.val_dtm.sum(axis=1)
223+
if np.any(D_val == 0):
224+
raise ValueError(
225+
"All validation documents must have positive length; "
226+
f"found {(D_val == 0).sum()} empty document(s)."
227+
)
226228

227229
self.hidden = num_topics
228230
self.F = num_topics

0 commit comments

Comments
 (0)