Skip to content

Commit dd62b6e

Browse files
Move _tensor_sorting_impl extension and use it for dpnp (#2793)
This PR completely moves `_tensor_sorting_impl` pybind11 extension into `dpctl_ext.tensor` and extends dpctl_ext.tensor Python API with the functions `searchsorted isin, unique_all, unique_counts, unique_inverse, unique_values, argsort, sort and top_k ` reusing them in dpnp
1 parent e96405c commit dd62b6e

37 files changed

+8362
-14
lines changed

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,23 @@ set(_accumulator_sources
6969
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp
7070
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp
7171
)
72+
set(_sorting_sources
73+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/isin.cpp
74+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_sort.cpp
75+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_argsort.cpp
76+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp
77+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_argsort.cpp
78+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp
79+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/topk.cpp
80+
)
7281
set(_tensor_accumulation_impl_sources
7382
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_accumulation.cpp
7483
${_accumulator_sources}
7584
)
85+
set(_tensor_sorting_impl_sources
86+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp
87+
${_sorting_sources}
88+
)
7689

7790
set(_static_lib_trgt simplify_iteration_space)
7891

@@ -101,6 +114,12 @@ add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_accumulation_i
101114
target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
102115
list(APPEND _py_trgts ${python_module_name})
103116

117+
set(python_module_name _tensor_sorting_impl)
118+
pybind11_add_module(${python_module_name} MODULE ${_tensor_sorting_impl_sources})
119+
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_impl_sources})
120+
target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
121+
list(APPEND _py_trgts ${python_module_name})
122+
104123
set(_clang_prefix "")
105124
if(WIN32)
106125
set(_clang_prefix "/clang:")
@@ -117,7 +136,7 @@ list(
117136
APPEND _no_fast_math_sources
118137
# ${_elementwise_sources}
119138
# ${_reduction_sources}
120-
# ${_sorting_sources}
139+
${_sorting_sources}
121140
# ${_linalg_sources}
122141
${_accumulator_sources}
123142
)

dpctl_ext/tensor/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,20 @@
8080
)
8181
from ._reshape import reshape
8282
from ._search_functions import where
83+
from ._searchsorted import searchsorted
84+
from ._set_functions import (
85+
isin,
86+
unique_all,
87+
unique_counts,
88+
unique_inverse,
89+
unique_values,
90+
)
91+
from ._sorting import argsort, sort, top_k
8392
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
8493

8594
__all__ = [
8695
"arange",
96+
"argsort",
8797
"asarray",
8898
"asnumpy",
8999
"astype",
@@ -108,6 +118,7 @@
108118
"full_like",
109119
"iinfo",
110120
"isdtype",
121+
"isin",
111122
"linspace",
112123
"meshgrid",
113124
"moveaxis",
@@ -122,15 +133,22 @@
122133
"reshape",
123134
"result_type",
124135
"roll",
136+
"searchsorted",
137+
"sort",
125138
"squeeze",
126139
"stack",
127140
"swapaxes",
128141
"take",
129142
"take_along_axis",
130143
"tile",
144+
"top_k",
131145
"to_numpy",
132146
"tril",
133147
"triu",
148+
"unique_all",
149+
"unique_counts",
150+
"unique_inverse",
151+
"unique_values",
134152
"unstack",
135153
"where",
136154
"zeros",

dpctl_ext/tensor/_searchsorted.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2026, Intel Corporation
3+
# All rights reserved.
4+
#
5+
# Redistribution and use in source and binary forms, with or without
6+
# modification, are permitted provided that the following conditions are met:
7+
# - Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# - Redistributions in binary form must reproduce the above copyright notice,
10+
# this list of conditions and the following disclaimer in the documentation
11+
# and/or other materials provided with the distribution.
12+
# - Neither the name of the copyright holder nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
# THE POSSIBILITY OF SUCH DAMAGE.
27+
# *****************************************************************************
28+
29+
30+
from typing import Literal, Union
31+
32+
import dpctl
33+
import dpctl.utils as du
34+
35+
# TODO: revert to `from ._usmarray import...`
36+
# when dpnp fully migrates dpctl/tensor
37+
from dpctl.tensor._usmarray import usm_ndarray
38+
39+
from ._copy_utils import _empty_like_orderK
40+
from ._ctors import empty
41+
from ._tensor_impl import _copy_usm_ndarray_into_usm_ndarray as ti_copy
42+
from ._tensor_impl import _take as ti_take
43+
from ._tensor_impl import (
44+
default_device_index_type as ti_default_device_index_type,
45+
)
46+
from ._tensor_sorting_impl import _searchsorted_left, _searchsorted_right
47+
from ._type_utils import isdtype, result_type
48+
49+
50+
def searchsorted(
51+
x1: usm_ndarray,
52+
x2: usm_ndarray,
53+
/,
54+
*,
55+
side: Literal["left", "right"] = "left",
56+
sorter: Union[usm_ndarray, None] = None,
57+
) -> usm_ndarray:
58+
"""searchsorted(x1, x2, side='left', sorter=None)
59+
60+
Finds the indices into `x1` such that, if the corresponding elements
61+
in `x2` were inserted before the indices, the order of `x1`, when sorted
62+
in ascending order, would be preserved.
63+
64+
Args:
65+
x1 (usm_ndarray):
66+
input array. Must be a one-dimensional array. If `sorter` is
67+
`None`, must be sorted in ascending order; otherwise, `sorter` must
68+
be an array of indices that sort `x1` in ascending order.
69+
x2 (usm_ndarray):
70+
array containing search values.
71+
side (Literal["left", "right]):
72+
argument controlling which index is returned if a value lands
73+
exactly on an edge. If `x2` is an array of rank `N` where
74+
`v = x2[n, m, ..., j]`, the element `ret[n, m, ..., j]` in the
75+
return array `ret` contains the position `i` such that
76+
if `side="left"`, it is the first index such that
77+
`x1[i-1] < v <= x1[i]`, `0` if `v <= x1[0]`, and `x1.size`
78+
if `v > x1[-1]`;
79+
and if `side="right"`, it is the first position `i` such that
80+
`x1[i-1] <= v < x1[i]`, `0` if `v < x1[0]`, and `x1.size`
81+
if `v >= x1[-1]`. Default: `"left"`.
82+
sorter (Optional[usm_ndarray]):
83+
array of indices that sort `x1` in ascending order. The array must
84+
have the same shape as `x1` and have an integral data type.
85+
Out of bound index values of `sorter` array are treated using
86+
`"wrap"` mode documented in :py:func:`dpctl.tensor.take`.
87+
Default: `None`.
88+
"""
89+
if not isinstance(x1, usm_ndarray):
90+
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}")
91+
if not isinstance(x2, usm_ndarray):
92+
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}")
93+
if sorter is not None and not isinstance(sorter, usm_ndarray):
94+
raise TypeError(
95+
f"Expected dpctl.tensor.usm_ndarray, got {type(sorter)}"
96+
)
97+
98+
if side not in ["left", "right"]:
99+
raise ValueError(
100+
"Unrecognized value of 'side' keyword argument. "
101+
"Expected either 'left' or 'right'"
102+
)
103+
104+
if sorter is None:
105+
q = du.get_execution_queue([x1.sycl_queue, x2.sycl_queue])
106+
else:
107+
q = du.get_execution_queue(
108+
[x1.sycl_queue, x2.sycl_queue, sorter.sycl_queue]
109+
)
110+
if q is None:
111+
raise du.ExecutionPlacementError(
112+
"Execution placement can not be unambiguously "
113+
"inferred from input arguments."
114+
)
115+
116+
if x1.ndim != 1:
117+
raise ValueError("First argument array must be one-dimensional")
118+
119+
x1_dt = x1.dtype
120+
x2_dt = x2.dtype
121+
122+
_manager = du.SequentialOrderManager[q]
123+
dep_evs = _manager.submitted_events
124+
ev = dpctl.SyclEvent()
125+
if sorter is not None:
126+
if not isdtype(sorter.dtype, "integral"):
127+
raise ValueError(
128+
f"Sorter array must have integral data type, got {sorter.dtype}"
129+
)
130+
if x1.shape != sorter.shape:
131+
raise ValueError(
132+
"Sorter array must be one-dimension with the same "
133+
"shape as the first argument array"
134+
)
135+
res = empty(x1.shape, dtype=x1_dt, usm_type=x1.usm_type, sycl_queue=q)
136+
ind = (sorter,)
137+
axis = 0
138+
wrap_out_of_bound_indices_mode = 0
139+
ht_ev, ev = ti_take(
140+
x1,
141+
ind,
142+
res,
143+
axis,
144+
wrap_out_of_bound_indices_mode,
145+
sycl_queue=q,
146+
depends=dep_evs,
147+
)
148+
x1 = res
149+
_manager.add_event_pair(ht_ev, ev)
150+
151+
if x1_dt != x2_dt:
152+
dt = result_type(x1, x2)
153+
if x1_dt != dt:
154+
x1_buf = _empty_like_orderK(x1, dt)
155+
dep_evs = _manager.submitted_events
156+
ht_ev, ev = ti_copy(
157+
src=x1, dst=x1_buf, sycl_queue=q, depends=dep_evs
158+
)
159+
_manager.add_event_pair(ht_ev, ev)
160+
x1 = x1_buf
161+
if x2_dt != dt:
162+
x2_buf = _empty_like_orderK(x2, dt)
163+
dep_evs = _manager.submitted_events
164+
ht_ev, ev = ti_copy(
165+
src=x2, dst=x2_buf, sycl_queue=q, depends=dep_evs
166+
)
167+
_manager.add_event_pair(ht_ev, ev)
168+
x2 = x2_buf
169+
170+
dst_usm_type = du.get_coerced_usm_type([x1.usm_type, x2.usm_type])
171+
index_dt = ti_default_device_index_type(q)
172+
173+
dst = _empty_like_orderK(x2, index_dt, usm_type=dst_usm_type)
174+
175+
dep_evs = _manager.submitted_events
176+
if side == "left":
177+
ht_ev, s_ev = _searchsorted_left(
178+
hay=x1,
179+
needles=x2,
180+
positions=dst,
181+
sycl_queue=q,
182+
depends=dep_evs,
183+
)
184+
else:
185+
ht_ev, s_ev = _searchsorted_right(
186+
hay=x1, needles=x2, positions=dst, sycl_queue=q, depends=dep_evs
187+
)
188+
_manager.add_event_pair(ht_ev, s_ev)
189+
return dst

0 commit comments

Comments
 (0)