Skip to content

Commit 5e6aa6e

Browse files
committed
Add dedicated SlidingWindow1dFunctor and SlidingWindow1dSmallFunctor kernels
1 parent 0a7d38c commit 5e6aa6e

File tree

3 files changed

+297
-187
lines changed

3 files changed

+297
-187
lines changed

dpnp/backend/extensions/statistics/histogram_common.hpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,23 +35,19 @@
3535

3636
#include <sycl/sycl.hpp>
3737

38+
#include "dpctl4pybind11.hpp"
39+
3840
#include "ext/common.hpp"
3941
#include "kernels/statistics/histogram.hpp"
4042

41-
namespace dpctl::tensor
43+
namespace statistics::histogram
4244
{
43-
class usm_ndarray;
44-
}
45-
4645
using dpctl::tensor::usm_ndarray;
4746

4847
using ext::common::AtomicOp;
4948
using ext::common::IsNan;
5049
using ext::common::Less;
5150

52-
namespace statistics::histogram
53-
{
54-
5551
template <typename T, int Dims>
5652
struct CachedData
5753
{

dpnp/backend/extensions/statistics/sliding_window1d.hpp

Lines changed: 18 additions & 180 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,19 @@
2828

2929
#pragma once
3030

31-
#include <algorithm>
32-
33-
#include "utils/math_utils.hpp"
34-
#include <sycl/sycl.hpp>
31+
#include <cstddef>
32+
#include <cstdint>
3533
#include <type_traits>
3634

37-
#include <stdio.h>
38-
39-
#include "ext/common.hpp"
35+
#include <sycl/sycl.hpp>
4036

41-
using dpctl::tensor::usm_ndarray;
37+
#include "dpctl4pybind11.hpp"
4238

43-
using ext::common::Align;
44-
using ext::common::CeilDiv;
39+
#include "kernels/statistics/sliding_window1d.hpp"
4540

4641
namespace statistics::sliding_window1d
4742
{
43+
using dpctl::tensor::usm_ndarray;
4844

4945
template <typename T, uint32_t Size>
5046
class _RegistryDataStorage
@@ -424,60 +420,6 @@ PaddedSpan<T, SizeT>
424420
return PaddedSpan<T, SizeT>(data, size, offset);
425421
}
426422

427-
template <typename Results,
428-
typename AData,
429-
typename VData,
430-
typename Op,
431-
typename Red>
432-
void process_block(Results &results,
433-
uint32_t r_size,
434-
AData &a_data,
435-
VData &v_data,
436-
uint32_t block_size,
437-
Op op,
438-
Red red)
439-
{
440-
for (uint32_t i = 0; i < block_size; ++i) {
441-
auto v_val = v_data.broadcast(i);
442-
for (uint32_t r = 0; r < r_size; ++r) {
443-
results[r] = red(results[r], op(a_data[r], v_val));
444-
}
445-
a_data.advance_left();
446-
}
447-
}
448-
449-
template <typename SizeT>
450-
SizeT get_global_linear_id(const uint32_t wpi, const sycl::nd_item<1> &item)
451-
{
452-
auto sbgroup = item.get_sub_group();
453-
const auto sg_loc_id = sbgroup.get_local_linear_id();
454-
455-
const SizeT sg_base_id = wpi * (item.get_global_linear_id() - sg_loc_id);
456-
const SizeT id = sg_base_id + sg_loc_id;
457-
458-
return id;
459-
}
460-
461-
template <typename SizeT>
462-
uint32_t get_results_num(const uint32_t wpi,
463-
const SizeT size,
464-
const SizeT global_id,
465-
const sycl::nd_item<1> &item)
466-
{
467-
auto sbgroup = item.get_sub_group();
468-
469-
const auto sbg_size = sbgroup.get_max_local_range()[0];
470-
const auto size_ = sycl::sub_sat(size, global_id);
471-
return std::min(SizeT(wpi), CeilDiv(size_, sbg_size));
472-
}
473-
474-
template <uint32_t WorkPI,
475-
typename T,
476-
typename SizeT,
477-
typename Op,
478-
typename Red>
479-
class sliding_window1d_kernel;
480-
481423
template <uint32_t WorkPI,
482424
typename T,
483425
typename SizeT,
@@ -491,76 +433,15 @@ void submit_sliding_window1d(const PaddedSpan<const T, SizeT> &a,
491433
sycl::nd_range<1> nd_range,
492434
sycl::handler &cgh)
493435
{
494-
cgh.parallel_for<sliding_window1d_kernel<WorkPI, T, SizeT, Op, Red>>(
495-
nd_range, [=](sycl::nd_item<1> item) {
496-
auto glid = get_global_linear_id<SizeT>(WorkPI, item);
497-
498-
auto results = RegistryData<T, WorkPI>(item);
499-
results.fill(0);
500-
501-
auto results_num = get_results_num(WorkPI, out.size(), glid, item);
502-
503-
const auto *a_begin = a.begin();
504-
const auto *a_end = a.end();
436+
using SlidingWindow1dKernel =
437+
dpnp::kernels::sliding_window1d::SlidingWindow1dFunctor<
438+
WorkPI, PaddedSpan<const T, SizeT>, Span<const T, SizeT>, Op, Red,
439+
Span<T, SizeT>, RegistryData, RegistryWindow>;
505440

506-
auto sbgroup = item.get_sub_group();
507-
508-
const auto chunks_count =
509-
CeilDiv(v.size(), sbgroup.get_max_local_range()[0]);
510-
511-
const auto *a_ptr = &a.padded_begin()[glid];
512-
513-
auto _a_load_cond = [a_begin, a_end](auto &&ptr) {
514-
return ptr >= a_begin && ptr < a_end;
515-
};
516-
517-
auto a_data = RegistryWindow<const T, WorkPI + 1>(item);
518-
a_ptr = a_data.load(a_ptr, _a_load_cond, 0);
519-
520-
const auto *v_ptr = &v.begin()[sbgroup.get_local_linear_id()];
521-
auto v_size = v.size();
522-
523-
for (uint32_t b = 0; b < chunks_count; ++b) {
524-
auto v_data = RegistryData<const T>(item);
525-
v_ptr = v_data.load(v_ptr, v_data.x() < v_size, 0);
526-
527-
uint32_t chunk_size_ =
528-
std::min(v_size, SizeT(v_data.total_size()));
529-
process_block(results, results_num, a_data, v_data, chunk_size_,
530-
op, red);
531-
532-
if (b != chunks_count - 1) {
533-
a_ptr = a_data.load_lane(a_data.size_y() - 1, a_ptr,
534-
_a_load_cond, 0);
535-
v_size -= v_data.total_size();
536-
}
537-
}
538-
539-
auto *const out_ptr = out.begin();
540-
// auto *const out_end = out.end();
541-
542-
auto y_start = glid;
543-
auto y_stop =
544-
std::min(y_start + WorkPI * results.size_x(), out.size());
545-
uint32_t i = 0;
546-
for (uint32_t y = y_start; y < y_stop; y += results.size_x()) {
547-
out_ptr[y] = results[i++];
548-
}
549-
// while the code itself seems to be valid, inside correlate
550-
// kernel it results in memory corruption. Further investigation
551-
// is needed. SAT-7693
552-
// corruption results.store(&out_ptr[glid],
553-
// [out_end](auto &&ptr) { return ptr < out_end; });
554-
});
441+
cgh.parallel_for<SlidingWindow1dKernel>(
442+
nd_range, SlidingWindow1dKernel(a, v, op, red, out));
555443
}
556444

557-
template <uint32_t WorkPI,
558-
typename T,
559-
typename SizeT,
560-
typename Op,
561-
typename Red>
562-
class sliding_window1d_small_kernel;
563-
564445
template <uint32_t WorkPI,
565446
typename T,
566447
typename SizeT,
@@ -574,56 +455,13 @@ void submit_sliding_window1d_small_kernel(const PaddedSpan<const T, SizeT> &a,
574455
sycl::nd_range<1> nd_range,
575456
sycl::handler &cgh)
576457
{
577-
cgh.parallel_for<sliding_window1d_small_kernel<WorkPI, T, SizeT, Op, Red>>(
578-
nd_range, [=](sycl::nd_item<1> item) {
579-
auto glid = get_global_linear_id<SizeT>(WorkPI, item);
580-
581-
auto results = RegistryData<T, WorkPI>(item);
582-
results.fill(0);
583-
584-
auto sbgroup = item.get_sub_group();
585-
auto sg_size = sbgroup.get_max_local_range()[0];
586-
587-
const uint32_t to_read = WorkPI * sg_size + v.size();
588-
const auto *a_begin = a.begin();
589-
590-
const auto *a_ptr = &a.padded_begin()[glid];
591-
const auto *a_end = std::min(a_ptr + to_read, a.end());
592-
593-
auto _a_load_cond = [a_begin, a_end](auto &&ptr) {
594-
return ptr >= a_begin && ptr < a_end;
595-
};
458+
using SlidingWindow1dSmallKernel =
459+
dpnp::kernels::sliding_window1d::SlidingWindow1dSmallFunctor<
460+
WorkPI, PaddedSpan<const T, SizeT>, Span<const T, SizeT>, Op, Red,
461+
Span<T, SizeT>, RegistryData, RegistryWindow>;
596462

597-
auto a_data = RegistryWindow<const T, WorkPI + 1>(item);
598-
a_data.load(a_ptr, _a_load_cond, 0);
599-
600-
const auto *v_ptr = &v.begin()[sbgroup.get_local_linear_id()];
601-
auto v_size = v.size();
602-
603-
auto v_data = RegistryData<const T>(item);
604-
v_ptr = v_data.load(v_ptr, v_data.x() < v_size, 0);
605-
606-
auto results_num = get_results_num(WorkPI, out.size(), glid, item);
607-
608-
process_block(results, results_num, a_data, v_data, v_size, op,
609-
red);
610-
611-
auto *const out_ptr = out.begin();
612-
// auto *const out_end = out.end();
613-
614-
auto y_start = glid;
615-
auto y_stop =
616-
std::min(y_start + WorkPI * results.size_x(), out.size());
617-
uint32_t i = 0;
618-
for (uint32_t y = y_start; y < y_stop; y += results.size_x()) {
619-
out_ptr[y] = results[i++];
620-
}
621-
// while the code itself seems to be valid, inside correlate
622-
// kernel it results in memory corruption. Further investigation
623-
// is needed. SAT-7693
624-
// corruption results.store(&out_ptr[glid],
625-
// [out_end](auto &&ptr) { return ptr < out_end; });
626-
});
463+
cgh.parallel_for<SlidingWindow1dSmallKernel>(
464+
nd_range, SlidingWindow1dSmallKernel(a, v, op, red, out));
627465
}
628466

629467
void validate(const usm_ndarray &a,

0 commit comments

Comments
 (0)