Skip to content

Commit d80285e

Browse files
Merge move_elementwise_binary_impl into move_elementwise_binary_impl_part_2
2 parents ee6ba17 + c9efd4d commit d80285e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+259
-240
lines changed

dpctl_ext/tensor/__init__.py

Lines changed: 38 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,16 @@
2727
# *****************************************************************************
2828

2929

30-
from dpctl.tensor._search_functions import where
31-
32-
from dpctl_ext.tensor._copy_utils import (
30+
from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum
31+
from ._clip import clip
32+
from ._copy_utils import (
3333
asnumpy,
3434
astype,
3535
copy,
3636
from_numpy,
3737
to_numpy,
3838
)
39-
from dpctl_ext.tensor._ctors import (
39+
from ._ctors import (
4040
arange,
4141
asarray,
4242
empty,
@@ -53,42 +53,6 @@
5353
zeros,
5454
zeros_like,
5555
)
56-
from dpctl_ext.tensor._indexing_functions import (
57-
extract,
58-
nonzero,
59-
place,
60-
put,
61-
put_along_axis,
62-
take,
63-
take_along_axis,
64-
)
65-
from dpctl_ext.tensor._linear_algebra_functions import (
66-
matmul,
67-
matrix_transpose,
68-
tensordot,
69-
vecdot,
70-
)
71-
from dpctl_ext.tensor._manipulation_functions import (
72-
broadcast_arrays,
73-
broadcast_to,
74-
concat,
75-
expand_dims,
76-
flip,
77-
moveaxis,
78-
permute_dims,
79-
repeat,
80-
roll,
81-
squeeze,
82-
stack,
83-
swapaxes,
84-
tile,
85-
unstack,
86-
)
87-
from dpctl_ext.tensor._reshape import reshape
88-
from dpctl_ext.tensor._utility_functions import all, any, diff
89-
90-
from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum
91-
from ._clip import clip
9256
from ._elementwise_funcs import (
9357
abs,
9458
acos,
@@ -150,6 +114,37 @@
150114
tanh,
151115
trunc,
152116
)
117+
from ._indexing_functions import (
118+
extract,
119+
nonzero,
120+
place,
121+
put,
122+
put_along_axis,
123+
take,
124+
take_along_axis,
125+
)
126+
from ._linear_algebra_functions import (
127+
matmul,
128+
matrix_transpose,
129+
tensordot,
130+
vecdot,
131+
)
132+
from ._manipulation_functions import (
133+
broadcast_arrays,
134+
broadcast_to,
135+
concat,
136+
expand_dims,
137+
flip,
138+
moveaxis,
139+
permute_dims,
140+
repeat,
141+
roll,
142+
squeeze,
143+
stack,
144+
swapaxes,
145+
tile,
146+
unstack,
147+
)
153148
from ._reduction import (
154149
argmax,
155150
argmin,
@@ -161,6 +156,8 @@
161156
reduce_hypot,
162157
sum,
163158
)
159+
from ._reshape import reshape
160+
from ._search_functions import where
164161
from ._searchsorted import searchsorted
165162
from ._set_functions import (
166163
isin,
@@ -171,6 +168,7 @@
171168
)
172169
from ._sorting import argsort, sort, top_k
173170
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
171+
from ._utility_functions import all, any, diff
174172

175173
__all__ = [
176174
"abs",

dpctl_ext/tensor/_accumulation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@
3535
import dpctl_ext.tensor as dpt_ext
3636
import dpctl_ext.tensor._tensor_accumulation_impl as tai
3737
import dpctl_ext.tensor._tensor_impl as ti
38-
from dpctl_ext.tensor._type_utils import (
38+
39+
from ._numpy_helper import normalize_axis_index
40+
from ._type_utils import (
3941
_default_accumulation_dtype,
4042
_default_accumulation_dtype_fp_types,
4143
_to_device_supported_dtype,
4244
)
4345

44-
from ._numpy_helper import normalize_axis_index
45-
4646

4747
def _accumulate_common(
4848
x,

dpctl_ext/tensor/_clip.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,21 @@
3535
# when dpnp fully migrates dpctl/tensor
3636
import dpctl_ext.tensor as dpt_ext
3737
import dpctl_ext.tensor._tensor_impl as ti
38-
from dpctl_ext.tensor._copy_utils import (
38+
39+
from ._copy_utils import (
3940
_empty_like_orderK,
4041
_empty_like_pair_orderK,
4142
_empty_like_triple_orderK,
4243
)
43-
from dpctl_ext.tensor._manipulation_functions import _broadcast_shape_impl
44-
from dpctl_ext.tensor._type_utils import _can_cast
45-
44+
from ._manipulation_functions import _broadcast_shape_impl
4645
from ._scalar_utils import (
4746
_get_dtype,
4847
_get_queue_usm_type,
4948
_get_shape,
5049
_validate_dtype,
5150
)
5251
from ._type_utils import (
52+
_can_cast,
5353
_resolve_one_strong_one_weak_types,
5454
_resolve_one_strong_two_weak_types,
5555
)

dpctl_ext/tensor/_copy_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@
4242
# when dpnp fully migrates dpctl/tensor
4343
import dpctl_ext.tensor as dpt_ext
4444
import dpctl_ext.tensor._tensor_impl as ti
45-
from dpctl_ext.tensor._type_utils import _dtype_supported_by_device_impl
4645

4746
from ._numpy_helper import normalize_axis_index
47+
from ._type_utils import _dtype_supported_by_device_impl
4848

4949
__doc__ = (
5050
"Implementation module for copy- and cast- operations on "
@@ -299,7 +299,7 @@ def _prepare_indices_arrays(inds, q, usm_type):
299299
inds = tuple(
300300
map(
301301
lambda ind: (
302-
ind if ind.dtype == ind_dt else dpt.astype(ind, ind_dt)
302+
ind if ind.dtype == ind_dt else dpt_ext.astype(ind, ind_dt)
303303
),
304304
inds,
305305
)

dpctl_ext/tensor/_ctors.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@
4242
# when dpnp fully migrates dpctl/tensor
4343
import dpctl_ext.tensor as dpt_ext
4444
import dpctl_ext.tensor._tensor_impl as ti
45-
from dpctl_ext.tensor._copy_utils import (
45+
46+
from ._copy_utils import (
4647
_empty_like_orderK,
4748
_from_numpy_empty_like_orderK,
4849
)
@@ -1440,7 +1441,7 @@ def linspace(
14401441
)
14411442
_manager.add_event_pair(hev, la_ev)
14421443

1443-
return res if int_dt is None else dpt.astype(res, int_dt)
1444+
return res if int_dt is None else dpt_ext.astype(res, int_dt)
14441445

14451446

14461447
def meshgrid(*arrays, indexing="xy"):

dpctl_ext/tensor/_indexing_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def place(arr, mask, vals):
190190
if vals.dtype == arr.dtype:
191191
rhs = vals
192192
else:
193-
rhs = dpt.astype(vals, arr.dtype)
193+
rhs = dpt_ext.astype(vals, arr.dtype)
194194
hev, pl_ev = ti._place(
195195
dst=arr,
196196
cumsum=cumsum,

dpctl_ext/tensor/_reshape.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,17 @@
3131
import dpctl.tensor as dpt
3232
import dpctl.utils
3333
import numpy as np
34-
from dpctl.tensor._tensor_impl import (
35-
_copy_usm_ndarray_for_reshape,
36-
_ravel_multi_index,
37-
_unravel_index,
38-
)
3934

4035
# TODO: revert to `import dpctl.tensor...`
4136
# when dpnp fully migrates dpctl/tensor
4237
import dpctl_ext.tensor as dpt_ext
4338

39+
from ._tensor_impl import (
40+
_copy_usm_ndarray_for_reshape,
41+
_ravel_multi_index,
42+
_unravel_index,
43+
)
44+
4445
__doc__ = "Implementation module for :func:`dpctl.tensor.reshape`."
4546

4647

dpctl_ext/tensor/_search_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434
# when dpnp fully migrates dpctl/tensor
3535
import dpctl_ext.tensor as dpt_ext
3636
import dpctl_ext.tensor._tensor_impl as ti
37-
from dpctl_ext.tensor._manipulation_functions import _broadcast_shape_impl
3837

3938
from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK
39+
from ._manipulation_functions import _broadcast_shape_impl
4040
from ._scalar_utils import (
4141
_get_dtype,
4242
_get_queue_usm_type,

dpctl_ext/tensor/libtensor/include/kernels/clip.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@
3535
#pragma once
3636
#include <algorithm>
3737
#include <cmath>
38-
#include <complex>
3938
#include <cstddef>
4039
#include <cstdint>
4140
#include <type_traits>
41+
#include <vector>
4242

4343
#include <sycl/sycl.hpp>
4444

dpctl_ext/tensor/libtensor/include/kernels/constructors.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@
3333
//===----------------------------------------------------------------------===//
3434

3535
#pragma once
36+
3637
#include <array>
37-
#include <complex>
3838
#include <cstddef>
39+
#include <type_traits>
3940
#include <vector>
4041

4142
#include <sycl/sycl.hpp>
@@ -130,7 +131,7 @@ sycl::event lin_space_step_impl(sycl::queue &exec_q,
130131
}
131132

132133
// Constructor to populate tensor with linear sequence defined by
133-
// start and and data
134+
// start and data
134135

135136
template <typename Ty, typename wTy>
136137
class LinearSequenceAffineFunctor
@@ -191,7 +192,7 @@ class LinearSequenceAffineFunctor
191192
*
192193
* @param exec_q Sycl queue to which kernel is submitted for execution.
193194
* @param nelems Length of the sequence.
194-
* @param start_v Stating value of the sequence.
195+
* @param start_v Starting value of the sequence.
195196
* @param end_v End-value of the sequence.
196197
* @param include_endpoint Whether the end-value is included in the sequence.
197198
* @param array_data Kernel accessible USM pointer to the start of array to be

0 commit comments

Comments
 (0)