@@ -88,11 +88,11 @@ def memory_usage(
8888 # Get number of bytes of dtype used in the solver
8989 nbytes = np .dtype (self .Op .dtype ).itemsize
9090
91- # Setup: x0, y, self.r, self.c
91+ # Setup: x0 - y, self.r, self.c
9292 memuse = (self .Op .shape [1 ] + 3 * self .Op .shape [0 ]) * nbytes
9393
94- # Step (additional variables to those in setup): Opc, c1
95- memuse += (2 * self .Op .shape [0 ]) * nbytes
94+ # Step (additional variables to those in setup): c1 - Opc
95+ memuse += (self . Op . shape [ 1 ] + self .Op .shape [0 ]) * nbytes
9696
9797 if show :
9898 print (f"CG predicted memory usage: { memuse / _units [unit ]:.2f} { unit } " )
@@ -125,7 +125,10 @@ def setup(
125125 preallocate : :obj:`bool`, optional
126126 .. versionadded:: 2.5.0
127127
128- Pre-allocate all variables used by the solver
128+ Pre-allocate all variables used by the solver. Note that if ``y``
129+ is a JAX array, this option is ignored and variables are not
130+ pre-allocated since JAX does not support in-place operations.
131+
129132 show : :obj:`bool`, optional
130133 Display setup log
131134
@@ -138,17 +141,18 @@ def setup(
138141 self .y = y
139142 self .niter = niter
140143 self .tol = tol
141- self . preallocate = preallocate
144+
142145 self .ncp = get_array_module (y )
143146 self .isjax = get_module_name (self .ncp ) == "jax"
147+ self ._setpreallocate (preallocate )
144148
145149 # initialize solver
146150 if x0 is None :
147151 x = self .ncp .zeros (self .Op .shape [1 ], dtype = self .y .dtype )
148152 self .r = self .y .copy ()
149153 else :
150154 x = x0
151- if self .isjax :
155+ if not self .preallocate :
152156 self .r = self .y - self .Op .matvec (x )
153157 else :
154158 self .r = self .ncp .empty_like (self .y )
@@ -186,23 +190,20 @@ def step(self, x: NDArray, show: bool = False) -> NDArray:
186190 Updated model vector
187191
188192 """
189- if not self .preallocate :
190- c1 = self .ncp .empty_like (self .c )
191-
192193 Opc = self .Op .matvec (self .c )
193194 cOpc = self .ncp .abs (self .c .dot (Opc .conj ()))
194195 a = self .kold / cOpc
195- if self .isjax :
196+ if not self .preallocate :
196197 x += a * self .c
197198 self .r -= a * Opc
198199 else :
199- self .ncp .multiply (self .c , a , out = self .c1 if self . preallocate else c1 )
200- self .ncp .add (x , self .c1 if self . preallocate else c1 , out = x )
200+ self .ncp .multiply (self .c , a , out = self .c1 )
201+ self .ncp .add (x , self .c1 , out = x )
201202 self .ncp .multiply (Opc , a , out = Opc )
202203 self .ncp .subtract (self .r , Opc , out = self .r )
203204 k = self .ncp .abs (self .r .dot (self .r .conj ()))
204205 b = k / self .kold
205- if self .isjax :
206+ if not self .preallocate :
206207 self .c = self .r + b * self .c
207208 else :
208209 self .ncp .multiply (self .c , b , out = self .c )
@@ -401,11 +402,11 @@ def memory_usage(
401402 # Get number of bytes of dtype used in the solver
402403 nbytes = np .dtype (self .Op .dtype ).itemsize
403404
404- # Setup: x0, y, self.s , self.c , self.q
405+ # Setup: x0, self.c - y , self.s , self.q
405406 memuse = (2 * self .Op .shape [1 ] + 3 * self .Op .shape [0 ]) * nbytes
406407
407408 # Step (additional variables to those in setup): r, x1, c1
408- memuse += (self . Op . shape [ 1 ] + 2 * self .Op .shape [0 ]) * nbytes
409+ memuse += (3 * self .Op .shape [1 ]) * nbytes
409410
410411 if show :
411412 print (f"CGLS predicted memory usage: { memuse / _units [unit ]:.2f} { unit } " )
@@ -455,9 +456,10 @@ def setup(
455456 self .damp = damp ** 2
456457 self .tol = tol
457458 self .niter = niter
458- self . preallocate = preallocate
459+
459460 self .ncp = get_array_module (y )
460461 self .isjax = get_module_name (self .ncp ) == "jax"
462+ self ._setpreallocate (preallocate )
461463
462464 # initialize solver
463465 if x0 is None :
@@ -466,7 +468,7 @@ def setup(
466468 self .c = self .Op .rmatvec (self .s )
467469 else :
468470 x = x0 .copy ()
469- if self .isjax :
471+ if not self .preallocate :
470472 self .s = self .y - self .Op .matvec (x )
471473 self .c = self .Op .rmatvec (self .s ) - damp * x
472474 else :
@@ -512,40 +514,35 @@ def step(self, x: NDArray, show: bool = False) -> NDArray:
512514 Display iteration log
513515
514516 """
515- if not self .preallocate :
516- c1 = self .ncp .empty_like (self .c )
517- x1 = self .ncp .empty_like (x )
518- r = self .ncp .empty_like (x )
519-
520517 a = self .kold / (
521518 self .q .dot (self .q .conj ()) + self .damp * self .c .dot (self .c .conj ())
522519 )
523- if self .isjax :
520+ if not self .preallocate :
524521 x += a * self .c
525522 self .s = self .s - a * self .q
526523 r = self .Op .rmatvec (self .s ) - self .damp * x
527524 else :
528- self .ncp .multiply (self .c , a , out = self .c1 if self . preallocate else c1 )
529- self .ncp .add (x , self .c1 if self . preallocate else c1 , out = x )
525+ self .ncp .multiply (self .c , a , out = self .c1 )
526+ self .ncp .add (x , self .c1 , out = x )
530527
531528 self .ncp .multiply (self .q , a , out = self .q )
532529 self .ncp .subtract (self .s , self .q , out = self .s )
533530
534- self .ncp .multiply (x , self .damp , out = self .x1 if self . preallocate else x1 )
531+ self .ncp .multiply (x , self .damp , out = self .x1 )
535532 self .ncp .subtract (
536533 self .Op .rmatvec (self .s ),
537- self .x1 if self . preallocate else x1 ,
538- out = self .r if self . preallocate else r ,
534+ self .x1 ,
535+ out = self .r ,
539536 )
540537 k = self .ncp .abs (
541538 self .r .dot (self .r .conj ()) if self .preallocate else r .dot (r .conj ())
542539 )
543540 b = k / self .kold
544- if self .isjax :
541+ if not self .preallocate :
545542 self .c = r + b * self .c
546543 else :
547544 self .ncp .multiply (self .c , b , out = self .c )
548- self .ncp .add (self .c , self .r if self . preallocate else r , out = self .c )
545+ self .ncp .add (self .c , self .r , out = self .c )
549546 self .q = self .Op .matvec (self .c )
550547 self .kold = k
551548 self .iiter += 1
@@ -818,7 +815,7 @@ def memory_usage(
818815 # Get number of bytes of dtype used in the solver
819816 nbytes = np .dtype (self .Op .dtype ).itemsize
820817
821- # Setup: x0, y, self.u , self.v , self.w , self.dk
818+ # Setup: x0, self.v , self.w , self.dk - y , self.u
822819 memuse = (4 * self .Op .shape [1 ] + 2 * self .Op .shape [0 ]) * nbytes
823820
824821 # Step (additional variables to those in setup): w1
@@ -890,9 +887,10 @@ def setup(
890887 self .conlim = conlim
891888 self .niter = niter
892889 self .calc_var = calc_var
893- self . preallocate = preallocate
890+
894891 self .ncp = get_array_module (y )
895892 self .isjax = get_module_name (self .ncp ) == "jax"
893+ self ._setpreallocate (preallocate )
896894
897895 m , n = self .Op .shape
898896
@@ -924,22 +922,22 @@ def setup(
924922 self .u = y .copy ()
925923 else :
926924 x = x0 .copy ()
927- if self .isjax :
925+ if self .preallocate :
928926 self .u = self .y - self .Op .matvec (x0 )
929927 else :
930928 self .u = self .ncp .empty_like (self .y )
931929 self .ncp .subtract (self .y , self .Op .matvec (x0 ), out = self .u )
932930 self .alfa = 0.0
933931 self .beta = self .ncp .linalg .norm (self .u )
934932 if self .beta > 0.0 :
935- if self .isjax :
933+ if self .preallocate :
936934 self .u = self .u / self .beta
937935 else :
938936 self .ncp .divide (self .u , self .beta , out = self .u )
939937 self .v = self .Op .rmatvec (self .u )
940938 self .alfa = self .ncp .linalg .norm (self .v )
941939 if self .alfa > 0 :
942- if self .isjax :
940+ if self .preallocate :
943941 self .v = self .v / self .alfa
944942 else :
945943 self .ncp .divide (self .v , self .alfa , out = self .v )
@@ -994,35 +992,32 @@ def step(self, x: NDArray, show: bool = False) -> NDArray:
994992 Estimated model of size :math:`[M \times 1]`
995993
996994 """
997- if not self .preallocate :
998- w1 = self .ncp .empty_like (self .w )
999-
1000995 # perform the next step of the bidiagonalization to obtain the
1001996 # next beta, u, alfa, v. These satisfy the relations
1002997 # beta*u = Op*v - alfa*u,
1003998 # alfa*v = Op'*u - beta*v'
1004- if self .isjax :
999+ if not self .preallocate :
10051000 self .u = self .Op .matvec (self .v ) - self .alfa * self .u
10061001 else :
10071002 self .ncp .multiply (self .u , self .alfa , out = self .u )
10081003 self .ncp .subtract (self .Op .matvec (self .v ), self .u , out = self .u )
10091004 self .beta = self .ncp .linalg .norm (self .u )
10101005 if self .beta > 0 :
1011- if self .isjax :
1006+ if not self .preallocate :
10121007 self .u = self .u / self .beta
10131008 else :
10141009 self .ncp .divide (self .u , self .beta , out = self .u )
10151010 self .anorm = np .linalg .norm (
10161011 [self .anorm , to_numpy (self .alfa ), to_numpy (self .beta ), self .damp ]
10171012 )
1018- if self .isjax :
1013+ if not self .preallocate :
10191014 self .v = self .Op .rmatvec (self .u ) - self .beta * self .v
10201015 else :
10211016 self .ncp .multiply (self .v , self .beta , out = self .v )
10221017 self .ncp .subtract (self .Op .rmatvec (self .u ), self .v , out = self .v )
10231018 self .alfa = self .ncp .linalg .norm (self .v )
10241019 if self .alfa > 0 :
1025- if self .isjax :
1020+ if not self .preallocate :
10261021 self .v = self .v / self .alfa
10271022 else :
10281023 self .ncp .divide (self .v , self .alfa , out = self .v )
@@ -1049,14 +1044,14 @@ def step(self, x: NDArray, show: bool = False) -> NDArray:
10491044 # update x and w.
10501045 self .t1 = self .phi / self .rho
10511046 self .t2 = - self .theta / self .rho
1052- if self .isjax :
1047+ if not self .preallocate :
10531048 self .dk = self .w / self .rho
10541049 x = x + self .t1 * self .w
10551050 self .w = self .v + self .t2 * self .w
10561051 else :
10571052 self .ncp .divide (self .w , self .rho , out = self .dk )
1058- self .ncp .multiply (self .w , self .t1 , out = self .w1 if self . preallocate else w1 )
1059- self .ncp .add (x , self .w1 if self . preallocate else w1 , out = x )
1053+ self .ncp .multiply (self .w , self .t1 , out = self .w1 )
1054+ self .ncp .add (x , self .w1 , out = x )
10601055 self .ncp .multiply (self .w , self .t2 , out = self .w )
10611056 self .ncp .add (self .v , self .w , out = self .w )
10621057 self .ddnorm = self .ddnorm + self .ncp .linalg .norm (self .dk ) ** 2
0 commit comments