@@ -50,7 +50,7 @@ namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
5050typedef sycl::event (*kaiser_fn_ptr_t )(sycl::queue &,
5151 char *,
5252 const std::size_t ,
53- const float ,
53+ const py::object & ,
5454 const std::vector<sycl::event> &);
5555
5656static kaiser_fn_ptr_t kaiser_dispatch_vector[dpctl_td_ns::num_types];
@@ -61,10 +61,10 @@ class KaiserFunctor
6161private:
6262 T *data = nullptr ;
6363 const std::size_t N;
64- const float beta;
64+ const T beta;
6565
6666public:
67- KaiserFunctor (T *data, const std::size_t N, const float beta)
67+ KaiserFunctor (T *data, const std::size_t N, const T beta)
6868 : data(data), N(N), beta(beta)
6969 {
7070 }
@@ -89,12 +89,13 @@ template <typename T, template <typename> class Functor>
8989sycl::event kaiser_impl (sycl::queue &q,
9090 char *result,
9191 const std::size_t nelems,
92- const float beta ,
92+ const py::object &py_beta ,
9393 const std::vector<sycl::event> &depends)
9494{
9595 dpctl::tensor::type_utils::validate_type_for_device<T>(q);
9696
9797 T *res = reinterpret_cast <T *>(result);
98+ const T beta = py::cast<const T>(py_beta);
9899
99100 sycl::event kaiser_ev = q.submit ([&](sycl::handler &cgh) {
100101 cgh.depends_on (depends);
@@ -123,7 +124,7 @@ struct KaiserFactory
123124
124125std::pair<sycl::event, sycl::event>
125126 py_kaiser (sycl::queue &exec_q,
126- const float beta ,
127+ const py::object &py_beta ,
127128 const dpctl::tensor::usm_ndarray &result,
128129 const std::vector<sycl::event> &depends)
129130{
@@ -160,7 +161,7 @@ std::pair<sycl::event, sycl::event>
160161
161162 char *result_typeless_ptr = result.get_data ();
162163 sycl::event kaiser_ev =
163- fn (exec_q, result_typeless_ptr, nelems, beta , depends);
164+ fn (exec_q, result_typeless_ptr, nelems, py_beta , depends);
164165 sycl::event args_ev =
165166 dpctl::utils::keep_args_alive (exec_q, {result}, {kaiser_ev});
166167
0 commit comments