diff --git a/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh b/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh index a977c5547..a6f451c7a 100644 --- a/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh +++ b/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh @@ -7,6 +7,13 @@ namespace deep_gemm { +/// Pack exponent bytes from 4 fp32 bit-patterns into one uint32 (UE8M0 format). Bits 23..30 of each input (the fp32 biased-exponent field) are extracted; v0 lands in the least-significant byte and v3 in the most-significant byte; sign and mantissa bits are masked off. +__device__ __forceinline__ uint32_t pack_4_fp32_exponents( + uint32_t v0, uint32_t v1, uint32_t v2, uint32_t v3) { + return ((v0 >> 23u) & 0xFFu) | (((v1 >> 23u) & 0xFFu) << 8u) + | (((v2 >> 23u) & 0xFFu) << 16u) | (((v3 >> 23u) & 0xFFu) << 24u); +} + template CUTLASS_GLOBAL void transpose_fp32(const float* sf, float* out, const uint32_t mn) { @@ -98,11 +105,7 @@ CUTLASS_GLOBAL void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, } // Pack and store - uint32_t packed = 0; - packed |= (values[0] >> 23u); - packed |= (values[1] >> 15u); - packed |= (values[2] >> 7u); - packed |= (values[3] << 1u); + uint32_t packed = pack_4_fp32_exponents(values[0], values[1], values[2], values[3]); if (const auto global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; global_mn_idx < mn) out[sf_k_pack_idx * tma_aligned_mn + global_mn_idx] = packed; } @@ -178,10 +181,10 @@ CUTLASS_GLOBAL void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, // Pack and store uint4 packed; - packed.x = (values[0].x >> 23u) | (values[1].x >> 15u) | (values[2].x >> 7u) | (values[3].x << 1u); - packed.y = (values[0].y >> 23u) | (values[1].y >> 15u) | (values[2].y >> 7u) | (values[3].y << 1u); - packed.z = (values[0].z >> 23u) | (values[1].z >> 15u) | (values[2].z >> 7u) | (values[3].z << 1u); - packed.w = (values[0].w >> 23u) | (values[1].w >> 15u) | (values[2].w >> 7u) | (values[3].w << 1u); + packed.x = pack_4_fp32_exponents(values[0].x, values[1].x, values[2].x, values[3].x); + packed.y = pack_4_fp32_exponents(values[0].y, values[1].y, values[2].y, values[3].y); + packed.z = pack_4_fp32_exponents(values[0].z, 
values[1].z, values[2].z, values[3].z); + packed.w = pack_4_fp32_exponents(values[0].w, values[1].w, values[2].w, values[3].w); reinterpret_cast(out + packed_sf_k_idx * mn)[mn_idx] = packed; } }