@@ -11,8 +11,8 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
1111 int ne0, int ne1, int ne2, int ne3,
1212 int ne10, int ne11, int ne12, int ne13,
1313 /* int s0, */ int s1, int s2, int s3,
14- /* int s00,*/ int s01, int s02, int s03,
15- /* int s10,*/ int s11, int s12, int s13,
14+ int s00, int s01, int s02, int s03,
15+ int s10, int s11, int s12, int s13,
1616 const sycl::nd_item<3 > &item_ct1) {
1717 const int i0s = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) +
1818 item_ct1.get_local_id (2 );
@@ -44,7 +44,7 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
4444 for (int i0 = i0s; i0 < ne0;
4545 i0 += item_ct1.get_local_range (2 ) * item_ct1.get_group_range (2 )) {
4646 const int i10 = i0 % ne10;
47- dst_row[i0] = (dst_t )bin_op (src0 ? (float )src0_row[i0] : 0 .0f , (float )src1_row[i10]);
47+ dst_row[i0] = (dst_t )bin_op (src0 ? (float )src0_row[i0*s00 ] : 0 .0f , (float )src1_row[i10*s10 ]);
4848 }
4949}
5050
@@ -53,8 +53,8 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
5353 int ne0, int ne1, int ne2, int ne3,
5454 int ne10, int ne11, int ne12, int ne13,
5555 /* int s0, */ int s1, int s2, int s3,
56- /* int s00,*/ int s01, int s02, int s03,
57- /* int s10,*/ int s11, int s12, int s13,
56+ int s00, int s01, int s02, int s03,
57+ int s10, int s11, int s12, int s13,
5858 const sycl::nd_item<3 > &item_ct1) {
5959
6060 const int i = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) +
@@ -82,7 +82,7 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
8282 dst_t * dst_row = dst + i_dst;
8383
8484 const int i10 = i0 % ne10;
85- dst_row[i0] = (dst_t )bin_op (src0 ? (float )src0_row[i0] : 0 .0f , (float )src1_row[i10]);
85+ dst_row[i0] = (dst_t )bin_op (src0 ? (float )src0_row[i0*s00 ] : 0 .0f , (float )src1_row[i10*s10 ]);
8686}
8787
8888
@@ -95,7 +95,8 @@ struct bin_bcast_sycl {
9595 const int64_t ne3, const size_t nb00, const size_t nb01, const size_t nb02, const size_t nb03,
9696 const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
9797 const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
98- const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
98+ const bool src1_is_contiguous, const bool src0_is_permuted, const bool src1_is_permuted,
99+ queue_ptr stream) {
99100 int nr0 = ne10 / ne0;
100101 int nr1 = ne11/ne1;
101102 int nr2 = ne12/ne2;
@@ -123,7 +124,7 @@ struct bin_bcast_sycl {
123124 cnb[3 ] *= cne[3 ];
124125 };
125126
126- if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous ) {
127+ if (src0_is_contiguous && src1_is_contiguous && !src0_is_permuted && !src1_is_permuted ) {
127128 for (int i = 0 ; i < 4 ; i++) {
128129 if (nr[i] != 1 ) {
129130 break ;
@@ -164,7 +165,7 @@ struct bin_bcast_sycl {
164165 size_t nb12 = cnb1[2 ];
165166 size_t nb13 = cnb1[3 ];
166167
167- size_t s0 = nb0 / sizeof (dst_t );
168+ // size_t s0 = nb0 / sizeof(dst_t);
168169 size_t s1 = nb1 / sizeof (dst_t );
169170 size_t s2 = nb2 / sizeof (dst_t );
170171 size_t s3 = nb3 / sizeof (dst_t );
@@ -196,9 +197,6 @@ struct bin_bcast_sycl {
196197 GGML_ASSERT (nb12 % sizeof (src1_t ) == 0 );
197198 GGML_ASSERT (nb13 % sizeof (src1_t ) == 0 );
198199
199- GGML_ASSERT (s0 == 1 );
200- GGML_ASSERT (s10 == 1 );
201-
202200 const int block_size = 128 ;
203201
204202 int64_t hne0 = std::max (ne0/2LL , 1LL );
@@ -232,8 +230,8 @@ struct bin_bcast_sycl {
232230 [=](sycl::nd_item<3 > item_ct1) {
233231 k_bin_bcast_unravel<bin_op>(
234232 src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
235- ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
236- s03, s11, s12, s13, item_ct1);
233+ ne10, ne11, ne12, ne13, s1, s2, s3, s00, s01, s02,
234+ s03, s10, s11, s12, s13, item_ct1);
237235 });
238236 }
239237 } else {
@@ -251,7 +249,7 @@ struct bin_bcast_sycl {
251249 [=](sycl::nd_item<3 > item_ct1) {
252250 k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
253251 ne2, ne3, ne10, ne11, ne12, ne13,
254- s1, s2, s3, s01, s02, s03, s11, s12, s13,
252+ s1, s2, s3, s00, s01, s02, s03, s10 , s11, s12, s13,
255253 item_ct1);
256254 });
257255 }
@@ -268,24 +266,27 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
268266 if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
269267 op ()((const float *) src0->data , (const float *) src1->data , (float *) dst->data , ne00, ne01, ne02, ne03, ne10,
270268 ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3,
271- ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_contiguous (dst ), main_stream);
269+ ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_permuted (src0), ggml_is_permuted (src1 ), main_stream);
272270 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
273271 op ()((const sycl::half *) src0->data , (const sycl::half *) src1->data , (sycl::half *) dst->data , ne00, ne01,
274272 ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13,
275- nb0, nb1, nb2, nb3, ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_contiguous (dst ),
273+ nb0, nb1, nb2, nb3, ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_permuted (src0), ggml_is_permuted (src1 ),
276274 main_stream);
277275 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
278276 op ()((const sycl::half *) src0->data , (const float *) src1->data , (sycl::half *) dst->data , ne00, ne01, ne02,
279277 ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1,
280- nb2, nb3, ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_contiguous (dst), main_stream);
278+ nb2, nb3, ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_permuted (src0), ggml_is_permuted (src1),
279+ main_stream);
281280 } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
282281 op ()((const int32_t *) src0->data , (const int32_t *) src1->data , (int32_t *) dst->data , ne00, ne01, ne02, ne03,
283282 ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
284- nb3, ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_contiguous (dst), main_stream);
283+ nb3, ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_permuted (src0), ggml_is_permuted (src1),
284+ main_stream);
285285 } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
286286 op ()((const int16_t *) src0->data , (const int16_t *) src1->data , (int16_t *) dst->data , ne00, ne01, ne02, ne03,
287287 ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
288- nb3, ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_contiguous (dst), main_stream);
288+ nb3, ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_permuted (src0), ggml_is_permuted (src1),
289+ main_stream);
289290 } else {
290291 fprintf (stderr, " %s: unsupported types: dst: %s, src0: %s, src1: %s\n " , __func__, ggml_type_name (dst->type ),
291292 ggml_type_name (src0->type ), ggml_type_name (src1->type ));
0 commit comments