|
9 | 9 |
|
10 | 10 | import time |
11 | 11 | import warnings |
12 | | -import sklearn |
13 | | -import scipy as sp |
| 12 | + |
14 | 13 | import numpy as np |
15 | | -from tqdm import trange |
16 | | -from sklearn.base import clone |
| 14 | +import scipy as sp |
17 | 15 | import scipy.linalg as linalg |
18 | 16 | import scipy.optimize as optimize |
19 | | -from sklearn.utils.extmath import fast_logdet |
| 17 | +import sklearn |
| 18 | +from sklearn.base import clone |
20 | 19 | from sklearn.decomposition import FactorAnalysis |
21 | | -from sklearn.gaussian_process.kernels import Kernel |
22 | | -from sklearn.gaussian_process.kernels import RBF, WhiteKernel |
23 | | -from sklearn.gaussian_process.kernels import ConstantKernel |
24 | | -from concurrent.futures import ThreadPoolExecutor, as_completed |
| 20 | +from sklearn.gaussian_process.kernels import ( |
| 21 | + RBF, |
| 22 | + ConstantKernel, |
| 23 | + Kernel, |
| 24 | + WhiteKernel |
| 25 | +) |
| 26 | +from tqdm import trange |
25 | 27 |
|
26 | 28 | __all__ = [ |
27 | 29 | "GPFA" |
@@ -354,7 +356,7 @@ def fit(self, X, use_cut_trials=False): |
354 | 356 |
|
355 | 357 | 4. **EM iteration**: Steps 2 and 3 are repeated until either the change |
356 | 358 | in complete data likelihood drops below the set threshold, or if the |
357 | | - maximum number of interations is reached. |
| 359 | + maximum number of iterations is reached. |
358 | 360 |
|
359 | 361 | **Orthonormalization:** Finally, this function computes an |
360 | 362 | orthonormalization transform to the loading matrix |
@@ -839,9 +841,11 @@ def _infer_latents(self, X, get_ll=True): |
839 | 841 | for t in unique_Ts: |
840 | 842 | if t == unique_Ts[0]: |
841 | 843 | K_big_inv = linalg.inv(K_big[:t * self.z_dim, :t * self.z_dim]) |
| 844 | + K_big_inv = (K_big_inv + K_big_inv.T) / 2 |
842 | 845 | logdet_k_big = self._logdet(K_big[:t * self.z_dim, :t * self.z_dim]) |
843 | 846 | M = K_big_inv + C_rinv_c_big[:t * self.z_dim,:t * self.z_dim] |
844 | 847 | M_inv = linalg.inv(M) |
| 848 | + M_inv = (M_inv + M_inv.T) / 2 |
845 | 849 | logdet_M = self._logdet(M) |
846 | 850 | else: |
847 | 851 | # Here, we compute the inverse of K for the current t from its |
@@ -934,6 +938,7 @@ def _learn_gp_params(self, latent_seqs, precomp): |
934 | 938 | jac=True |
935 | 939 | ) |
936 | 940 | self.gp_kernel[i].theta = res_opt.x |
| 941 | + |
937 | 942 | for j in range(len(precomp['Tu'])): |
938 | 943 | precomp['Tu'][j]['PautoSUM'][i, :, :].fill(0) |
939 | 944 |
|
@@ -1108,18 +1113,22 @@ def _sym_block_inversion( |
1108 | 1113 | [MAinv, MCinv.T], |
1109 | 1114 | [MCinv, MDinv] |
1110 | 1115 | ]) |
| 1116 | + Minv = (Minv + Minv.T) / 2 # Ensure symmetry |
1111 | 1117 | # Check if MD is positive definite |
1112 | 1118 | try: |
1113 | 1119 | logdet_MD = self._logdet(MD) # Use Cholesky decomposition if possible |
1114 | 1120 | except np.linalg.LinAlgError: |
1115 | | - logdet_MD = fast_logdet(MD) # Fallback to fast_logdet for non-PD matrices |
| 1121 | + logdet_MD = np.linalg.slogdet(MD)[1] # Fallback to slogdet for non-PD matrices |
| 1122 | + warnings.warn("Cholesky decomposition failed for MD; using slogdet instead. " |
| 1123 | + "MD may not be positive definite.", |
| 1124 | + UserWarning) |
1116 | 1125 | logdet_M = -logdet_Ainv + logdet_MD |
1117 | 1126 | if X is not None: |
1118 | 1127 | if logdet_X is None: |
1119 | 1128 | logdet_X = self._logdet(X) |
1120 | 1129 | MDpAinvBXAinvB = MD + AinvB.T @ X @ AinvB |
1121 | 1130 | MAinv = X - X @ AinvB @ linalg.inv(MDpAinvBXAinvB) @ AinvB.T @ X |
1122 | | - logdet_MAinv = logdet_MD + logdet_X - fast_logdet(MDpAinvBXAinvB) |
| 1131 | + logdet_MAinv = logdet_MD + logdet_X - np.linalg.slogdet(MDpAinvBXAinvB)[1] |
1123 | 1132 | return Minv, logdet_M, MAinv, logdet_MAinv |
1124 | 1133 | return Minv, logdet_M |
1125 | 1134 |
|
@@ -1233,6 +1242,7 @@ def _grad_bet_theta(self, theta, gp_kernel_i, precomp, i): |
1233 | 1242 | T = precomp['Tu'][j]['T'] |
1234 | 1243 | if j == 0: |
1235 | 1244 | Kinv = linalg.inv(Kmax[:T, :T]) |
| 1245 | + Kinv = (Kinv + Kinv.T) / 2 |
1236 | 1246 | logdet_K = self._logdet(Kmax[:T, :T]) |
1237 | 1247 | else: |
1238 | 1248 | # Here, we compute the inverse of K for the current |
|
0 commit comments