@@ -2048,60 +2048,99 @@ class joint_matrix {
20482048 const size_t num_elements;
20492049};
20502050
2051+ // / Loads 1 8x8 b16 matrix from shared memory to local memory (32-bits per wi)
2052+ // / Requires the sub-group size of kernel calling this function to be 32
2053+ // / \tparam [in] T The type of result variable
2054+ // / \param [in] addr The address of the matrix in shared memory
2055+ // / \param [in] m The local memory to store the matrix
2056+ // / \param [in] item_ct1 The sycl::nd_item object
2057+ // / \param [in] trans Indicates whether the matrix to be loaded transposed
2058+ // / \param [in] mat The matrix index to be loaded
20512059template <typename T>
20522060void ldmatrix (uintptr_t addr, T *m, const sycl::nd_item<3 > &item_ct1,
20532061 bool trans = false , unsigned mat = 0 ) {
2054- int lane = item_ct1.get_local_id (2 );
2062+ int lane = item_ct1.get_local_id (2 ) % 32 ;
20552063
2056- int group = lane / 8 ;
2057- int sub = lane % 8 ;
2058- int src_base = group * 2 ;
2064+ int lane_group8_row = lane / 8 ;
2065+ int lane_group8_col = lane % 8 ;
20592066
20602067 if (!trans) {
20612068 // calculate the source lane
2062- int src_lane = (sub / 4 ) ? (src_base + 1 ) : src_base;
2069+ int src_lane = 2 * lane_group8_row;
2070+ if (lane_group8_col >= 4 )
2071+ src_lane += 1 ;
20632072
20642073 // Broadcast the address from the source lane
20652074 auto recv_addr_uintp = dpct::select_from_sub_group (
20662075 item_ct1.get_sub_group (), addr, mat * 8 + src_lane);
2076+
2077+ // Cast the received address from uintptr_t to the type of 'm'
20672078 auto recv_addr = reinterpret_cast <T *>(recv_addr_uintp);
20682079
20692080 // Non-transposed load
2070- *m = recv_addr[sub % 4 ];
2081+ *m = recv_addr[lane_group8_col % 4 ];
20712082 } else {
20722083 // calculate the source lane
20732084 int src_lane = (lane % 4 ) * 2 ;
20742085
2075- // Broadcast the address from the source lane:
2086+ // Broadcast the address from the source lane
20762087 auto recv_addr_uintp_1 = dpct::select_from_sub_group (
20772088 item_ct1.get_sub_group (), addr, mat * 8 + src_lane);
20782089 auto recv_addr_uintp_2 = dpct::select_from_sub_group (
20792090 item_ct1.get_sub_group (), addr, mat * 8 + src_lane + 1 );
2091+
2092+ // Cast the received address from uintptr_t to 'half *'
20802093 auto recv_addr_1 = reinterpret_cast <sycl::half *>(recv_addr_uintp_1);
20812094 auto recv_addr_2 = reinterpret_cast <sycl::half *>(recv_addr_uintp_2);
20822095
20832096 // Transposed load
2084- int index = ( lane / 4 ) ;
2097+ int index = lane / 4 ;
20852098 sycl::half val0 = recv_addr_1[index];
20862099 sycl::half val1 = recv_addr_2[index];
2100+
2101+ // Combine the two 16-bits into one 32-bit value
20872102 sycl::half2 val = sycl::half2 (val0, val1);
20882103 *m = *reinterpret_cast <T *>(&val);
20892104 }
20902105}
20912106
2107+ // / Loads 2 8x8 b16 matrix from shared memory to local memory (32-bits per wi)
2108+ // / Requires the sub-group size of kernel calling this function to be 32
2109+ // / \tparam [in] T The type of result variable
2110+ // / \param [in] addr The address of the matrix in shared memory
2111+ // / \param [in] m1 The local memory to store data of 1st matrix
2112+ // / \param [in] m2 The local memory to store data of 2nd matrix
2113+ // / \param [in] item_ct1 The sycl::nd_item object
2114+ // / \param [in] trans Indicates whether the matrix to be loaded transposed
20922115template <typename T>
20932116void ldmatrix (uintptr_t addr, T *m1, T *m2, const sycl::nd_item<3 > &item_ct1,
20942117 bool trans = false ) {
2118+ // Load 1st matrix
20952119 ldmatrix (addr, m1, item_ct1, trans, 0 );
2120+ // Load 2nd matrix
20962121 ldmatrix (addr, m2, item_ct1, trans, 1 );
20972122}
20982123
2124+ // / Loads 4 8x8 b16 matrix from shared memory to local memory (32-bits per wi)
2125+ // / Requires the sub-group size of kernel calling this function to be 32
2126+ // / \tparam [in] T The type of result variable
2127+ // / \param [in] addr The address of the matrix in shared memory
2128+ // / \param [in] m1 The local memory to store data of 1st matrix
2129+ // / \param [in] m2 The local memory to store data of 2nd matrix
2130+ // / \param [in] m3 The local memory to store data of 3rd matrix
2131+ // / \param [in] m4 The local memory to store data of 4th matrix
2132+ // / \param [in] item_ct1 The sycl::nd_item object
2133+ // / \param [in] trans Indicates whether the matrix to be loaded transposed
20992134template <typename T>
21002135void ldmatrix (uintptr_t addr, T *m1, T *m2, T *m3, T *m4,
21012136 const sycl::nd_item<3 > &item_ct1, bool trans = false ) {
2137+ // Load 1st matrix
21022138 ldmatrix (addr, m1, item_ct1, trans, 0 );
2139+ // Load 2nd matrix
21032140 ldmatrix (addr, m2, item_ct1, trans, 1 );
2141+ // Load 3rd matrix
21042142 ldmatrix (addr, m3, item_ct1, trans, 2 );
2143+ // Load 4th matrix
21052144 ldmatrix (addr, m4, item_ct1, trans, 3 );
21062145}
21072146
0 commit comments