Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions pylops/optimization/cls_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
from pylops.optimization.eigs import power_iteration
from pylops.optimization.leastsquares import regularized_inversion
from pylops.utils import deps
from pylops.utils.backend import get_array_module, get_module_name, inplace_set
from pylops.utils.backend import (
get_array_module,
get_module_name,
get_real_dtype,
inplace_set,
)
from pylops.utils.typing import InputDimsLike, NDArray, SamplingLike

spgl1_message = deps.spgl1_import("the spgl1 solver")
Expand Down Expand Up @@ -58,7 +63,7 @@ def _hardthreshold(x: NDArray, thresh: float) -> NDArray:

"""
x1 = x.copy()
x1[np.abs(x) <= sqrt(2 * thresh)] = 0
x1[np.abs(x) <= sqrt(2 * thresh)] = 0.0
return x1


Expand Down Expand Up @@ -120,13 +125,17 @@ def _halfthreshold(x: NDArray, thresh: float) -> NDArray:
Since version 1.17.0 does not produce ``np.nan`` on bad input.

"""
arg = np.ones_like(x)
arg[x != 0] = (thresh / 8.0) * (np.abs(x[x != 0]) / 3.0) ** (-1.5)
arg = np.clip(arg, -1, 1)
phi = 2.0 / 3.0 * np.arccos(arg)
x1 = 2.0 / 3.0 * x * (1 + np.cos(2.0 * np.pi / 3.0 - phi))
# x1[np.abs(x) <= 1.5 * thresh ** (2. / 3.)] = 0
x1[np.abs(x) <= (54 ** (1.0 / 3.0) / 4.0) * thresh ** (2.0 / 3.0)] = 0
ncp = get_array_module(x)
arg = ncp.ones_like(x)
arg[x != 0] = (thresh / 8.0) * (ncp.abs(x[x != 0.0]) / 3.0) ** (-1.5)
if ncp.iscomplexobj(arg):
arg.real = ncp.clip(arg.real, -1.0, 1.0)
arg.imag = ncp.clip(arg.imag, -1.0, 1.0)
else:
arg = ncp.clip(arg, -1.0, 1.0)
phi = 2.0 / 3.0 * ncp.arccos(arg)
x1 = 2.0 / 3.0 * x * (1.0 + ncp.cos(2.0 * np.pi / 3.0 - phi))
x1[ncp.abs(x) <= (54.0 ** (1.0 / 3.0) / 4.0) * thresh ** (2.0 / 3.0)] = 0
return x1


Expand Down Expand Up @@ -1609,7 +1618,7 @@ def setup(

# prepare decay (if not passed)
if perc is None and decay is None:
self.decay = self.ncp.ones(niter, dtype=self.Op)
self.decay = self.ncp.ones(niter, dtype=get_real_dtype(self.Op.dtype))

# step size
if alpha is not None:
Expand Down Expand Up @@ -1751,6 +1760,7 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]:
self.SOpx_unthesh[:] = self.SOprmatvec(self.x_unthesh)
else:
SOpx_unthesh = self.SOprmatvec(x_unthesh)

# threshold current solution or current solution projected onto SOp.H space
if self.SOp is None:
x_unthesh_or_SOpx_unthesh = (
Expand All @@ -1767,6 +1777,7 @@ def step(self, x: NDArray, show: bool = False) -> Tuple[NDArray, float]:
)
else:
x = self.threshf(x_unthesh_or_SOpx_unthesh, 100 - self.perc)

# apply SOp to thresholded x
if self.SOp is not None:
x = self.SOpmatvec(x)
Expand Down
Loading