|
28 | 28 |
|
29 | 29 | import builtins |
30 | 30 | import operator |
| 31 | +from numbers import Integral |
31 | 32 |
|
32 | 33 | import dpctl |
33 | 34 | import dpctl.memory as dpm |
|
40 | 41 |
|
41 | 42 | # TODO: revert to `import dpctl.tensor...` |
42 | 43 | # when dpnp fully migrates dpctl/tensor |
| 44 | +import dpctl_ext.tensor as dpt_ext |
43 | 45 | import dpctl_ext.tensor._tensor_impl as ti |
44 | 46 |
|
45 | 47 | from ._numpy_helper import normalize_axis_index |
@@ -200,6 +202,42 @@ def _extract_impl(ary, ary_mask, axis=0): |
200 | 202 | return dst |
201 | 203 |
|
202 | 204 |
|
| 205 | +def _get_indices_queue_usm_type(inds, queue, usm_type): |
| 206 | + """ |
| 207 | + Utility for validating indices are NumPy ndarray or usm_ndarray of integral |
| 208 | + dtype or Python integers. At least one must be an array. |
| 209 | +
|
| 210 | + For each array, the queue and usm type are appended to `queue_list` and |
| 211 | + `usm_type_list`, respectively. |
| 212 | + """ |
| 213 | + queues = [queue] |
| 214 | + usm_types = [usm_type] |
| 215 | + any_array = False |
| 216 | + for ind in inds: |
| 217 | + if isinstance(ind, (np.ndarray, dpt.usm_ndarray)): |
| 218 | + any_array = True |
| 219 | + if ind.dtype.kind not in "ui": |
| 220 | + raise IndexError( |
| 221 | + "arrays used as indices must be of integer (or boolean) " |
| 222 | + "type" |
| 223 | + ) |
| 224 | + if isinstance(ind, dpt.usm_ndarray): |
| 225 | + queues.append(ind.sycl_queue) |
| 226 | + usm_types.append(ind.usm_type) |
| 227 | + elif not isinstance(ind, Integral): |
| 228 | + raise TypeError( |
| 229 | + "all elements of `ind` expected to be usm_ndarrays, " |
| 230 | + f"NumPy arrays, or integers, found {type(ind)}" |
| 231 | + ) |
| 232 | + if not any_array: |
| 233 | + raise TypeError( |
| 234 | + "at least one element of `inds` expected to be an array" |
| 235 | + ) |
| 236 | + usm_type = dpctl.utils.get_coerced_usm_type(usm_types) |
| 237 | + q = dpctl.utils.get_execution_queue(queues) |
| 238 | + return q, usm_type |
| 239 | + |
| 240 | + |
203 | 241 | def _nonzero_impl(ary): |
204 | 242 | if not isinstance(ary, dpt.usm_ndarray): |
205 | 243 | raise TypeError( |
@@ -231,6 +269,121 @@ def _nonzero_impl(ary): |
231 | 269 | return res |
232 | 270 |
|
233 | 271 |
|
def _prepare_indices_arrays(inds, q, usm_type):
    """
    Convert a mix of usm_ndarray and scalar indices into index arrays that
    share one data type and one shape.

    Elements of `inds` that are not usm_ndarray instances are converted with
    `dpt.asarray` on queue `q` (assumed to be common to the arrays in `inds`)
    using `usm_type`. The arrays are then cast to their common result type
    and broadcast against each other.

    Args:
        inds: iterable of usm_ndarray and/or Python integer indices.
        q: queue used to allocate arrays created from non-array indices.
        usm_type: USM type used for arrays created from non-array indices.

    Returns:
        The result of `dpt.broadcast_arrays` applied to the converted
        indices.

    Raises:
        ValueError: the common result type has a kind other than "u" or "i".
    """
    # step 1: everything that is not yet a usm_ndarray becomes one
    as_arrays = tuple(
        (
            item
            if isinstance(item, dpt.usm_ndarray)
            else dpt.asarray(item, usm_type=usm_type, sycl_queue=q)
        )
        for item in inds
    )

    # step 2: find the common type and insist that it is integral
    common_dt = dpt.result_type(*as_arrays)
    if common_dt.kind not in "ui":
        raise ValueError(
            "cannot safely promote indices to an integer data type"
        )

    # step 3: cast only the arrays that do not already have the common type
    same_typed = tuple(
        arr if arr.dtype == common_dt else dpt.astype(arr, common_dt)
        for arr in as_arrays
    )

    # step 4: bring all index arrays to a common shape
    return dpt.broadcast_arrays(*same_typed)
| 312 | + |
| 313 | + |
def _put_multi_index(ary, inds, p, vals, mode=0):
    """
    Write `vals` into `ary` at the positions selected by the index arrays
    `inds`, which address consecutive axes of `ary` starting at axis `p`.

    Args:
        ary: destination usm_ndarray; it is passed to `ti._put` as `dst`.
        inds: a single index or a list/tuple of indices; validated by
            `_get_indices_queue_usm_type` and normalized by
            `_prepare_indices_arrays`.
        p: first indexed axis of `ary`; normalized against `ary.ndim` and
            passed to `ti._put` as `axis_start`.
        vals: values to write. Anything that is not a usm_ndarray is
            converted with `dpt.asarray` using the dtype of `ary`.
        mode: integer forwarded to `ti._put`; only 0 and 1 are accepted.

    Returns:
        None.

    Raises:
        TypeError: `ary` is not a usm_ndarray.
        ValueError: `mode` is neither 0 nor 1.
        dpctl.utils.ExecutionPlacementError: no common execution queue can
            be determined for `ary`, the indices and `vals`.
        IndexError: one of the indexed axes of `ary` has zero length while
            the index arrays are not empty.
    """
    if not isinstance(ary, dpt.usm_ndarray):
        raise TypeError(
            f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
        )
    p = normalize_axis_index(operator.index(p), ary.ndim)
    mode = operator.index(mode)
    if mode not in (0, 1):
        raise ValueError(
            "Invalid value for mode keyword, only 0 or 1 is supported"
        )
    # a lone index is treated as a one-element tuple of indices
    if not isinstance(inds, (list, tuple)):
        inds = (inds,)

    exec_q, coerced_usm_type = _get_indices_queue_usm_type(
        inds, ary.sycl_queue, ary.usm_type
    )

    if exec_q is not None:
        if isinstance(vals, dpt.usm_ndarray):
            # vals already lives on a device: it takes part in deducing
            # the execution queue and the USM type
            exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue))
            coerced_usm_type = dpctl.utils.get_coerced_usm_type(
                (coerced_usm_type, vals.usm_type)
            )
        else:
            vals = dpt.asarray(
                vals,
                dtype=ary.dtype,
                usm_type=coerced_usm_type,
                sycl_queue=exec_q,
            )
    # checked separately because the branch above may reset exec_q to None
    if exec_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "Can not automatically determine where to allocate the "
            "result or performance execution. "
            "Use `usm_ndarray.to_device` method to migrate data to "
            "be associated with the same queue."
        )

    index_arrays = _prepare_indices_arrays(inds, exec_q, coerced_usm_type)

    first_ind = index_arrays[0]
    dst_shape = ary.shape
    p_end = p + len(index_arrays)
    if 0 in dst_shape[p:p_end] and first_ind.size != 0:
        raise IndexError(
            "cannot put into non-empty indices along an empty axis"
        )

    # the indexed axes of ary are replaced by the shape of the index arrays
    expected_vals_shape = dst_shape[:p] + first_ind.shape + dst_shape[p_end:]
    rhs = vals if vals.dtype == ary.dtype else dpt_ext.astype(vals, ary.dtype)
    rhs = dpt.broadcast_to(rhs, expected_vals_shape)

    order_manager = dpctl.utils.SequentialOrderManager[exec_q]
    pending_events = order_manager.submitted_events
    host_ev, put_ev = ti._put(
        dst=ary,
        ind=index_arrays,
        val=rhs,
        axis_start=p,
        mode=mode,
        sycl_queue=exec_q,
        depends=pending_events,
    )
    # register the events so that work submitted later is ordered after it
    order_manager.add_event_pair(host_ev, put_ev)
| 385 | + |
| 386 | + |
234 | 387 | def from_numpy(np_ary, /, *, device=None, usm_type="device", sycl_queue=None): |
235 | 388 | """ |
236 | 389 | from_numpy(arg, device=None, usm_type="device", sycl_queue=None) |
|
0 commit comments