Skip to content

Commit 56d397d

Browse files
Move dpt.top_k()
1 parent 62d19f1 commit 56d397d

File tree

2 files changed

+169
-4
lines changed

2 files changed

+169
-4
lines changed

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@
9090
unique_inverse,
9191
unique_values,
9292
)
93-
from ._sorting import argsort, sort
93+
from ._sorting import argsort, sort, top_k
9494
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
9595

9696
__all__ = [
@@ -143,6 +143,7 @@
143143
"take",
144144
"take_along_axis",
145145
"tile",
146+
"top_k",
146147
"to_numpy",
147148
"tril",
148149
"triu",

dpctl_ext/tensor/_sorting.py

Lines changed: 167 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
# THE POSSIBILITY OF SUCH DAMAGE.
2727
# *****************************************************************************
2828

29-
# import operator
30-
# from typing import NamedTuple
29+
import operator
30+
from typing import NamedTuple
3131

3232
import dpctl.tensor as dpt
3333
import dpctl.utils as du
@@ -48,9 +48,10 @@
4848
_radix_sort_dtype_supported,
4949
_sort_ascending,
5050
_sort_descending,
51+
_topk,
5152
)
5253

53-
__all__ = ["sort", "argsort"]
54+
__all__ = ["sort", "argsort", "top_k"]
5455

5556

5657
def _get_mergesort_impl_fn(descending):
@@ -284,3 +285,166 @@ def argsort(x, axis=-1, descending=False, stable=True, kind=None):
284285
inv_perm = sorted(range(nd), key=lambda d: perm[d])
285286
res = dpt_ext.permute_dims(res, inv_perm)
286287
return res
288+
289+
290+
def _get_top_k_largest(mode):
291+
modes = {"largest": True, "smallest": False}
292+
try:
293+
return modes[mode]
294+
except KeyError:
295+
raise ValueError(
296+
f"`mode` must be `largest` or `smallest`. Got `{mode}`."
297+
)
298+
299+
300+
class TopKResult(NamedTuple):
    """Named pair ``(values, indices)`` used as the return type of ``top_k``."""

    # Array holding the selected elements.
    values: dpt.usm_ndarray
    # Array holding the positions of the selected elements in the input.
    indices: dpt.usm_ndarray
303+
304+
305+
def top_k(x, k, /, *, axis=None, mode="largest"):
    """top_k(x, k, axis=None, mode="largest")

    Finds the `k` largest or `k` smallest elements of array `x` along
    `axis`, together with the indices at which they occur.

    Args:
        x (usm_ndarray):
            input array.
        k (int):
            number of elements to select. Anything supporting
            ``operator.index`` is accepted. Negative values are rejected,
            as are values exceeding the number of elements searched.
        axis (Optional[int]):
            axis to search along. When `None`, the whole array is searched
            as if it were flattened. Default: ``None``.
        mode (Literal["largest", "smallest"]):
            which end of the ordering to select:

            - `"largest"`: select the `k` largest elements.
            - `"smallest"`: select the `k` smallest elements.

            Default: `"largest"`.

    Returns:
        tuple[usm_ndarray, usm_ndarray]
            a namedtuple `(values, indices)` where

            * `values` holds the selected elements and has the data type
              of `x`. Its shape is `(k,)` when `axis` is `None`, and
              `x.shape[:axis] + (k,) + x.shape[axis+1:]` otherwise.
            * `indices` holds the indices into `x` that produce `values`.
              It has the same shape as `values` and the default array
              index data type of the device.

            For a zero-dimensional `x` with `axis=None`, a copy of `x` and
            an all-zero index array of the same shape are returned.

    Raises:
        ValueError:
            if `mode` is not recognized, if `k` is negative, or if `k`
            is larger than the number of elements searched.
        TypeError:
            if `x` is not a `usm_ndarray`.
    """
    largest = _get_top_k_largest(mode)
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(
            f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
        )

    k = operator.index(k)
    if k < 0:
        raise ValueError("`k` must be a positive integer value")

    ndim = x.ndim
    # Permutation that moves the searched axis back to its original place;
    # stays None when no axes had to be moved.
    restore_perm = None
    if axis is None:
        extent = x.size
        if ndim == 0:
            # A 0-d array has exactly one element, located at index 0.
            if k > 1:
                raise ValueError(f"`k`={k} is out of bounds 1")
            index_dt = ti.default_device_index_type(x.sycl_queue)
            return TopKResult(
                dpt_ext.copy(x, order="C"),
                dpt_ext.zeros_like(x, dtype=index_dt),
            )
        src = x
        n_search_dims = None
        out_shape = k
    else:
        axis = normalize_axis_index(axis, ndim=ndim, msg_prefix="axis")
        extent = x.shape[axis]
        if axis + 1 == ndim:
            # The searched axis is already the trailing one.
            src = x
        else:
            # Move the searched axis to the end, keeping the order of the
            # remaining axes.
            perm = [d for d in range(ndim) if d != axis]
            perm.append(axis)
            src = dpt_ext.permute_dims(x, perm)
            restore_perm = sorted(range(ndim), key=perm.__getitem__)
        n_search_dims = 1
        out_shape = src.shape[: ndim - 1] + (k,)

    if k > extent:
        raise ValueError(f"`k`={k} is out of bounds {extent}")

    exec_q = x.sycl_queue
    _manager = du.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events

    out_usm_type = src.usm_type
    if src.flags.c_contiguous:
        kernel_src = src
        kernel_deps = dep_evs
    else:
        # Stage a C-contiguous copy of the input and make the search wait
        # for that copy to finish.
        kernel_src = dpt_ext.empty_like(src, order="C")
        ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=src, dst=kernel_src, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(ht_copy_ev, copy_ev)
        kernel_deps = [copy_ev]

    vals = dpt_ext.empty(
        out_shape,
        dtype=src.dtype,
        usm_type=out_usm_type,
        order="C",
        sycl_queue=exec_q,
    )
    inds = dpt_ext.empty(
        out_shape,
        dtype=ti.default_device_index_type(exec_q),
        usm_type=out_usm_type,
        order="C",
        sycl_queue=exec_q,
    )
    ht_ev, impl_ev = _topk(
        src=kernel_src,
        trailing_dims_to_search=n_search_dims,
        k=k,
        largest=largest,
        vals=vals,
        inds=inds,
        sycl_queue=exec_q,
        depends=kernel_deps,
    )
    _manager.add_event_pair(ht_ev, impl_ev)

    if restore_perm is not None:
        # Put the searched axis back where the caller expects it.
        vals = dpt_ext.permute_dims(vals, restore_perm)
        inds = dpt_ext.permute_dims(inds, restore_perm)

    return TopKResult(vals, inds)

0 commit comments

Comments
 (0)