Skip to content

Commit 4492ec9

Browse files
Added comments for the helper functions
1 parent 0cf075b commit 4492ec9

1 file changed

Lines changed: 47 additions & 8 deletions

File tree

  • clang/runtime/dpct-rt/include/dpct

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2048,60 +2048,99 @@ class joint_matrix {
20482048
const size_t num_elements;
20492049
};
20502050

2051+
/// Loads one 8x8 b16 matrix from shared memory to local memory (32 bits per work-item)
2052+
/// Requires the sub-group size of kernel calling this function to be 32
2053+
/// \tparam T The type of result variable
2054+
/// \param [in] addr The address of the matrix in shared memory
2055+
/// \param [out] m The local memory to store the matrix
2056+
/// \param [in] item_ct1 The sycl::nd_item object
2057+
/// \param [in] trans Indicates whether the matrix is to be loaded transposed
2058+
/// \param [in] mat The index of the matrix to be loaded
20512059
template <typename T>
20522060
void ldmatrix(uintptr_t addr, T *m, const sycl::nd_item<3> &item_ct1,
20532061
bool trans = false, unsigned mat = 0) {
2054-
int lane = item_ct1.get_local_id(2);
2062+
int lane = item_ct1.get_local_id(2) % 32;
20552063

2056-
int group = lane / 8;
2057-
int sub = lane % 8;
2058-
int src_base = group * 2;
2064+
int lane_group8_row = lane / 8;
2065+
int lane_group8_col = lane % 8;
20592066

20602067
if (!trans) {
20612068
// calculate the source lane
2062-
int src_lane = (sub / 4) ? (src_base + 1) : src_base;
2069+
int src_lane = 2 * lane_group8_row;
2070+
if (lane_group8_col >= 4)
2071+
src_lane += 1;
20632072

20642073
// Broadcast the address from the source lane
20652074
auto recv_addr_uintp = dpct::select_from_sub_group(
20662075
item_ct1.get_sub_group(), addr, mat * 8 + src_lane);
2076+
2077+
// Cast the received address from uintptr_t to the type of 'm'
20672078
auto recv_addr = reinterpret_cast<T *>(recv_addr_uintp);
20682079

20692080
// Non-transposed load
2070-
*m = recv_addr[sub % 4];
2081+
*m = recv_addr[lane_group8_col % 4];
20712082
} else {
20722083
// calculate the source lane
20732084
int src_lane = (lane % 4) * 2;
20742085

2075-
// Broadcast the address from the source lane:
2086+
// Broadcast the address from the source lane
20762087
auto recv_addr_uintp_1 = dpct::select_from_sub_group(
20772088
item_ct1.get_sub_group(), addr, mat * 8 + src_lane);
20782089
auto recv_addr_uintp_2 = dpct::select_from_sub_group(
20792090
item_ct1.get_sub_group(), addr, mat * 8 + src_lane + 1);
2091+
2092+
// Cast the received address from uintptr_t to 'half *'
20802093
auto recv_addr_1 = reinterpret_cast<sycl::half *>(recv_addr_uintp_1);
20812094
auto recv_addr_2 = reinterpret_cast<sycl::half *>(recv_addr_uintp_2);
20822095

20832096
// Transposed load
2084-
int index = (lane / 4);
2097+
int index = lane / 4;
20852098
sycl::half val0 = recv_addr_1[index];
20862099
sycl::half val1 = recv_addr_2[index];
2100+
2101+
// Combine the two 16-bit values into one 32-bit value
20872102
sycl::half2 val = sycl::half2(val0, val1);
20882103
*m = *reinterpret_cast<T *>(&val);
20892104
}
20902105
}
20912106

2107+
/// Loads two 8x8 b16 matrices from shared memory to local memory (32 bits per work-item)
2108+
/// Requires the sub-group size of kernel calling this function to be 32
2109+
/// \tparam T The type of result variable
2110+
/// \param [in] addr The address of the matrix in shared memory
2111+
/// \param [out] m1 The local memory to store data of 1st matrix
2112+
/// \param [out] m2 The local memory to store data of 2nd matrix
2113+
/// \param [in] item_ct1 The sycl::nd_item object
2114+
/// \param [in] trans Indicates whether the matrices are to be loaded transposed
20922115
template <typename T>
20932116
void ldmatrix(uintptr_t addr, T *m1, T *m2, const sycl::nd_item<3> &item_ct1,
20942117
bool trans = false) {
2118+
// Load 1st matrix
20952119
ldmatrix(addr, m1, item_ct1, trans, 0);
2120+
// Load 2nd matrix
20962121
ldmatrix(addr, m2, item_ct1, trans, 1);
20972122
}
20982123

2124+
/// Loads four 8x8 b16 matrices from shared memory to local memory (32 bits per work-item)
2125+
/// Requires the sub-group size of kernel calling this function to be 32
2126+
/// \tparam T The type of result variable
2127+
/// \param [in] addr The address of the matrix in shared memory
2128+
/// \param [out] m1 The local memory to store data of 1st matrix
2129+
/// \param [out] m2 The local memory to store data of 2nd matrix
2130+
/// \param [out] m3 The local memory to store data of 3rd matrix
2131+
/// \param [out] m4 The local memory to store data of 4th matrix
2132+
/// \param [in] item_ct1 The sycl::nd_item object
2133+
/// \param [in] trans Indicates whether the matrices are to be loaded transposed
20992134
template <typename T>
21002135
void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4,
21012136
const sycl::nd_item<3> &item_ct1, bool trans = false) {
2137+
// Load 1st matrix
21022138
ldmatrix(addr, m1, item_ct1, trans, 0);
2139+
// Load 2nd matrix
21032140
ldmatrix(addr, m2, item_ct1, trans, 1);
2141+
// Load 3rd matrix
21042142
ldmatrix(addr, m3, item_ct1, trans, 2);
2143+
// Load 4th matrix
21052144
ldmatrix(addr, m4, item_ct1, trans, 3);
21062145
}
21072146

0 commit comments

Comments
 (0)