Skip to content

Commit bd88679

Browse files
Added comments for the helper function
1 parent 0cf075b commit bd88679

1 file changed

Lines changed: 22 additions & 8 deletions

File tree

  • clang/runtime/dpct-rt/include/dpct

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2048,42 +2048,56 @@ class joint_matrix {
20482048
const size_t num_elements;
20492049
};
20502050

2051+
/// Loads an 8x8 b16 matrix from shared memory to local memory (32 bits per work-item)
2052+
/// \tparam [in] T The type of result variable
2053+
/// \param [in] addr The address of the matrix in shared memory
2054+
/// \param [out] m The local memory to store the matrix
2055+
/// \param [in] item_ct1 The sycl::nd_item object
2056+
/// \param [in] trans Indicates whether the matrix is to be loaded transposed
2057+
/// \param [in] mat The matrix index to be loaded
20512058
template <typename T>
20522059
void ldmatrix(uintptr_t addr, T *m, const sycl::nd_item<3> &item_ct1,
20532060
bool trans = false, unsigned mat = 0) {
2054-
int lane = item_ct1.get_local_id(2);
2061+
int lane = item_ct1.get_local_id(2) % 32;
20552062

2056-
int group = lane / 8;
2057-
int sub = lane % 8;
2058-
int src_base = group * 2;
2063+
int lane_group8_row = lane / 8;
2064+
int lane_group8_col = lane % 8;
20592065

20602066
if (!trans) {
20612067
// calculate the source lane
2062-
int src_lane = (sub / 4) ? (src_base + 1) : src_base;
2068+
int src_lane = 2 * lane_group8_row;
2069+
if (lane_group8_col >= 4)
2070+
src_lane += 1;
20632071

20642072
// Broadcast the address from the source lane
20652073
auto recv_addr_uintp = dpct::select_from_sub_group(
20662074
item_ct1.get_sub_group(), addr, mat * 8 + src_lane);
2075+
2076+
// Cast the received address from uintptr_t to the type of 'm'
20672077
auto recv_addr = reinterpret_cast<T *>(recv_addr_uintp);
20682078

20692079
// Non-transposed load
2070-
*m = recv_addr[sub % 4];
2080+
*m = recv_addr[lane_group8_col % 4];
20712081
} else {
20722082
// calculate the source lane
20732083
int src_lane = (lane % 4) * 2;
20742084

2075-
// Broadcast the address from the source lane:
2085+
// Broadcast the address from the source lane
20762086
auto recv_addr_uintp_1 = dpct::select_from_sub_group(
20772087
item_ct1.get_sub_group(), addr, mat * 8 + src_lane);
20782088
auto recv_addr_uintp_2 = dpct::select_from_sub_group(
20792089
item_ct1.get_sub_group(), addr, mat * 8 + src_lane + 1);
2090+
2091+
// Cast the received address from uintptr_t to 'half *'
20802092
auto recv_addr_1 = reinterpret_cast<sycl::half *>(recv_addr_uintp_1);
20812093
auto recv_addr_2 = reinterpret_cast<sycl::half *>(recv_addr_uintp_2);
20822094

20832095
// Transposed load
2084-
int index = (lane / 4);
2096+
int index = lane / 4;
20852097
sycl::half val0 = recv_addr_1[index];
20862098
sycl::half val1 = recv_addr_2[index];
2099+
2100+
// Combine the two 16-bit values into one 32-bit value
20872101
sycl::half2 val = sycl::half2(val0, val1);
20882102
*m = *reinterpret_cast<T *>(&val);
20892103
}

0 commit comments

Comments
 (0)