Skip to content

Commit 67ff814

Browse files
authored
Merge pull request #692 from mrava87/patch-solverscupy
Patch: fix ISTA/FISTA with complex numbers
2 parents 0ae0af3 + 7c50d07 commit 67ff814

2 files changed

Lines changed: 145 additions & 97 deletions

File tree

pylops/optimization/cls_sparsity.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@
2323
from pylops.optimization.eigs import power_iteration
2424
from pylops.optimization.leastsquares import regularized_inversion
2525
from pylops.utils import deps
26-
from pylops.utils.backend import get_array_module, get_module_name, inplace_set
26+
from pylops.utils.backend import (
27+
get_array_module,
28+
get_module_name,
29+
get_real_dtype,
30+
inplace_set,
31+
)
2732
from pylops.utils.typing import InputDimsLike, NDArray, SamplingLike
2833

2934
spgl1_message = deps.spgl1_import("the spgl1 solver")
@@ -58,7 +63,7 @@ def _hardthreshold(x: NDArray, thresh: float) -> NDArray:
5863
5964
"""
6065
x1 = x.copy()
61-
x1[np.abs(x) <= sqrt(2 * thresh)] = 0
66+
x1[np.abs(x) <= sqrt(2 * thresh)] = 0.0
6267
return x1
6368

6469

@@ -120,13 +125,17 @@ def _halfthreshold(x: NDArray, thresh: float) -> NDArray:
120125
Since version 1.17.0 does not produce ``np.nan`` on bad input.
121126
122127
"""
123-
arg = np.ones_like(x)
124-
arg[x != 0] = (thresh / 8.0) * (np.abs(x[x != 0]) / 3.0) ** (-1.5)
125-
arg = np.clip(arg, -1, 1)
126-
phi = 2.0 / 3.0 * np.arccos(arg)
127-
x1 = 2.0 / 3.0 * x * (1 + np.cos(2.0 * np.pi / 3.0 - phi))
128-
# x1[np.abs(x) <= 1.5 * thresh ** (2. / 3.)] = 0
129-
x1[np.abs(x) <= (54 ** (1.0 / 3.0) / 4.0) * thresh ** (2.0 / 3.0)] = 0
128+
ncp = get_array_module(x)
129+
arg = ncp.ones_like(x)
130+
arg[x != 0] = (thresh / 8.0) * (ncp.abs(x[x != 0.0]) / 3.0) ** (-1.5)
131+
if ncp.iscomplexobj(arg):
132+
arg.real = ncp.clip(arg.real, -1.0, 1.0)
133+
arg.imag = ncp.clip(arg.imag, -1.0, 1.0)
134+
else:
135+
arg = ncp.clip(arg, -1.0, 1.0)
136+
phi = 2.0 / 3.0 * ncp.arccos(arg)
137+
x1 = 2.0 / 3.0 * x * (1.0 + ncp.cos(2.0 * np.pi / 3.0 - phi))
138+
x1[ncp.abs(x) <= (54.0 ** (1.0 / 3.0) / 4.0) * thresh ** (2.0 / 3.0)] = 0
130139
return x1
131140

132141

@@ -1609,7 +1618,7 @@ def setup(
16091618

16101619
# prepare decay (if not passed)
16111620
if perc is None and decay is None:
1612-
self.decay = self.ncp.ones(niter, dtype=self.Op)
1621+
self.decay = self.ncp.ones(niter, dtype=get_real_dtype(self.Op.dtype))
16131622

16141623
# step size
16151624
if alpha is not None:
@@ -1751,6 +1760,7 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]:
17511760
self.SOpx_unthesh[:] = self.SOprmatvec(self.x_unthesh)
17521761
else:
17531762
SOpx_unthesh = self.SOprmatvec(x_unthesh)
1763+
17541764
# threshold current solution or current solution projected onto SOp.H space
17551765
if self.SOp is None:
17561766
x_unthesh_or_SOpx_unthesh = (
@@ -1767,6 +1777,7 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]:
17671777
)
17681778
else:
17691779
x = self.threshf(x_unthesh_or_SOpx_unthesh, 100 - self.perc)
1780+
17701781
# apply SOp to thresholded x
17711782
if self.SOp is not None:
17721783
x = self.SOpmatvec(x)

0 commit comments

Comments
 (0)