Skip to content

Commit d1d8732

Browse files
committed
refactor recur2 version
1 parent a60cfe7 commit d1d8732

1 file changed

Lines changed: 17 additions & 40 deletions

File tree

torchlpc/csrc/cuda/scan.cu

Lines changed: 17 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
#include <thrust/copy.h>
66
#include <thrust/device_vector.h>
77
#include <thrust/execution_policy.h>
8+
#include <thrust/for_each.h>
89
#include <thrust/functional.h>
910
#include <thrust/iterator/transform_output_iterator.h>
1011
#include <thrust/pair.h>
@@ -27,14 +28,6 @@ struct recur_binary_op {
2728
}
2829
};
2930

// input_unary_op: deleted by this commit (every line below carries a "-"
// gutter marker). Its call operator packed a (decay, impulse) argument pair
// into a cuda::std::pair<T, T> via cuda::std::make_pair.
30-
template <typename T>
31-
struct input_unary_op {
32-
__host__ __device__ cuda::std::pair<T, T> operator()(
33-
const T &decay, const T &impulse) const {
34-
return cuda::std::make_pair(decay, impulse);
35-
}
36-
};
37-
3831
template <typename T>
3932
struct output_unary_op {
4033
__host__ __device__ T
@@ -43,54 +36,38 @@ struct output_unary_op {
4336
}
4437
};
4538

// scan_functor (old version): deleted by this commit (every line below
// carries a "-" gutter marker). For an index i it ran an in-place
// thrust::inclusive_scan with recur_binary_op<T> over the n_steps pairs in
// [data + i * n_steps, data + (i + 1) * n_steps).
46-
template <typename T>
47-
struct scan_functor {
48-
thrust::device_ptr<cuda::std::pair<T, T>> data;
49-
int n_steps;
50-
__host__ __device__ void operator()(int i) const {
51-
thrust::inclusive_scan(thrust::device, data + i * n_steps,
52-
data + (i + 1) * n_steps, data + i * n_steps,
53-
recur_binary_op<T>());
54-
}
55-
};
56-
// compute_linear_recurrence: runs one thrust::inclusive_scan over n_steps
// elements. The input is a zip iterator over (decays, impulses), elements are
// combined with recur_binary_op<scalar_t>, and each scanned value is passed
// through output_unary_op<scalar_t> before being written to `out`.
//
// Diff summary: the old host-only signature ("58-"/"59-") is replaced by a
// __host__ __device__ one ("40+".."42+"), and a commented-out
// `thrust::get<1>` line ("64-") is dropped.
//
// NOTE(review): the scan uses the thrust::device policy, and after this commit
// the function is also called from scan_functor::operator(), which
// compute_linear_recurrence2 dispatches through
// thrust::for_each(thrust::device, ...). How a thrust::device algorithm
// behaves when invoked from device code depends on the Thrust version and
// build flags -- TODO confirm this is the intended execution model.
5739
template <typename scalar_t>
58-
void compute_linear_recurrence(const scalar_t *decays, const scalar_t *impulses,
59-
scalar_t *out, int n_steps) {
40+
__host__ __device__ void compute_linear_recurrence(const scalar_t *decays,
41+
const scalar_t *impulses,
42+
scalar_t *out, int n_steps) {
6043
thrust::inclusive_scan(
6144
thrust::device, thrust::make_zip_iterator(decays, impulses),
6245
thrust::make_zip_iterator(decays + n_steps, impulses + n_steps),
6346
thrust::make_transform_output_iterator(out,
64-
// thrust::get<1>),
6547
output_unary_op<scalar_t>()),
6648
recur_binary_op<scalar_t>());
6749
}
6850

// scan_functor (new version): added by this commit (every line below carries
// a "+" gutter marker). Holds raw decays/impulses/out pointers plus n_steps.
// For an index i it calls compute_linear_recurrence<T> on the three pointers,
// each offset by i * n_steps, with length n_steps.
//
// NOTE(review): the offset i * n_steps is computed in int arithmetic --
// confirm callers never pass sizes where that product exceeds INT_MAX.
51+
template <typename T>
52+
struct scan_functor {
53+
const T *decays, *impulses;
54+
T *out;
55+
int n_steps;
56+
__host__ __device__ void operator()(int i) {
57+
compute_linear_recurrence<T>(decays + i * n_steps,
58+
impulses + i * n_steps, out + i * n_steps,
59+
n_steps);
60+
}
61+
};
62+
// compute_linear_recurrence2: applies scan_functor once for every index in
// [0, n_dims) via thrust::for_each with the thrust::device policy, using a
// thrust::counting_iterator<int> starting at 0 as the index source.
//
// Diff summary: lines marked "-" are the previous implementation, which
// (1) packed decays/impulses into a temporary thrust::device_vector of pairs,
// (2) declared a recur_binary_op local, (3) ran scan_functor on that buffer,
// and (4) copied each pair's .second into `out` with a final
// thrust::transform. The single "+" line constructs the new scan_functor
// directly from the decays, impulses and out pointers instead.
6963
template <typename scalar_t>
7064
void compute_linear_recurrence2(const scalar_t *decays,
7165
const scalar_t *impulses,
7266
// const scalar_t *initials,
7367
scalar_t *out, int n_dims, int n_steps) {
74-
thrust::device_vector<cuda::std::pair<scalar_t, scalar_t>> pairs(n_steps *
75-
n_dims);
76-
thrust::transform(
77-
thrust::device, decays, decays + n_steps * n_dims, impulses,
78-
pairs.begin(),
79-
[] __host__ __device__(const scalar_t &decay, const scalar_t &impulse) {
80-
return cuda::std::make_pair(decay, impulse);
81-
});
82-
83-
recur_binary_op<scalar_t> binary_op;
8468
thrust::counting_iterator<int> it(0);
85-
scan_functor<scalar_t> scan_op{pairs.data(), n_steps};
86-
69+
scan_functor<scalar_t> scan_op{decays, impulses, out, n_steps};
8770
thrust::for_each(thrust::device, it, it + n_dims, scan_op);
88-
89-
thrust::transform(thrust::device, pairs.begin(), pairs.end(), out,
90-
[] __host__ __device__(
91-
const cuda::std::pair<scalar_t, scalar_t> &state) {
92-
return state.second;
93-
});
9471
}
9572

9673
at::Tensor scan_cuda_wrapper(const at::Tensor &input, const at::Tensor &weights,

0 commit comments

Comments
 (0)