@@ -34,60 +34,67 @@ template <float (*bin_op)(const float, const float),
3434static __global__ void k_bin_bcast (const src0_t * src0,
3535 const src1_t * src1,
3636 dst_t * dst,
37- const int ne0,
38- const int ne1,
39- const int ne2,
37+ const uint32_t ne0,
38+ const uint32_t ne1,
39+ const uint32_t ne2,
4040 const uint3 ne3,
4141 const uint3 ne10,
4242 const uint3 ne11,
4343 const uint3 ne12,
4444 const uint3 ne13,
45- /* const int s0,*/
46- const int s1,
47- const int s2,
48- const int s3,
49- const int s00,
50- const int s01,
51- const int s02,
52- const int s03,
53- const int s10,
54- const int s11,
55- const int s12,
56- const int s13,
45+ /* const uint32_t s0,*/
46+ const uint32_t s1,
47+ const uint32_t s2,
48+ const uint32_t s3,
49+ const uint32_t s00,
50+ const uint32_t s01,
51+ const uint32_t s02,
52+ const uint32_t s03,
53+ const uint32_t s10,
54+ const uint32_t s11,
55+ const uint32_t s12,
56+ const uint32_t s13,
5757 src1_ptrs... src1s) {
5858 ggml_cuda_pdl_lc ();
5959 const uint32_t i0s = blockDim .x * blockIdx .x + threadIdx .x ;
6060 const uint32_t i1 = (blockDim .y * blockIdx .y + threadIdx .y );
6161 const uint32_t i2 = fastdiv ((blockDim .z * blockIdx .z + threadIdx .z ), ne3);
6262 const uint32_t i3 = (blockDim .z * blockIdx .z + threadIdx .z ) - (i2 * ne3.z );
6363
64- if (i0s >= ( uint32_t ) ne0 || i1 >= ( uint32_t ) ne1 || i2 >= ( uint32_t ) ne2 || i3 >= ne3.z ) {
64+ if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z ) {
6565 return ;
6666 }
6767
6868 const uint32_t i11 = fastmodulo (i1, ne11);
6969 const uint32_t i12 = fastmodulo (i2, ne12);
7070 const uint32_t i13 = fastmodulo (i3, ne13);
7171
72- const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
73- const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
74- const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
72+ const size_t i_src0 = size_t ( i3) *s03 + size_t ( i2) *s02 + size_t ( i1) *s01;
73+ const size_t i_src1 = size_t ( i13) *s13 + size_t ( i12) *s12 + size_t ( i11) *s11;
74+ const size_t i_dst = size_t ( i3) *s3 + size_t ( i2) *s2 + size_t ( i1) *s1;
7575
7676 const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr ;
7777 dst_t * dst_row = dst + i_dst;
7878
79+ const uint32_t s0 = blockDim .x * gridDim .x ;
80+
7981 ggml_cuda_pdl_sync ();
80- for (int i0 = i0s; i0 < ne0; i0 += blockDim . x * gridDim . x ) {
82+ for (uint32_t i0 = i0s; i0 < ne0; i0 += s0 ) {
8183 const uint32_t i10 = fastmodulo (i0, ne10);
8284
83- float result = src0_row ? (float ) src0_row[i0 *s00] : 0 .0f ;
85+ float result = src0_row ? (float ) src0_row[size_t (i0) *s00] : 0 .0f ;
8486 if constexpr (sizeof ...(src1_ptrs) > 0 ) {
85- result = (..., (result = bin_op (result, (float )src1s[i_src1 + i10*s10])));
87+ result = (..., (result = bin_op (result, (float )src1s[i_src1 + size_t ( i10) *s10])));
8688 } else {
87- result = bin_op (result, (float )src1[i_src1 + i10*s10]);
89+ result = bin_op (result, (float )src1[i_src1 + size_t ( i10) *s10]);
8890 }
8991
9092 dst_row[i0] = (dst_t ) result;
93+
94+ // protect i0 from overflow
95+ if (ne0 - i0 <= s0) {
96+ break ;
97+ }
9198 }
9299}
93100
@@ -110,19 +117,19 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
110117 const uint3 ne12,
111118 const uint3 ne13,
112119 /* const int s0,*/
113- const int s1,
114- const int s2,
115- const int s3,
116- const int s00,
117- const int s01,
118- const int s02,
119- const int s03,
120- const int s10,
121- const int s11,
122- const int s12,
123- const int s13,
120+ const uint32_t s1,
121+ const uint32_t s2,
122+ const uint32_t s3,
123+ const uint32_t s00,
124+ const uint32_t s01,
125+ const uint32_t s02,
126+ const uint32_t s03,
127+ const uint32_t s10,
128+ const uint32_t s11,
129+ const uint32_t s12,
130+ const uint32_t s13,
124131 src1_ptrs... src1s) {
125- const int i = blockDim .x *blockIdx .x + threadIdx .x ;
132+ const uint32_t i = blockDim .x *blockIdx .x + threadIdx .x ;
126133
127134 const uint32_t i3 = fastdiv (i, prod_012);
128135 const uint32_t i2 = fastdiv (i - i3 * prod_012.z , prod_01);
@@ -133,25 +140,25 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
133140 return ;
134141 }
135142
136- const int i11 = fastmodulo (i1, ne11);
137- const int i12 = fastmodulo (i2, ne12);
138- const int i13 = fastmodulo (i3, ne13);
143+ const uint32_t i11 = fastmodulo (i1, ne11);
144+ const uint32_t i12 = fastmodulo (i2, ne12);
145+ const uint32_t i13 = fastmodulo (i3, ne13);
139146
140- const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
141- const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
142- const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
147+ const size_t i_src0 = size_t ( i3) *s03 + size_t ( i2) *s02 + size_t ( i1) *s01;
148+ const size_t i_src1 = size_t ( i13) *s13 + size_t ( i12) *s12 + size_t ( i11) *s11;
149+ const size_t i_dst = size_t ( i3) *s3 + size_t ( i2) *s2 + size_t ( i1) *s1;
143150
144151 const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr ;
145152 dst_t * dst_row = dst + i_dst;
146153
147- const int i10 = fastmodulo (i0, ne10);
154+ const uint32_t i10 = fastmodulo (i0, ne10);
148155
149156 ggml_cuda_pdl_sync ();
150- float result = src0_row ? (float ) src0_row[i0 *s00] : 0 .0f ;
157+ float result = src0_row ? (float ) src0_row[size_t (i0) *s00] : 0 .0f ;
151158 if constexpr (sizeof ...(src1_ptrs) > 0 ) {
152- result = (..., (result = bin_op (result, (float )src1s[i_src1 + i10*s10])));
159+ result = (..., (result = bin_op (result, (float )src1s[i_src1 + size_t ( i10) *s10])));
153160 } else {
154- result = bin_op (result, (float )src1[i_src1 + i10*s10]);
161+ result = bin_op (result, (float )src1[i_src1 + size_t ( i10) *s10]);
155162 }
156163
157164 dst_row[i0] = (dst_t ) result;
@@ -248,6 +255,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
248255 size_t s02 = nb02 / sizeof (src0_t );
249256 size_t s03 = nb03 / sizeof (src0_t );
250257
258+ GGML_ASSERT (ne0 <= std::numeric_limits<uint32_t >::max ());
259+ GGML_ASSERT (ne1 <= std::numeric_limits<uint32_t >::max ());
260+ GGML_ASSERT (ne2 <= std::numeric_limits<uint32_t >::max ());
261+ GGML_ASSERT (ne3 <= std::numeric_limits<uint32_t >::max ());
262+
263+ // GGML_ASSERT(s0 <= std::numeric_limits<uint32_t>::max());
264+ GGML_ASSERT (s1 <= std::numeric_limits<uint32_t >::max ());
265+ GGML_ASSERT (s2 <= std::numeric_limits<uint32_t >::max ());
266+ GGML_ASSERT (s3 <= std::numeric_limits<uint32_t >::max ());
267+
268+ GGML_ASSERT (s00 <= std::numeric_limits<uint32_t >::max ());
269+ GGML_ASSERT (s01 <= std::numeric_limits<uint32_t >::max ());
270+ GGML_ASSERT (s02 <= std::numeric_limits<uint32_t >::max ());
271+ GGML_ASSERT (s03 <= std::numeric_limits<uint32_t >::max ());
272+
273+ GGML_ASSERT (s10 <= std::numeric_limits<uint32_t >::max ());
274+ GGML_ASSERT (s11 <= std::numeric_limits<uint32_t >::max ());
275+ GGML_ASSERT (s12 <= std::numeric_limits<uint32_t >::max ());
276+ GGML_ASSERT (s13 <= std::numeric_limits<uint32_t >::max ());
277+
278+ GGML_ASSERT (cne1[0 ] <= std::numeric_limits<uint32_t >::max ());
279+ GGML_ASSERT (cne1[1 ] <= std::numeric_limits<uint32_t >::max ());
280+ GGML_ASSERT (cne1[2 ] <= std::numeric_limits<uint32_t >::max ());
281+ GGML_ASSERT (cne1[3 ] <= std::numeric_limits<uint32_t >::max ());
282+
251283 GGML_ASSERT (nb0 % sizeof (dst_t ) == 0 );
252284 GGML_ASSERT (nb1 % sizeof (dst_t ) == 0 );
253285 GGML_ASSERT (nb2 % sizeof (dst_t ) == 0 );
@@ -263,6 +295,8 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
263295 GGML_ASSERT (nb12 % sizeof (src1_t ) == 0 );
264296 GGML_ASSERT (nb13 % sizeof (src1_t ) == 0 );
265297
298+ GGML_ASSERT (ne2 * ne3 <= std::numeric_limits<unsigned int >::max ());
299+
266300 const int block_size = 128 ;
267301
268302 int64_t hne0 = std::max (ne0 / 2LL , 1LL );
@@ -281,7 +315,13 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
281315 const uint3 ne13 = init_fastdiv_values ((uint32_t ) cne1[3 ]);
282316
283317 if (block_nums.z > 65535 || block_nums.y > 65535 ) {
284- int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1 ) / block_size;
318+ int64_t block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1 ) / block_size;
319+
320+ GGML_ASSERT (block_num <= std::numeric_limits<uint32_t >::max ());
321+ GGML_ASSERT (block_num * block_size <= std::numeric_limits<uint32_t >::max ());
322+ GGML_ASSERT (ne0 * ne1 <= std::numeric_limits<uint32_t >::max ());
323+ GGML_ASSERT (ne0 * ne1 * ne2 <= std::numeric_limits<uint32_t >::max ());
324+
285325 const uint3 prod_012 = init_fastdiv_values ((uint32_t ) (ne0 * ne1 * ne2));
286326 const uint3 prod_01 = init_fastdiv_values ((uint32_t ) (ne0 * ne1));
287327 const uint3 ne0_fastdiv = init_fastdiv_values ((uint32_t ) ne0);
@@ -298,6 +338,10 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
298338 s10, s11, s12, s13, (const src1_t *) dst->src [I + 1 ]->data ...);
299339 }
300340 } else {
341+ GGML_ASSERT (int64_t (block_nums.x ) * block_dims.x <= std::numeric_limits<uint32_t >::max ());
342+ GGML_ASSERT (int64_t (block_nums.y ) * block_dims.y <= std::numeric_limits<uint32_t >::max ());
343+ GGML_ASSERT (int64_t (block_nums.z ) * block_dims.z <= std::numeric_limits<uint32_t >::max ());
344+
301345 const uint3 ne3_fastdiv = init_fastdiv_values ((uint32_t ) ne3);
302346 {
303347 const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params (block_nums, block_dims, 0 , stream);
0 commit comments