Skip to content

Commit 63edabc

Browse files
committed
feat: added finalize and callback to basesolver
1 parent 21949ba commit 63edabc

2 files changed

Lines changed: 439 additions & 84 deletions

File tree

pyproximal/optimization/basesolver.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from pylops.optimization.basesolver import Solver as pSolver
1010
from pylops.optimization.callback import Callbacks
11+
from pylops.utils.typing import NDArray
1112

1213
if TYPE_CHECKING:
1314
from pyproximal.ProxOperator import ProxOperator
@@ -134,3 +135,61 @@ def solve(
134135
135136
"""
136137
pass
138+
139+
def finalize(self, nbar: int = 60, show: bool = False) -> None:
140+
r"""Finalize solver
141+
142+
Parameters
143+
----------
144+
nbar : :obj:`int`, optional
145+
Number of ``-`` in the bar dividing iterations
146+
from finalize messages in the print message of
147+
the solver
148+
show : :obj:`bool`, optional
149+
Display finalize log
150+
151+
"""
152+
self.tend = time.time()
153+
self.telapsed = self.tend - self.tstart
154+
155+
if show:
156+
self._print_finalize(nbar=nbar)
157+
158+
def callback( # noqa: B027
159+
self,
160+
x: NDArray,
161+
z: NDArray | None = None,
162+
*args,
163+
**kwargs,
164+
) -> None:
165+
"""Callback routine
166+
167+
This routine must be passed by the user. Its function signature must contain
168+
either a single input that contains the current solution or two inputs
169+
that contain the current solutions for methods that apply splitting
170+
(when using the `solve` method it will be automatically invoked after
171+
each step of the solve)
172+
173+
Parameters
174+
----------
175+
x : :obj:`numpy.ndarray`
176+
Current solution
177+
z : :obj:`numpy.ndarray`
178+
Current additional solution
179+
180+
Examples
181+
--------
182+
>>> import numpy as np
183+
>>> from pyproximal.optimization.cls_primal import ADMM
184+
>>> def callback(x, z):
185+
... print(f"Running callback, current solutions {x} - {z}")
186+
...
187+
>>> admmsolve.callback = callback
188+
189+
>>> x = np.ones(2)
190+
>>> z = np.zeros(2)
191+
>>> admmsolve.callback(x, z)
192+
Running callback, current solutions [1. 1.] - [0. 0.]
193+
194+
"""
195+
pass

0 commit comments

Comments
 (0)