@@ -84,9 +84,11 @@ def __init__(
8484 debug (bool): Make additional tests at extra computational cost
8585 """
8686 solver_args = {} if solver_args is None else solver_args
87+
8788 preconditioner_args = {} if preconditioner_args is None else preconditioner_args
8889 preconditioner_args ['drop_tol' ] = preconditioner_args .get ('drop_tol' , 1e-3 )
8990 preconditioner_args ['fill_factor' ] = preconditioner_args .get ('fill_factor' , 100 )
91+
9092 self ._makeAttributeAndRegister (
9193 'max_cached_factorizations' ,
9294 'useGPU' ,
@@ -209,7 +211,7 @@ def setup_preconditioner(self, Dirichlet_recombination=True, left_preconditioner
209211
210212 self .Pl = self .spectral .sparse_lib .csc_matrix (R )
211213
212- if Dirichlet_recombination and type (self .axes [- 1 ]).__name__ in ['ChebychevHelper, Ultraspherical ' ]:
214+ if Dirichlet_recombination and type (self .axes [- 1 ]).__name__ in ['ChebychevHelper' , 'UltrasphericalHelper ' ]:
213215 _Pr = self .spectral .get_Dirichlet_recombination_matrix (axis = - 1 )
214216 else :
215217 _Pr = Id
@@ -234,18 +236,21 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs)
234236 if self .spectral_space :
235237 rhs_hat = rhs .copy ()
236238 if u0 is not None :
237- u0_hat = self . Pr . T @ u0 .copy ().flatten ()
239+ u0_hat = u0 .copy ().flatten ()
238240 else :
239241 u0_hat = None
240242 else :
241243 rhs_hat = self .spectral .transform (rhs )
242244 if u0 is not None :
243- u0_hat = self .Pr . T @ self . spectral .transform (u0 ).flatten ()
245+ u0_hat = self .spectral .transform (u0 ).flatten ()
244246 else :
245247 u0_hat = None
246248
247- if self .useGPU :
248- self .xp .cuda .Device ().synchronize ()
249+ # apply inverse right preconditioner to initial guess
250+ if u0_hat is not None and 'direct' not in self .solver_type :
251+ if not hasattr (self , '_Pr_inv' ):
252+ self ._PR_inv = self .linalg .splu (self .Pr .astype (complex )).solve
253+ u0_hat [...] = self ._PR_inv (u0_hat )
249254
250255 rhs_hat = (self .M @ rhs_hat .flatten ()).reshape (rhs_hat .shape )
251256 rhs_hat = self .spectral .put_BCs_in_rhs_hat (rhs_hat )
@@ -255,17 +260,13 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs)
255260 A = self .M + dt * self .L
256261 A = self .Pl @ self .spectral .put_BCs_in_matrix (A ) @ self .Pr
257262
258- # import numpy as np
259- # if A.shape[0] < 200:
260- # import matplotlib.pyplot as plt
263+ # if A.shape[0] < 200e20:
264+ # import matplotlib.pyplot as plt
261265
262- # # M = self.spectral.put_BCs_in_matrix(self.L.copy())
263- # M = A # self.L
264- # im = plt.imshow((M / abs(M)).real)
265- # # im = plt.imshow(np.log10(abs(A.toarray())).real)
266- # # im = plt.imshow(((A.toarray())).real)
267- # plt.colorbar(im)
268- # plt.show()
266+ # # M = self.spectral.put_BCs_in_matrix(self.L.copy())
267+ # M = A # self.L
268+ # im = plt.spy(M)
269+ # plt.show()
269270
270271 if 'ilu' in self .solver_type .lower ():
271272 if dt not in self .cached_factorizations .keys ():
@@ -330,9 +331,6 @@ def solve_system(self, rhs, dt, u0=None, *args, skip_itransform=False, **kwargs)
330331 sol_hat = self .spectral .u_init_forward
331332 sol_hat [...] = (self .Pr @ _sol_hat ).reshape (sol_hat .shape )
332333
333- if self .useGPU :
334- self .xp .cuda .Device ().synchronize ()
335-
336334 if self .spectral_space :
337335 return sol_hat
338336 else :
0 commit comments