Skip to content

Commit 608463d

Browse files
committed
formatting last changes with ruff
1 parent ebb379f commit 608463d

3 files changed

Lines changed: 36 additions & 68 deletions

File tree

octis/models/RSM.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def __init__(
9191
9292
from octis.dataset.dataset import Dataset
9393
from octis.models.RSM import RSM
94-
94+
9595
dataset_20ng = Dataset()
9696
dataset_20ng.fetch_dataset("20NewsGroup")
9797
@@ -254,8 +254,6 @@ class RSM_model(Replicated_Softmax):
254254
def __init__(self):
255255
super().__init__()
256256

257-
258-
259257
############################## energy and probability
260258

261259
def neg_energy(self, v, h):
@@ -291,7 +289,6 @@ def hidden2visible(self, h):
291289

292290
##################################### leapfrog transition operators
293291

294-
295292
def gibbs_transition(self, v):
296293
D = v.sum(axis=1)
297294
hidden_probs = self.visible2hidden(v)
@@ -340,7 +337,6 @@ def MH_transition_vec(self, state, logpdf):
340337

341338
################################## gradient descent optimization
342339

343-
344340
def gradient_simple(self, v1, v2, h1, h2):
345341
w_vh, w_v, w_h = self.W
346342
lr = self.lr
@@ -675,7 +671,6 @@ def train_epoch(self):
675671

676672
self.t += 1
677673

678-
679674
def set_train_hyper(
680675
self,
681676
epochs=3,
@@ -833,7 +828,7 @@ def log_ppl_approx(self, dtm):
833828
vprob = self.hidden2visible(mfh)
834829
vprob = np.clip(vprob, 1e-12, None)
835830
sum_dtm = np.sum(dtm)
836-
assert sum_dtm > 0, 'the sum of the dtm s entries has to be positive'
831+
assert sum_dtm > 0, "the sum of the dtm s entries has to be positive"
837832
lpub = -np.nansum(np.log(vprob) * dtm) / sum_dtm
838833
return lpub
839834

octis/models/RS_class.py

Lines changed: 32 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import numpy as np
22

33

4-
5-
class Replicated_Softmax():
4+
class Replicated_Softmax:
65
def __init__(self):
76
self.W = None
8-
97

108
###### to implement in the specific class
119

@@ -21,62 +19,59 @@ def set_train_hyper(self):
2119
def visible2hidden(self):
2220
raise NotImplementedError
2321

24-
2522
####### activations and sampling
2623

2724
def softmax(self, x):
28-
'''
25+
"""
2926
Softmax activation by row of the matrix x.
30-
The denominator log(sum(exp(x[i]))) leads to many inf, so the
27+
The denominator log(sum(exp(x[i]))) leads to many inf, so the
3128
LogSumExp approximation is used instead.
32-
'''
29+
"""
3330
maxs = np.max(x, axis=1, keepdims=True)
3431
lse = maxs + np.log(np.sum(np.exp(x - maxs), axis=1, keepdims=True))
3532
return np.exp(x - lse)
36-
33+
3734
def softmax_vec(self, array):
38-
'''simple softmax activation for an array vector (single document)'''
35+
"""simple softmax activation for an array vector (single document)"""
3936
exparr = np.exp(array)
4037
return exparr / exparr.sum()
4138

4239
def sigmoid(self, x):
43-
'''basic sigmoid activation'''
40+
"""basic sigmoid activation"""
4441
return 1 / (1 + np.exp(-x))
4542

46-
4743
def multinomial_sample(self, probs, N):
48-
'''
44+
"""
4945
wrapper of np.random.multinomial
5046
probs: vector of probabilities for words count
5147
N: number of words to sample
52-
'''
48+
"""
5349
return np.random.multinomial(N, probs, size=1)[0]
5450

5551
def unif_reject_sample(self, probs):
56-
'''
52+
"""
5753
function to sample topics (bernoulli distributed)
5854
given a vector of probabilities.
5955
It samples from a uniform distribution U(0,1)
6056
to get the thresholds for each topic.
61-
'''
57+
"""
6258
h_unif = np.random.rand(*probs.shape)
6359
h_sample = np.array(h_unif < probs, dtype=int)
6460
return h_sample
6561

6662
def deterministic_sample(self, probs):
67-
'''
63+
"""
6864
function to sample topics (bernoulli distributed)
6965
given a vector of probabilities.
7066
It uses the >0.5 rule to assign 1 to each topic.
71-
'''
67+
"""
7268
return (probs > 0.5).astype(int)
7369

74-
7570
################### gradient utils
7671

7772
def interaction_penalty(self, vel_vh, w_vh):
78-
'''
79-
function to adjust the gradient of the
73+
"""
74+
function to adjust the gradient of the
8075
topic-word interaction weights during a training iteration
8176
of a RS model by a penalty factor.
8277
The model should have the attributes:
@@ -87,7 +82,7 @@ def interaction_penalty(self, vel_vh, w_vh):
8782
This function also requires two numpy arrays as arguments:
8883
- the interaction weights matrix w_vh, that connects topics to words
8984
- the respective gradients vel_vh (also a matrix)
90-
'''
85+
"""
9186
if self.penalty:
9287
if self.penL1: # L1 penalty
9388
if self.local_penalty:
@@ -103,14 +98,13 @@ def interaction_penalty(self, vel_vh, w_vh):
10398
vel_vh = vel_vh - penal
10499
return vel_vh
105100

106-
107101
############### likelihood utils
108102

109103
def neg_free_energy(self, v):
110-
'''
111-
given an array v similar to the dtm, computes the
104+
"""
105+
given an array v similar to the dtm, computes the
112106
log pdf under the replicated softmax
113-
'''
107+
"""
114108
w_vh, w_v, w_h = self.W
115109
T = self.hidden
116110
D = v.sum(axis=1)
@@ -121,12 +115,11 @@ def neg_free_energy(self, v):
121115
fren += np.log(1 + np.exp(D * a_j + np.dot(v, w_j)))
122116
return fren
123117

124-
125118
def neg_free_energy_single_doc(self, v):
126-
'''
119+
"""
127120
given a one dimensional Bow vector v representing a single document,
128121
computes the log pdf under the replicated softmax
129-
'''
122+
"""
130123
w_vh, w_v, w_h = self.W
131124
T = self.hidden
132125
D = v.sum()
@@ -137,22 +130,20 @@ def neg_free_energy_single_doc(self, v):
137130
fren += np.log(1 + np.exp(D * a_j + np.dot(v, w_j)))
138131
return fren
139132

140-
141133
def marginal_pdf(self, v):
142134
return np.exp(self.neg_free_energy(v))
143135

144136
def marginal_pdf_single_doc(self, v):
145137
return np.exp(self.neg_free_energy_single_doc(v))
146138

147-
148139
############ octis output functions
149140

150141
def topic_words(self, topk, id2word=None):
151-
'''
142+
"""
152143
Given a gensim dictionary id2word,
153144
returns the topk most important words for each topic
154145
inside a list of T lists, where T is the number of topics
155-
'''
146+
"""
156147
w_vh, w_v, w_h = self.W
157148
T = self.hidden
158149
if id2word is None:
@@ -168,22 +159,20 @@ def topic_words(self, topk, id2word=None):
168159
return toplist
169160

170161
def _get_topics(self, topk):
171-
'''
162+
"""
172163
Given a gensim dictionary id2word,
173164
returns the topk most important words for each topic
174165
inside a list of T lists, where T is the number of topics
175166
(this function is a wrapper of topic_words, used by octis class)
176-
'''
167+
"""
177168
return self.topic_words(topk, self.id2word)
178169

179-
180-
181170
def _get_topic_word_matrix(self):
182171
"""
183172
Returns the topic representation of the words.
184173
Uses min-max normalization by topic of the interaction weights
185174
matrix w_vh. The ranking of the words using this matrix
186-
is equivalent to the ranking obtained from the unnormalized
175+
is equivalent to the ranking obtained from the unnormalized
187176
matrix of weights w_vh.
188177
"""
189178
w_vh, w_v, w_h = self.W
@@ -196,16 +185,14 @@ def _get_topic_word_matrix(self):
196185
topic_word_matrix = np.array(normalized)
197186
return topic_word_matrix
198187

199-
200188
def _get_topic_doc(self, dtm):
201-
'''
202-
given a bidimensional array dtm like, returns
189+
"""
190+
given a bidimensional array dtm like, returns
203191
the probabilities of each topic for each document
204192
(as an array of probabilities).
205-
'''
193+
"""
206194
return self.visible2hidden(dtm).T
207195

208-
209196
####################### train utils
210197

211198
def set_structure_from_dtm(
@@ -221,9 +208,8 @@ def set_structure_from_dtm(
221208
monitor_loglik=False,
222209
logdtm=False,
223210
):
224-
225-
'''function to initialize the weights matrices
226-
given the dtm and the number of topics'''
211+
"""function to initialize the weights matrices
212+
given the dtm and the number of topics"""
227213
doval = val_dtm is not None
228214

229215
if logdtm:
@@ -235,9 +221,8 @@ def set_structure_from_dtm(
235221
if doval:
236222
self.val_dtm = val_dtm
237223

238-
239224
D = self.dtm.sum(axis=1)
240-
assert not np.any(D==0), 'all the documents should have positive length'
225+
assert not np.any(D == 0), "all the documents should have positive length"
241226

242227
self.hidden = num_topics
243228
self.F = num_topics
@@ -276,7 +261,3 @@ def set_structure_from_dtm(
276261
self.train_loglik = np.empty(epochs)
277262
if doval:
278263
self.val_loglik = np.empty(epochs)
279-
280-
281-
282-

octis/models/oRSM.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def __init__(
9999
100100
from octis.dataset.dataset import Dataset
101101
from octis.models.oRSM import oRSM
102-
102+
103103
dataset_20ng = Dataset()
104104
dataset_20ng.fetch_dataset("20NewsGroup")
105105
@@ -305,7 +305,6 @@ def v_to_mf_h1(self, v):
305305
def visible2hidden(self, v):
306306
return self.v_to_mf_h1(v)
307307

308-
309308
def visible_to_hiddens_gibbs(self, v):
310309
"""
311310
main function to compute the hidden states given visible states
@@ -330,13 +329,11 @@ def visible_to_hiddens_gibbs(self, v):
330329

331330
return mu1, mu2
332331

333-
334332
def sample_hidden(self, v):
335333
h1_probs = self.v_to_mf_h1(v)
336334
h1_sample = self.unif_reject_sample(h1_probs)
337335
return h1_sample
338336

339-
340337
##################################### leapfrog transition operators
341338

342339
def gibbs_transition(self, v):
@@ -367,7 +364,6 @@ def gibbs_transition_lowcost(self, v):
367364
visible_sample[i] = self.multinomial_sample(visible_probs[i], D[i])
368365
return visible_sample
369366

370-
371367
######################## gradient descent optimization
372368

373369
def gradient_simple(self, v1, v2, h11, h12, h21, h22):
@@ -607,9 +603,7 @@ def pretrain_kcd_step(self, ids):
607603

608604
h1 = self.v_to_mf_h1(v)
609605
D = v.sum(axis=1)
610-
h2 = (
611-
v * self.M / D.reshape(-1, 1)
612-
)
606+
h2 = v * self.M / D.reshape(-1, 1)
613607

614608
for k in range(self.tK):
615609
v_model = self.sample_visible(h1, D)
@@ -786,7 +780,6 @@ def train_epoch(self):
786780

787781
self.t += 1
788782

789-
790783
def set_train_hyper(
791784
self,
792785
epochs=3,
@@ -960,4 +953,3 @@ def ppl_approx(self, testmatrix):
960953
"""
961954
ppl = np.exp(self.log_ppl_approx(testmatrix))
962955
return ppl
963-

0 commit comments

Comments
 (0)