Skip to content

Commit 9de86fc

Browse files
committed
fix bug
1 parent b46a846 commit 9de86fc

2 files changed

Lines changed: 9 additions & 14 deletions

File tree

scopen/MF.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -198,21 +198,18 @@ def _compute_regularization(alpha, l1_ratio, regularization):
198198
return l1_reg_W, l1_reg_H, l2_reg_W, l2_reg_H
199199

200200

201-
def _loss(X, W, H, square_root=False):
201+
def _loss(X, W, H):
202202
"""Compute the Frobenius *squared* norm of X - dot(W, H).
203203
Parameters
204204
----------
205205
X : float or array-like, shape (n_samples, n_features)
206206
Numpy masked arrays or arrays containing NaN are accepted.
207207
W : float or dense array-like, shape (n_samples, n_components)
208208
H : float or dense array-like, shape (n_components, n_features)
209-
square_root : boolean, default False
210-
If True, return np.sqrt(2 * res)
211-
For beta == 2, it corresponds to the Frobenius norm.
212209
Returns
213210
-------
214211
res : float
215-
Beta divergence of X and np.dot(X, H)
212+
Frobenius norm of X and np.dot(X, H)
216213
"""
217214
# The method can be called with scalars
218215
if not sp.issparse(X):
@@ -232,10 +229,7 @@ def _loss(X, W, H, square_root=False):
232229

233230
assert not np.isnan(res)
234231
assert res >= 0
235-
if square_root:
236-
return np.sqrt(res * 2)
237-
else:
238-
return res
232+
return np.sqrt(res * 2)
239233

240234

241235
def _initialize_nmf(X, n_components, init=None, eps=1e-6,
@@ -522,7 +516,7 @@ def _fit_coordinate_descent(X, W, H, tol=1e-4, max_iter=200, l1_reg_W=0,
522516
f"violation: {_violation: .8f}")
523517

524518
elif verbose == 2:
525-
err = _loss(X, W, Ht.T, square_root=True)
519+
err = _loss(X, W, Ht.T)
526520
print(f"{datetime.now().strftime('%m/%d/%Y %H:%M:%S')}, iteration: {n_iter: }, "
527521
f"violation: {_violation: .8f}, error: {err: .8f}")
528522

@@ -688,7 +682,6 @@ def non_negative_factorization(X, W=None, H=None, n_components=None,
688682
W, H, n_iter = _fit_coordinate_descent(X, W, H, tol, max_iter,
689683
l1_reg_W, l1_reg_H,
690684
l2_reg_W, l2_reg_H,
691-
update_H=True,
692685
verbose=verbose,
693686
shuffle=shuffle,
694687
random_state=random_state)
@@ -876,7 +869,7 @@ def fit_transform(self, X, y=None, W=None, H=None):
876869
random_state=self.random_state, verbose=self.verbose,
877870
shuffle=self.shuffle)
878871

879-
self.reconstruction_err_ = _loss(X, W, H, square_root=True)
872+
self.reconstruction_err_ = _loss(X, W, H)
880873

881874
self.n_components_ = H.shape[0]
882875
self.components_ = H

scopen/Main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def estimate_rank(data, args):
129129
for n_components in n_components_list:
130130
arguments = (data, n_components, args.alpha,
131131
args.max_iter, args.verbose,
132-
args.random_state. args.init)
132+
args.random_state, args.init)
133133

134134
res = run_nmf(arguments)
135135
w_hat_dict[n_components] = res[0]
@@ -139,7 +139,9 @@ def estimate_rank(data, args):
139139
elif args.nc > 1:
140140
arguments_list = list()
141141
for n_components in n_components_list:
142-
arguments = (data, n_components, args.alpha, args.max_iter, args.verbose, args.random_state, args.init)
142+
arguments = (data, n_components, args.alpha,
143+
args.max_iter, args.verbose,
144+
args.random_state, args.init)
143145
arguments_list.append(arguments)
144146

145147
with Pool(processes=args.nc) as pool:

0 commit comments

Comments
 (0)