55#include < thrust/copy.h>
66#include < thrust/device_vector.h>
77#include < thrust/execution_policy.h>
8+ #include < thrust/for_each.h>
89#include < thrust/functional.h>
910#include < thrust/iterator/transform_output_iterator.h>
1011#include < thrust/pair.h>
@@ -27,14 +28,6 @@ struct recur_binary_op {
2728 }
2829};
2930
30- template <typename T>
31- struct input_unary_op {
32- __host__ __device__ cuda::std::pair<T, T> operator ()(
33- const T &decay, const T &impulse) const {
34- return cuda::std::make_pair (decay, impulse);
35- }
36- };
37-
3831template <typename T>
3932struct output_unary_op {
4033 __host__ __device__ T
@@ -43,54 +36,38 @@ struct output_unary_op {
4336 }
4437};
4538
46- template <typename T>
47- struct scan_functor {
48- thrust::device_ptr<cuda::std::pair<T, T>> data;
49- int n_steps;
50- __host__ __device__ void operator ()(int i) const {
51- thrust::inclusive_scan (thrust::device, data + i * n_steps,
52- data + (i + 1 ) * n_steps, data + i * n_steps,
53- recur_binary_op<T>());
54- }
55- };
56-
/**
 * Scans a single sequence of (decay, impulse) pairs and writes one result per
 * step to `out`.
 *
 * The two input arrays are zipped element-wise, combined with an inclusive
 * scan under `recur_binary_op<scalar_t>`, and every scanned pair is passed
 * through `output_unary_op<scalar_t>` before being stored in `out`.
 * NOTE(review): the names suggest this evaluates
 * out[t] = decay[t] * out[t-1] + impulse[t]; confirm against the two functor
 * definitions.
 *
 * Marked `__host__ __device__` so it can be invoked both directly from host
 * code and from inside another functor running on the device.
 *
 * @param decays   first element of the decay sequence (`n_steps` values)
 * @param impulses first element of the impulse sequence (`n_steps` values)
 * @param out      destination for the `n_steps` results
 * @param n_steps  number of elements in the sequence
 */
template <typename scalar_t>
__host__ __device__ void compute_linear_recurrence(const scalar_t *decays,
                                                   const scalar_t *impulses,
                                                   scalar_t *out, int n_steps) {
  // View the two parallel arrays as one range of (decay, impulse) tuples.
  auto pairs_begin = thrust::make_zip_iterator(decays, impulses);
  auto pairs_end = pairs_begin + n_steps;

  // Each scanned tuple is reduced to a scalar on its way into `out`.
  auto result_sink = thrust::make_transform_output_iterator(
      out, output_unary_op<scalar_t>());

  thrust::inclusive_scan(thrust::device, pairs_begin, pairs_end, result_sink,
                         recur_binary_op<scalar_t>());
}
6850
/**
 * Functor that runs `compute_linear_recurrence` on one row of a row-major
 * batch of sequences.
 *
 * Row `i` occupies elements [i * n_steps, (i + 1) * n_steps) of `decays`,
 * `impulses` and `out`. Member order is part of the interface: callers
 * aggregate-initialise it as {decays, impulses, out, n_steps}.
 */
template <typename T>
struct scan_functor {
  const T *decays;    // start of the batched decay values
  const T *impulses;  // start of the batched impulse values
  T *out;             // start of the batched output buffer
  int n_steps;        // length of every row

  /**
   * Processes row `i`.
   *
   * @param i zero-based row index
   */
  __host__ __device__ void operator()(int i) const {
    // Widen before multiplying: `i * n_steps` evaluated in `int` overflows as
    // soon as the batch holds more than 2^31 - 1 elements in total.
    const long long row_offset = static_cast<long long>(i) * n_steps;
    compute_linear_recurrence<T>(decays + row_offset, impulses + row_offset,
                                 out + row_offset, n_steps);
  }
};
62+
/**
 * Batched variant: applies `compute_linear_recurrence` to `n_dims` rows of
 * `n_steps` elements each.
 *
 * A `scan_functor` is invoked once per row index in [0, n_dims) through
 * `thrust::for_each` under the `thrust::device` policy; the functor offsets
 * all three buffers by `row * n_steps`, so the buffers are laid out row-major
 * with `n_dims * n_steps` elements.
 * NOTE(review): the buffers are presumably device-accessible, given the
 * execution policy — confirm against the caller.
 *
 * @param decays   batched decay values
 * @param impulses batched impulse values
 * @param out      batched output buffer
 * @param n_dims   number of independent rows
 * @param n_steps  number of elements per row
 */
template <typename scalar_t>
void compute_linear_recurrence2(const scalar_t *decays,
                                const scalar_t *impulses,
                                // const scalar_t *initials,
                                scalar_t *out, int n_dims, int n_steps) {
  // One functor instance describes the whole batch; the row index selects
  // which slice it works on.
  scan_functor<scalar_t> per_row_scan{decays, impulses, out, n_steps};

  thrust::counting_iterator<int> first_row(0);
  thrust::counting_iterator<int> past_last_row = first_row + n_dims;

  thrust::for_each(thrust::device, first_row, past_last_row, per_row_scan);
}
9572
9673at::Tensor scan_cuda_wrapper (const at::Tensor &input, const at::Tensor &weights,
0 commit comments