Skip to content

Commit 3e39872

Browse files
Merge move_tensor_sorting_impl into move_tensor_reductions_impl_ext
2 parents e0bba1d + 3c26859 commit 3e39872

File tree

17 files changed

+43
-65
lines changed

17 files changed

+43
-65
lines changed

dpctl_ext/tensor/libtensor/include/kernels/sorting/search_sorted_detail.hpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,7 @@
3636

3737
#include <cstddef>
3838

39-
namespace dpctl::tensor::kernels
40-
{
41-
42-
namespace search_sorted_detail
39+
namespace dpctl::tensor::kernels::search_sorted_detail
4340
{
4441

4542
template <typename T>
@@ -119,6 +116,4 @@ std::size_t upper_bound_indexed_impl(const Acc acc,
119116
acc_indexer);
120117
}
121118

122-
} // namespace search_sorted_detail
123-
124-
} // namespace dpctl::tensor::kernels
119+
} // namespace dpctl::tensor::kernels::search_sorted_detail

dpctl_ext/tensor/libtensor/include/kernels/sorting/searchsorted.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
#pragma once
3636

3737
#include <cstddef>
38-
#include <cstdint>
3938
#include <vector>
4039

4140
#include <sycl/sycl.hpp>

dpctl_ext/tensor/libtensor/include/kernels/sorting/sort_utils.hpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
#include <array>
3838
#include <cstddef>
3939
#include <cstdint>
40-
#include <span>
4140
#include <vector>
4241

4342
#include <sycl/sycl.hpp>
@@ -78,10 +77,8 @@ sycl::event iota_impl(sycl::queue &exec_q,
7877
}
7978

8079
if (offset + n_wi * max_sgSize < nelems) {
81-
static constexpr auto group_ls_props = syclexp::properties{
82-
syclexp::data_placement_striped
83-
// , syclexp::full_group
84-
};
80+
static constexpr auto group_ls_props =
81+
syclexp::properties{syclexp::data_placement_striped};
8582

8683
auto out_multi_ptr = sycl::address_space_cast<
8784
sycl::access::address_space::global_space,

dpctl_ext/tensor/libtensor/include/kernels/sorting/topk.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,13 @@
3434

3535
#pragma once
3636

37+
#include <algorithm>
3738
#include <cstddef>
3839
#include <cstdint>
40+
#include <functional>
3941
#include <stdexcept>
4042
#include <vector>
4143

42-
#include <sycl/ext/oneapi/sub_group_mask.hpp>
4344
#include <sycl/sycl.hpp>
4445

4546
#include "kernels/sorting/merge_sort.hpp"

dpctl_ext/tensor/libtensor/source/accumulators/cumulative_logsumexp.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,6 @@ void init_cumulative_logsumexp(py::module_ m)
337337

338338
auto cumlogsumexp_dtype_supported = [&](const py::dtype &input_dtype,
339339
const py::dtype &output_dtype) {
340-
using dpctl::tensor::py_internal::py_accumulate_dtype_supported;
341340
return py_accumulate_dtype_supported(
342341
input_dtype, output_dtype, cumlogsumexp_strided_dispatch_table);
343342
};

dpctl_ext/tensor/libtensor/source/accumulators/cumulative_prod.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,6 @@ void init_cumulative_prod(py::module_ m)
346346

347347
auto cumprod_dtype_supported = [&](const py::dtype &input_dtype,
348348
const py::dtype &output_dtype) {
349-
using dpctl::tensor::py_internal::py_accumulate_dtype_supported;
350349
return py_accumulate_dtype_supported(input_dtype, output_dtype,
351350
cumprod_strided_dispatch_table);
352351
};

dpctl_ext/tensor/libtensor/source/accumulators/cumulative_sum.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,6 @@ void init_cumulative_sum(py::module_ m)
344344

345345
auto cumsum_dtype_supported = [&](const py::dtype &input_dtype,
346346
const py::dtype &output_dtype) {
347-
using dpctl::tensor::py_internal::py_accumulate_dtype_supported;
348347
return py_accumulate_dtype_supported(input_dtype, output_dtype,
349348
cumsum_strided_dispatch_table);
350349
};

dpctl_ext/tensor/libtensor/source/sorting/isin.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
#include <cstddef>
3737
#include <stdexcept>
38+
#include <tuple>
3839
#include <utility>
3940
#include <vector>
4041

@@ -46,6 +47,7 @@
4647

4748
#include "kernels/sorting/isin.hpp"
4849
#include "utils/memory_overlap.hpp"
50+
#include "utils/offset_utils.hpp"
4951
#include "utils/output_validation.hpp"
5052
#include "utils/sycl_alloc_utils.hpp"
5153
#include "utils/type_dispatch.hpp"
@@ -254,7 +256,7 @@ std::pair<sycl::event, sycl::event>
254256
simplified_dst_strides.push_back(0);
255257
}
256258
else {
257-
dpctl::tensor::py_internal::simplify_iteration_space(
259+
simplify_iteration_space(
258260
// modified by reference
259261
simplified_nd,
260262
// read-only inputs
@@ -313,9 +315,8 @@ std::pair<sycl::event, sycl::event>
313315

314316
void init_isin_functions(py::module_ m)
315317
{
316-
dpctl::tensor::py_internal::detail::init_isin_dispatch_vector();
318+
detail::init_isin_dispatch_vector();
317319

318-
using dpctl::tensor::py_internal::py_isin;
319320
m.def("_isin", &py_isin, py::arg("needles"), py::arg("hay"), py::arg("dst"),
320321
py::arg("sycl_queue"), py::arg("invert"),
321322
py::arg("depends") = py::list());

dpctl_ext/tensor/libtensor/source/sorting/merge_argsort.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535

3636
#include <cstdint>
3737
#include <type_traits>
38+
#include <utility>
39+
#include <vector>
3840

3941
#include <sycl/sycl.hpp>
4042

@@ -121,18 +123,16 @@ void init_merge_argsort_dispatch_tables(void)
121123

122124
void init_merge_argsort_functions(py::module_ m)
123125
{
124-
dpctl::tensor::py_internal::init_merge_argsort_dispatch_tables();
126+
init_merge_argsort_dispatch_tables();
125127

126128
auto py_argsort_ascending = [](const dpctl::tensor::usm_ndarray &src,
127129
const int trailing_dims_to_sort,
128130
const dpctl::tensor::usm_ndarray &dst,
129131
sycl::queue &exec_q,
130132
const std::vector<sycl::event> &depends)
131133
-> std::pair<sycl::event, sycl::event> {
132-
return dpctl::tensor::py_internal::py_argsort(
133-
src, trailing_dims_to_sort, dst, exec_q, depends,
134-
dpctl::tensor::py_internal::
135-
ascending_argsort_contig_dispatch_table);
134+
return py_argsort(src, trailing_dims_to_sort, dst, exec_q, depends,
135+
ascending_argsort_contig_dispatch_table);
136136
};
137137
m.def("_argsort_ascending", py_argsort_ascending, py::arg("src"),
138138
py::arg("trailing_dims_to_sort"), py::arg("dst"),
@@ -144,10 +144,8 @@ void init_merge_argsort_functions(py::module_ m)
144144
sycl::queue &exec_q,
145145
const std::vector<sycl::event> &depends)
146146
-> std::pair<sycl::event, sycl::event> {
147-
return dpctl::tensor::py_internal::py_argsort(
148-
src, trailing_dims_to_sort, dst, exec_q, depends,
149-
dpctl::tensor::py_internal::
150-
descending_argsort_contig_dispatch_table);
147+
return py_argsort(src, trailing_dims_to_sort, dst, exec_q, depends,
148+
descending_argsort_contig_dispatch_table);
151149
};
152150
m.def("_argsort_descending", py_argsort_descending, py::arg("src"),
153151
py::arg("trailing_dims_to_sort"), py::arg("dst"),

dpctl_ext/tensor/libtensor/source/sorting/merge_sort.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333
/// extension.
3434
//===----------------------------------------------------------------------===//
3535

36+
#include <utility>
37+
#include <vector>
38+
3639
#include <sycl/sycl.hpp>
3740

3841
#include "dpnp4pybind11.hpp"
@@ -102,17 +105,16 @@ void init_merge_sort_dispatch_vectors(void)
102105

103106
void init_merge_sort_functions(py::module_ m)
104107
{
105-
dpctl::tensor::py_internal::init_merge_sort_dispatch_vectors();
108+
init_merge_sort_dispatch_vectors();
106109

107110
auto py_sort_ascending = [](const dpctl::tensor::usm_ndarray &src,
108111
const int trailing_dims_to_sort,
109112
const dpctl::tensor::usm_ndarray &dst,
110113
sycl::queue &exec_q,
111114
const std::vector<sycl::event> &depends)
112115
-> std::pair<sycl::event, sycl::event> {
113-
return dpctl::tensor::py_internal::py_sort(
114-
src, trailing_dims_to_sort, dst, exec_q, depends,
115-
dpctl::tensor::py_internal::ascending_sort_contig_dispatch_vector);
116+
return py_sort(src, trailing_dims_to_sort, dst, exec_q, depends,
117+
ascending_sort_contig_dispatch_vector);
116118
};
117119
m.def("_sort_ascending", py_sort_ascending, py::arg("src"),
118120
py::arg("trailing_dims_to_sort"), py::arg("dst"),
@@ -124,9 +126,8 @@ void init_merge_sort_functions(py::module_ m)
124126
sycl::queue &exec_q,
125127
const std::vector<sycl::event> &depends)
126128
-> std::pair<sycl::event, sycl::event> {
127-
return dpctl::tensor::py_internal::py_sort(
128-
src, trailing_dims_to_sort, dst, exec_q, depends,
129-
dpctl::tensor::py_internal::descending_sort_contig_dispatch_vector);
129+
return py_sort(src, trailing_dims_to_sort, dst, exec_q, depends,
130+
descending_sort_contig_dispatch_vector);
130131
};
131132
m.def("_sort_descending", py_sort_descending, py::arg("src"),
132133
py::arg("trailing_dims_to_sort"), py::arg("dst"),

0 commit comments

Comments
 (0)