@@ -2057,9 +2057,11 @@ void ldmatrix(uintptr_t addr, T *m,
20572057 int group = lane / 8 ;
20582058 int sub = lane % 8 ;
20592059 int src_base = group * 2 ;
2060- int src_lane = (sub / 4 ) ? (src_base + 1 ) : src_base;
20612060
20622061 if (!trans) {
2062+ // calculate the source lane
2063+ int src_lane = (sub / 4 ) ? (src_base + 1 ) : src_base;
2064+
20632065 // Broadcast the address from the source lane
20642066 auto recv_addr_uintp = dpct::select_from_sub_group (
20652067 item_ct1.get_sub_group (), addr, mat * 8 + src_lane);
@@ -2073,16 +2075,21 @@ void ldmatrix(uintptr_t addr, T *m,
20732075 uint16_t bits1 = sycl::bit_cast<unsigned short , sycl::half>(val1);
20742076 *m = ((uint32_t )bits1 << 16 ) | bits0;
20752077 } else {
2078+ // calculate the source lane
2079+ int src_lane = (lane % 4 ) * 2 ;
2080+
20762081 // Broadcast the address from the source lane:
2077- auto recv_addr_uintp = dpct::select_from_sub_group (
2078- item_ct1.get_sub_group (), addr, mat * 8 );
2079- auto recv_addr = reinterpret_cast <sycl::half *>(recv_addr_uintp);
2080- recv_addr += src_lane;
2082+ auto recv_addr_uintp_1 = dpct::select_from_sub_group (
2083+ item_ct1.get_sub_group (), addr, mat * 8 + src_lane);
2084+ auto recv_addr_uintp_2 = dpct::select_from_sub_group (
2085+ item_ct1.get_sub_group (), addr, mat * 8 + src_lane + 1 );
2086+ auto recv_addr_1 = reinterpret_cast <sycl::half *>(recv_addr_uintp_1);
2087+ auto recv_addr_2 = reinterpret_cast <sycl::half *>(recv_addr_uintp_2);
20812088
20822089 // Transposed load
2083- int index = (lane % 4 ) * 8 * 2 ;
2084- sycl::half val0 = recv_addr [index];
2085- sycl::half val1 = recv_addr [index + 8 ];
2090+ int index = (lane / 4 );
2091+ sycl::half val0 = recv_addr_1 [index];
2092+ sycl::half val1 = recv_addr_2 [index];
20862093 uint16_t bits0 = sycl::bit_cast<unsigned short , sycl::half>(val0);
20872094 uint16_t bits1 = sycl::bit_cast<unsigned short , sycl::half>(val1);
20882095 *m = ((uint32_t )bits1 << 16 ) | bits0;
0 commit comments