Skip to content

Commit 8790132

Browse files
committed
path: fix _halfthreshold for complex valued cupy arrays
1 parent 269a077 commit 8790132

1 file changed

Lines changed: 11 additions & 7 deletions

File tree

pylops/optimization/cls_sparsity.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,17 @@ def _halfthreshold(x: NDArray, thresh: float) -> NDArray:
125125
Since version 1.17.0 does not produce ``np.nan`` on bad input.
126126
127127
"""
128-
arg = np.ones_like(x)
129-
arg[x != 0] = (thresh / 8.0) * (np.abs(x[x != 0]) / 3.0) ** (-1.5)
130-
arg = np.clip(arg, -1, 1)
131-
phi = 2.0 / 3.0 * np.arccos(arg)
132-
x1 = 2.0 / 3.0 * x * (1 + np.cos(2.0 * np.pi / 3.0 - phi))
133-
# x1[np.abs(x) <= 1.5 * thresh ** (2. / 3.)] = 0
134-
x1[np.abs(x) <= (54 ** (1.0 / 3.0) / 4.0) * thresh ** (2.0 / 3.0)] = 0
128+
ncp = get_array_module(x)
129+
arg = ncp.ones_like(x)
130+
arg[x != 0] = (thresh / 8.0) * (ncp.abs(x[x != 0.0]) / 3.0) ** (-1.5)
131+
if ncp.iscomplexobj(arg):
132+
arg.real = ncp.clip(arg.real, -1.0, 1.0)
133+
arg.imag = ncp.clip(arg.imag, -1.0, 1.0)
134+
else:
135+
arg = ncp.clip(arg, -1.0, 1.0)
136+
phi = 2.0 / 3.0 * ncp.arccos(arg)
137+
x1 = 2.0 / 3.0 * x * (1.0 + ncp.cos(2.0 * np.pi / 3.0 - phi))
138+
x1[ncp.abs(x) <= (54.0 ** (1.0 / 3.0) / 4.0) * thresh ** (2.0 / 3.0)] = 0
135139
return x1
136140

137141

0 commit comments

Comments
 (0)