2323from pylops .optimization .eigs import power_iteration
2424from pylops .optimization .leastsquares import regularized_inversion
2525from pylops .utils import deps
26- from pylops .utils .backend import get_array_module , get_module_name , inplace_set
26+ from pylops .utils .backend import (
27+ get_array_module ,
28+ get_module_name ,
29+ get_real_dtype ,
30+ inplace_set ,
31+ )
2732from pylops .utils .typing import InputDimsLike , NDArray , SamplingLike
2833
2934spgl1_message = deps .spgl1_import ("the spgl1 solver" )
@@ -58,7 +63,7 @@ def _hardthreshold(x: NDArray, thresh: float) -> NDArray:
5863
5964 """
6065 x1 = x .copy ()
61- x1 [np .abs (x ) <= sqrt (2 * thresh )] = 0
66+ x1 [np .abs (x ) <= sqrt (2 * thresh )] = 0.0
6267 return x1
6368
6469
@@ -120,13 +125,17 @@ def _halfthreshold(x: NDArray, thresh: float) -> NDArray:
120125 Since version 1.17.0 does not produce ``np.nan`` on bad input.
121126
122127 """
123- arg = np .ones_like (x )
124- arg [x != 0 ] = (thresh / 8.0 ) * (np .abs (x [x != 0 ]) / 3.0 ) ** (- 1.5 )
125- arg = np .clip (arg , - 1 , 1 )
126- phi = 2.0 / 3.0 * np .arccos (arg )
127- x1 = 2.0 / 3.0 * x * (1 + np .cos (2.0 * np .pi / 3.0 - phi ))
128- # x1[np.abs(x) <= 1.5 * thresh ** (2. / 3.)] = 0
129- x1 [np .abs (x ) <= (54 ** (1.0 / 3.0 ) / 4.0 ) * thresh ** (2.0 / 3.0 )] = 0
128+ ncp = get_array_module (x )
129+ arg = ncp .ones_like (x )
130+ arg [x != 0 ] = (thresh / 8.0 ) * (ncp .abs (x [x != 0.0 ]) / 3.0 ) ** (- 1.5 )
131+ if ncp .iscomplexobj (arg ):
132+ arg .real = ncp .clip (arg .real , - 1.0 , 1.0 )
133+ arg .imag = ncp .clip (arg .imag , - 1.0 , 1.0 )
134+ else :
135+ arg = ncp .clip (arg , - 1.0 , 1.0 )
136+ phi = 2.0 / 3.0 * ncp .arccos (arg )
137+ x1 = 2.0 / 3.0 * x * (1.0 + ncp .cos (2.0 * np .pi / 3.0 - phi ))
138+ x1 [ncp .abs (x ) <= (54.0 ** (1.0 / 3.0 ) / 4.0 ) * thresh ** (2.0 / 3.0 )] = 0
130139 return x1
131140
132141
@@ -1609,7 +1618,7 @@ def setup(
16091618
16101619 # prepare decay (if not passed)
16111620 if perc is None and decay is None :
1612- self .decay = self .ncp .ones (niter , dtype = self .Op )
1621+ self .decay = self .ncp .ones (niter , dtype = get_real_dtype ( self .Op . dtype ) )
16131622
16141623 # step size
16151624 if alpha is not None :
@@ -1751,6 +1760,7 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]:
17511760 self .SOpx_unthesh [:] = self .SOprmatvec (self .x_unthesh )
17521761 else :
17531762 SOpx_unthesh = self .SOprmatvec (x_unthesh )
1763+
17541764 # threshold current solution or current solution projected onto SOp.H space
17551765 if self .SOp is None :
17561766 x_unthesh_or_SOpx_unthesh = (
@@ -1767,6 +1777,7 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]:
17671777 )
17681778 else :
17691779 x = self .threshf (x_unthesh_or_SOpx_unthesh , 100 - self .perc )
1780+
17701781 # apply SOp to thresholded x
17711782 if self .SOp is not None :
17721783 x = self .SOpmatvec (x )
0 commit comments