Skip to content

Commit ead551a

Browse files
Updated trans (transposed load) logic in ldmatrix
1 parent b03c620 commit ead551a

1 file changed

Lines changed: 15 additions & 8 deletions

File tree

  • clang/runtime/dpct-rt/include/dpct

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2057,9 +2057,11 @@ void ldmatrix(uintptr_t addr, T *m,
20572057
int group = lane / 8;
20582058
int sub = lane % 8;
20592059
int src_base = group * 2;
2060-
int src_lane = (sub / 4) ? (src_base + 1) : src_base;
20612060

20622061
if (!trans) {
2062+
// Calculate the source lane offset (added to mat * 8 below)
2063+
int src_lane = (sub / 4) ? (src_base + 1) : src_base;
2064+
20632065
// Broadcast the address from the source lane
20642066
auto recv_addr_uintp = dpct::select_from_sub_group(
20652067
item_ct1.get_sub_group(), addr, mat * 8 + src_lane);
@@ -2073,16 +2075,21 @@ void ldmatrix(uintptr_t addr, T *m,
20732075
uint16_t bits1 = sycl::bit_cast<unsigned short, sycl::half>(val1);
20742076
*m = ((uint32_t)bits1 << 16) | bits0;
20752077
} else {
2078+
// Calculate the first of the two adjacent source lanes
// (the second one is src_lane + 1)
2079+
int src_lane = (lane % 4) * 2;
2080+
20762081
// Broadcast the addresses from the two source lanes
2077-
auto recv_addr_uintp = dpct::select_from_sub_group(
2078-
item_ct1.get_sub_group(), addr, mat * 8);
2079-
auto recv_addr = reinterpret_cast<sycl::half *>(recv_addr_uintp);
2080-
recv_addr += src_lane;
2082+
auto recv_addr_uintp_1 = dpct::select_from_sub_group(
2083+
item_ct1.get_sub_group(), addr, mat * 8 + src_lane);
2084+
auto recv_addr_uintp_2 = dpct::select_from_sub_group(
2085+
item_ct1.get_sub_group(), addr, mat * 8 + src_lane + 1);
2086+
auto recv_addr_1 = reinterpret_cast<sycl::half *>(recv_addr_uintp_1);
2087+
auto recv_addr_2 = reinterpret_cast<sycl::half *>(recv_addr_uintp_2);
20812088

20822089
// Transposed load
2083-
int index = (lane % 4) * 8 * 2;
2084-
sycl::half val0 = recv_addr[index];
2085-
sycl::half val1 = recv_addr[index + 8];
2090+
int index = (lane / 4);
2091+
sycl::half val0 = recv_addr_1[index];
2092+
sycl::half val1 = recv_addr_2[index];
20862093
uint16_t bits0 = sycl::bit_cast<unsigned short, sycl::half>(val0);
20872094
uint16_t bits1 = sycl::bit_cast<unsigned short, sycl::half>(val1);
20882095
*m = ((uint32_t)bits1 << 16) | bits0;

0 commit comments

Comments
 (0)