@@ -2048,42 +2048,56 @@ class joint_matrix {
20482048 const size_t num_elements;
20492049};
20502050
2051+ /// Loads an 8x8 b16 matrix from shared memory to local memory (32 bits per work-item)
2052+ /// \tparam [in] T The type of the result variable
2053+ /// \param [in] addr The address of the matrix in shared memory
2054+ /// \param [out] m The local memory in which to store the matrix
2055+ /// \param [in] item_ct1 The sycl::nd_item object
2056+ /// \param [in] trans Indicates whether the matrix is to be loaded transposed
2057+ /// \param [in] mat The index of the matrix to be loaded
20512058template <typename T>
20522059void ldmatrix (uintptr_t addr, T *m, const sycl::nd_item<3 > &item_ct1,
20532060 bool trans = false , unsigned mat = 0 ) {
2054- int lane = item_ct1.get_local_id (2 );
2061+ int lane = item_ct1.get_local_id (2 ) % 32 ;
20552062
2056- int group = lane / 8 ;
2057- int sub = lane % 8 ;
2058- int src_base = group * 2 ;
2063+ int lane_group8_row = lane / 8 ;
2064+ int lane_group8_col = lane % 8 ;
20592065
20602066 if (!trans) {
20612067 // calculate the source lane
2062- int src_lane = (sub / 4 ) ? (src_base + 1 ) : src_base;
2068+ int src_lane = 2 * lane_group8_row;
2069+ if (lane_group8_col >= 4 )
2070+ src_lane += 1 ;
20632071
20642072 // Broadcast the address from the source lane
20652073 auto recv_addr_uintp = dpct::select_from_sub_group (
20662074 item_ct1.get_sub_group (), addr, mat * 8 + src_lane);
2075+
2076+ // Cast the received address from uintptr_t to the type of 'm'
20672077 auto recv_addr = reinterpret_cast <T *>(recv_addr_uintp);
20682078
20692079 // Non-transposed load
2070- *m = recv_addr[sub % 4 ];
2080+ *m = recv_addr[lane_group8_col % 4 ];
20712081 } else {
20722082 // calculate the source lane
20732083 int src_lane = (lane % 4 ) * 2 ;
20742084
2075- // Broadcast the address from the source lane:
2085+ // Broadcast the address from the source lane
20762086 auto recv_addr_uintp_1 = dpct::select_from_sub_group (
20772087 item_ct1.get_sub_group (), addr, mat * 8 + src_lane);
20782088 auto recv_addr_uintp_2 = dpct::select_from_sub_group (
20792089 item_ct1.get_sub_group (), addr, mat * 8 + src_lane + 1 );
2090+
2091+ // Cast the received address from uintptr_t to 'half *'
20802092 auto recv_addr_1 = reinterpret_cast <sycl::half *>(recv_addr_uintp_1);
20812093 auto recv_addr_2 = reinterpret_cast <sycl::half *>(recv_addr_uintp_2);
20822094
20832095 // Transposed load
2084- int index = ( lane / 4 ) ;
2096+ int index = lane / 4 ;
20852097 sycl::half val0 = recv_addr_1[index];
20862098 sycl::half val1 = recv_addr_2[index];
2099+
2100+ // Combine the two 16-bits into one 32-bit value
20872101 sycl::half2 val = sycl::half2 (val0, val1);
20882102 *m = *reinterpret_cast <T *>(&val);
20892103 }
0 commit comments