Skip to content

Commit ada8a5c

Browse files
Merge include-dpctl-tensor into move_tensor_accumulation_impl
2 parents d8c3680 + 585f2e5 commit ada8a5c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+223
-198
lines changed

dpctl_ext/tensor/__init__.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,16 @@
2727
# *****************************************************************************
2828

2929

30-
from dpctl.tensor._search_functions import where
31-
32-
from dpctl_ext.tensor._copy_utils import (
30+
from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum
31+
from ._clip import clip
32+
from ._copy_utils import (
3333
asnumpy,
3434
astype,
3535
copy,
3636
from_numpy,
3737
to_numpy,
3838
)
39-
from dpctl_ext.tensor._ctors import (
39+
from ._ctors import (
4040
arange,
4141
asarray,
4242
empty,
@@ -53,7 +53,7 @@
5353
zeros,
5454
zeros_like,
5555
)
56-
from dpctl_ext.tensor._indexing_functions import (
56+
from ._indexing_functions import (
5757
extract,
5858
nonzero,
5959
place,
@@ -62,7 +62,7 @@
6262
take,
6363
take_along_axis,
6464
)
65-
from dpctl_ext.tensor._manipulation_functions import (
65+
from ._manipulation_functions import (
6666
broadcast_arrays,
6767
broadcast_to,
6868
concat,
@@ -78,10 +78,8 @@
7878
tile,
7979
unstack,
8080
)
81-
from dpctl_ext.tensor._reshape import reshape
82-
83-
from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum
84-
from ._clip import clip
81+
from ._reshape import reshape
82+
from ._search_functions import where
8583
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
8684

8785
__all__ = [

dpctl_ext/tensor/_clip.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,21 @@
3535
# when dpnp fully migrates dpctl/tensor
3636
import dpctl_ext.tensor as dpt_ext
3737
import dpctl_ext.tensor._tensor_impl as ti
38-
from dpctl_ext.tensor._copy_utils import (
38+
39+
from ._copy_utils import (
3940
_empty_like_orderK,
4041
_empty_like_pair_orderK,
4142
_empty_like_triple_orderK,
4243
)
43-
from dpctl_ext.tensor._manipulation_functions import _broadcast_shape_impl
44-
from dpctl_ext.tensor._type_utils import _can_cast
45-
44+
from ._manipulation_functions import _broadcast_shape_impl
4645
from ._scalar_utils import (
4746
_get_dtype,
4847
_get_queue_usm_type,
4948
_get_shape,
5049
_validate_dtype,
5150
)
5251
from ._type_utils import (
52+
_can_cast,
5353
_resolve_one_strong_one_weak_types,
5454
_resolve_one_strong_two_weak_types,
5555
)

dpctl_ext/tensor/_copy_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@
4242
# when dpnp fully migrates dpctl/tensor
4343
import dpctl_ext.tensor as dpt_ext
4444
import dpctl_ext.tensor._tensor_impl as ti
45-
from dpctl_ext.tensor._type_utils import _dtype_supported_by_device_impl
4645

4746
from ._numpy_helper import normalize_axis_index
47+
from ._type_utils import _dtype_supported_by_device_impl
4848

4949
__doc__ = (
5050
"Implementation module for copy- and cast- operations on "
@@ -299,7 +299,7 @@ def _prepare_indices_arrays(inds, q, usm_type):
299299
inds = tuple(
300300
map(
301301
lambda ind: (
302-
ind if ind.dtype == ind_dt else dpt.astype(ind, ind_dt)
302+
ind if ind.dtype == ind_dt else dpt_ext.astype(ind, ind_dt)
303303
),
304304
inds,
305305
)

dpctl_ext/tensor/_ctors.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@
4242
# when dpnp fully migrates dpctl/tensor
4343
import dpctl_ext.tensor as dpt_ext
4444
import dpctl_ext.tensor._tensor_impl as ti
45-
from dpctl_ext.tensor._copy_utils import (
45+
46+
from ._copy_utils import (
4647
_empty_like_orderK,
4748
_from_numpy_empty_like_orderK,
4849
)
@@ -1440,7 +1441,7 @@ def linspace(
14401441
)
14411442
_manager.add_event_pair(hev, la_ev)
14421443

1443-
return res if int_dt is None else dpt.astype(res, int_dt)
1444+
return res if int_dt is None else dpt_ext.astype(res, int_dt)
14441445

14451446

14461447
def meshgrid(*arrays, indexing="xy"):

dpctl_ext/tensor/_indexing_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def place(arr, mask, vals):
190190
if vals.dtype == arr.dtype:
191191
rhs = vals
192192
else:
193-
rhs = dpt.astype(vals, arr.dtype)
193+
rhs = dpt_ext.astype(vals, arr.dtype)
194194
hev, pl_ev = ti._place(
195195
dst=arr,
196196
cumsum=cumsum,

dpctl_ext/tensor/_reshape.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,17 @@
3131
import dpctl.tensor as dpt
3232
import dpctl.utils
3333
import numpy as np
34-
from dpctl.tensor._tensor_impl import (
35-
_copy_usm_ndarray_for_reshape,
36-
_ravel_multi_index,
37-
_unravel_index,
38-
)
3934

4035
# TODO: revert to `import dpctl.tensor...`
4136
# when dpnp fully migrates dpctl/tensor
4237
import dpctl_ext.tensor as dpt_ext
4338

39+
from ._tensor_impl import (
40+
_copy_usm_ndarray_for_reshape,
41+
_ravel_multi_index,
42+
_unravel_index,
43+
)
44+
4445
__doc__ = "Implementation module for :func:`dpctl.tensor.reshape`."
4546

4647

dpctl_ext/tensor/_search_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434
# when dpnp fully migrates dpctl/tensor
3535
import dpctl_ext.tensor as dpt_ext
3636
import dpctl_ext.tensor._tensor_impl as ti
37-
from dpctl_ext.tensor._manipulation_functions import _broadcast_shape_impl
3837

3938
from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK
39+
from ._manipulation_functions import _broadcast_shape_impl
4040
from ._scalar_utils import (
4141
_get_dtype,
4242
_get_queue_usm_type,

dpctl_ext/tensor/libtensor/include/kernels/clip.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@
3535
#pragma once
3636
#include <algorithm>
3737
#include <cmath>
38-
#include <complex>
3938
#include <cstddef>
4039
#include <cstdint>
4140
#include <type_traits>
41+
#include <vector>
4242

4343
#include <sycl/sycl.hpp>
4444

dpctl_ext/tensor/libtensor/include/kernels/constructors.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@
3333
//===----------------------------------------------------------------------===//
3434

3535
#pragma once
36+
3637
#include <array>
37-
#include <complex>
3838
#include <cstddef>
39+
#include <type_traits>
3940
#include <vector>
4041

4142
#include <sycl/sycl.hpp>
@@ -130,7 +131,7 @@ sycl::event lin_space_step_impl(sycl::queue &exec_q,
130131
}
131132

132133
// Constructor to populate tensor with linear sequence defined by
133-
// start and and data
134+
// start and data
134135

135136
template <typename Ty, typename wTy>
136137
class LinearSequenceAffineFunctor
@@ -191,7 +192,7 @@ class LinearSequenceAffineFunctor
191192
*
192193
* @param exec_q Sycl queue to which kernel is submitted for execution.
193194
* @param nelems Length of the sequence.
194-
* @param start_v Stating value of the sequence.
195+
* @param start_v Starting value of the sequence.
195196
* @param end_v End-value of the sequence.
196197
* @param include_endpoint Whether the end-value is included in the sequence.
197198
* @param array_data Kernel accessible USM pointer to the start of array to be

dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/common.hpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,8 @@
3333
#pragma once
3434

3535
#include <algorithm>
36-
#include <cmath>
37-
#include <complex>
3836
#include <cstddef>
3937
#include <cstdint>
40-
#include <limits>
4138
#include <type_traits>
4239
#include <vector>
4340

0 commit comments

Comments
 (0)