Skip to content

Commit 6a2c31b

Browse files
Merge move_tensor_reductions_impl_ext into move_tensor_elementwise_impl_unary
2 parents e24b129 + c9644ac commit 6a2c31b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+243
-224
lines changed

dpctl_ext/tensor/__init__.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,16 @@
2727
# *****************************************************************************
2828

2929

30-
from dpctl.tensor._search_functions import where
31-
32-
from dpctl_ext.tensor._copy_utils import (
30+
from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum
31+
from ._clip import clip
32+
from ._copy_utils import (
3333
asnumpy,
3434
astype,
3535
copy,
3636
from_numpy,
3737
to_numpy,
3838
)
39-
from dpctl_ext.tensor._ctors import (
39+
from ._ctors import (
4040
arange,
4141
asarray,
4242
empty,
@@ -53,7 +53,20 @@
5353
zeros,
5454
zeros_like,
5555
)
56-
from dpctl_ext.tensor._indexing_functions import (
56+
from ._elementwise_funcs import (
57+
abs,
58+
acos,
59+
acosh,
60+
angle,
61+
asin,
62+
asinh,
63+
atan,
64+
atanh,
65+
bitwise_invert,
66+
ceil,
67+
conj,
68+
)
69+
from ._indexing_functions import (
5770
extract,
5871
nonzero,
5972
place,
@@ -62,7 +75,7 @@
6275
take,
6376
take_along_axis,
6477
)
65-
from dpctl_ext.tensor._manipulation_functions import (
78+
from ._manipulation_functions import (
6679
broadcast_arrays,
6780
broadcast_to,
6881
concat,
@@ -78,24 +91,6 @@
7891
tile,
7992
unstack,
8093
)
81-
from dpctl_ext.tensor._reshape import reshape
82-
from dpctl_ext.tensor._utility_functions import all, any, diff
83-
84-
from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum
85-
from ._clip import clip
86-
from ._elementwise_funcs import (
87-
abs,
88-
acos,
89-
acosh,
90-
angle,
91-
asin,
92-
asinh,
93-
atan,
94-
atanh,
95-
bitwise_invert,
96-
ceil,
97-
conj,
98-
)
9994
from ._reduction import (
10095
argmax,
10196
argmin,
@@ -107,6 +102,8 @@
107102
reduce_hypot,
108103
sum,
109104
)
105+
from ._reshape import reshape
106+
from ._search_functions import where
110107
from ._searchsorted import searchsorted
111108
from ._set_functions import (
112109
isin,
@@ -117,6 +114,7 @@
117114
)
118115
from ._sorting import argsort, sort, top_k
119116
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
117+
from ._utility_functions import all, any, diff
120118

121119
__all__ = [
122120
"abs",

dpctl_ext/tensor/_accumulation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@
3535
import dpctl_ext.tensor as dpt_ext
3636
import dpctl_ext.tensor._tensor_accumulation_impl as tai
3737
import dpctl_ext.tensor._tensor_impl as ti
38-
from dpctl_ext.tensor._type_utils import (
38+
39+
from ._numpy_helper import normalize_axis_index
40+
from ._type_utils import (
3941
_default_accumulation_dtype,
4042
_default_accumulation_dtype_fp_types,
4143
_to_device_supported_dtype,
4244
)
4345

44-
from ._numpy_helper import normalize_axis_index
45-
4646

4747
def _accumulate_common(
4848
x,

dpctl_ext/tensor/_clip.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,21 @@
3535
# when dpnp fully migrates dpctl/tensor
3636
import dpctl_ext.tensor as dpt_ext
3737
import dpctl_ext.tensor._tensor_impl as ti
38-
from dpctl_ext.tensor._copy_utils import (
38+
39+
from ._copy_utils import (
3940
_empty_like_orderK,
4041
_empty_like_pair_orderK,
4142
_empty_like_triple_orderK,
4243
)
43-
from dpctl_ext.tensor._manipulation_functions import _broadcast_shape_impl
44-
from dpctl_ext.tensor._type_utils import _can_cast
45-
44+
from ._manipulation_functions import _broadcast_shape_impl
4645
from ._scalar_utils import (
4746
_get_dtype,
4847
_get_queue_usm_type,
4948
_get_shape,
5049
_validate_dtype,
5150
)
5251
from ._type_utils import (
52+
_can_cast,
5353
_resolve_one_strong_one_weak_types,
5454
_resolve_one_strong_two_weak_types,
5555
)

dpctl_ext/tensor/_copy_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@
4242
# when dpnp fully migrates dpctl/tensor
4343
import dpctl_ext.tensor as dpt_ext
4444
import dpctl_ext.tensor._tensor_impl as ti
45-
from dpctl_ext.tensor._type_utils import _dtype_supported_by_device_impl
4645

4746
from ._numpy_helper import normalize_axis_index
47+
from ._type_utils import _dtype_supported_by_device_impl
4848

4949
__doc__ = (
5050
"Implementation module for copy- and cast- operations on "
@@ -299,7 +299,7 @@ def _prepare_indices_arrays(inds, q, usm_type):
299299
inds = tuple(
300300
map(
301301
lambda ind: (
302-
ind if ind.dtype == ind_dt else dpt.astype(ind, ind_dt)
302+
ind if ind.dtype == ind_dt else dpt_ext.astype(ind, ind_dt)
303303
),
304304
inds,
305305
)

dpctl_ext/tensor/_ctors.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@
4242
# when dpnp fully migrates dpctl/tensor
4343
import dpctl_ext.tensor as dpt_ext
4444
import dpctl_ext.tensor._tensor_impl as ti
45-
from dpctl_ext.tensor._copy_utils import (
45+
46+
from ._copy_utils import (
4647
_empty_like_orderK,
4748
_from_numpy_empty_like_orderK,
4849
)
@@ -1440,7 +1441,7 @@ def linspace(
14401441
)
14411442
_manager.add_event_pair(hev, la_ev)
14421443

1443-
return res if int_dt is None else dpt.astype(res, int_dt)
1444+
return res if int_dt is None else dpt_ext.astype(res, int_dt)
14441445

14451446

14461447
def meshgrid(*arrays, indexing="xy"):

dpctl_ext/tensor/_indexing_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def place(arr, mask, vals):
190190
if vals.dtype == arr.dtype:
191191
rhs = vals
192192
else:
193-
rhs = dpt.astype(vals, arr.dtype)
193+
rhs = dpt_ext.astype(vals, arr.dtype)
194194
hev, pl_ev = ti._place(
195195
dst=arr,
196196
cumsum=cumsum,

dpctl_ext/tensor/_reshape.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,17 @@
3131
import dpctl.tensor as dpt
3232
import dpctl.utils
3333
import numpy as np
34-
from dpctl.tensor._tensor_impl import (
35-
_copy_usm_ndarray_for_reshape,
36-
_ravel_multi_index,
37-
_unravel_index,
38-
)
3934

4035
# TODO: revert to `import dpctl.tensor...`
4136
# when dpnp fully migrates dpctl/tensor
4237
import dpctl_ext.tensor as dpt_ext
4338

39+
from ._tensor_impl import (
40+
_copy_usm_ndarray_for_reshape,
41+
_ravel_multi_index,
42+
_unravel_index,
43+
)
44+
4445
__doc__ = "Implementation module for :func:`dpctl.tensor.reshape`."
4546

4647

dpctl_ext/tensor/_search_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434
# when dpnp fully migrates dpctl/tensor
3535
import dpctl_ext.tensor as dpt_ext
3636
import dpctl_ext.tensor._tensor_impl as ti
37-
from dpctl_ext.tensor._manipulation_functions import _broadcast_shape_impl
3837

3938
from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK
39+
from ._manipulation_functions import _broadcast_shape_impl
4040
from ._scalar_utils import (
4141
_get_dtype,
4242
_get_queue_usm_type,

dpctl_ext/tensor/libtensor/include/kernels/clip.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@
3535
#pragma once
3636
#include <algorithm>
3737
#include <cmath>
38-
#include <complex>
3938
#include <cstddef>
4039
#include <cstdint>
4140
#include <type_traits>
41+
#include <vector>
4242

4343
#include <sycl/sycl.hpp>
4444

dpctl_ext/tensor/libtensor/include/kernels/constructors.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@
3333
//===----------------------------------------------------------------------===//
3434

3535
#pragma once
36+
3637
#include <array>
37-
#include <complex>
3838
#include <cstddef>
39+
#include <type_traits>
3940
#include <vector>
4041

4142
#include <sycl/sycl.hpp>
@@ -130,7 +131,7 @@ sycl::event lin_space_step_impl(sycl::queue &exec_q,
130131
}
131132

132133
// Constructor to populate tensor with linear sequence defined by
133-
// start and and data
134+
// start and data
134135

135136
template <typename Ty, typename wTy>
136137
class LinearSequenceAffineFunctor
@@ -191,7 +192,7 @@ class LinearSequenceAffineFunctor
191192
*
192193
* @param exec_q Sycl queue to which kernel is submitted for execution.
193194
* @param nelems Length of the sequence.
194-
* @param start_v Stating value of the sequence.
195+
* @param start_v Starting value of the sequence.
195196
* @param end_v End-value of the sequence.
196197
* @param include_endpoint Whether the end-value is included in the sequence.
197198
* @param array_data Kernel accessible USM pointer to the start of array to be

0 commit comments

Comments
 (0)