Skip to content

Commit 7feb4ee

Browse files
Move nonzero() to dpctl_ext/tensor and reuse it in dpnp
1 parent afa5411 commit 7feb4ee

File tree

4 files changed

+63
-2
lines changed

4 files changed

+63
-2
lines changed

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
)
4242
from dpctl_ext.tensor._indexing_functions import (
4343
extract,
44+
nonzero,
4445
place,
4546
put,
4647
take,
@@ -57,6 +58,7 @@
5758
"extract",
5859
"from_numpy",
5960
"full",
61+
"nonzero",
6062
"place",
6163
"put",
6264
"reshape",

dpctl_ext/tensor/_copy_utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,37 @@ def _extract_impl(ary, ary_mask, axis=0):
200200
return dst
201201

202202

203+
def _nonzero_impl(ary):
    """Build per-dimension index arrays for ``ary`` via the ``ti`` kernels.

    Args:
        ary (usm_ndarray):
            Array to inspect.

    Returns:
        Tuple[usm_ndarray, ...]:
            ``ary.ndim`` one-dimensional views, one per row of an index
            matrix of shape ``(ary.ndim, mask_count)``.

    Raises:
        TypeError: if ``ary`` is not a ``dpt.usm_ndarray``.
    """
    if not isinstance(ary, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
        )

    queue = ary.sycl_queue
    alloc_kind = ary.usm_type
    ndim = ary.ndim
    nelems = ary.size

    # Use the narrower integer type for the scratch buffer when the element
    # count fits below ``int32_t_max``; otherwise fall back to 64-bit.
    if nelems < int32_t_max:
        scratch_dt = dpt.int32
    else:
        scratch_dt = dpt.int64
    scratch = dpt.empty(nelems, dtype=scratch_dt, sycl_queue=queue, order="C")

    # Order the new work after everything already submitted to this queue.
    order_manager = dpctl.utils.SequentialOrderManager[queue]
    pending = order_manager.submitted_events

    # The returned value sizes the second axis of the index matrix below
    # (presumably the number of non-zero elements -- confirm against ``ti``).
    n_selected = ti.mask_positions(
        ary, scratch, sycl_queue=queue, depends=pending
    )

    index_dt = ti.default_device_index_type(queue.sycl_device)
    index_matrix = dpt.empty(
        (ndim, n_selected),
        dtype=index_dt,
        usm_type=alloc_kind,
        sycl_queue=queue,
        order="C",
    )

    host_ev, compute_ev = ti._nonzero(scratch, index_matrix, ary.shape, queue)
    rows = tuple(index_matrix[axis, :] for axis in range(ndim))
    # Register the event pair with the manager for this queue.
    order_manager.add_event_pair(host_ev, compute_ev)
    return rows
232+
233+
203234
def from_numpy(np_ary, /, *, device=None, usm_type="device", sycl_queue=None):
204235
"""
205236
from_numpy(arg, device=None, usm_type="device", sycl_queue=None)

dpctl_ext/tensor/_indexing_functions.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
from ._copy_utils import (
4141
_extract_impl,
42+
_nonzero_impl,
4243
)
4344
from ._numpy_helper import normalize_axis_index
4445

@@ -98,6 +99,33 @@ def extract(condition, arr):
9899
return _extract_impl(arr, condition)
99100

100101

102+
def nonzero(arr):
    """nonzero(arr)

    Locate the non-zero elements of an array.

    The result is a tuple holding one usm_ndarray per dimension of
    ``arr``; each array contains the indices of the non-zero elements
    along that dimension. Elements of ``arr`` are always examined in
    row-major, C-style order.

    Args:
        arr (usm_ndarray):
            Input array; its rank must be non-zero.

    Returns:
        Tuple[usm_ndarray, ...]:
            Indices of non-zero array elements.

    Raises:
        TypeError:
            If ``arr`` is not a ``usm_ndarray``.
        ValueError:
            If ``arr`` is a zero-dimensional array.
    """
    # The type check must come first: ``ndim`` is only read once ``arr`` is
    # known to be a usm_ndarray.
    is_usm_array = isinstance(arr, dpt.usm_ndarray)
    if not is_usm_array:
        raise TypeError(
            "Expecting dpctl.tensor.usm_ndarray type, " f"got {type(arr)}"
        )

    rank = arr.ndim
    if rank == 0:
        raise ValueError("Array of positive rank is expected")

    return _nonzero_impl(arr)
127+
128+
101129
def place(arr, mask, vals):
102130
"""place(arr, mask, vals)
103131

dpnp/dpnp_iface_indexing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -817,7 +817,7 @@ def extract(condition, a):
817817
usm_a = dpt_ext.reshape(usm_a, -1)
818818
usm_cond = dpt_ext.reshape(usm_cond, -1)
819819

820-
usm_res = dpt_ext.take(usm_a, dpt.nonzero(usm_cond)[0])
820+
usm_res = dpt_ext.take(usm_a, dpt_ext.nonzero(usm_cond)[0])
821821
else:
822822
if usm_cond.shape != usm_a.shape:
823823
usm_a = dpt_ext.reshape(usm_a, -1)
@@ -1546,7 +1546,7 @@ def nonzero(a):
15461546

15471547
usm_a = dpnp.get_usm_ndarray(a)
15481548
return tuple(
1549-
dpnp_array._create_from_usm_ndarray(y) for y in dpt.nonzero(usm_a)
1549+
dpnp_array._create_from_usm_ndarray(y) for y in dpt_ext.nonzero(usm_a)
15501550
)
15511551

15521552

0 commit comments

Comments
 (0)