Skip to content

Commit 4c4842f

Browse files
committed
Fix typos and ensure symmetry in GPFA calculations; improve logdet fallback method.
1 parent 275732a commit 4c4842f

2 files changed

Lines changed: 23 additions & 13 deletions

File tree

gpfa/gpfa.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,21 @@
99

1010
import time
1111
import warnings
12-
import sklearn
13-
import scipy as sp
12+
1413
import numpy as np
15-
from tqdm import trange
16-
from sklearn.base import clone
14+
import scipy as sp
1715
import scipy.linalg as linalg
1816
import scipy.optimize as optimize
19-
from sklearn.utils.extmath import fast_logdet
17+
import sklearn
18+
from sklearn.base import clone
2019
from sklearn.decomposition import FactorAnalysis
21-
from sklearn.gaussian_process.kernels import Kernel
22-
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
23-
from sklearn.gaussian_process.kernels import ConstantKernel
24-
from concurrent.futures import ThreadPoolExecutor, as_completed
20+
from sklearn.gaussian_process.kernels import (
21+
RBF,
22+
ConstantKernel,
23+
Kernel,
24+
WhiteKernel
25+
)
26+
from tqdm import trange
2527

2628
__all__ = [
2729
"GPFA"
@@ -354,7 +356,7 @@ def fit(self, X, use_cut_trials=False):
354356
355357
4. **EM iteration**: Steps 2 and 3 are repeated until either the change
356358
in complete data likelihood drops below the set threshold, or if the
357-
maximum number of interations is reached.
359+
maximum number of interactions is reached.
358360
359361
**Orthonormalization:** Finally, this function computes an
360362
orthonormalization transform to the loading matrix
@@ -839,9 +841,11 @@ def _infer_latents(self, X, get_ll=True):
839841
for t in unique_Ts:
840842
if t == unique_Ts[0]:
841843
K_big_inv = linalg.inv(K_big[:t * self.z_dim, :t * self.z_dim])
844+
K_big_inv = (K_big_inv + K_big_inv.T) / 2
842845
logdet_k_big = self._logdet(K_big[:t * self.z_dim, :t * self.z_dim])
843846
M = K_big_inv + C_rinv_c_big[:t * self.z_dim,:t * self.z_dim]
844847
M_inv = linalg.inv(M)
848+
M_inv = (M_inv + M_inv.T) / 2
845849
logdet_M = self._logdet(M)
846850
else:
847851
# Here, we compute the inverse of K for the current t from its
@@ -934,6 +938,7 @@ def _learn_gp_params(self, latent_seqs, precomp):
934938
jac=True
935939
)
936940
self.gp_kernel[i].theta = res_opt.x
941+
937942
for j in range(len(precomp['Tu'])):
938943
precomp['Tu'][j]['PautoSUM'][i, :, :].fill(0)
939944

@@ -1108,18 +1113,22 @@ def _sym_block_inversion(
11081113
[MAinv, MCinv.T],
11091114
[MCinv, MDinv]
11101115
])
1116+
Minv = (Minv + Minv.T) / 2 # Ensure symmetry
11111117
# Check if MD is positive definite
11121118
try:
11131119
logdet_MD = self._logdet(MD) # Use Cholesky decomposition if possible
11141120
except np.linalg.LinAlgError:
1115-
logdet_MD = fast_logdet(MD) # Fallback to fast_logdet for non-PD matrices
1121+
logdet_MD = np.linalg.slogdet(MD)[1] # Fallback to slogdet for non-PD matrices
1122+
warnings.warn("Cholesky decomposition failed for MD; using slogdet instead. "
1123+
"MD may not be positive definite.",
1124+
UserWarning)
11161125
logdet_M = -logdet_Ainv + logdet_MD
11171126
if X is not None:
11181127
if logdet_X is None:
11191128
logdet_X = self._logdet(X)
11201129
MDpAinvBXAinvB = MD + AinvB.T @ X @ AinvB
11211130
MAinv = X - X @ AinvB @ linalg.inv(MDpAinvBXAinvB) @ AinvB.T @ X
1122-
logdet_MAinv = logdet_MD + logdet_X - fast_logdet(MDpAinvBXAinvB)
1131+
logdet_MAinv = logdet_MD + logdet_X - np.linalg.slogdet(MDpAinvBXAinvB)[1]
11231132
return Minv, logdet_M, MAinv, logdet_MAinv
11241133
return Minv, logdet_M
11251134

@@ -1233,6 +1242,7 @@ def _grad_bet_theta(self, theta, gp_kernel_i, precomp, i):
12331242
T = precomp['Tu'][j]['T']
12341243
if j == 0:
12351244
Kinv = linalg.inv(Kmax[:T, :T])
1245+
Kinv = (Kinv + Kinv.T) / 2
12361246
logdet_K = self._logdet(Kmax[:T, :T])
12371247
else:
12381248
# Here, we compute the inverse of K for the current

gpfa/preprocessing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
class EventTimesToCounts(sklearn.base.TransformerMixin):
1717
"""
18-
Bins sequence of event times into event counts whithin evenly spaced
18+
Bins sequence of event times into event counts within evenly spaced
1919
time bins.
2020
2121
This class supports binning sequences of event times (e.g., spike trains)

0 commit comments

Comments
 (0)