Skip to content

Commit 029fb66

Browse files
authored
feat: GPU dyn dispatch patches support (#7563)
Integrates the structural plumbing as well as applying the patches in source and scalar ops within the GPU dynamic dispatch kernel. None of the CUDA dynamic dispatch benchmarks regressed by adding patches support. As part of this change, the fastlanes lane count for a given type is now determined at compile time via `FL_LANES<type>`. --------- Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent 452a4a3 commit 029fb66

19 files changed

Lines changed: 1190 additions & 564 deletions

vortex-cuda/benches/dynamic_dispatch_cuda.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ struct BenchRunner {
123123
}
124124

125125
impl BenchRunner {
126-
fn new(array: &vortex::array::ArrayRef, len: usize, cuda_ctx: &CudaExecutionCtx) -> Self {
126+
fn new(array: &vortex::array::ArrayRef, len: usize, cuda_ctx: &mut CudaExecutionCtx) -> Self {
127127
let plan = match DispatchPlan::new(array).vortex_expect("build_dyn_dispatch_plan") {
128128
DispatchPlan::Fused(plan) => plan,
129129
_ => unreachable!("encoding not fusable"),
@@ -201,7 +201,7 @@ fn bench_for_bitpacked(c: &mut Criterion) {
201201
let mut cuda_ctx =
202202
CudaSession::create_execution_ctx(&VortexSession::empty()).vortex_expect("ctx");
203203

204-
let bench_runner = BenchRunner::new(&array, n, &cuda_ctx);
204+
let bench_runner = BenchRunner::new(&array, n, &mut cuda_ctx);
205205

206206
b.iter_custom(|iters| {
207207
let mut total_time = Duration::ZERO;
@@ -246,7 +246,7 @@ fn bench_dict_bp_codes(c: &mut Criterion) {
246246
let mut cuda_ctx =
247247
CudaSession::create_execution_ctx(&VortexSession::empty()).vortex_expect("ctx");
248248

249-
let bench_runner = BenchRunner::new(&array, n, &cuda_ctx);
249+
let bench_runner = BenchRunner::new(&array, n, &mut cuda_ctx);
250250

251251
b.iter_custom(|iters| {
252252
let mut total_time = Duration::ZERO;
@@ -290,7 +290,7 @@ fn bench_runend(c: &mut Criterion) {
290290
let mut cuda_ctx =
291291
CudaSession::create_execution_ctx(&VortexSession::empty()).vortex_expect("ctx");
292292

293-
let bench_runner = BenchRunner::new(&array, n, &cuda_ctx);
293+
let bench_runner = BenchRunner::new(&array, n, &mut cuda_ctx);
294294

295295
b.iter_custom(|iters| {
296296
let mut total_time = Duration::ZERO;
@@ -344,7 +344,7 @@ fn bench_dict_bp_codes_bp_for_values(c: &mut Criterion) {
344344
let mut cuda_ctx =
345345
CudaSession::create_execution_ctx(&VortexSession::empty()).vortex_expect("ctx");
346346

347-
let bench_runner = BenchRunner::new(&array, n, &cuda_ctx);
347+
let bench_runner = BenchRunner::new(&array, n, &mut cuda_ctx);
348348

349349
b.iter_custom(|iters| {
350350
let mut total_time = Duration::ZERO;
@@ -409,7 +409,7 @@ fn bench_alp_for_bitpacked(c: &mut Criterion) {
409409
let mut cuda_ctx =
410410
CudaSession::create_execution_ctx(&VortexSession::empty()).vortex_expect("ctx");
411411

412-
let bench_runner = BenchRunner::new(&array, n, &cuda_ctx);
412+
let bench_runner = BenchRunner::new(&array, n, &mut cuda_ctx);
413413

414414
b.iter_custom(|iters| {
415415
let mut total_time = Duration::ZERO;

vortex-cuda/kernels/src/bit_unpack_16.cu

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,148 +4,148 @@
44

55
template <int BW>
66
__device__ void _bit_unpack_16_device(const uint16_t *__restrict in, uint16_t *__restrict out, uint16_t reference, int thread_idx, GPUPatches& patches) {
7-
__shared__ uint16_t shared_out[1024];
7+
__shared__ uint16_t shared_out[FL_CHUNK];
88

99
// Step 1: Unpack into shared memory
1010
#pragma unroll
11-
for (int i = 0; i < 2; i++) {
12-
_bit_unpack_16_lane<BW>(in, shared_out, reference, thread_idx * 2 + i);
11+
for (int i = 0; i < FL_LANES<uint16_t> / 32; i++) {
12+
_bit_unpack_16_lane<BW>(in, shared_out, reference, thread_idx * (FL_LANES<uint16_t> / 32) + i);
1313
}
1414
__syncwarp();
1515

1616
// Step 2: Apply patches to shared memory in parallel
1717
PatchesCursor<uint16_t> cursor(patches, blockIdx.x, thread_idx, 32);
1818
auto patch = cursor.next();
19-
while (patch.index != 1024) {
19+
while (patch.index != FL_CHUNK) {
2020
shared_out[patch.index] = patch.value;
2121
patch = cursor.next();
2222
}
2323
__syncwarp();
2424

2525
// Step 3: Copy to global memory
2626
#pragma unroll
27-
for (int i = 0; i < 32; i++) {
27+
for (int i = 0; i < FL_CHUNK / 32; i++) {
2828
auto idx = i * 32 + thread_idx;
2929
out[idx] = shared_out[idx];
3030
}
3131
}
3232

3333
extern "C" __global__ void bit_unpack_16_0bw_32t(const uint16_t *__restrict full_in, uint16_t *__restrict full_out, uint16_t reference, GPUPatches patches) {
3434
int thread_idx = threadIdx.x;
35-
auto in = full_in + (blockIdx.x * (128 * 0 / sizeof(uint16_t)));
36-
auto out = full_out + (blockIdx.x * 1024);
35+
auto in = full_in + (blockIdx.x * (FL_LANES<uint16_t> * 0));
36+
auto out = full_out + (blockIdx.x * FL_CHUNK);
3737
_bit_unpack_16_device<0>(in, out, reference, thread_idx, patches);
3838
}
3939

4040
extern "C" __global__ void bit_unpack_16_1bw_32t(const uint16_t *__restrict full_in, uint16_t *__restrict full_out, uint16_t reference, GPUPatches patches) {
4141
int thread_idx = threadIdx.x;
42-
auto in = full_in + (blockIdx.x * (128 * 1 / sizeof(uint16_t)));
43-
auto out = full_out + (blockIdx.x * 1024);
42+
auto in = full_in + (blockIdx.x * (FL_LANES<uint16_t> * 1));
43+
auto out = full_out + (blockIdx.x * FL_CHUNK);
4444
_bit_unpack_16_device<1>(in, out, reference, thread_idx, patches);
4545
}
4646

4747
extern "C" __global__ void bit_unpack_16_2bw_32t(const uint16_t *__restrict full_in, uint16_t *__restrict full_out, uint16_t reference, GPUPatches patches) {
4848
int thread_idx = threadIdx.x;
49-
auto in = full_in + (blockIdx.x * (128 * 2 / sizeof(uint16_t)));
50-
auto out = full_out + (blockIdx.x * 1024);
49+
auto in = full_in + (blockIdx.x * (FL_LANES<uint16_t> * 2));
50+
auto out = full_out + (blockIdx.x * FL_CHUNK);
5151
_bit_unpack_16_device<2>(in, out, reference, thread_idx, patches);
5252
}
5353

5454
extern "C" __global__ void bit_unpack_16_3bw_32t(const uint16_t *__restrict full_in, uint16_t *__restrict full_out, uint16_t reference, GPUPatches patches) {
5555
int thread_idx = threadIdx.x;
56-
auto in = full_in + (blockIdx.x * (128 * 3 / sizeof(uint16_t)));
57-
auto out = full_out + (blockIdx.x * 1024);
56+
auto in = full_in + (blockIdx.x * (FL_LANES<uint16_t> * 3));
57+
auto out = full_out + (blockIdx.x * FL_CHUNK);
5858
_bit_unpack_16_device<3>(in, out, reference, thread_idx, patches);
5959
}
6060

6161
extern "C" __global__ void bit_unpack_16_4bw_32t(const uint16_t *__restrict full_in, uint16_t *__restrict full_out, uint16_t reference, GPUPatches patches) {
6262
int thread_idx = threadIdx.x;
63-
auto in = full_in + (blockIdx.x * (128 * 4 / sizeof(uint16_t)));
64-
auto out = full_out + (blockIdx.x * 1024);
63+
auto in = full_in + (blockIdx.x * (FL_LANES<uint16_t> * 4));
64+
auto out = full_out + (blockIdx.x * FL_CHUNK);
6565
_bit_unpack_16_device<4>(in, out, reference, thread_idx, patches);
6666
}
6767

6868
extern "C" __global__ void bit_unpack_16_5bw_32t(const uint16_t *__restrict full_in, uint16_t *__restrict full_out, uint16_t reference, GPUPatches patches) {
6969
int thread_idx = threadIdx.x;
70-
auto in = full_in + (blockIdx.x * (128 * 5 / sizeof(uint16_t)));
71-
auto out = full_out + (blockIdx.x * 1024);
70+
auto in = full_in + (blockIdx.x * (FL_LANES<uint16_t> * 5));
71+
auto out = full_out + (blockIdx.x * FL_CHUNK);
7272
_bit_unpack_16_device<5>(in, out, reference, thread_idx, patches);
7373
}
7474

7575
extern "C" __global__ void bit_unpack_16_6bw_32t(const uint16_t *__restrict full_in, uint16_t *__restrict full_out, uint16_t reference, GPUPatches patches) {
7676
int thread_idx = threadIdx.x;
77-
auto in = full_in + (blockIdx.x * (128 * 6 / sizeof(uint16_t)));
78-
auto out = full_out + (blockIdx.x * 1024);
77+
auto in = full_in + (blockIdx.x * (FL_LANES<uint16_t> * 6));
78+
auto out = full_out + (blockIdx.x * FL_CHUNK);
7979
_bit_unpack_16_device<6>(in, out, reference, thread_idx, patches);
8080
}
8181

8282
extern "C" __global__ void bit_unpack_16_7bw_32t(const uint16_t *__restrict full_in, uint16_t *__restrict full_out, uint16_t reference, GPUPatches patches) {
8383
int thread_idx = threadIdx.x;
84-
auto in = full_in + (blockIdx.x * (128 * 7 / sizeof(uint16_t)));
85-
auto out = full_out + (blockIdx.x * 1024);
84+
auto in = full_in + (blockIdx.x * (FL_LANES<uint16_t> * 7));
85+
auto out = full_out + (blockIdx.x * FL_CHUNK);
8686
_bit_unpack_16_device<7>(in, out, reference, thread_idx, patches);
8787
}
8888

8989
extern "C" __global__ void bit_unpack_16_8bw_32t(const uint16_t *__restrict full_in, uint16_t *__restrict full_out, uint16_t reference, GPUPatches patches) {
9090
int thread_idx = threadIdx.x;
91-
auto in = full_in + (blockIdx.x * (128 * 8 / sizeof(uint16_t)));
92-
auto out = full_out + (blockIdx.x * 1024);
91+
auto in = full_in + (blockIdx.x * (FL_LANES<uint16_t> * 8));
92+
auto out = full_out + (blockIdx.x * FL_CHUNK);
9393
_bit_unpack_16_device<8>(in, out, reference, thread_idx, patches);
9494
}
9595

9696
extern "C" __global__ void bit_unpack_16_9bw_32t(const uint16_t *__restrict full_in, uint16_t *__restrict full_out, uint16_t reference, GPUPatches patches) {
9797
int thread_idx = threadIdx.x;
98-
auto in = full_in + (blockIdx.x * (128 * 9 / sizeof(uint16_t)));
99-
auto out = full_out + (blockIdx.x * 1024);
98+
auto in = full_in + (blockIdx.x * (FL_LANES<uint16_t> * 9));
99+
auto out = full_out + (blockIdx.x * FL_CHUNK);
100100
_bit_unpack_16_device<9>(in, out, reference, thread_idx, patches);
101101
}
102102

103103
extern "C" __global__ void bit_unpack_16_10bw_32t(const uint16_t *__restrict full_in, uint16_t *__restrict full_out, uint16_t reference, GPUPatches patches) {
104104
int thread_idx = threadIdx.x;
105-
auto in = full_in + (blockIdx.x * (128 * 10 / sizeof(uint16_t)));
106-
auto out = full_out + (blockIdx.x * 1024);
105+
auto in = full_in + (blockIdx.x * (FL_LANES<uint16_t> * 10));
106+
auto out = full_out + (blockIdx.x * FL_CHUNK);
107107
_bit_unpack_16_device<10>(in, out, reference, thread_idx, patches);
108108
}
109109

110110
extern "C" __global__ void bit_unpack_16_11bw_32t(const uint16_t *__restrict full_in, uint16_t *__restrict full_out, uint16_t reference, GPUPatches patches) {
111111
int thread_idx = threadIdx.x;
112-
auto in = full_in + (blockIdx.x * (128 * 11 / sizeof(uint16_t)));
113-
auto out = full_out + (blockIdx.x * 1024);
112+
auto in = full_in + (blockIdx.x * (FL_LANES<uint16_t> * 11));
113+
auto out = full_out + (blockIdx.x * FL_CHUNK);
114114
_bit_unpack_16_device<11>(in, out, reference, thread_idx, patches);
115115
}
116116

117117
extern "C" __global__ void bit_unpack_16_12bw_32t(const uint16_t *__restrict full_in, uint16_t *__restrict full_out, uint16_t reference, GPUPatches patches) {
118118
int thread_idx = threadIdx.x;
119-
auto in = full_in + (blockIdx.x * (128 * 12 / sizeof(uint16_t)));
120-
auto out = full_out + (blockIdx.x * 1024);
119+
auto in = full_in + (blockIdx.x * (FL_LANES<uint16_t> * 12));
120+
auto out = full_out + (blockIdx.x * FL_CHUNK);
121121
_bit_unpack_16_device<12>(in, out, reference, thread_idx, patches);
122122
}
123123

124124
extern "C" __global__ void bit_unpack_16_13bw_32t(const uint16_t *__restrict full_in, uint16_t *__restrict full_out, uint16_t reference, GPUPatches patches) {
125125
int thread_idx = threadIdx.x;
126-
auto in = full_in + (blockIdx.x * (128 * 13 / sizeof(uint16_t)));
127-
auto out = full_out + (blockIdx.x * 1024);
126+
auto in = full_in + (blockIdx.x * (FL_LANES<uint16_t> * 13));
127+
auto out = full_out + (blockIdx.x * FL_CHUNK);
128128
_bit_unpack_16_device<13>(in, out, reference, thread_idx, patches);
129129
}
130130

131131
extern "C" __global__ void bit_unpack_16_14bw_32t(const uint16_t *__restrict full_in, uint16_t *__restrict full_out, uint16_t reference, GPUPatches patches) {
132132
int thread_idx = threadIdx.x;
133-
auto in = full_in + (blockIdx.x * (128 * 14 / sizeof(uint16_t)));
134-
auto out = full_out + (blockIdx.x * 1024);
133+
auto in = full_in + (blockIdx.x * (FL_LANES<uint16_t> * 14));
134+
auto out = full_out + (blockIdx.x * FL_CHUNK);
135135
_bit_unpack_16_device<14>(in, out, reference, thread_idx, patches);
136136
}
137137

138138
extern "C" __global__ void bit_unpack_16_15bw_32t(const uint16_t *__restrict full_in, uint16_t *__restrict full_out, uint16_t reference, GPUPatches patches) {
139139
int thread_idx = threadIdx.x;
140-
auto in = full_in + (blockIdx.x * (128 * 15 / sizeof(uint16_t)));
141-
auto out = full_out + (blockIdx.x * 1024);
140+
auto in = full_in + (blockIdx.x * (FL_LANES<uint16_t> * 15));
141+
auto out = full_out + (blockIdx.x * FL_CHUNK);
142142
_bit_unpack_16_device<15>(in, out, reference, thread_idx, patches);
143143
}
144144

145145
extern "C" __global__ void bit_unpack_16_16bw_32t(const uint16_t *__restrict full_in, uint16_t *__restrict full_out, uint16_t reference, GPUPatches patches) {
146146
int thread_idx = threadIdx.x;
147-
auto in = full_in + (blockIdx.x * (128 * 16 / sizeof(uint16_t)));
148-
auto out = full_out + (blockIdx.x * 1024);
147+
auto in = full_in + (blockIdx.x * (FL_LANES<uint16_t> * 16));
148+
auto out = full_out + (blockIdx.x * FL_CHUNK);
149149
_bit_unpack_16_device<16>(in, out, reference, thread_idx, patches);
150150
}
151151

vortex-cuda/kernels/src/bit_unpack_16_lanes.cuh

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ __device__ void _bit_unpack_16_lane<0>(const uint16_t *__restrict in, uint16_t *
1919

2020
template <>
2121
__device__ void _bit_unpack_16_lane<1>(const uint16_t *__restrict in, uint16_t *__restrict out, uint16_t reference, unsigned int lane) {
22-
unsigned int LANE_COUNT = 64;
22+
constexpr unsigned int LANE_COUNT = FL_LANES<uint16_t>;
2323
uint16_t src;
2424
uint16_t tmp;
2525
src = in[lane];
@@ -59,7 +59,7 @@ __device__ void _bit_unpack_16_lane<1>(const uint16_t *__restrict in, uint16_t *
5959

6060
template <>
6161
__device__ void _bit_unpack_16_lane<2>(const uint16_t *__restrict in, uint16_t *__restrict out, uint16_t reference, unsigned int lane) {
62-
unsigned int LANE_COUNT = 64;
62+
constexpr unsigned int LANE_COUNT = FL_LANES<uint16_t>;
6363
uint16_t src;
6464
uint16_t tmp;
6565
src = in[lane];
@@ -101,7 +101,7 @@ __device__ void _bit_unpack_16_lane<2>(const uint16_t *__restrict in, uint16_t *
101101

102102
template <>
103103
__device__ void _bit_unpack_16_lane<3>(const uint16_t *__restrict in, uint16_t *__restrict out, uint16_t reference, unsigned int lane) {
104-
unsigned int LANE_COUNT = 64;
104+
constexpr unsigned int LANE_COUNT = FL_LANES<uint16_t>;
105105
uint16_t src;
106106
uint16_t tmp;
107107
src = in[lane];
@@ -145,7 +145,7 @@ __device__ void _bit_unpack_16_lane<3>(const uint16_t *__restrict in, uint16_t *
145145

146146
template <>
147147
__device__ void _bit_unpack_16_lane<4>(const uint16_t *__restrict in, uint16_t *__restrict out, uint16_t reference, unsigned int lane) {
148-
unsigned int LANE_COUNT = 64;
148+
constexpr unsigned int LANE_COUNT = FL_LANES<uint16_t>;
149149
uint16_t src;
150150
uint16_t tmp;
151151
src = in[lane];
@@ -191,7 +191,7 @@ __device__ void _bit_unpack_16_lane<4>(const uint16_t *__restrict in, uint16_t *
191191

192192
template <>
193193
__device__ void _bit_unpack_16_lane<5>(const uint16_t *__restrict in, uint16_t *__restrict out, uint16_t reference, unsigned int lane) {
194-
unsigned int LANE_COUNT = 64;
194+
constexpr unsigned int LANE_COUNT = FL_LANES<uint16_t>;
195195
uint16_t src;
196196
uint16_t tmp;
197197
src = in[lane];
@@ -239,7 +239,7 @@ __device__ void _bit_unpack_16_lane<5>(const uint16_t *__restrict in, uint16_t *
239239

240240
template <>
241241
__device__ void _bit_unpack_16_lane<6>(const uint16_t *__restrict in, uint16_t *__restrict out, uint16_t reference, unsigned int lane) {
242-
unsigned int LANE_COUNT = 64;
242+
constexpr unsigned int LANE_COUNT = FL_LANES<uint16_t>;
243243
uint16_t src;
244244
uint16_t tmp;
245245
src = in[lane];
@@ -289,7 +289,7 @@ __device__ void _bit_unpack_16_lane<6>(const uint16_t *__restrict in, uint16_t *
289289

290290
template <>
291291
__device__ void _bit_unpack_16_lane<7>(const uint16_t *__restrict in, uint16_t *__restrict out, uint16_t reference, unsigned int lane) {
292-
unsigned int LANE_COUNT = 64;
292+
constexpr unsigned int LANE_COUNT = FL_LANES<uint16_t>;
293293
uint16_t src;
294294
uint16_t tmp;
295295
src = in[lane];
@@ -341,7 +341,7 @@ __device__ void _bit_unpack_16_lane<7>(const uint16_t *__restrict in, uint16_t *
341341

342342
template <>
343343
__device__ void _bit_unpack_16_lane<8>(const uint16_t *__restrict in, uint16_t *__restrict out, uint16_t reference, unsigned int lane) {
344-
unsigned int LANE_COUNT = 64;
344+
constexpr unsigned int LANE_COUNT = FL_LANES<uint16_t>;
345345
uint16_t src;
346346
uint16_t tmp;
347347
src = in[lane];
@@ -395,7 +395,7 @@ __device__ void _bit_unpack_16_lane<8>(const uint16_t *__restrict in, uint16_t *
395395

396396
template <>
397397
__device__ void _bit_unpack_16_lane<9>(const uint16_t *__restrict in, uint16_t *__restrict out, uint16_t reference, unsigned int lane) {
398-
unsigned int LANE_COUNT = 64;
398+
constexpr unsigned int LANE_COUNT = FL_LANES<uint16_t>;
399399
uint16_t src;
400400
uint16_t tmp;
401401
src = in[lane];
@@ -451,7 +451,7 @@ __device__ void _bit_unpack_16_lane<9>(const uint16_t *__restrict in, uint16_t *
451451

452452
template <>
453453
__device__ void _bit_unpack_16_lane<10>(const uint16_t *__restrict in, uint16_t *__restrict out, uint16_t reference, unsigned int lane) {
454-
unsigned int LANE_COUNT = 64;
454+
constexpr unsigned int LANE_COUNT = FL_LANES<uint16_t>;
455455
uint16_t src;
456456
uint16_t tmp;
457457
src = in[lane];
@@ -509,7 +509,7 @@ __device__ void _bit_unpack_16_lane<10>(const uint16_t *__restrict in, uint16_t
509509

510510
template <>
511511
__device__ void _bit_unpack_16_lane<11>(const uint16_t *__restrict in, uint16_t *__restrict out, uint16_t reference, unsigned int lane) {
512-
unsigned int LANE_COUNT = 64;
512+
constexpr unsigned int LANE_COUNT = FL_LANES<uint16_t>;
513513
uint16_t src;
514514
uint16_t tmp;
515515
src = in[lane];
@@ -569,7 +569,7 @@ __device__ void _bit_unpack_16_lane<11>(const uint16_t *__restrict in, uint16_t
569569

570570
template <>
571571
__device__ void _bit_unpack_16_lane<12>(const uint16_t *__restrict in, uint16_t *__restrict out, uint16_t reference, unsigned int lane) {
572-
unsigned int LANE_COUNT = 64;
572+
constexpr unsigned int LANE_COUNT = FL_LANES<uint16_t>;
573573
uint16_t src;
574574
uint16_t tmp;
575575
src = in[lane];
@@ -631,7 +631,7 @@ __device__ void _bit_unpack_16_lane<12>(const uint16_t *__restrict in, uint16_t
631631

632632
template <>
633633
__device__ void _bit_unpack_16_lane<13>(const uint16_t *__restrict in, uint16_t *__restrict out, uint16_t reference, unsigned int lane) {
634-
unsigned int LANE_COUNT = 64;
634+
constexpr unsigned int LANE_COUNT = FL_LANES<uint16_t>;
635635
uint16_t src;
636636
uint16_t tmp;
637637
src = in[lane];
@@ -695,7 +695,7 @@ __device__ void _bit_unpack_16_lane<13>(const uint16_t *__restrict in, uint16_t
695695

696696
template <>
697697
__device__ void _bit_unpack_16_lane<14>(const uint16_t *__restrict in, uint16_t *__restrict out, uint16_t reference, unsigned int lane) {
698-
unsigned int LANE_COUNT = 64;
698+
constexpr unsigned int LANE_COUNT = FL_LANES<uint16_t>;
699699
uint16_t src;
700700
uint16_t tmp;
701701
src = in[lane];
@@ -761,7 +761,7 @@ __device__ void _bit_unpack_16_lane<14>(const uint16_t *__restrict in, uint16_t
761761

762762
template <>
763763
__device__ void _bit_unpack_16_lane<15>(const uint16_t *__restrict in, uint16_t *__restrict out, uint16_t reference, unsigned int lane) {
764-
unsigned int LANE_COUNT = 64;
764+
constexpr unsigned int LANE_COUNT = FL_LANES<uint16_t>;
765765
uint16_t src;
766766
uint16_t tmp;
767767
src = in[lane];
@@ -829,7 +829,7 @@ __device__ void _bit_unpack_16_lane<15>(const uint16_t *__restrict in, uint16_t
829829

830830
template <>
831831
__device__ void _bit_unpack_16_lane<16>(const uint16_t *__restrict in, uint16_t *__restrict out, uint16_t reference, unsigned int lane) {
832-
unsigned int LANE_COUNT = 64;
832+
constexpr unsigned int LANE_COUNT = FL_LANES<uint16_t>;
833833
#pragma unroll
834834
for (int row = 0; row < 16; row++) {
835835
out[INDEX(row, lane)] = in[LANE_COUNT * row + lane] + reference;

0 commit comments

Comments
 (0)