
Commit ebb379f

Shared superclass for RS and oRS to remove duplicated functions. Correct RS and oRS to handle errors caused by empty documents.

1 parent bbd95a5 commit ebb379f

4 files changed

Lines changed: 332 additions & 394 deletions

File tree

octis/models/RSM.py

Lines changed: 20 additions & 184 deletions
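
The first hunks below re-parent RSM_model onto a new Replicated_Softmax base class (imported from octis.models.RS_class, another file in this commit not shown here) and delete the numeric helpers the class previously defined inline. A minimal sketch of the resulting pattern, assuming the deleted helpers were moved into the base class unchanged so RS and oRS can share them:

import numpy as np

class Replicated_Softmax(object):
    # Hypothetical reconstruction: the helpers removed from RSM_model in the
    # diff below are assumed to live here, shared by the RS and oRS models.
    def __init__(self):
        self.W = None  # weight triple (w_vh, w_v, w_h), filled in by subclasses

    def softmax_vec(self, array):
        # softmax over a single vector, as in the deleted helper
        exparr = np.exp(array)
        return exparr / exparr.sum()

    def softmax(self, array):
        # row-wise softmax via the log-sum-exp trick for numerical stability
        maxs = np.max(array, axis=1, keepdims=True)
        lse = maxs + np.log(np.sum(np.exp(array - maxs), axis=1, keepdims=True))
        return np.exp(array - lse)

    def sigmoid(self, x):
        return 1 / (1 + np.exp(-x))

class RSM_model(Replicated_Softmax):
    def __init__(self):
        super().__init__()  # inherit the shared numerics instead of redefining them

With a single copy of each helper, numerical fixes only have to land in one place for both models.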
@@ -1,4 +1,5 @@
 from octis.models.model import AbstractModel
+from octis.models.RS_class import Replicated_Softmax
 import numpy as np
 from tqdm import tqdm
 import gensim.corpora as corpora
@@ -83,6 +84,19 @@ def __init__(
             'rmsprop' for RMSProp optimizer,
             'adam' for Adam optimizer,
             'adagrad' for Adagrad optimizer
+
+
+        Example usage
+        --------------------
+
+        from octis.dataset.dataset import Dataset
+        from octis.models.RSM import RSM
+
+        dataset_20ng = Dataset()
+        dataset_20ng.fetch_dataset("20NewsGroup")
+
+        rsm = RSM(num_topics=20, epochs=500, btsz=20, lr=0.0001, cd_type='mfcd', train_optimizer='rmsprop')
+        output_rsm = rsm.train(dataset_20ng)
         """
         super().__init__()
         self.hyperparameters = dict()
@@ -236,21 +250,11 @@ def build_dtm(self, tokenized_corpus, id2word=None):

 ############################################################## RSM original class

-class RSM_model(object):
+class RSM_model(Replicated_Softmax):
     def __init__(self):
-        self.W = None
+        super().__init__()

-    def softmax_vec(self, array):
-        exparr = np.exp(array)
-        return exparr / exparr.sum()

-    def softmax(self, array):
-        maxs = np.max(array, axis=1, keepdims=True)
-        lse = maxs + np.log(np.sum(np.exp(array - maxs), axis=1, keepdims=True))
-        return np.exp(array - lse)
-
-    def sigmoid(self, x):
-        return 1 / (1 + np.exp(-x))

 ############################## energy and probability

@@ -263,31 +267,6 @@ def neg_energy(self, v, h):
         en = t1 + t2 + t3
         return en

-    def neg_free_energy(self, v):  # it's equivalent to the log pdf
-        w_vh, w_v, w_h = self.W
-        T = self.hidden
-        D = v.sum(axis=1)
-        fren = np.dot(v, w_v)
-        for j in range(T):
-            w_j = w_vh[:, j]
-            a_j = w_h[j]
-            fren += np.log(1 + np.exp(D * a_j + np.dot(v, w_j)))
-        return fren
-
-    def neg_free_energy_single_doc(self, v):  # it's equivalent to the log pdf
-        w_vh, w_v, w_h = self.W
-        T = self.hidden
-        D = v.sum()
-        fren = np.dot(v, w_v)
-        for j in range(T):
-            w_j = w_vh[:, j]
-            a_j = w_h[j]
-            fren += np.log(1 + np.exp(D * a_j + np.dot(v, w_j)))
-        return fren
-
-    def marginal_pdf(self, v):
-        return np.exp(self.neg_free_energy(v))
-
     def visible2hidden_vec(self, v):
         w_vh, w_v, w_h = self.W
         D = v.sum()
@@ -310,79 +289,8 @@ def hidden2visible(self, h):
         energy = np.tile(w_v, (h.shape[0], 1)).T + np.dot(w_vh, h.T)
         return self.softmax(energy.T)

-    def topic_words(self, topk, id2word=None):
-        w_vh, w_v, w_h = self.W
-        T = self.hidden
-        if id2word is None:
-            id2word = self.id2word
-        words = np.array([k for k in id2word.token2id.keys()])
-
-        toplist = []
-        for t in range(T):
-            topw = w_vh[:, t]
-            bestwords = words[np.argsort(topw)[::-1]][0:topk]
-            toplist.append(bestwords)
-
-        return toplist
-
-    def _get_topic_word_matrix(self):
-        """
-        Return the topic representation of the words
-        """
-        w_vh, w_v, w_h = self.W
-        topic_word_matrix = w_vh.T
-        normalized = []
-        for words_w in topic_word_matrix:
-            minimum = min(words_w)
-            words = words_w - minimum
-            normalized.append([float(i) / sum(words) for i in words])
-        topic_word_matrix = np.array(normalized)
-        return topic_word_matrix
-
-    def _get_topic_word_matrix0(self):
-        """
-        Return the topic representation of the words
-        """
-        w_vh, w_v, w_h = self.W
-        topic_word_matrix = np.empty(w_vh.T.shape)
-        for t in range(w_vh.T.shape[0]):
-            topic_word_matrix[t, :] = self.softmax_vec(w_vh.T[t, :] - w_v)
-        return topic_word_matrix
-
-    def _get_topic_doc(self, dtm):
-        return self.visible2hidden(dtm).T
-
-    def _get_topics(self, topk):
-        w_vh, w_v, w_h = self.W
-        T = self.hidden
-        words = np.array([k for k in self.id2word.token2id.keys()])
-
-        toplist = []
-        for t in range(T):
-            topw = w_vh[:, t]
-            bestwords = words[np.argsort(topw)[::-1]][0:topk]
-            toplist.append(bestwords)
-
-        return toplist
-
-    # topics_output = []
-    # for topic in result["topic-word-matrix"]:
-    #     top_k = np.argsort(topic)[-top_words:]
-    #     top_k_words = list(reversed([self.id2word[i] for i in top_k]))
-    #     topics_output.append(top_k_words)
-
 ##################################### leapfrog trainsition operators

-    def multinomial_sample(self, probs, N):
-        return np.random.multinomial(N, probs, size=1)[0]
-
-    def unif_reject_sample(self, probs):
-        h_unif = np.random.rand(*probs.shape)
-        h_sample = np.array(h_unif < probs, dtype=int)
-        return h_sample
-
-    def deterministic_sample(self, probs):
-        return (probs > 0.5).astype(int)

     def gibbs_transition(self, v):
         D = v.sum(axis=1)
@@ -432,21 +340,6 @@ def MH_transition_vec(self, state, logpdf):

 ################################## gradient descent optimization

-    def interaction_penalty(self, vel_vh, w_vh):
-        if self.penalty:
-            if self.penL1:  # L1 penalty
-                if self.local_penalty:
-                    penal = self.decay * np.sign(w_vh)
-                else:
-                    penal = self.decay * np.sum(np.abs(w_vh)) * np.sign(w_vh)
-            else:  # L2 penalty
-                if self.local_penalty:
-                    penal = self.decay * w_vh
-                else:
-                    penal = self.decay * np.sum(w_vh)
-
-            vel_vh = vel_vh - penal
-        return vel_vh

     def gradient_simple(self, v1, v2, h1, h2):
         w_vh, w_v, w_h = self.W
@@ -782,66 +675,6 @@ def train_epoch(self):

         self.t += 1

-    def set_structure_from_dtm(
-        self,
-        winit=None,
-        dtm=None,
-        val_dtm=None,
-        softstart=0.001,
-        num_topics=5,
-        epochs=5,
-        monitor_ppl=False,
-        monitor_time=False,
-        monitor_loglik=False,
-        logdtm=False,
-    ):
-        doval = val_dtm is not None
-
-        if logdtm:
-            self.dtm = np.log(1 + dtm)
-            if doval:
-                self.val_dtm = np.log(1 + val_dtm)
-        else:
-            self.dtm = dtm
-            if doval:
-                self.val_dtm = np.log(1 + val_dtm)
-
-        self.hidden = num_topics
-        N, dictsize = dtm.shape
-        self.visible = dictsize
-
-        self.obs_ids = np.arange(N)
-
-        if winit is not None:
-            ###self.W = winit WRONG: You are referencing the same arrays across runs
-            # defensive copy to avoid sharing mutable numpy arrays across runs
-            try:
-                self.W = tuple(np.array(arr, copy=True) for arr in winit)
-            except Exception:
-                # fallback: keep original if not iterable
-                self.W = winit
-
-        if self.W is None:
-            w_vh = softstart * np.random.randn(dictsize, num_topics)
-            w_v = softstart * np.random.randn(dictsize)
-            w_h = softstart * np.random.randn(num_topics)
-            self.W = w_vh, w_v, w_h
-        else:
-            print("train already available weights")
-            w_vh, w_v, w_h = self.W
-
-        if monitor_time:
-            self.train_time = np.empty(epochs)
-
-        if monitor_ppl:
-            self.train_ppl = np.empty(epochs)
-            if doval:
-                self.val_ppl = np.empty(epochs)
-
-        if monitor_loglik:
-            self.train_loglik = np.empty(epochs)
-            if doval:
-                self.val_loglik = np.empty(epochs)

     def set_train_hyper(
         self,
@@ -998,7 +831,10 @@ def log_ppl_approx(self, dtm):
         """
         mfh = self.visible2hidden(dtm)
         vprob = self.hidden2visible(mfh)
-        lpub = np.exp(-np.nansum(np.log(vprob) * dtm) / np.sum(dtm))
+        vprob = np.clip(vprob, 1e-12, None)
+        sum_dtm = np.sum(dtm)
+        assert sum_dtm > 0, "the sum of the dtm's entries has to be positive"
+        lpub = -np.nansum(np.log(vprob) * dtm) / sum_dtm
         return lpub

     def ppl_approx(self, testmatrix):
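
The log_ppl_approx hunk above appears to be the empty-document fix named in the commit message: clipping keeps np.log(vprob) finite when the model assigns a word zero probability, and the assertion makes an all-zero document-term matrix fail loudly instead of dividing by zero. The surrounding np.exp() is also dropped, so the method now returns the log-perplexity its name promises rather than the perplexity itself. A standalone sketch of the same computation (the function and argument names here are illustrative, not part of the library):

import numpy as np

def log_ppl_approx_sketch(vprob, dtm):
    # vprob: model word probabilities per document, shape (n_docs, vocab_size)
    # dtm:   document-term count matrix of the same shape
    vprob = np.clip(vprob, 1e-12, None)  # keep log() away from -inf
    sum_dtm = np.sum(dtm)                # total number of observed words
    assert sum_dtm > 0, "the sum of the dtm's entries has to be positive"
    # mean negative log-likelihood per observed word, i.e. log-perplexity
    return -np.nansum(np.log(vprob) * dtm) / sum_dtm

For example, a uniform model over a 4-word vocabulary scores log(4) ≈ 1.386 per word: log_ppl_approx_sketch(np.full((1, 4), 0.25), np.array([[1, 2, 0, 1]])).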

0 commit comments
