Skip to content

Commit f728ada

Browse files
ggml : address integer overflows in binary ops CUDA implementation (ggml-org#24706)
* ggml : address integer overflows in binary ops CUDA implementation * ggml : add size_t casts to avoid integer overflows * ggml : add more asserts checking integer overflows in binary ops CUDA implementation --------- Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
1 parent 3e61ea0 commit f728ada

1 file changed

Lines changed: 90 additions & 46 deletions

File tree

ggml/src/ggml-cuda/binbcast.cu

Lines changed: 90 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -34,60 +34,67 @@ template <float (*bin_op)(const float, const float),
3434
static __global__ void k_bin_bcast(const src0_t * src0,
3535
const src1_t * src1,
3636
dst_t * dst,
37-
const int ne0,
38-
const int ne1,
39-
const int ne2,
37+
const uint32_t ne0,
38+
const uint32_t ne1,
39+
const uint32_t ne2,
4040
const uint3 ne3,
4141
const uint3 ne10,
4242
const uint3 ne11,
4343
const uint3 ne12,
4444
const uint3 ne13,
45-
/*const int s0,*/
46-
const int s1,
47-
const int s2,
48-
const int s3,
49-
const int s00,
50-
const int s01,
51-
const int s02,
52-
const int s03,
53-
const int s10,
54-
const int s11,
55-
const int s12,
56-
const int s13,
45+
/*const uint32_t s0,*/
46+
const uint32_t s1,
47+
const uint32_t s2,
48+
const uint32_t s3,
49+
const uint32_t s00,
50+
const uint32_t s01,
51+
const uint32_t s02,
52+
const uint32_t s03,
53+
const uint32_t s10,
54+
const uint32_t s11,
55+
const uint32_t s12,
56+
const uint32_t s13,
5757
src1_ptrs... src1s) {
5858
ggml_cuda_pdl_lc();
5959
const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
6060
const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y);
6161
const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
6262
const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);
6363

64-
if (i0s >= (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) {
64+
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z) {
6565
return;
6666
}
6767

6868
const uint32_t i11 = fastmodulo(i1, ne11);
6969
const uint32_t i12 = fastmodulo(i2, ne12);
7070
const uint32_t i13 = fastmodulo(i3, ne13);
7171

72-
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
73-
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
74-
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
72+
const size_t i_src0 = size_t( i3)*s03 + size_t( i2)*s02 + size_t( i1)*s01;
73+
const size_t i_src1 = size_t(i13)*s13 + size_t(i12)*s12 + size_t(i11)*s11;
74+
const size_t i_dst = size_t( i3)*s3 + size_t( i2)*s2 + size_t( i1)*s1;
7575

7676
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
7777
dst_t * dst_row = dst + i_dst;
7878

79+
const uint32_t s0 = blockDim.x * gridDim.x;
80+
7981
ggml_cuda_pdl_sync();
80-
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
82+
for (uint32_t i0 = i0s; i0 < ne0; i0 += s0) {
8183
const uint32_t i10 = fastmodulo(i0, ne10);
8284

83-
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
85+
float result = src0_row ? (float) src0_row[size_t(i0)*s00] : 0.0f;
8486
if constexpr (sizeof...(src1_ptrs) > 0) {
85-
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
87+
result = (..., (result = bin_op(result, (float)src1s[i_src1 + size_t(i10)*s10])));
8688
} else {
87-
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
89+
result = bin_op(result, (float)src1[i_src1 + size_t(i10)*s10]);
8890
}
8991

9092
dst_row[i0] = (dst_t) result;
93+
94+
// protect i0 from overflow
95+
if (ne0 - i0 <= s0) {
96+
break;
97+
}
9198
}
9299
}
93100

@@ -110,19 +117,19 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
110117
const uint3 ne12,
111118
const uint3 ne13,
112119
/*const int s0,*/
113-
const int s1,
114-
const int s2,
115-
const int s3,
116-
const int s00,
117-
const int s01,
118-
const int s02,
119-
const int s03,
120-
const int s10,
121-
const int s11,
122-
const int s12,
123-
const int s13,
120+
const uint32_t s1,
121+
const uint32_t s2,
122+
const uint32_t s3,
123+
const uint32_t s00,
124+
const uint32_t s01,
125+
const uint32_t s02,
126+
const uint32_t s03,
127+
const uint32_t s10,
128+
const uint32_t s11,
129+
const uint32_t s12,
130+
const uint32_t s13,
124131
src1_ptrs... src1s) {
125-
const int i = blockDim.x*blockIdx.x + threadIdx.x;
132+
const uint32_t i = blockDim.x*blockIdx.x + threadIdx.x;
126133

127134
const uint32_t i3 = fastdiv(i, prod_012);
128135
const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
@@ -133,25 +140,25 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
133140
return;
134141
}
135142

136-
const int i11 = fastmodulo(i1, ne11);
137-
const int i12 = fastmodulo(i2, ne12);
138-
const int i13 = fastmodulo(i3, ne13);
143+
const uint32_t i11 = fastmodulo(i1, ne11);
144+
const uint32_t i12 = fastmodulo(i2, ne12);
145+
const uint32_t i13 = fastmodulo(i3, ne13);
139146

140-
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
141-
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
142-
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
147+
const size_t i_src0 = size_t( i3)*s03 + size_t( i2)*s02 + size_t( i1)*s01;
148+
const size_t i_src1 = size_t(i13)*s13 + size_t(i12)*s12 + size_t(i11)*s11;
149+
const size_t i_dst = size_t( i3)*s3 + size_t( i2)*s2 + size_t( i1)*s1;
143150

144151
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
145152
dst_t * dst_row = dst + i_dst;
146153

147-
const int i10 = fastmodulo(i0, ne10);
154+
const uint32_t i10 = fastmodulo(i0, ne10);
148155

149156
ggml_cuda_pdl_sync();
150-
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
157+
float result = src0_row ? (float) src0_row[size_t(i0)*s00] : 0.0f;
151158
if constexpr (sizeof...(src1_ptrs) > 0) {
152-
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
159+
result = (..., (result = bin_op(result, (float)src1s[i_src1 + size_t(i10)*s10])));
153160
} else {
154-
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
161+
result = bin_op(result, (float)src1[i_src1 + size_t(i10)*s10]);
155162
}
156163

157164
dst_row[i0] = (dst_t) result;
@@ -248,6 +255,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
248255
size_t s02 = nb02 / sizeof(src0_t);
249256
size_t s03 = nb03 / sizeof(src0_t);
250257

258+
GGML_ASSERT(ne0 <= std::numeric_limits<uint32_t>::max());
259+
GGML_ASSERT(ne1 <= std::numeric_limits<uint32_t>::max());
260+
GGML_ASSERT(ne2 <= std::numeric_limits<uint32_t>::max());
261+
GGML_ASSERT(ne3 <= std::numeric_limits<uint32_t>::max());
262+
263+
//GGML_ASSERT(s0 <= std::numeric_limits<uint32_t>::max());
264+
GGML_ASSERT(s1 <= std::numeric_limits<uint32_t>::max());
265+
GGML_ASSERT(s2 <= std::numeric_limits<uint32_t>::max());
266+
GGML_ASSERT(s3 <= std::numeric_limits<uint32_t>::max());
267+
268+
GGML_ASSERT(s00 <= std::numeric_limits<uint32_t>::max());
269+
GGML_ASSERT(s01 <= std::numeric_limits<uint32_t>::max());
270+
GGML_ASSERT(s02 <= std::numeric_limits<uint32_t>::max());
271+
GGML_ASSERT(s03 <= std::numeric_limits<uint32_t>::max());
272+
273+
GGML_ASSERT(s10 <= std::numeric_limits<uint32_t>::max());
274+
GGML_ASSERT(s11 <= std::numeric_limits<uint32_t>::max());
275+
GGML_ASSERT(s12 <= std::numeric_limits<uint32_t>::max());
276+
GGML_ASSERT(s13 <= std::numeric_limits<uint32_t>::max());
277+
278+
GGML_ASSERT(cne1[0] <= std::numeric_limits<uint32_t>::max());
279+
GGML_ASSERT(cne1[1] <= std::numeric_limits<uint32_t>::max());
280+
GGML_ASSERT(cne1[2] <= std::numeric_limits<uint32_t>::max());
281+
GGML_ASSERT(cne1[3] <= std::numeric_limits<uint32_t>::max());
282+
251283
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
252284
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
253285
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
@@ -263,6 +295,8 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
263295
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
264296
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
265297

298+
GGML_ASSERT(ne2 * ne3 <= std::numeric_limits<unsigned int>::max());
299+
266300
const int block_size = 128;
267301

268302
int64_t hne0 = std::max(ne0 / 2LL, 1LL);
@@ -281,7 +315,13 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
281315
const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
282316

283317
if (block_nums.z > 65535 || block_nums.y > 65535) {
284-
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
318+
int64_t block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
319+
320+
GGML_ASSERT(block_num <= std::numeric_limits<uint32_t>::max());
321+
GGML_ASSERT(block_num * block_size <= std::numeric_limits<uint32_t>::max());
322+
GGML_ASSERT(ne0 * ne1 <= std::numeric_limits<uint32_t>::max());
323+
GGML_ASSERT(ne0 * ne1 * ne2 <= std::numeric_limits<uint32_t>::max());
324+
285325
const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
286326
const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
287327
const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
@@ -298,6 +338,10 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
298338
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
299339
}
300340
} else {
341+
GGML_ASSERT(int64_t(block_nums.x) * block_dims.x <= std::numeric_limits<uint32_t>::max());
342+
GGML_ASSERT(int64_t(block_nums.y) * block_dims.y <= std::numeric_limits<uint32_t>::max());
343+
GGML_ASSERT(int64_t(block_nums.z) * block_dims.z <= std::numeric_limits<uint32_t>::max());
344+
301345
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
302346
{
303347
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);

0 commit comments

Comments
 (0)