@@ -1015,25 +1015,35 @@ namespace ggml_cuda_mma {
10151015#endif // AMD_MFMA_AVAILABLE
10161016 }
10171017
1018- static __device__ __forceinline__ void mma_block_scaled (tile<16 , 8 , float > & D,
1019- const tile<16 , 8 , int > & A,
1020- const tile<8 , 8 , int > & B,
1021- uint32_t a_scale,
1022- uint32_t b_scale) {
1018+ template <ggml_type type>
1019+ static __device__ __forceinline__ void mma_block_scaled_fp4 (tile<16 , 8 , float > & D,
1020+ const tile<16 , 8 , int > & A,
1021+ const tile<8 , 8 , int > & B,
1022+ uint32_t a_scale,
1023+ uint32_t b_scale) {
10231024#ifdef BLACKWELL_MMA_AVAILABLE
10241025 const int * Axi = (const int *) A.x ;
10251026 const int * Bxi = (const int *) B.x ;
10261027 float * Dxi = (float *) D.x ;
10271028
1028- asm volatile (
1029- " mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
1030- " {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
1031- " %10, {0, 0}, %11, {0, 0};"
1032- : " +f" (Dxi[0 ]), " +f" (Dxi[1 ]), " +f" (Dxi[2 ]), " +f" (Dxi[3 ])
1033- : " r" (Axi[0 ]), " r" (Axi[1 ]), " r" (Axi[2 ]), " r" (Axi[3 ]), " r" (Bxi[0 ]), " r" (Bxi[1 ]), " r" (a_scale), " r" (b_scale));
1029+ if constexpr (type == GGML_TYPE_MXFP4) {
1030+ asm volatile (
1031+ " mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
1032+ " {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
1033+ " %10, {0, 0}, %11, {0, 0};"
1034+ : " +f" (Dxi[0 ]), " +f" (Dxi[1 ]), " +f" (Dxi[2 ]), " +f" (Dxi[3 ])
1035+ : " r" (Axi[0 ]), " r" (Axi[1 ]), " r" (Axi[2 ]), " r" (Axi[3 ]), " r" (Bxi[0 ]), " r" (Bxi[1 ]), " r" (a_scale), " r" (b_scale));
1036+ } else {
1037+ asm volatile (
1038+ " mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 "
1039+ " {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
1040+ " %10, {0, 0}, %11, {0, 0};"
1041+ : " +f" (Dxi[0 ]), " +f" (Dxi[1 ]), " +f" (Dxi[2 ]), " +f" (Dxi[3 ])
1042+ : " r" (Axi[0 ]), " r" (Axi[1 ]), " r" (Axi[2 ]), " r" (Axi[3 ]), " r" (Bxi[0 ]), " r" (Bxi[1 ]), " r" (a_scale), " r" (b_scale));
1043+ }
10341044#else
10351045 GGML_UNUSED_VARS (D, A, B, a_scale, b_scale);
1036- #endif // BLACKWELL_MMA_AVAILABLE
1046+ #endif // BLACKWELL_MMA_AVAILABLE
10371047 }
10381048
10391049 static __device__ __forceinline__ void mma (
0 commit comments