Skip to content

Commit 9339cb8

Browse files
Merge move_tensor_elementwise_impl_unary into move_tensor_elementwise_impl_unary_par_2
2 parents 7f444bd + 6a2c31b commit 9339cb8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+253
-234
lines changed

dpctl_ext/tensor/__init__.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,16 @@
2727
# *****************************************************************************
2828

2929

30-
from dpctl.tensor._search_functions import where
31-
32-
from dpctl_ext.tensor._copy_utils import (
30+
from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum
31+
from ._clip import clip
32+
from ._copy_utils import (
3333
asnumpy,
3434
astype,
3535
copy,
3636
from_numpy,
3737
to_numpy,
3838
)
39-
from dpctl_ext.tensor._ctors import (
39+
from ._ctors import (
4040
arange,
4141
asarray,
4242
empty,
@@ -53,36 +53,6 @@
5353
zeros,
5454
zeros_like,
5555
)
56-
from dpctl_ext.tensor._indexing_functions import (
57-
extract,
58-
nonzero,
59-
place,
60-
put,
61-
put_along_axis,
62-
take,
63-
take_along_axis,
64-
)
65-
from dpctl_ext.tensor._manipulation_functions import (
66-
broadcast_arrays,
67-
broadcast_to,
68-
concat,
69-
expand_dims,
70-
flip,
71-
moveaxis,
72-
permute_dims,
73-
repeat,
74-
roll,
75-
squeeze,
76-
stack,
77-
swapaxes,
78-
tile,
79-
unstack,
80-
)
81-
from dpctl_ext.tensor._reshape import reshape
82-
from dpctl_ext.tensor._utility_functions import all, any, diff
83-
84-
from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum
85-
from ._clip import clip
8656
from ._elementwise_funcs import (
8757
abs,
8858
acos,
@@ -112,6 +82,31 @@
11282
negative,
11383
positive,
11484
)
85+
from ._indexing_functions import (
86+
extract,
87+
nonzero,
88+
place,
89+
put,
90+
put_along_axis,
91+
take,
92+
take_along_axis,
93+
)
94+
from ._manipulation_functions import (
95+
broadcast_arrays,
96+
broadcast_to,
97+
concat,
98+
expand_dims,
99+
flip,
100+
moveaxis,
101+
permute_dims,
102+
repeat,
103+
roll,
104+
squeeze,
105+
stack,
106+
swapaxes,
107+
tile,
108+
unstack,
109+
)
115110
from ._reduction import (
116111
argmax,
117112
argmin,
@@ -123,6 +118,8 @@
123118
reduce_hypot,
124119
sum,
125120
)
121+
from ._reshape import reshape
122+
from ._search_functions import where
126123
from ._searchsorted import searchsorted
127124
from ._set_functions import (
128125
isin,
@@ -133,6 +130,7 @@
133130
)
134131
from ._sorting import argsort, sort, top_k
135132
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
133+
from ._utility_functions import all, any, diff
136134

137135
__all__ = [
138136
"abs",

dpctl_ext/tensor/_accumulation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@
3535
import dpctl_ext.tensor as dpt_ext
3636
import dpctl_ext.tensor._tensor_accumulation_impl as tai
3737
import dpctl_ext.tensor._tensor_impl as ti
38-
from dpctl_ext.tensor._type_utils import (
38+
39+
from ._numpy_helper import normalize_axis_index
40+
from ._type_utils import (
3941
_default_accumulation_dtype,
4042
_default_accumulation_dtype_fp_types,
4143
_to_device_supported_dtype,
4244
)
4345

44-
from ._numpy_helper import normalize_axis_index
45-
4646

4747
def _accumulate_common(
4848
x,

dpctl_ext/tensor/_clip.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,21 @@
3535
# when dpnp fully migrates dpctl/tensor
3636
import dpctl_ext.tensor as dpt_ext
3737
import dpctl_ext.tensor._tensor_impl as ti
38-
from dpctl_ext.tensor._copy_utils import (
38+
39+
from ._copy_utils import (
3940
_empty_like_orderK,
4041
_empty_like_pair_orderK,
4142
_empty_like_triple_orderK,
4243
)
43-
from dpctl_ext.tensor._manipulation_functions import _broadcast_shape_impl
44-
from dpctl_ext.tensor._type_utils import _can_cast
45-
44+
from ._manipulation_functions import _broadcast_shape_impl
4645
from ._scalar_utils import (
4746
_get_dtype,
4847
_get_queue_usm_type,
4948
_get_shape,
5049
_validate_dtype,
5150
)
5251
from ._type_utils import (
52+
_can_cast,
5353
_resolve_one_strong_one_weak_types,
5454
_resolve_one_strong_two_weak_types,
5555
)

dpctl_ext/tensor/_copy_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@
4242
# when dpnp fully migrates dpctl/tensor
4343
import dpctl_ext.tensor as dpt_ext
4444
import dpctl_ext.tensor._tensor_impl as ti
45-
from dpctl_ext.tensor._type_utils import _dtype_supported_by_device_impl
4645

4746
from ._numpy_helper import normalize_axis_index
47+
from ._type_utils import _dtype_supported_by_device_impl
4848

4949
__doc__ = (
5050
"Implementation module for copy- and cast- operations on "
@@ -299,7 +299,7 @@ def _prepare_indices_arrays(inds, q, usm_type):
299299
inds = tuple(
300300
map(
301301
lambda ind: (
302-
ind if ind.dtype == ind_dt else dpt.astype(ind, ind_dt)
302+
ind if ind.dtype == ind_dt else dpt_ext.astype(ind, ind_dt)
303303
),
304304
inds,
305305
)

dpctl_ext/tensor/_ctors.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@
4242
# when dpnp fully migrates dpctl/tensor
4343
import dpctl_ext.tensor as dpt_ext
4444
import dpctl_ext.tensor._tensor_impl as ti
45-
from dpctl_ext.tensor._copy_utils import (
45+
46+
from ._copy_utils import (
4647
_empty_like_orderK,
4748
_from_numpy_empty_like_orderK,
4849
)
@@ -1440,7 +1441,7 @@ def linspace(
14401441
)
14411442
_manager.add_event_pair(hev, la_ev)
14421443

1443-
return res if int_dt is None else dpt.astype(res, int_dt)
1444+
return res if int_dt is None else dpt_ext.astype(res, int_dt)
14441445

14451446

14461447
def meshgrid(*arrays, indexing="xy"):

dpctl_ext/tensor/_indexing_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def place(arr, mask, vals):
190190
if vals.dtype == arr.dtype:
191191
rhs = vals
192192
else:
193-
rhs = dpt.astype(vals, arr.dtype)
193+
rhs = dpt_ext.astype(vals, arr.dtype)
194194
hev, pl_ev = ti._place(
195195
dst=arr,
196196
cumsum=cumsum,

dpctl_ext/tensor/_reshape.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,17 @@
3131
import dpctl.tensor as dpt
3232
import dpctl.utils
3333
import numpy as np
34-
from dpctl.tensor._tensor_impl import (
35-
_copy_usm_ndarray_for_reshape,
36-
_ravel_multi_index,
37-
_unravel_index,
38-
)
3934

4035
# TODO: revert to `import dpctl.tensor...`
4136
# when dpnp fully migrates dpctl/tensor
4237
import dpctl_ext.tensor as dpt_ext
4338

39+
from ._tensor_impl import (
40+
_copy_usm_ndarray_for_reshape,
41+
_ravel_multi_index,
42+
_unravel_index,
43+
)
44+
4445
__doc__ = "Implementation module for :func:`dpctl.tensor.reshape`."
4546

4647

dpctl_ext/tensor/_search_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434
# when dpnp fully migrates dpctl/tensor
3535
import dpctl_ext.tensor as dpt_ext
3636
import dpctl_ext.tensor._tensor_impl as ti
37-
from dpctl_ext.tensor._manipulation_functions import _broadcast_shape_impl
3837

3938
from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK
39+
from ._manipulation_functions import _broadcast_shape_impl
4040
from ._scalar_utils import (
4141
_get_dtype,
4242
_get_queue_usm_type,

dpctl_ext/tensor/libtensor/include/kernels/clip.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@
3535
#pragma once
3636
#include <algorithm>
3737
#include <cmath>
38-
#include <complex>
3938
#include <cstddef>
4039
#include <cstdint>
4140
#include <type_traits>
41+
#include <vector>
4242

4343
#include <sycl/sycl.hpp>
4444

dpctl_ext/tensor/libtensor/include/kernels/constructors.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@
3333
//===----------------------------------------------------------------------===//
3434

3535
#pragma once
36+
3637
#include <array>
37-
#include <complex>
3838
#include <cstddef>
39+
#include <type_traits>
3940
#include <vector>
4041

4142
#include <sycl/sycl.hpp>
@@ -130,7 +131,7 @@ sycl::event lin_space_step_impl(sycl::queue &exec_q,
130131
}
131132

132133
// Constructor to populate tensor with linear sequence defined by
133-
// start and and data
134+
// start and data
134135

135136
template <typename Ty, typename wTy>
136137
class LinearSequenceAffineFunctor
@@ -191,7 +192,7 @@ class LinearSequenceAffineFunctor
191192
*
192193
* @param exec_q Sycl queue to which kernel is submitted for execution.
193194
* @param nelems Length of the sequence.
194-
* @param start_v Stating value of the sequence.
195+
* @param start_v Starting value of the sequence.
195196
* @param end_v End-value of the sequence.
196197
* @param include_endpoint Whether the end-value is included in the sequence.
197198
* @param array_data Kernel accessible USM pointer to the start of array to be

0 commit comments

Comments
 (0)