@@ -67,11 +67,11 @@ sycl::event window_impl(sycl::queue &q,
6767 return window_ev;
6868}
6969
70- inline std::pair<sycl::event, sycl::event>
71- py_window (sycl::queue &exec_q,
70+ template <typename funcPtrT>
71+ std::tuple<size_t , char *, funcPtrT>
72+ window_fn (sycl::queue &exec_q,
7273 const dpctl::tensor::usm_ndarray &result,
73- const std::vector<sycl::event> &depends,
74- const window_fn_ptr_t *window_dispatch_vector)
74+ const funcPtrT *window_dispatch_vector)
7575{
7676 dpctl::tensor::validation::CheckWritable::throw_if_not_writable (result);
7777
@@ -92,19 +92,35 @@ inline std::pair<sycl::event, sycl::event>
9292
9393 size_t nelems = result.get_size ();
9494 if (nelems == 0 ) {
95- return std::make_pair (sycl::event{}, sycl::event{} );
95+ return std::make_tuple (nelems, nullptr , nullptr );
9696 }
9797
9898 int result_typenum = result.get_typenum ();
9999 auto array_types = dpctl_td_ns::usm_ndarray_types ();
100100 int result_type_id = array_types.typenum_to_lookup_id (result_typenum);
101- auto fn = window_dispatch_vector[result_type_id];
101+ funcPtrT fn = window_dispatch_vector[result_type_id];
102102
103103 if (fn == nullptr ) {
104104 throw std::runtime_error (" Type of given array is not supported" );
105105 }
106106
107107 char *result_typeless_ptr = result.get_data ();
108+ return std::make_tuple (nelems, result_typeless_ptr, fn);
109+ }
110+
111+ inline std::pair<sycl::event, sycl::event>
112+ py_window (sycl::queue &exec_q,
113+ const dpctl::tensor::usm_ndarray &result,
114+ const std::vector<sycl::event> &depends,
115+ const window_fn_ptr_t *window_dispatch_vector)
116+ {
117+ auto [nelems, result_typeless_ptr, fn] =
118+ window_fn<window_fn_ptr_t >(exec_q, result, window_dispatch_vector);
119+
120+ if (nelems == 0 ) {
121+ return std::make_pair (sycl::event{}, sycl::event{});
122+ }
123+
108124 sycl::event window_ev = fn (exec_q, result_typeless_ptr, nelems, depends);
109125 sycl::event args_ev =
110126 dpctl::utils::keep_args_alive (exec_q, {result}, {window_ev});
0 commit comments