Skip to content

Commit 6295ca3

Browse files
committed
Add dedicated SlidingWindow1dFunctor and SlidingWindow1dSmallFunctor kernels
1 parent 20d444e commit 6295ca3

File tree

3 files changed

+297
-187
lines changed

3 files changed

+297
-187
lines changed

dpnp/backend/extensions/statistics/histogram_common.hpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,23 +35,19 @@
3535

3636
#include <sycl/sycl.hpp>
3737

38+
#include "dpctl4pybind11.hpp"
39+
3840
#include "ext/common.hpp"
3941
#include "kernels/statistics/histogram.hpp"
4042

41-
namespace dpctl::tensor
43+
namespace statistics::histogram
4244
{
43-
class usm_ndarray;
44-
}
45-
4645
using dpctl::tensor::usm_ndarray;
4746

4847
using ext::common::AtomicOp;
4948
using ext::common::IsNan;
5049
using ext::common::Less;
5150

52-
namespace statistics::histogram
53-
{
54-
5551
template <typename T, int Dims>
5652
struct CachedData
5753
{

dpnp/backend/extensions/statistics/sliding_window1d.hpp

Lines changed: 18 additions & 180 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,19 @@
2828

2929
#pragma once
3030

31-
#include <algorithm>
32-
33-
#include "utils/math_utils.hpp"
34-
#include <sycl/sycl.hpp>
31+
#include <cstddef>
32+
#include <cstdint>
3533
#include <type_traits>
3634

37-
#include <stdio.h>
38-
39-
#include "ext/common.hpp"
35+
#include <sycl/sycl.hpp>
4036

41-
using dpctl::tensor::usm_ndarray;
37+
#include "dpctl4pybind11.hpp"
4238

43-
using ext::common::Align;
44-
using ext::common::CeilDiv;
39+
#include "kernels/statistics/sliding_window1d.hpp"
4540

4641
namespace statistics::sliding_window1d
4742
{
43+
using dpctl::tensor::usm_ndarray;
4844

4945
template <typename T, uint32_t Size>
5046
class _RegistryDataStorage
@@ -464,60 +460,6 @@ PaddedSpan<T, SizeT>
464460
return PaddedSpan<T, SizeT>(data, size, offset);
465461
}
466462

467-
template <typename Results,
468-
typename AData,
469-
typename VData,
470-
typename Op,
471-
typename Red>
472-
void process_block(Results &results,
473-
uint32_t r_size,
474-
AData &a_data,
475-
VData &v_data,
476-
uint32_t block_size,
477-
Op op,
478-
Red red)
479-
{
480-
for (uint32_t i = 0; i < block_size; ++i) {
481-
auto v_val = v_data.broadcast(i);
482-
for (uint32_t r = 0; r < r_size; ++r) {
483-
results[r] = red(results[r], op(a_data[r], v_val));
484-
}
485-
a_data.advance_left();
486-
}
487-
}
488-
489-
template <typename SizeT>
490-
SizeT get_global_linear_id(const uint32_t wpi, const sycl::nd_item<1> &item)
491-
{
492-
auto sbgroup = item.get_sub_group();
493-
const auto sg_loc_id = sbgroup.get_local_linear_id();
494-
495-
const SizeT sg_base_id = wpi * (item.get_global_linear_id() - sg_loc_id);
496-
const SizeT id = sg_base_id + sg_loc_id;
497-
498-
return id;
499-
}
500-
501-
template <typename SizeT>
502-
uint32_t get_results_num(const uint32_t wpi,
503-
const SizeT size,
504-
const SizeT global_id,
505-
const sycl::nd_item<1> &item)
506-
{
507-
auto sbgroup = item.get_sub_group();
508-
509-
const auto sbg_size = sbgroup.get_max_local_range()[0];
510-
const auto size_ = sycl::sub_sat(size, global_id);
511-
return std::min(SizeT(wpi), CeilDiv(size_, sbg_size));
512-
}
513-
514-
template <uint32_t WorkPI,
515-
typename T,
516-
typename SizeT,
517-
typename Op,
518-
typename Red>
519-
class sliding_window1d_kernel;
520-
521463
template <uint32_t WorkPI,
522464
typename T,
523465
typename SizeT,
@@ -531,76 +473,15 @@ void submit_sliding_window1d(const PaddedSpan<const T, SizeT> &a,
531473
sycl::nd_range<1> nd_range,
532474
sycl::handler &cgh)
533475
{
534-
cgh.parallel_for<sliding_window1d_kernel<WorkPI, T, SizeT, Op, Red>>(
535-
nd_range, [=](sycl::nd_item<1> item) {
536-
auto glid = get_global_linear_id<SizeT>(WorkPI, item);
537-
538-
auto results = RegistryData<T, WorkPI>(item);
539-
results.fill(0);
540-
541-
auto results_num = get_results_num(WorkPI, out.size(), glid, item);
542-
543-
const auto *a_begin = a.begin();
544-
const auto *a_end = a.end();
476+
using SlidingWindow1dKernel =
477+
dpnp::kernels::sliding_window1d::SlidingWindow1dFunctor<
478+
WorkPI, PaddedSpan<const T, SizeT>, Span<const T, SizeT>, Op, Red,
479+
Span<T, SizeT>, RegistryData, RegistryWindow>;
545480

546-
auto sbgroup = item.get_sub_group();
547-
548-
const auto chunks_count =
549-
CeilDiv(v.size(), sbgroup.get_max_local_range()[0]);
550-
551-
const auto *a_ptr = &a.padded_begin()[glid];
552-
553-
auto _a_load_cond = [a_begin, a_end](auto &&ptr) {
554-
return ptr >= a_begin && ptr < a_end;
555-
};
556-
557-
auto a_data = RegistryWindow<const T, WorkPI + 1>(item);
558-
a_ptr = a_data.load(a_ptr, _a_load_cond, 0);
559-
560-
const auto *v_ptr = &v.begin()[sbgroup.get_local_linear_id()];
561-
auto v_size = v.size();
562-
563-
for (uint32_t b = 0; b < chunks_count; ++b) {
564-
auto v_data = RegistryData<const T>(item);
565-
v_ptr = v_data.load(v_ptr, v_data.x() < v_size, 0);
566-
567-
uint32_t chunk_size_ =
568-
std::min(v_size, SizeT(v_data.total_size()));
569-
process_block(results, results_num, a_data, v_data, chunk_size_,
570-
op, red);
571-
572-
if (b != chunks_count - 1) {
573-
a_ptr = a_data.load_lane(a_data.size_y() - 1, a_ptr,
574-
_a_load_cond, 0);
575-
v_size -= v_data.total_size();
576-
}
577-
}
578-
579-
auto *const out_ptr = out.begin();
580-
// auto *const out_end = out.end();
581-
582-
auto y_start = glid;
583-
auto y_stop =
584-
std::min(y_start + WorkPI * results.size_x(), out.size());
585-
uint32_t i = 0;
586-
for (uint32_t y = y_start; y < y_stop; y += results.size_x()) {
587-
out_ptr[y] = results[i++];
588-
}
589-
// while the code itself seems to be valid, inside correlate
590-
// kernel it results in memory corruption. Further investigation
591-
// is needed. SAT-7693
592-
// corruption results.store(&out_ptr[glid],
593-
// [out_end](auto &&ptr) { return ptr < out_end; });
594-
});
481+
cgh.parallel_for<SlidingWindow1dKernel>(
482+
nd_range, SlidingWindow1dKernel(a, v, op, red, out));
595483
}
596484

597-
template <uint32_t WorkPI,
598-
typename T,
599-
typename SizeT,
600-
typename Op,
601-
typename Red>
602-
class sliding_window1d_small_kernel;
603-
604485
template <uint32_t WorkPI,
605486
typename T,
606487
typename SizeT,
@@ -614,56 +495,13 @@ void submit_sliding_window1d_small_kernel(const PaddedSpan<const T, SizeT> &a,
614495
sycl::nd_range<1> nd_range,
615496
sycl::handler &cgh)
616497
{
617-
cgh.parallel_for<sliding_window1d_small_kernel<WorkPI, T, SizeT, Op, Red>>(
618-
nd_range, [=](sycl::nd_item<1> item) {
619-
auto glid = get_global_linear_id<SizeT>(WorkPI, item);
620-
621-
auto results = RegistryData<T, WorkPI>(item);
622-
results.fill(0);
623-
624-
auto sbgroup = item.get_sub_group();
625-
auto sg_size = sbgroup.get_max_local_range()[0];
626-
627-
const uint32_t to_read = WorkPI * sg_size + v.size();
628-
const auto *a_begin = a.begin();
629-
630-
const auto *a_ptr = &a.padded_begin()[glid];
631-
const auto *a_end = std::min(a_ptr + to_read, a.end());
632-
633-
auto _a_load_cond = [a_begin, a_end](auto &&ptr) {
634-
return ptr >= a_begin && ptr < a_end;
635-
};
498+
using SlidingWindow1dSmallKernel =
499+
dpnp::kernels::sliding_window1d::SlidingWindow1dSmallFunctor<
500+
WorkPI, PaddedSpan<const T, SizeT>, Span<const T, SizeT>, Op, Red,
501+
Span<T, SizeT>, RegistryData, RegistryWindow>;
636502

637-
auto a_data = RegistryWindow<const T, WorkPI + 1>(item);
638-
a_data.load(a_ptr, _a_load_cond, 0);
639-
640-
const auto *v_ptr = &v.begin()[sbgroup.get_local_linear_id()];
641-
auto v_size = v.size();
642-
643-
auto v_data = RegistryData<const T>(item);
644-
v_ptr = v_data.load(v_ptr, v_data.x() < v_size, 0);
645-
646-
auto results_num = get_results_num(WorkPI, out.size(), glid, item);
647-
648-
process_block(results, results_num, a_data, v_data, v_size, op,
649-
red);
650-
651-
auto *const out_ptr = out.begin();
652-
// auto *const out_end = out.end();
653-
654-
auto y_start = glid;
655-
auto y_stop =
656-
std::min(y_start + WorkPI * results.size_x(), out.size());
657-
uint32_t i = 0;
658-
for (uint32_t y = y_start; y < y_stop; y += results.size_x()) {
659-
out_ptr[y] = results[i++];
660-
}
661-
// while the code itself seems to be valid, inside correlate
662-
// kernel it results in memory corruption. Further investigation
663-
// is needed. SAT-7693
664-
// corruption results.store(&out_ptr[glid],
665-
// [out_end](auto &&ptr) { return ptr < out_end; });
666-
});
503+
cgh.parallel_for<SlidingWindow1dSmallKernel>(
504+
nd_range, SlidingWindow1dSmallKernel(a, v, op, red, out));
667505
}
668506

669507
void validate(const usm_ndarray &a,

0 commit comments

Comments
 (0)