|
27 | 27 | # ***************************************************************************** |
28 | 28 |
|
29 | 29 | import builtins |
| 30 | +import operator |
| 31 | +from numbers import Integral |
30 | 32 |
|
31 | 33 | import dpctl |
32 | 34 | import dpctl.memory as dpm |
|
39 | 41 |
|
40 | 42 | # TODO: revert to `import dpctl.tensor...` |
41 | 43 | # when dpnp fully migrates dpctl/tensor |
| 44 | +import dpctl_ext.tensor as dpt_ext |
42 | 45 | import dpctl_ext.tensor._tensor_impl as ti |
43 | 46 |
|
| 47 | +from ._numpy_helper import normalize_axis_index |
| 48 | + |
44 | 49 | __doc__ = ( |
45 | 50 | "Implementation module for copy- and cast- operations on " |
46 | 51 | ":class:`dpctl.tensor.usm_ndarray`." |
@@ -130,6 +135,307 @@ def _copy_from_numpy_into(dst, np_ary): |
130 | 135 | ) |
131 | 136 |
|
132 | 137 |
|
def _extract_impl(ary, ary_mask, axis=0):
    """
    Extract elements of `ary` selected by the boolean mask `ary_mask`,
    with the mask applied to the block of dimensions of `ary` that
    starts at position `axis`.
    """
    if not isinstance(ary, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
        )
    if isinstance(ary_mask, np.ndarray):
        # host mask: allocate it on the array's queue/usm type
        res_usm_type = ary.usm_type
        q = ary.sycl_queue
        ary_mask = dpt.asarray(ary_mask, usm_type=res_usm_type, sycl_queue=q)
    elif isinstance(ary_mask, dpt.usm_ndarray):
        res_usm_type = dpctl.utils.get_coerced_usm_type(
            (ary.usm_type, ary_mask.usm_type)
        )
        q = dpctl.utils.get_execution_queue(
            (ary.sycl_queue, ary_mask.sycl_queue)
        )
        if q is None:
            raise dpctl.utils.ExecutionPlacementError(
                "arrays have different associated queues. "
                "Use `y.to_device(x.device)` to migrate."
            )
    else:
        raise TypeError(
            "Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got "
            f"{type(ary_mask)}"
        )
    nd = ary.ndim
    pos = normalize_axis_index(operator.index(axis), nd)
    mask_ndim = ary_mask.ndim
    if pos < 0 or pos + mask_ndim > nd:
        raise ValueError(
            "Parameter p is inconsistent with input array dimensions"
        )
    n_mask = ary_mask.size
    # cumulative-sum dtype wide enough to index every mask element
    cs_dtype = dpt.int64 if n_mask >= int32_t_max else dpt.int32
    cumsum = dpt.empty(n_mask, dtype=cs_dtype, device=ary_mask.device)
    q = cumsum.sycl_queue
    order_manager = dpctl.utils.SequentialOrderManager[q]
    deps = order_manager.submitted_events
    # synchronous call: returns the number of selected elements
    mask_count = ti.mask_positions(
        ary_mask, cumsum, sycl_queue=q, depends=deps
    )
    res_shape = ary.shape[:pos] + (mask_count,) + ary.shape[pos + mask_ndim :]
    res = dpt.empty(
        res_shape, dtype=ary.dtype, usm_type=res_usm_type, device=ary.device
    )
    if res.size == 0:
        # nothing selected (or empty input): no kernel to submit
        return res
    ht_ev, ext_ev = ti._extract(
        src=ary,
        cumsum=cumsum,
        axis_start=pos,
        axis_end=pos + mask_ndim,
        dst=res,
        sycl_queue=q,
        depends=deps,
    )
    order_manager.add_event_pair(ht_ev, ext_ev)
    return res
| 203 | + |
| 204 | + |
def _get_indices_queue_usm_type(inds, queue, usm_type):
    """
    Validate a sequence of indices and derive the common execution queue
    and USM allocation type for them.

    Each element of `inds` must be a NumPy ndarray or usm_ndarray of
    integral dtype, or a Python integer; at least one must be an array.

    The queue and usm type of every usm_ndarray index are coerced
    together with the provided `queue` and `usm_type`.

    Args:
        inds: sequence of candidate index objects.
        queue: sycl queue included in the queue coercion.
        usm_type: usm type included in the usm type coercion.

    Returns:
        tuple: (common execution queue, or None if the queues are not
        compatible; coerced usm type).

    Raises:
        IndexError: if an array index has non-integral dtype.
        TypeError: if an element is neither an array nor a Python
            integer, or if no element is an array.
    """
    queues = [queue]
    usm_types = [usm_type]
    any_array = False
    for ind in inds:
        if isinstance(ind, (np.ndarray, dpt.usm_ndarray)):
            any_array = True
            if ind.dtype.kind not in "ui":
                raise IndexError(
                    "arrays used as indices must be of integer (or boolean) "
                    "type"
                )
            # only device arrays contribute a queue/usm type;
            # NumPy arrays are validated but allocated later by the caller
            if isinstance(ind, dpt.usm_ndarray):
                queues.append(ind.sycl_queue)
                usm_types.append(ind.usm_type)
        elif not isinstance(ind, Integral):
            raise TypeError(
                "all elements of `ind` expected to be usm_ndarrays, "
                f"NumPy arrays, or integers, found {type(ind)}"
            )
    if not any_array:
        raise TypeError(
            "at least one element of `inds` expected to be an array"
        )
    usm_type = dpctl.utils.get_coerced_usm_type(usm_types)
    q = dpctl.utils.get_execution_queue(queues)
    return q, usm_type
| 239 | + |
| 240 | + |
def _nonzero_impl(ary):
    """
    Compute indices of the non-zero elements of `ary`.

    Returns a tuple of `ary.ndim` index arrays, one per dimension,
    analogous to ``numpy.nonzero``.
    """
    if not isinstance(ary, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
        )
    q = ary.sycl_queue
    res_usm_type = ary.usm_type
    nelems = ary.size
    # cumulative-sum dtype wide enough to count every element
    cs_dtype = dpt.int64 if nelems >= int32_t_max else dpt.int32
    cumsum = dpt.empty(nelems, dtype=cs_dtype, sycl_queue=q, order="C")
    order_manager = dpctl.utils.SequentialOrderManager[q]
    deps = order_manager.submitted_events
    # synchronous call: returns the number of non-zero elements
    nz_count = ti.mask_positions(ary, cumsum, sycl_queue=q, depends=deps)
    ind_dtype = ti.default_device_index_type(q.sycl_device)
    indexes = dpt.empty(
        (ary.ndim, nz_count),
        dtype=ind_dtype,
        usm_type=res_usm_type,
        sycl_queue=q,
        order="C",
    )
    ht_ev, nz_ev = ti._nonzero(cumsum, indexes, ary.shape, q)
    # one row of `indexes` per dimension of `ary`
    res = tuple(indexes[i, :] for i in range(ary.ndim))
    order_manager.add_event_pair(ht_ev, nz_ev)
    return res
| 270 | + |
| 271 | + |
def _prepare_indices_arrays(inds, q, usm_type):
    """
    Promote a mix of usm_ndarray and Python integer indices to
    broadcast-compatible usm_ndarrays of a common integral dtype.

    Args:
        inds: sequence of usm_ndarrays and/or Python integers.
        q: sycl queue on which arrays for scalar indices are allocated
            (assumed common to all arrays already in `inds`).
        usm_type: usm type for arrays created from scalar indices.

    Returns:
        Sequence of usm_ndarrays of a common integral dtype, broadcast
        to a common shape.

    Raises:
        ValueError: if the indices cannot be safely promoted to a
            common integer dtype.
    """
    # scalar integers -> 0d arrays on the common queue
    inds = tuple(
        ind
        if isinstance(ind, dpt.usm_ndarray)
        else dpt.asarray(ind, usm_type=usm_type, sycl_queue=q)
        for ind in inds
    )

    # promote to a common integral type if possible
    ind_dt = dpt.result_type(*inds)
    if ind_dt.kind not in "ui":
        raise ValueError(
            "cannot safely promote indices to an integer data type"
        )
    inds = tuple(
        ind if ind.dtype == ind_dt else dpt.astype(ind, ind_dt)
        for ind in inds
    )

    # broadcast to a common shape
    return dpt.broadcast_arrays(*inds)
| 312 | + |
| 313 | + |
def _put_multi_index(ary, inds, p, vals, mode=0):
    """
    Write `vals` into `ary` at positions selected by index arrays
    `inds`, applied along consecutive dimensions starting at axis `p`.

    Args:
        ary: destination usm_ndarray.
        inds: a single index (array or Python integer) or a sequence of
            them; at least one must be an array.
        p: axis at which the index arrays begin to apply.
        vals: values to write; cast to `ary.dtype` and broadcast to the
            expected shape as needed.
        mode (int): 0 or 1, forwarded to `ti._put`
            (presumably out-of-bounds index handling -- confirm against
            ti._put).

    Raises:
        TypeError: if `ary` is not a usm_ndarray or `inds` is invalid.
        ValueError: if `mode` is not 0 or 1.
        ExecutionPlacementError: if a common execution queue cannot be
            determined.
        IndexError: if non-empty indices address an empty axis.
    """
    if not isinstance(ary, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
        )
    ary_nd = ary.ndim
    p = normalize_axis_index(operator.index(p), ary_nd)
    mode = operator.index(mode)
    if mode not in [0, 1]:
        raise ValueError(
            "Invalid value for mode keyword, only 0 or 1 is supported"
        )
    if not isinstance(inds, (list, tuple)):
        inds = (inds,)

    # validates indices and coerces their queues/usm types with ary's
    exec_q, coerced_usm_type = _get_indices_queue_usm_type(
        inds, ary.sycl_queue, ary.usm_type
    )

    if exec_q is not None:
        if not isinstance(vals, dpt.usm_ndarray):
            # host values: allocate on the already-coerced queue/usm type
            vals = dpt.asarray(
                vals,
                dtype=ary.dtype,
                usm_type=coerced_usm_type,
                sycl_queue=exec_q,
            )
        else:
            # device values: fold their queue/usm type into the coercion
            exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue))
            coerced_usm_type = dpctl.utils.get_coerced_usm_type(
                (
                    coerced_usm_type,
                    vals.usm_type,
                )
            )
    if exec_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "Can not automatically determine where to allocate the "
            "result or performance execution. "
            "Use `usm_ndarray.to_device` method to migrate data to "
            "be associated with the same queue."
        )

    inds = _prepare_indices_arrays(inds, exec_q, coerced_usm_type)

    ind0 = inds[0]
    ary_sh = ary.shape
    # index arrays consume axes p..p_end of ary
    p_end = p + len(inds)
    if 0 in ary_sh[p:p_end] and ind0.size != 0:
        raise IndexError(
            "cannot put into non-empty indices along an empty axis"
        )
    # indexed axes are replaced by the (common) shape of the index arrays
    expected_vals_shape = ary_sh[:p] + ind0.shape + ary_sh[p_end:]
    if vals.dtype == ary.dtype:
        rhs = vals
    else:
        rhs = dpt_ext.astype(vals, ary.dtype)
    rhs = dpt.broadcast_to(rhs, expected_vals_shape)
    _manager = dpctl.utils.SequentialOrderManager[exec_q]
    dep_ev = _manager.submitted_events
    hev, put_ev = ti._put(
        dst=ary,
        ind=inds,
        val=rhs,
        axis_start=p,
        mode=mode,
        sycl_queue=exec_q,
        depends=dep_ev,
    )
    _manager.add_event_pair(hev, put_ev)
    return
| 385 | + |
| 386 | + |
def _take_multi_index(ary, inds, p, mode=0):
    """
    Gather elements of `ary` using index arrays `inds` applied along
    consecutive dimensions starting at axis `p`.
    """
    if not isinstance(ary, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
        )
    nd = ary.ndim
    axis_start = normalize_axis_index(operator.index(p), nd)
    mode = operator.index(mode)
    if mode not in (0, 1):
        raise ValueError(
            "Invalid value for mode keyword, only 0 or 1 is supported"
        )
    if not isinstance(inds, (list, tuple)):
        inds = (inds,)

    # validates indices and coerces their queues/usm types with ary's
    q, res_usm_type = _get_indices_queue_usm_type(
        inds, ary.sycl_queue, ary.usm_type
    )
    if q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "Can not automatically determine where to allocate the "
            "result or performance execution. "
            "Use `usm_ndarray.to_device` method to migrate data to "
            "be associated with the same queue."
        )

    inds = _prepare_indices_arrays(inds, q, res_usm_type)

    first_ind = inds[0]
    shape = ary.shape
    # index arrays consume axes axis_start..axis_end of ary
    axis_end = axis_start + len(inds)
    if first_ind.size != 0 and 0 in shape[axis_start:axis_end]:
        raise IndexError("cannot take non-empty indices from an empty axis")
    # indexed axes are replaced by the (common) index-array shape
    res = dpt.empty(
        shape[:axis_start] + first_ind.shape + shape[axis_end:],
        dtype=ary.dtype,
        usm_type=res_usm_type,
        sycl_queue=q,
    )
    order_manager = dpctl.utils.SequentialOrderManager[q]
    deps = order_manager.submitted_events
    ht_ev, take_ev = ti._take(
        src=ary,
        ind=inds,
        dst=res,
        axis_start=axis_start,
        mode=mode,
        sycl_queue=q,
        depends=deps,
    )
    order_manager.add_event_pair(ht_ev, take_ev)
    return res
| 437 | + |
| 438 | + |
133 | 439 | def from_numpy(np_ary, /, *, device=None, usm_type="device", sycl_queue=None): |
134 | 440 | """ |
135 | 441 | from_numpy(arg, device=None, usm_type="device", sycl_queue=None) |
|
0 commit comments