Skip to content

Commit 9bebfcb

Browse files
authored
sycl : fix failed ut cases of norm (ggml-org#25044)
1 parent 0b6529d commit 9bebfcb

1 file changed

Lines changed: 103 additions & 49 deletions

File tree

ggml/src/ggml-sycl/norm.cpp

Lines changed: 103 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,10 @@
22
#include "ggml-sycl/common.hpp"
33
#include "ggml-sycl/presets.hpp"
44

5-
static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
6-
const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
5+
static void norm_f32(const float* x, float* dst, const int ncols,
6+
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
7+
const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
8+
const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
79

810
const int nrows = item_ct1.get_group_range(2);
911
const int nchannels = item_ct1.get_group_range(1);
@@ -16,16 +18,16 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
1618
const int tid = item_ct1.get_local_id(2);
1719
const int nwarps = nthreads / WARP_SIZE;
1820

19-
const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
20-
const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
21+
const auto src_offset = calculate_offset<3>({src_stride_sample, src_stride_channel, src_stride_row}, {sample, channel, row});
22+
const auto dst_offset = calculate_offset<3>({dst_stride_sample, dst_stride_channel, dst_stride_row}, {sample, channel, row});
2123

22-
x += strided_offset;
23-
dst += packed_offset;
24+
x += src_offset;
25+
dst += dst_offset;
2426

2527
sycl::float2 mean_var = sycl::float2(0.f, 0.f);
2628

2729
for (int col = tid; col < ncols; col += block_size) {
28-
const float xi = x[col];
30+
const float xi = x[col * src_stride_col];
2931
mean_var.x() += xi;
3032
mean_var.y() += xi * xi;
3133
}
@@ -54,7 +56,7 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
5456
const float inv_std = sycl::rsqrt(var + eps);
5557

5658
for (int col = tid; col < ncols; col += block_size) {
57-
dst[col] = (x[col] - mean) * inv_std;
59+
dst[col * dst_stride_col] = (x[col * src_stride_col] - mean) * inv_std;
5860
}
5961
}
6062

@@ -145,8 +147,10 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
145147
}
146148
}
147149

148-
static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
149-
const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
150+
static void rms_norm_f32(const float* x, float* dst, const int ncols,
151+
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
152+
const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
153+
const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
150154

151155
const int nrows = item_ct1.get_group_range(2);
152156
const int nchannels = item_ct1.get_group_range(1);
@@ -160,17 +164,17 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
160164
const int tid = item_ct1.get_local_id(2);
161165
const int nwarps = nthreads / WARP_SIZE;
162166

163-
const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
164-
const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
167+
const auto src_offset = calculate_offset<3>({src_stride_sample, src_stride_channel, src_stride_row}, {sample, channel, row});
168+
const auto dst_offset = calculate_offset<3>({dst_stride_sample, dst_stride_channel, dst_stride_row}, {sample, channel, row});
165169

166-
x += strided_offset;
167-
dst += packed_offset;
170+
x += src_offset;
171+
dst += dst_offset;
168172

169173

170174
float tmp = 0.0f; // partial sum for thread in warp
171175

172176
for (int col = tid; col < ncols; col += block_size) {
173-
const float xi = x[col];
177+
const float xi = x[col * src_stride_col];
174178
tmp += xi * xi;
175179
}
176180

@@ -198,14 +202,15 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
198202
const float scale = sycl::rsqrt(mean + eps);
199203

200204
for (int col = tid; col < ncols; col += block_size) {
201-
dst[col] = scale * x[col];
205+
dst[col * dst_stride_col] = scale * x[col * src_stride_col];
202206
}
203207
}
204208

205209
template<int warp_size>
206210
static void l2_norm_f32(const float * x, float * dst, const int ncols,
207-
const int64_t stride_row, const int64_t stride_channel,
208-
const int64_t stride_sample, const float eps,
211+
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel,
212+
const int64_t src_stride_sample, const int64_t dst_stride_col, const int64_t dst_stride_row,
213+
const int64_t dst_stride_channel, const int64_t dst_stride_sample, const float eps,
209214
const sycl::nd_item<3>& item_ct1, float* s_sum, const int block_size) {
210215
const int nrows = item_ct1.get_group_range(2);
211216
const int nchannels = item_ct1.get_group_range(1);
@@ -215,26 +220,27 @@ static void l2_norm_f32(const float * x, float * dst, const int ncols,
215220
const int sample = item_ct1.get_group(0);
216221
const int tid = item_ct1.get_local_id(2);
217222

218-
x += sample*stride_sample + channel*stride_channel + row*stride_row;
219-
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
223+
x += sample*src_stride_sample + channel*src_stride_channel + row*src_stride_row;
224+
dst += sample*dst_stride_sample + channel*dst_stride_channel + row*dst_stride_row;
220225

221226
float tmp = 0.0f; // partial sum for thread in warp
222227

223228
for (int col = tid; col < ncols; col += block_size) {
224-
const float xi = x[col];
229+
const float xi = x[col * src_stride_col];
225230
tmp += xi * xi;
226231
}
227232

228233
tmp = block_reduce<block_reduce_method::SUM, warp_size>(tmp, s_sum, block_size);
229234
const float scale = sycl::rsqrt(sycl::fmax(tmp, eps * eps));
230235

231236
for (int col = tid; col < ncols; col += block_size) {
232-
dst[col] = scale * x[col];
237+
dst[col * dst_stride_col] = scale * x[col * src_stride_col];
233238
}
234239
}
235240

236241
static void norm_f32_sycl(const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
237-
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
242+
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
243+
const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
238244
const float eps, queue_ptr stream, int device) {
239245

240246
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
@@ -245,7 +251,10 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
245251
sycl::nd_range<3>(global_dims * block_dims, block_dims),
246252
[=](sycl::nd_item<3> item_ct1)
247253
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
248-
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
254+
norm_f32(x, dst, ncols,
255+
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
256+
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
257+
eps, item_ct1, nullptr, WARP_SIZE);
249258
});
250259
});
251260
}
@@ -265,7 +274,10 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
265274
sycl::nd_range<3>(global_dims * block_dims, block_dims),
266275
[=](sycl::nd_item<3> item_ct1)
267276
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
268-
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
277+
norm_f32(x, dst, ncols,
278+
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
279+
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
280+
eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
269281
});
270282
});
271283
}
@@ -319,7 +331,9 @@ static void group_norm_f32_sycl(const float* x, float* dst,
319331
}
320332

321333
static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
322-
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
334+
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
335+
const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
336+
const float eps, queue_ptr stream, int device) {
323337
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
324338

325339
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
@@ -330,7 +344,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
330344
sycl::nd_range<3>(global_dims * block_dims, block_dims),
331345
[=](sycl::nd_item<3> item_ct1)
332346
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
333-
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
347+
rms_norm_f32(x, dst, ncols,
348+
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
349+
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
350+
eps, item_ct1, nullptr, WARP_SIZE);
334351
});
335352
});
336353
}
@@ -350,7 +367,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
350367
sycl::nd_range<3>(global_dims * block_dims, block_dims),
351368
[=](sycl::nd_item<3> item_ct1)
352369
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
353-
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
370+
rms_norm_f32(x, dst, ncols,
371+
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
372+
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
373+
eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
354374
});
355375
});
356376
}
@@ -363,9 +383,14 @@ static void l2_norm_f32_sycl(const float * x,
363383
const int nrows,
364384
const int nchannels,
365385
const int nsamples,
366-
const int64_t stride_row,
367-
const int64_t stride_channel,
368-
const int64_t stride_sample,
386+
const int64_t src_stride_col,
387+
const int64_t src_stride_row,
388+
const int64_t src_stride_channel,
389+
const int64_t src_stride_sample,
390+
const int64_t dst_stride_col,
391+
const int64_t dst_stride_row,
392+
const int64_t dst_stride_channel,
393+
const int64_t dst_stride_sample,
369394
const float eps,
370395
queue_ptr stream,
371396
int device) {
@@ -379,7 +404,10 @@ static void l2_norm_f32_sycl(const float * x,
379404
block_dims),
380405
[=](sycl::nd_item<3> item_ct1)
381406
[[sycl::reqd_sub_group_size(warp_size)]] {
382-
l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
407+
l2_norm_f32<warp_size>(x, dst, ncols,
408+
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
409+
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
410+
eps, item_ct1,
383411
nullptr, warp_size);
384412
});
385413
});
@@ -398,7 +426,9 @@ static void l2_norm_f32_sycl(const float * x,
398426
block_dims),
399427
[=](sycl::nd_item<3> item_ct1)
400428
[[sycl::reqd_sub_group_size(warp_size)]] {
401-
l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample,
429+
l2_norm_f32<warp_size>(x, dst, ncols,
430+
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
431+
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
402432
eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
403433
});
404434
});
@@ -421,12 +451,20 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
421451
memcpy(&eps, dst->op_params, sizeof(float));
422452
GGML_ASSERT(eps >= 0.0f);
423453
const size_t ts0 = ggml_type_size(src0->type);
424-
GGML_ASSERT(nb00 == ts0);
425-
const int64_t s01 = nb01 / ts0;
426-
const int64_t s02 = nb02 / ts0;
427-
const int64_t s03 = nb03 / ts0;
428-
429-
norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
454+
const size_t tdst = ggml_type_size(dst->type);
455+
GGML_ASSERT(nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0);
456+
GGML_ASSERT(nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0);
457+
const int64_t ss0 = nb00 / ts0;
458+
const int64_t ss1 = nb01 / ts0;
459+
const int64_t ss2 = nb02 / ts0;
460+
const int64_t ss3 = nb03 / ts0;
461+
const int64_t ds0 = nb0 / tdst;
462+
const int64_t ds1 = nb1 / tdst;
463+
const int64_t ds2 = nb2 / tdst;
464+
const int64_t ds3 = nb3 / tdst;
465+
466+
norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03,
467+
ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, main_stream, ctx.device);
430468
}
431469

432470
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
@@ -465,11 +503,19 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
465503

466504
GGML_TENSOR_UNARY_OP_LOCALS
467505
const size_t ts0 = ggml_type_size(src0->type);
468-
GGML_ASSERT(nb00 == ts0);
469-
const int64_t s01 = nb01 / ts0;
470-
const int64_t s02 = nb02 / ts0;
471-
const int64_t s03 = nb03 / ts0;
472-
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
506+
const size_t tdst = ggml_type_size(dst->type);
507+
GGML_ASSERT(nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0);
508+
GGML_ASSERT(nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0);
509+
const int64_t ss0 = nb00 / ts0;
510+
const int64_t ss1 = nb01 / ts0;
511+
const int64_t ss2 = nb02 / ts0;
512+
const int64_t ss3 = nb03 / ts0;
513+
const int64_t ds0 = nb0 / tdst;
514+
const int64_t ds1 = nb1 / tdst;
515+
const int64_t ds2 = nb2 / tdst;
516+
const int64_t ds3 = nb3 / tdst;
517+
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03,
518+
ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, main_stream, ctx.device);
473519
}
474520

475521
void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
@@ -644,13 +690,21 @@ void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
644690
GGML_ASSERT(eps >= 0.0f);
645691

646692
const size_t ts0 = ggml_type_size(src0->type);
647-
GGML_ASSERT(nb00 == ts0);
648-
const int64_t s01 = nb01 / ts0;
649-
const int64_t s02 = nb02 / ts0;
650-
const int64_t s03 = nb03 / ts0;
693+
const size_t tdst = ggml_type_size(dst->type);
694+
GGML_ASSERT(nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0);
695+
GGML_ASSERT(nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0);
696+
const int64_t ss0 = nb00 / ts0;
697+
const int64_t ss1 = nb01 / ts0;
698+
const int64_t ss2 = nb02 / ts0;
699+
const int64_t ss3 = nb03 / ts0;
700+
const int64_t ds0 = nb0 / tdst;
701+
const int64_t ds1 = nb1 / tdst;
702+
const int64_t ds2 = nb2 / tdst;
703+
const int64_t ds3 = nb3 / tdst;
651704

652705
/*support both WARP_SIZE or WARP_32_SIZE in code
653706
choose by hardware for better performance
654707
*/
655-
l2_norm_f32_sycl<WARP_SIZE>(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream, ctx.device);
708+
l2_norm_f32_sycl<WARP_SIZE>(src0_d, dst_d, ne00, ne01, ne02, ne03,
709+
ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, stream, ctx.device);
656710
}

0 commit comments

Comments
 (0)