|
| 1 | +# ***************************************************************************** |
| 2 | +# Copyright (c) 2026, Intel Corporation |
| 3 | +# All rights reserved. |
| 4 | +# |
| 5 | +# Redistribution and use in source and binary forms, with or without |
| 6 | +# modification, are permitted provided that the following conditions are met: |
| 7 | +# - Redistributions of source code must retain the above copyright notice, |
| 8 | +# this list of conditions and the following disclaimer. |
| 9 | +# - Redistributions in binary form must reproduce the above copyright notice, |
| 10 | +# this list of conditions and the following disclaimer in the documentation |
| 11 | +# and/or other materials provided with the distribution. |
| 12 | +# - Neither the name of the copyright holder nor the names of its contributors |
| 13 | +# may be used to endorse or promote products derived from this software |
| 14 | +# without specific prior written permission. |
| 15 | +# |
| 16 | +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
| 17 | +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
| 18 | +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE |
| 19 | +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE |
| 20 | +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR |
| 21 | +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF |
| 22 | +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS |
| 23 | +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN |
| 24 | +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) |
| 25 | +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF |
| 26 | +# THE POSSIBILITY OF SUCH DAMAGE. |
| 27 | +# ***************************************************************************** |
| 28 | + |
| 29 | + |
| 30 | +from typing import Literal, Union |
| 31 | + |
| 32 | +import dpctl |
| 33 | +import dpctl.utils as du |
| 34 | + |
| 35 | +# TODO: revert to `from ._usmarray import...` |
| 36 | +# when dpnp fully migrates dpctl/tensor |
| 37 | +from dpctl.tensor._usmarray import usm_ndarray |
| 38 | + |
| 39 | +from ._copy_utils import _empty_like_orderK |
| 40 | +from ._ctors import empty |
| 41 | +from ._tensor_impl import _copy_usm_ndarray_into_usm_ndarray as ti_copy |
| 42 | +from ._tensor_impl import _take as ti_take |
| 43 | +from ._tensor_impl import ( |
| 44 | + default_device_index_type as ti_default_device_index_type, |
| 45 | +) |
| 46 | +from ._tensor_sorting_impl import _searchsorted_left, _searchsorted_right |
| 47 | +from ._type_utils import isdtype, result_type |
| 48 | + |
| 49 | + |
| 50 | +def searchsorted( |
| 51 | + x1: usm_ndarray, |
| 52 | + x2: usm_ndarray, |
| 53 | + /, |
| 54 | + *, |
| 55 | + side: Literal["left", "right"] = "left", |
| 56 | + sorter: Union[usm_ndarray, None] = None, |
| 57 | +) -> usm_ndarray: |
| 58 | + """searchsorted(x1, x2, side='left', sorter=None) |
| 59 | +
|
| 60 | + Finds the indices into `x1` such that, if the corresponding elements |
| 61 | + in `x2` were inserted before the indices, the order of `x1`, when sorted |
| 62 | + in ascending order, would be preserved. |
| 63 | +
|
| 64 | + Args: |
| 65 | + x1 (usm_ndarray): |
| 66 | + input array. Must be a one-dimensional array. If `sorter` is |
| 67 | + `None`, must be sorted in ascending order; otherwise, `sorter` must |
| 68 | + be an array of indices that sort `x1` in ascending order. |
| 69 | + x2 (usm_ndarray): |
| 70 | + array containing search values. |
| 71 | + side (Literal["left", "right]): |
| 72 | + argument controlling which index is returned if a value lands |
| 73 | + exactly on an edge. If `x2` is an array of rank `N` where |
| 74 | + `v = x2[n, m, ..., j]`, the element `ret[n, m, ..., j]` in the |
| 75 | + return array `ret` contains the position `i` such that |
| 76 | + if `side="left"`, it is the first index such that |
| 77 | + `x1[i-1] < v <= x1[i]`, `0` if `v <= x1[0]`, and `x1.size` |
| 78 | + if `v > x1[-1]`; |
| 79 | + and if `side="right"`, it is the first position `i` such that |
| 80 | + `x1[i-1] <= v < x1[i]`, `0` if `v < x1[0]`, and `x1.size` |
| 81 | + if `v >= x1[-1]`. Default: `"left"`. |
| 82 | + sorter (Optional[usm_ndarray]): |
| 83 | + array of indices that sort `x1` in ascending order. The array must |
| 84 | + have the same shape as `x1` and have an integral data type. |
| 85 | + Out of bound index values of `sorter` array are treated using |
| 86 | + `"wrap"` mode documented in :py:func:`dpctl.tensor.take`. |
| 87 | + Default: `None`. |
| 88 | + """ |
| 89 | + if not isinstance(x1, usm_ndarray): |
| 90 | + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}") |
| 91 | + if not isinstance(x2, usm_ndarray): |
| 92 | + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}") |
| 93 | + if sorter is not None and not isinstance(sorter, usm_ndarray): |
| 94 | + raise TypeError( |
| 95 | + f"Expected dpctl.tensor.usm_ndarray, got {type(sorter)}" |
| 96 | + ) |
| 97 | + |
| 98 | + if side not in ["left", "right"]: |
| 99 | + raise ValueError( |
| 100 | + "Unrecognized value of 'side' keyword argument. " |
| 101 | + "Expected either 'left' or 'right'" |
| 102 | + ) |
| 103 | + |
| 104 | + if sorter is None: |
| 105 | + q = du.get_execution_queue([x1.sycl_queue, x2.sycl_queue]) |
| 106 | + else: |
| 107 | + q = du.get_execution_queue( |
| 108 | + [x1.sycl_queue, x2.sycl_queue, sorter.sycl_queue] |
| 109 | + ) |
| 110 | + if q is None: |
| 111 | + raise du.ExecutionPlacementError( |
| 112 | + "Execution placement can not be unambiguously " |
| 113 | + "inferred from input arguments." |
| 114 | + ) |
| 115 | + |
| 116 | + if x1.ndim != 1: |
| 117 | + raise ValueError("First argument array must be one-dimensional") |
| 118 | + |
| 119 | + x1_dt = x1.dtype |
| 120 | + x2_dt = x2.dtype |
| 121 | + |
| 122 | + _manager = du.SequentialOrderManager[q] |
| 123 | + dep_evs = _manager.submitted_events |
| 124 | + ev = dpctl.SyclEvent() |
| 125 | + if sorter is not None: |
| 126 | + if not isdtype(sorter.dtype, "integral"): |
| 127 | + raise ValueError( |
| 128 | + f"Sorter array must have integral data type, got {sorter.dtype}" |
| 129 | + ) |
| 130 | + if x1.shape != sorter.shape: |
| 131 | + raise ValueError( |
| 132 | + "Sorter array must be one-dimension with the same " |
| 133 | + "shape as the first argument array" |
| 134 | + ) |
| 135 | + res = empty(x1.shape, dtype=x1_dt, usm_type=x1.usm_type, sycl_queue=q) |
| 136 | + ind = (sorter,) |
| 137 | + axis = 0 |
| 138 | + wrap_out_of_bound_indices_mode = 0 |
| 139 | + ht_ev, ev = ti_take( |
| 140 | + x1, |
| 141 | + ind, |
| 142 | + res, |
| 143 | + axis, |
| 144 | + wrap_out_of_bound_indices_mode, |
| 145 | + sycl_queue=q, |
| 146 | + depends=dep_evs, |
| 147 | + ) |
| 148 | + x1 = res |
| 149 | + _manager.add_event_pair(ht_ev, ev) |
| 150 | + |
| 151 | + if x1_dt != x2_dt: |
| 152 | + dt = result_type(x1, x2) |
| 153 | + if x1_dt != dt: |
| 154 | + x1_buf = _empty_like_orderK(x1, dt) |
| 155 | + dep_evs = _manager.submitted_events |
| 156 | + ht_ev, ev = ti_copy( |
| 157 | + src=x1, dst=x1_buf, sycl_queue=q, depends=dep_evs |
| 158 | + ) |
| 159 | + _manager.add_event_pair(ht_ev, ev) |
| 160 | + x1 = x1_buf |
| 161 | + if x2_dt != dt: |
| 162 | + x2_buf = _empty_like_orderK(x2, dt) |
| 163 | + dep_evs = _manager.submitted_events |
| 164 | + ht_ev, ev = ti_copy( |
| 165 | + src=x2, dst=x2_buf, sycl_queue=q, depends=dep_evs |
| 166 | + ) |
| 167 | + _manager.add_event_pair(ht_ev, ev) |
| 168 | + x2 = x2_buf |
| 169 | + |
| 170 | + dst_usm_type = du.get_coerced_usm_type([x1.usm_type, x2.usm_type]) |
| 171 | + index_dt = ti_default_device_index_type(q) |
| 172 | + |
| 173 | + dst = _empty_like_orderK(x2, index_dt, usm_type=dst_usm_type) |
| 174 | + |
| 175 | + dep_evs = _manager.submitted_events |
| 176 | + if side == "left": |
| 177 | + ht_ev, s_ev = _searchsorted_left( |
| 178 | + hay=x1, |
| 179 | + needles=x2, |
| 180 | + positions=dst, |
| 181 | + sycl_queue=q, |
| 182 | + depends=dep_evs, |
| 183 | + ) |
| 184 | + else: |
| 185 | + ht_ev, s_ev = _searchsorted_right( |
| 186 | + hay=x1, needles=x2, positions=dst, sycl_queue=q, depends=dep_evs |
| 187 | + ) |
| 188 | + _manager.add_event_pair(ht_ev, s_ev) |
| 189 | + return dst |
0 commit comments