1919from pylops .basicoperators import Diagonal , Identity , VStack
2020from pylops .optimization .basesolver import Solver , _units
2121from pylops .optimization .basic import cgls
22+ from pylops .optimization .callback import _callback_stop
2223from pylops .optimization .eigs import power_iteration
2324from pylops .optimization .leastsquares import regularized_inversion
2425from pylops .utils import deps
@@ -399,14 +400,14 @@ def setup(
399400 epsR : :obj:`float`, optional
400401 Damping to be applied to residuals for weighting term
401402 epsI : :obj:`float`, optional
402- Tikhonov damping (for ``kind="data"``) or L1 model damping
403- (for ``kind="datamodel"``)
403+ Tikhonov damping
404404 tolIRLS : :obj:`float`, optional
405405 Tolerance. Stop outer iterations if difference between inverted model
406406 at subsequent iterations is smaller than ``tolIRLS``
407407 warm : :obj:`bool`, optional
408- Warm start each inversion inner step with previous estimate (``True``) or not (``False``).
409- This only applies to ``kind="data"`` and ``kind="datamodel"``
408+ Warm start each inversion inner step with previous estimate (``True``)
409+ or not (``False``). This only applies to ``kind="data"`` and
410+ ``kind="datamodel"``
410411 kind : :obj:`str`, optional
411412 Kind of solver (``model``, ``data`` or ``datamodel``)
412413 preallocate : :obj:`bool`, optional
@@ -432,9 +433,6 @@ def setup(
432433 self .isjax = get_module_name (self .ncp ) == "jax"
433434 self ._setpreallocate (preallocate )
434435
435- # initiate outer iteration counter
436- self .iiter = 0
437-
438436 # choose step to use
439437 if self .kind == "data" :
440438 self ._step = self ._step_data
@@ -456,6 +454,13 @@ def setup(
456454 self .rw = self .ncp .empty_like (self .y )
457455 else :
458456 self .rw = self .ncp .empty (self .Op .shape [1 ], dtype = self .Op .dtype )
457+
458+ # create variables to track the residual norm and iterations
459+ self .cost = [
460+ float (np .linalg .norm (self .y )),
461+ ]
462+ self .iiter = 0
463+
459464 # print setup
460465 if show :
461466 self ._print_setup ()
@@ -619,6 +624,7 @@ def step(
619624 self .rnorm = self .ncp .linalg .norm (self .r )
620625
621626 self .iiter += 1
627+ self .cost .append (float (self .rnorm ))
622628 if show :
623629 self ._print_step (x )
624630 return x
@@ -687,6 +693,10 @@ def run(
687693 xold = x .copy ()
688694 x = self .step (x , engine , showstep , ** kwargs_solver )
689695 self .callback (x )
696+ # check if any callback has raised a stop flag
697+ stop = _callback_stop (self .callbacks )
698+ if stop :
699+ break
690700
691701 # adding initial guess
692702 if hasattr (self , "x0" ):
@@ -1134,7 +1144,7 @@ def step(
11341144 self .ncp .subtract (self .res , self .y , out = self .res )
11351145
11361146 self .iiter += 1
1137- self .cost .append (float (np .linalg .norm (self .res )))
1147+ self .cost .append (float (self . ncp .linalg .norm (self .res )))
11381148 if show :
11391149 self ._print_step (x )
11401150 return x , cols
@@ -1187,6 +1197,10 @@ def run(
11871197 )
11881198 x , cols = self .step (x , cols , engine , showstep )
11891199 self .callback (x , cols )
1200+ # check if any callback has raised a stop flag
1201+ stop = _callback_stop (self .callbacks )
1202+ if stop :
1203+ break
11901204 return x , cols
11911205
11921206 def finalize (
@@ -1824,6 +1838,10 @@ def run(
18241838 )
18251839 x , xupdate = self .step (x , showstep )
18261840 self .callback (x )
1841+ # check if any callback has raised a stop flag
1842+ stop = _callback_stop (self .callbacks )
1843+ if stop :
1844+ break
18271845 if xupdate <= self .tol :
18281846 logger .info ("Update smaller that tolerance for iteration %d" , self .iiter )
18291847 return x
@@ -2205,6 +2223,10 @@ def run(
22052223 )
22062224 x , z , xupdate = self .step (x , z , showstep )
22072225 self .callback (x )
2226+ # check if any callback has raised a stop flag
2227+ stop = _callback_stop (self .callbacks )
2228+ if stop :
2229+ break
22082230 if xupdate <= self .tol :
22092231 logger .warning (
22102232 "Update smaller that tolerance for " "iteration %d" , self .iiter
@@ -2943,7 +2965,10 @@ def run(
29432965 )
29442966 x = self .step (x , engine , showstep , show_inner , ** kwargs_lsqr )
29452967 self .callback (x )
2946-
2968+ # check if any callback has raised a stop flag
2969+ stop = _callback_stop (self .callbacks )
2970+ if stop :
2971+ break
29472972 return x
29482973
29492974 def finalize (self , show : bool = False ) -> NDArray :
0 commit comments