22#include " ggml-sycl/common.hpp"
33#include " ggml-sycl/presets.hpp"
44
5- static void norm_f32 (const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
6- const int64_t stride_sample, const float eps, const sycl::nd_item<3 >& item_ct1, sycl::float2* s_sum, int block_size) {
5+ static void norm_f32 (const float * x, float * dst, const int ncols,
6+ const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
7+ const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
8+ const float eps, const sycl::nd_item<3 >& item_ct1, sycl::float2* s_sum, int block_size) {
79
810 const int nrows = item_ct1.get_group_range (2 );
911 const int nchannels = item_ct1.get_group_range (1 );
@@ -16,16 +18,16 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
1618 const int tid = item_ct1.get_local_id (2 );
1719 const int nwarps = nthreads / WARP_SIZE ;
1820
19- const auto strided_offset = calculate_offset<3 >({stride_sample, stride_channel, stride_row }, {sample, channel, row});
20- const auto packed_offset = calculate_offset<3 >({nchannels * nrows * ncols, nrows * ncols, ncols }, {sample, channel, row});
21+ const auto src_offset = calculate_offset<3 >({src_stride_sample, src_stride_channel, src_stride_row }, {sample, channel, row});
22+ const auto dst_offset = calculate_offset<3 >({dst_stride_sample, dst_stride_channel, dst_stride_row }, {sample, channel, row});
2123
22- x += strided_offset ;
23- dst += packed_offset ;
24+ x += src_offset ;
25+ dst += dst_offset ;
2426
2527 sycl::float2 mean_var = sycl::float2 (0 .f , 0 .f );
2628
2729 for (int col = tid; col < ncols; col += block_size) {
28- const float xi = x[col];
30+ const float xi = x[col * src_stride_col ];
2931 mean_var.x () += xi;
3032 mean_var.y () += xi * xi;
3133 }
@@ -54,7 +56,7 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
5456 const float inv_std = sycl::rsqrt (var + eps);
5557
5658 for (int col = tid; col < ncols; col += block_size) {
57- dst[col] = (x[col] - mean) * inv_std;
59+ dst[col * dst_stride_col ] = (x[col * src_stride_col ] - mean) * inv_std;
5860 }
5961}
6062
@@ -145,8 +147,10 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
145147 }
146148}
147149
148- static void rms_norm_f32 (const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
149- const int64_t stride_sample, const float eps, const sycl::nd_item<3 >& item_ct1, float * s_sum, int block_size) {
150+ static void rms_norm_f32 (const float * x, float * dst, const int ncols,
151+ const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
152+ const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
153+ const float eps, const sycl::nd_item<3 >& item_ct1, float * s_sum, int block_size) {
150154
151155 const int nrows = item_ct1.get_group_range (2 );
152156 const int nchannels = item_ct1.get_group_range (1 );
@@ -160,17 +164,17 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
160164 const int tid = item_ct1.get_local_id (2 );
161165 const int nwarps = nthreads / WARP_SIZE ;
162166
163- const auto strided_offset = calculate_offset<3 >({stride_sample, stride_channel, stride_row }, {sample, channel, row});
164- const auto packed_offset = calculate_offset<3 >({nchannels * nrows * ncols, nrows * ncols, ncols }, {sample, channel, row});
167+ const auto src_offset = calculate_offset<3 >({src_stride_sample, src_stride_channel, src_stride_row }, {sample, channel, row});
168+ const auto dst_offset = calculate_offset<3 >({dst_stride_sample, dst_stride_channel, dst_stride_row }, {sample, channel, row});
165169
166- x += strided_offset ;
167- dst += packed_offset ;
170+ x += src_offset ;
171+ dst += dst_offset ;
168172
169173
170174 float tmp = 0 .0f ; // partial sum for thread in warp
171175
172176 for (int col = tid; col < ncols; col += block_size) {
173- const float xi = x[col];
177+ const float xi = x[col * src_stride_col ];
174178 tmp += xi * xi;
175179 }
176180
@@ -198,14 +202,15 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
198202 const float scale = sycl::rsqrt (mean + eps);
199203
200204 for (int col = tid; col < ncols; col += block_size) {
201- dst[col] = scale * x[col];
205+ dst[col * dst_stride_col ] = scale * x[col * src_stride_col ];
202206 }
203207}
204208
205209template <int warp_size>
206210static void l2_norm_f32 (const float * x, float * dst, const int ncols,
207- const int64_t stride_row, const int64_t stride_channel,
208- const int64_t stride_sample, const float eps,
211+ const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel,
212+ const int64_t src_stride_sample, const int64_t dst_stride_col, const int64_t dst_stride_row,
213+ const int64_t dst_stride_channel, const int64_t dst_stride_sample, const float eps,
209214 const sycl::nd_item<3 >& item_ct1, float * s_sum, const int block_size) {
210215 const int nrows = item_ct1.get_group_range (2 );
211216 const int nchannels = item_ct1.get_group_range (1 );
@@ -215,26 +220,27 @@ static void l2_norm_f32(const float * x, float * dst, const int ncols,
215220 const int sample = item_ct1.get_group (0 );
216221 const int tid = item_ct1.get_local_id (2 );
217222
218- x += sample*stride_sample + channel*stride_channel + row*stride_row ;
219- dst += (( sample*nchannels + channel)*nrows + row)*ncols ;
223+ x += sample*src_stride_sample + channel*src_stride_channel + row*src_stride_row ;
224+ dst += sample*dst_stride_sample + channel*dst_stride_channel + row*dst_stride_row ;
220225
221226 float tmp = 0 .0f ; // partial sum for thread in warp
222227
223228 for (int col = tid; col < ncols; col += block_size) {
224- const float xi = x[col];
229+ const float xi = x[col * src_stride_col ];
225230 tmp += xi * xi;
226231 }
227232
228233 tmp = block_reduce<block_reduce_method::SUM , warp_size>(tmp, s_sum, block_size);
229234 const float scale = sycl::rsqrt (sycl::fmax (tmp, eps * eps));
230235
231236 for (int col = tid; col < ncols; col += block_size) {
232- dst[col] = scale * x[col];
237+ dst[col * dst_stride_col ] = scale * x[col * src_stride_col ];
233238 }
234239}
235240
236241static void norm_f32_sycl (const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
237- const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
242+ const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
243+ const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
238244 const float eps, queue_ptr stream, int device) {
239245
240246 const sycl::range<3 > global_dims (nsamples, nchannels, nrows);
@@ -245,7 +251,10 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
245251 sycl::nd_range<3 >(global_dims * block_dims, block_dims),
246252 [=](sycl::nd_item<3 > item_ct1)
247253 [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
248- norm_f32 (x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr , WARP_SIZE );
254+ norm_f32 (x, dst, ncols,
255+ src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
256+ dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
257+ eps, item_ct1, nullptr , WARP_SIZE );
249258 });
250259 });
251260 }
@@ -265,7 +274,10 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
265274 sycl::nd_range<3 >(global_dims * block_dims, block_dims),
266275 [=](sycl::nd_item<3 > item_ct1)
267276 [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
268- norm_f32 (x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer (s_sum_acc_ct1), work_group_size);
277+ norm_f32 (x, dst, ncols,
278+ src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
279+ dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
280+ eps, item_ct1, get_pointer (s_sum_acc_ct1), work_group_size);
269281 });
270282 });
271283 }
@@ -319,7 +331,9 @@ static void group_norm_f32_sycl(const float* x, float* dst,
319331}
320332
321333static void rms_norm_f32_sycl (const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
322- const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
334+ const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
335+ const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
336+ const float eps, queue_ptr stream, int device) {
323337 // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
324338
325339 const sycl::range<3 > global_dims (nsamples, nchannels, nrows);
@@ -330,7 +344,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
330344 sycl::nd_range<3 >(global_dims * block_dims, block_dims),
331345 [=](sycl::nd_item<3 > item_ct1)
332346 [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
333- rms_norm_f32 (x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr , WARP_SIZE );
347+ rms_norm_f32 (x, dst, ncols,
348+ src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
349+ dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
350+ eps, item_ct1, nullptr , WARP_SIZE );
334351 });
335352 });
336353 }
@@ -350,7 +367,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
350367 sycl::nd_range<3 >(global_dims * block_dims, block_dims),
351368 [=](sycl::nd_item<3 > item_ct1)
352369 [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
353- rms_norm_f32 (x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer (s_sum_acc_ct1), work_group_size);
370+ rms_norm_f32 (x, dst, ncols,
371+ src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
372+ dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
373+ eps, item_ct1, get_pointer (s_sum_acc_ct1), work_group_size);
354374 });
355375 });
356376 }
@@ -363,9 +383,14 @@ static void l2_norm_f32_sycl(const float * x,
363383 const int nrows,
364384 const int nchannels,
365385 const int nsamples,
366- const int64_t stride_row,
367- const int64_t stride_channel,
368- const int64_t stride_sample,
386+ const int64_t src_stride_col,
387+ const int64_t src_stride_row,
388+ const int64_t src_stride_channel,
389+ const int64_t src_stride_sample,
390+ const int64_t dst_stride_col,
391+ const int64_t dst_stride_row,
392+ const int64_t dst_stride_channel,
393+ const int64_t dst_stride_sample,
369394 const float eps,
370395 queue_ptr stream,
371396 int device) {
@@ -379,7 +404,10 @@ static void l2_norm_f32_sycl(const float * x,
379404 block_dims),
380405 [=](sycl::nd_item<3 > item_ct1)
381406 [[sycl::reqd_sub_group_size(warp_size)]] {
382- l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
407+ l2_norm_f32<warp_size>(x, dst, ncols,
408+ src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
409+ dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
410+ eps, item_ct1,
383411 nullptr , warp_size);
384412 });
385413 });
@@ -398,7 +426,9 @@ static void l2_norm_f32_sycl(const float * x,
398426 block_dims),
399427 [=](sycl::nd_item<3 > item_ct1)
400428 [[sycl::reqd_sub_group_size(warp_size)]] {
401- l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample,
429+ l2_norm_f32<warp_size>(x, dst, ncols,
430+ src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
431+ dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
402432 eps, item_ct1, get_pointer (s_sum_acc_ct1), work_group_size);
403433 });
404434 });
@@ -421,12 +451,20 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
421451 memcpy (&eps, dst->op_params , sizeof (float ));
422452 GGML_ASSERT (eps >= 0 .0f );
423453 const size_t ts0 = ggml_type_size (src0->type );
424- GGML_ASSERT (nb00 == ts0);
425- const int64_t s01 = nb01 / ts0;
426- const int64_t s02 = nb02 / ts0;
427- const int64_t s03 = nb03 / ts0;
428-
429- norm_f32_sycl (src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device );
454+ const size_t tdst = ggml_type_size (dst->type );
455+ GGML_ASSERT (nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0 );
456+ GGML_ASSERT (nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0 );
457+ const int64_t ss0 = nb00 / ts0;
458+ const int64_t ss1 = nb01 / ts0;
459+ const int64_t ss2 = nb02 / ts0;
460+ const int64_t ss3 = nb03 / ts0;
461+ const int64_t ds0 = nb0 / tdst;
462+ const int64_t ds1 = nb1 / tdst;
463+ const int64_t ds2 = nb2 / tdst;
464+ const int64_t ds3 = nb3 / tdst;
465+
466+ norm_f32_sycl (src0_dd, dst_dd, ne00, ne01, ne02, ne03,
467+ ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, main_stream, ctx.device );
430468}
431469
432470void ggml_sycl_op_group_norm (ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
@@ -465,11 +503,19 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
465503
466504 GGML_TENSOR_UNARY_OP_LOCALS
467505 const size_t ts0 = ggml_type_size (src0->type );
468- GGML_ASSERT (nb00 == ts0);
469- const int64_t s01 = nb01 / ts0;
470- const int64_t s02 = nb02 / ts0;
471- const int64_t s03 = nb03 / ts0;
472- rms_norm_f32_sycl (src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device );
506+ const size_t tdst = ggml_type_size (dst->type );
507+ GGML_ASSERT (nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0 );
508+ GGML_ASSERT (nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0 );
509+ const int64_t ss0 = nb00 / ts0;
510+ const int64_t ss1 = nb01 / ts0;
511+ const int64_t ss2 = nb02 / ts0;
512+ const int64_t ss3 = nb03 / ts0;
513+ const int64_t ds0 = nb0 / tdst;
514+ const int64_t ds1 = nb1 / tdst;
515+ const int64_t ds2 = nb2 / tdst;
516+ const int64_t ds3 = nb3 / tdst;
517+ rms_norm_f32_sycl (src0_dd, dst_dd, ne00, ne01, ne02, ne03,
518+ ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, main_stream, ctx.device );
473519}
474520
475521void ggml_sycl_op_rms_norm_back (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
@@ -644,13 +690,21 @@ void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
644690 GGML_ASSERT (eps >= 0 .0f );
645691
646692 const size_t ts0 = ggml_type_size (src0->type );
647- GGML_ASSERT (nb00 == ts0);
648- const int64_t s01 = nb01 / ts0;
649- const int64_t s02 = nb02 / ts0;
650- const int64_t s03 = nb03 / ts0;
693+ const size_t tdst = ggml_type_size (dst->type );
694+ GGML_ASSERT (nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0 );
695+ GGML_ASSERT (nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0 );
696+ const int64_t ss0 = nb00 / ts0;
697+ const int64_t ss1 = nb01 / ts0;
698+ const int64_t ss2 = nb02 / ts0;
699+ const int64_t ss3 = nb03 / ts0;
700+ const int64_t ds0 = nb0 / tdst;
701+ const int64_t ds1 = nb1 / tdst;
702+ const int64_t ds2 = nb2 / tdst;
703+ const int64_t ds3 = nb3 / tdst;
651704
652705 /* support both WARP_SIZE or WARP_32_SIZE in code
653706 choose by hardware for better performance
654707 */
655- l2_norm_f32_sycl<WARP_SIZE >(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream, ctx.device );
708+ l2_norm_f32_sycl<WARP_SIZE >(src0_d, dst_d, ne00, ne01, ne02, ne03,
709+ ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, stream, ctx.device );
656710}
0 commit comments