Skip to content

Commit b1a5bd4

Browse files
authored
CUDA: better coalesce data-access for contiguous concat (#22330)
Also, distribute all elements across CTAs evenly instead of launching one CTA per dim
1 parent 0c6ee1c commit b1a5bd4

1 file changed

Lines changed: 62 additions & 79 deletions

File tree

ggml/src/ggml-cuda/concat.cu

Lines changed: 62 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,96 +1,79 @@
11
#include "concat.cuh"
22

33
// contiguous kernels
4-
// Concatenate two contiguous f32 tensors along dim 0 (the innermost dim).
// Expected launch layout: grid = (ceil(ne0 / blockDim.x), ne1_dst, ne2_dst),
// one thread per element of the innermost dst dim. ne0 is the dst extent of
// dim 0; the first ne00 elements of each row come from src0 (x), the
// remaining ne0 - ne00 from src1 (y).
static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
    const int nidx = threadIdx.x + blockIdx.x * blockDim.x;
    if (nidx >= ne0) {
        return;
    }

    // compute offsets in 64 bits: blockIdx.z * ne0 * gridDim.y can overflow
    // a 32-bit int for large tensors
    const int64_t offset_dst =
        (int64_t) nidx +
        (int64_t) blockIdx.y * ne0 +
        (int64_t) blockIdx.z * ne0 * gridDim.y;

    if (nidx < ne00) { // element comes from src0
        const int64_t offset_src =
            (int64_t) nidx +
            (int64_t) blockIdx.y * ne00 +
            (int64_t) blockIdx.z * ne00 * gridDim.y;
        dst[offset_dst] = x[offset_src];
    } else { // element comes from src1
        const int64_t offset_src =
            (int64_t) (nidx - ne00) +
            (int64_t) blockIdx.y * (ne0 - ne00) +
            (int64_t) blockIdx.z * (ne0 - ne00) * gridDim.y;
        dst[offset_dst] = y[offset_src];
    }
}
29-
30-
static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) {
31-
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
32-
if (nidx >= ne0) {
33-
return;
34-
}
4+
// Contiguous concat kernel: writes every element of dst = concat(x, y) along
// axis `dim` (0, 1 or 2), with x, y and dst all contiguous f32 buffers.
// ne00/ne01/ne02 are the extents of src0 (x); ne0/ne1/ne2 those of dst.
// Elements are distributed over the whole grid with a grid-stride loop, so
// any 1D launch configuration covers the full output.
template <int dim>
static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont(const float * x,
                                                                                 const float * y,
                                                                                 float * dst,
                                                                                 int64_t ne00,
                                                                                 int64_t ne01,
                                                                                 int64_t ne02,
                                                                                 int64_t ne0,
                                                                                 int64_t ne1,
                                                                                 int64_t ne2) {
    static_assert(dim >= 0 && dim <= 2, "dim must be in [0, 2]");

    const int64_t n_total = ne0 * ne1 * ne2;
    const int64_t stride  = (int64_t) blockDim.x * gridDim.x;

    for (int64_t idx = (int64_t) blockIdx.x * blockDim.x + threadIdx.x; idx < n_total; idx += stride) {
        if constexpr (dim == 0) {
            // innermost dim is split: the first ne00 elements of each dst row
            // come from x, the remaining ne0 - ne00 from y
            const int64_t row = idx / ne0;
            const int64_t i0  = idx % ne0;
            dst[idx] = i0 < ne00 ? x[row * ne00 + i0] : y[row * (ne0 - ne00) + (i0 - ne00)];
        } else if constexpr (dim == 1) {
            // each ne0*ne1 dst plane is split: its first ne0*ne01 elements
            // belong to x, the rest to y
            const int64_t plane_dst  = ne0 * ne1;
            const int64_t plane_src0 = ne0 * ne01;
            const int64_t plane_src1 = plane_dst - plane_src0;
            const int64_t i2         = idx / plane_dst;
            const int64_t rem        = idx % plane_dst;
            dst[idx] = rem < plane_src0 ? x[i2 * plane_src0 + rem] : y[i2 * plane_src1 + (rem - plane_src0)];
        } else {
            // outermost dim: x occupies the first ne0*ne1*ne02 elements of
            // dst, y the remainder
            const int64_t n_src0 = ne0 * ne1 * ne02;
            dst[idx] = idx < n_src0 ? x[idx] : y[idx - n_src0];
        }
    }
}
5551

56-
// Concatenate two contiguous f32 tensors along dim 2.
// Expected launch layout: grid = (ceil(ne0 / blockDim.x), ne1_dst, ne2_dst);
// the first ne02 z-slices of dst are copied from src0 (x), the remaining
// gridDim.z - ne02 slices from src1 (y).
static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {
    const int nidx = threadIdx.x + blockIdx.x * blockDim.x;
    if (nidx >= ne0) {
        return;
    }

    // compute offsets in 64 bits: blockIdx.z * ne0 * gridDim.y can overflow
    // a 32-bit int for large tensors
    const int64_t offset_dst =
        (int64_t) nidx +
        (int64_t) blockIdx.y * ne0 +
        (int64_t) blockIdx.z * ne0 * gridDim.y;

    if (blockIdx.z < (unsigned) ne02) { // z-slice comes from src0
        // src0 shares the ne0/ne1 extents of dst, so src and dst offsets coincide
        const int64_t offset_src =
            (int64_t) nidx +
            (int64_t) blockIdx.y * ne0 +
            (int64_t) blockIdx.z * ne0 * gridDim.y;
        dst[offset_dst] = x[offset_src];
    } else { // z-slice comes from src1
        const int64_t offset_src =
            (int64_t) nidx +
            (int64_t) blockIdx.y * ne0 +
            (int64_t) (blockIdx.z - ne02) * ne0 * gridDim.y;
        dst[offset_dst] = y[offset_src];
    }
}
52+
// Host-side launcher for the contiguous concat kernels.
// x, y: device pointers to the two contiguous f32 sources; dst: device
// pointer to the contiguous f32 destination. ne00/ne01/ne02 are the extents
// of src0, ne0/ne1/ne2 those of dst; dim selects the concat axis (0-2).
// The launch is asynchronous on `stream`.
static void concat_f32_cuda(const float * x,
                            const float * y,
                            float * dst,
                            int64_t ne00,
                            int64_t ne01,
                            int64_t ne02,
                            int64_t ne0,
                            int64_t ne1,
                            int64_t ne2,
                            int dim,
                            cudaStream_t stream) {
    const int64_t n = ne0 * ne1 * ne2;
    if (n <= 0) {
        return; // empty dst: launching 0 blocks would be an invalid configuration
    }

    // the kernels use a grid-stride loop, so the block count can be clamped
    // without affecting correctness; clamping to the maximum x-dimension of a
    // CUDA grid (2^31 - 1) also prevents truncating the 64-bit block count
    // when casting to int for very large tensors
    const int64_t max_blocks = 2147483647; // 2^31 - 1, CUDA grid.x limit
    const int64_t n_blocks   = (n + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
    const int     num_blocks = (int) (n_blocks < max_blocks ? n_blocks : max_blocks);

    if (dim == 0) {
        concat_f32_cont<0>
            <<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
        return;
    }
    if (dim == 1) {
        concat_f32_cont<1>
            <<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
        return;
    }
    concat_f32_cont<2><<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
}
9578

9679
// non-contiguous kernel (slow)

0 commit comments

Comments
 (0)