2828
2929#pragma once
3030
31- #include < algorithm>
32-
33- #include " utils/math_utils.hpp"
34- #include < sycl/sycl.hpp>
31+ #include < cstddef>
32+ #include < cstdint>
3533#include < type_traits>
3634
37- #include < stdio.h>
38-
39- #include " ext/common.hpp"
35+ #include < sycl/sycl.hpp>
4036
41- using dpctl::tensor::usm_ndarray;
37+ # include " dpctl4pybind11.hpp "
4238
43- using ext::common::Align;
44- using ext::common::CeilDiv;
39+ #include " kernels/statistics/sliding_window1d.hpp"
4540
4641namespace statistics ::sliding_window1d
4742{
43+ using dpctl::tensor::usm_ndarray;
4844
4945template <typename T, uint32_t Size>
5046class _RegistryDataStorage
@@ -424,60 +420,6 @@ PaddedSpan<T, SizeT>
424420 return PaddedSpan<T, SizeT>(data, size, offset);
425421}
426422
427- template <typename Results,
428- typename AData,
429- typename VData,
430- typename Op,
431- typename Red>
432- void process_block (Results &results,
433- uint32_t r_size,
434- AData &a_data,
435- VData &v_data,
436- uint32_t block_size,
437- Op op,
438- Red red)
439- {
440- for (uint32_t i = 0 ; i < block_size; ++i) {
441- auto v_val = v_data.broadcast (i);
442- for (uint32_t r = 0 ; r < r_size; ++r) {
443- results[r] = red (results[r], op (a_data[r], v_val));
444- }
445- a_data.advance_left ();
446- }
447- }
448-
449- template <typename SizeT>
450- SizeT get_global_linear_id (const uint32_t wpi, const sycl::nd_item<1 > &item)
451- {
452- auto sbgroup = item.get_sub_group ();
453- const auto sg_loc_id = sbgroup.get_local_linear_id ();
454-
455- const SizeT sg_base_id = wpi * (item.get_global_linear_id () - sg_loc_id);
456- const SizeT id = sg_base_id + sg_loc_id;
457-
458- return id;
459- }
460-
461- template <typename SizeT>
462- uint32_t get_results_num (const uint32_t wpi,
463- const SizeT size,
464- const SizeT global_id,
465- const sycl::nd_item<1 > &item)
466- {
467- auto sbgroup = item.get_sub_group ();
468-
469- const auto sbg_size = sbgroup.get_max_local_range ()[0 ];
470- const auto size_ = sycl::sub_sat (size, global_id);
471- return std::min (SizeT (wpi), CeilDiv (size_, sbg_size));
472- }
473-
474- template <uint32_t WorkPI,
475- typename T,
476- typename SizeT,
477- typename Op,
478- typename Red>
479- class sliding_window1d_kernel ;
480-
481423template <uint32_t WorkPI,
482424 typename T,
483425 typename SizeT,
@@ -491,76 +433,15 @@ void submit_sliding_window1d(const PaddedSpan<const T, SizeT> &a,
491433 sycl::nd_range<1 > nd_range,
492434 sycl::handler &cgh)
493435{
494- cgh.parallel_for <sliding_window1d_kernel<WorkPI, T, SizeT, Op, Red>>(
495- nd_range, [=](sycl::nd_item<1 > item) {
496- auto glid = get_global_linear_id<SizeT>(WorkPI, item);
497-
498- auto results = RegistryData<T, WorkPI>(item);
499- results.fill (0 );
500-
501- auto results_num = get_results_num (WorkPI, out.size (), glid, item);
502-
503- const auto *a_begin = a.begin ();
504- const auto *a_end = a.end ();
436+ using SlidingWindow1dKernel =
437+ dpnp::kernels::sliding_window1d::SlidingWindow1dFunctor<
438+ WorkPI, PaddedSpan<const T, SizeT>, Span<const T, SizeT>, Op, Red,
439+ Span<T, SizeT>, RegistryData, RegistryWindow>;
505440
506- auto sbgroup = item.get_sub_group ();
507-
508- const auto chunks_count =
509- CeilDiv (v.size (), sbgroup.get_max_local_range ()[0 ]);
510-
511- const auto *a_ptr = &a.padded_begin ()[glid];
512-
513- auto _a_load_cond = [a_begin, a_end](auto &&ptr) {
514- return ptr >= a_begin && ptr < a_end;
515- };
516-
517- auto a_data = RegistryWindow<const T, WorkPI + 1 >(item);
518- a_ptr = a_data.load (a_ptr, _a_load_cond, 0 );
519-
520- const auto *v_ptr = &v.begin ()[sbgroup.get_local_linear_id ()];
521- auto v_size = v.size ();
522-
523- for (uint32_t b = 0 ; b < chunks_count; ++b) {
524- auto v_data = RegistryData<const T>(item);
525- v_ptr = v_data.load (v_ptr, v_data.x () < v_size, 0 );
526-
527- uint32_t chunk_size_ =
528- std::min (v_size, SizeT (v_data.total_size ()));
529- process_block (results, results_num, a_data, v_data, chunk_size_,
530- op, red);
531-
532- if (b != chunks_count - 1 ) {
533- a_ptr = a_data.load_lane (a_data.size_y () - 1 , a_ptr,
534- _a_load_cond, 0 );
535- v_size -= v_data.total_size ();
536- }
537- }
538-
539- auto *const out_ptr = out.begin ();
540- // auto *const out_end = out.end();
541-
542- auto y_start = glid;
543- auto y_stop =
544- std::min (y_start + WorkPI * results.size_x (), out.size ());
545- uint32_t i = 0 ;
546- for (uint32_t y = y_start; y < y_stop; y += results.size_x ()) {
547- out_ptr[y] = results[i++];
548- }
549- // while the code itself seems to be valid, inside correlate
550- // kernel it results in memory corruption. Further investigation
551- // is needed. SAT-7693
552- // corruption results.store(&out_ptr[glid],
553- // [out_end](auto &&ptr) { return ptr < out_end; });
554- });
441+ cgh.parallel_for <SlidingWindow1dKernel>(
442+ nd_range, SlidingWindow1dKernel (a, v, op, red, out));
555443}
556444
557- template <uint32_t WorkPI,
558- typename T,
559- typename SizeT,
560- typename Op,
561- typename Red>
562- class sliding_window1d_small_kernel ;
563-
564445template <uint32_t WorkPI,
565446 typename T,
566447 typename SizeT,
@@ -574,56 +455,13 @@ void submit_sliding_window1d_small_kernel(const PaddedSpan<const T, SizeT> &a,
574455 sycl::nd_range<1 > nd_range,
575456 sycl::handler &cgh)
576457{
577- cgh.parallel_for <sliding_window1d_small_kernel<WorkPI, T, SizeT, Op, Red>>(
578- nd_range, [=](sycl::nd_item<1 > item) {
579- auto glid = get_global_linear_id<SizeT>(WorkPI, item);
580-
581- auto results = RegistryData<T, WorkPI>(item);
582- results.fill (0 );
583-
584- auto sbgroup = item.get_sub_group ();
585- auto sg_size = sbgroup.get_max_local_range ()[0 ];
586-
587- const uint32_t to_read = WorkPI * sg_size + v.size ();
588- const auto *a_begin = a.begin ();
589-
590- const auto *a_ptr = &a.padded_begin ()[glid];
591- const auto *a_end = std::min (a_ptr + to_read, a.end ());
592-
593- auto _a_load_cond = [a_begin, a_end](auto &&ptr) {
594- return ptr >= a_begin && ptr < a_end;
595- };
458+ using SlidingWindow1dSmallKernel =
459+ dpnp::kernels::sliding_window1d::SlidingWindow1dSmallFunctor<
460+ WorkPI, PaddedSpan<const T, SizeT>, Span<const T, SizeT>, Op, Red,
461+ Span<T, SizeT>, RegistryData, RegistryWindow>;
596462
597- auto a_data = RegistryWindow<const T, WorkPI + 1 >(item);
598- a_data.load (a_ptr, _a_load_cond, 0 );
599-
600- const auto *v_ptr = &v.begin ()[sbgroup.get_local_linear_id ()];
601- auto v_size = v.size ();
602-
603- auto v_data = RegistryData<const T>(item);
604- v_ptr = v_data.load (v_ptr, v_data.x () < v_size, 0 );
605-
606- auto results_num = get_results_num (WorkPI, out.size (), glid, item);
607-
608- process_block (results, results_num, a_data, v_data, v_size, op,
609- red);
610-
611- auto *const out_ptr = out.begin ();
612- // auto *const out_end = out.end();
613-
614- auto y_start = glid;
615- auto y_stop =
616- std::min (y_start + WorkPI * results.size_x (), out.size ());
617- uint32_t i = 0 ;
618- for (uint32_t y = y_start; y < y_stop; y += results.size_x ()) {
619- out_ptr[y] = results[i++];
620- }
621- // while the code itself seems to be valid, inside correlate
622- // kernel it results in memory corruption. Further investigation
623- // is needed. SAT-7693
624- // corruption results.store(&out_ptr[glid],
625- // [out_end](auto &&ptr) { return ptr < out_end; });
626- });
463+ cgh.parallel_for <SlidingWindow1dSmallKernel>(
464+ nd_range, SlidingWindow1dSmallKernel (a, v, op, red, out));
627465}
628466
629467void validate (const usm_ndarray &a,
0 commit comments