@@ -5761,222 +5761,8 @@ INSTANTIATE_VQ_SCALAR_GEMV_F32(3, 8)
57615761INSTANTIATE_VQ_SCALAR_GEMV_F32(3 , 10 )
57625762INSTANTIATE_VQ_SCALAR_GEMV_F32(4 , 8 )
57635763
5764- // ============================================================================
5765- // Training Kernels (from QLORA-2 branch)
5766- // ============================================================================
5767-
5768- }
5769- }
5770-
5771- // ---- Grouped scalar GEMV launcher ---- Host-side launcher: builds a 2D grid (ceil(N / COLS_PER_BLOCK) x num_experts) and launches kbit_grouped_scalar_gemv<K, 4, scalar_t> with all arguments forwarded unchanged.
5772- template <int K, typename scalar_t >
5773- void kbitGroupedScalarGemv (
5774- const scalar_t * A_concat, const unsigned int * B_packed_all, const unsigned char * B_absmax_all,
5775- const float * codebook, scalar_t * C_concat, const int * expert_offsets, int K_dim, int N, int num_experts
5776- ) {
5777- constexpr int COLS_PER_BLOCK = 4 ; // divisor for the grid's x-dimension; presumably the number of output columns one block handles - confirm against the kernel
5778- constexpr int BLOCK_SIZE = 128 ; // threads per block used in the launch below
5779- int n_groups = (N + COLS_PER_BLOCK - 1 ) / COLS_PER_BLOCK; // ceil(N / COLS_PER_BLOCK)
5780- dim3 grid (n_groups, num_experts); // grid.x = column groups, grid.y = num_experts. NOTE(review): either dimension is 0 when N <= 0 or num_experts == 0, which is an invalid launch configuration - confirm callers never pass an empty problem
5781-
5782- kbit_grouped_scalar_gemv<K, 4 , scalar_t ><<<grid, BLOCK_SIZE>>> ( // NOTE(review): the template argument 4 is a literal rather than COLS_PER_BLOCK - presumably the two must stay equal. No shared-memory size or stream is given in the launch configuration.
5783- A_concat, B_packed_all, B_absmax_all, codebook, C_concat, expert_offsets, K_dim, N, num_experts
5784- );
5785- CUDA_CHECK_RETURN (cudaPeekAtLastError ()); // cudaPeekAtLastError reads the last error without clearing it; no synchronization is issued here
5786- }
5787-
5788- // ---- Debug: Simple MMA test kernel ----
5789- // Takes fp16 A[16,16] and fp16 B[16,8] (B stored row-major), outputs fp32 C[16,8]. All three are indexed as dense row-major arrays (A with stride 16, B and C with stride 8). Only threadIdx.x % 32 is used for indexing and there are no bounds checks, so the kernel is meant to run as a single warp (testMMA launches it as <<<1, 32>>>). NOTE(review): mma m16n8k16 with f16 inputs presumably needs sm_80 or newer - confirm against the PTX ISA and the build targets.
5790- __global__ void test_mma_kernel (const half* __restrict__ A, const half* __restrict__ B, float * __restrict__ C) {
5791- int lane_id = threadIdx .x % 32 ; // lane index within the warp
5792- int gid = lane_id / 4 ; // 0..7: selects rows gid and gid+8 of A and C, and column gid of B
5793- int tid = lane_id % 4 ; // 0..3: position inside the group of four lanes (not a thread id); selects the k pair for A/B and the column pair for C
5794-
5795- // Load A fragment: A is [16,16] row-major
5796- // m16n8k16 register order (from Turing m16n8k8 decomposition):
5797- // a[0]: row_lo (gid), k_lo (tid*2..tid*2+1)
5798- // a[1]: row_hi (gid+8), k_lo (tid*2..tid*2+1)
5799- // a[2]: row_lo (gid), k_hi (tid*2+8..tid*2+9)
5800- // a[3]: row_hi (gid+8), k_hi (tid*2+8..tid*2+9)
5801- uint32_t frag_a[4 ]; // each 32-bit register holds two packed fp16 values
5802- {
5803- half2 h_rlo_klo = __halves2half2 (A[gid * 16 + tid * 2 ], A[gid * 16 + tid * 2 + 1 ]);
5804- half2 h_rhi_klo = __halves2half2 (A[(gid + 8 ) * 16 + tid * 2 ], A[(gid + 8 ) * 16 + tid * 2 + 1 ]);
5805- half2 h_rlo_khi = __halves2half2 (A[gid * 16 + tid * 2 + 8 ], A[gid * 16 + tid * 2 + 9 ]);
5806- half2 h_rhi_khi = __halves2half2 (A[(gid + 8 ) * 16 + tid * 2 + 8 ], A[(gid + 8 ) * 16 + tid * 2 + 9 ]);
5807- frag_a[0 ] = *reinterpret_cast <uint32_t *>(&h_rlo_klo); // reinterpret the half2 bit pattern as one 32-bit register
5808- frag_a[1 ] = *reinterpret_cast <uint32_t *>(&h_rhi_klo);
5809- frag_a[2 ] = *reinterpret_cast <uint32_t *>(&h_rlo_khi);
5810- frag_a[3 ] = *reinterpret_cast <uint32_t *>(&h_rhi_khi);
5811- }
5812-
5813- // Load B fragment: B is [16,8] row-major, i.e. element (k, n) is read as B[k * 8 + n]. The mma below takes B as its .col operand, so each register packs two k-adjacent elements of column n = gid, gathered from the row-major array with stride 8: b0 holds k = tid*2, tid*2+1 and b1 holds k = tid*2+8, tid*2+9.
5814- uint32_t frag_b[2 ];
5815- {
5816- half2 b0 = __halves2half2 (B[(tid * 2 ) * 8 + gid], B[(tid * 2 + 1 ) * 8 + gid]);
5817- half2 b1 = __halves2half2 (B[(tid * 2 + 8 ) * 8 + gid], B[(tid * 2 + 9 ) * 8 + gid]);
5818- frag_b[0 ] = *reinterpret_cast <uint32_t *>(&b0);
5819- frag_b[1 ] = *reinterpret_cast <uint32_t *>(&b1);
5820- }
5821-
5822- float c[4 ] = {0 , 0 , 0 , 0 }; // fp32 accumulators; passed to the mma as both input and output, zero on entry so the result is just A*B
5823- asm volatile (" mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " // A is the .row operand, B the .col operand; fp16 inputs, fp32 accumulate
5824- " {%0, %1, %2, %3}, " // output accumulators (c[0..3])
5825- " {%4, %5, %6, %7}, " // A fragment (frag_a[0..3])
5826- " {%8, %9}, " // B fragment (frag_b[0..1])
5827- " {%10, %11, %12, %13};\n " // input accumulators (c[0..3])
5828- : " =f" (c[0 ]), " =f" (c[1 ]), " =f" (c[2 ]), " =f" (c[3 ])
5829- : " r" (frag_a[0 ]), " r" (frag_a[1 ]), " r" (frag_a[2 ]), " r" (frag_a[3 ]), " r" (frag_b[0 ]), " r" (frag_b[1 ]),
5830- " f" (c[0 ]), " f" (c[1 ]), " f" (c[2 ]), " f" (c[3 ]));
5831-
5832- // Write C[16,8] row-major: c[0], c[1] go to row gid, columns tid*2 and tid*2+1; c[2], c[3] go to row gid+8, same columns.
5833- C[gid * 8 + tid * 2 ] = c[0 ];
5834- C[gid * 8 + tid * 2 + 1 ] = c[1 ];
5835- C[(gid + 8 ) * 8 + tid * 2 ] = c[2 ];
5836- C[(gid + 8 ) * 8 + tid * 2 + 1 ] = c[3 ];
5837- }
5838-
5839- void testMMA (const half* A, const half* B, float * C) { // Host wrapper for test_mma_kernel; A, B and C are forwarded unchanged (presumably device pointers to 16x16, 16x8 and 16x8 buffers - confirm against callers)
5840- test_mma_kernel<<<1 , 32 >>> (A, B, C); // one block of 32 threads, i.e. a single warp; no stream argument is given
5841- CUDA_CHECK_RETURN (cudaPeekAtLastError ()); // cudaPeekAtLastError reads the last error without clearing it; no synchronization is issued here
5842- }
5843-
5844- // ---- Template instantiations ----
5845- // Quantize instantiations: quantizeBlockwise_kbit<T, K> for T in {half, __nv_bfloat16, float} and K in {2, 3, 4, 5} (12 explicit instantiations).
5846- #define INSTANTIATE_KBIT_QUANT (T, K ) \
5847- template void quantizeBlockwise_kbit<T, K>(const float *, const T*, float *, unsigned int *, int );
5848-
5849- INSTANTIATE_KBIT_QUANT (half, 2 )
5850- INSTANTIATE_KBIT_QUANT(half, 3 )
5851- INSTANTIATE_KBIT_QUANT(half, 4 )
5852- INSTANTIATE_KBIT_QUANT(half, 5 )
5853- INSTANTIATE_KBIT_QUANT(__nv_bfloat16, 2 )
5854- INSTANTIATE_KBIT_QUANT(__nv_bfloat16, 3 )
5855- INSTANTIATE_KBIT_QUANT(__nv_bfloat16, 4 )
5856- INSTANTIATE_KBIT_QUANT(__nv_bfloat16, 5 )
5857- INSTANTIATE_KBIT_QUANT(float , 2 )
5858- INSTANTIATE_KBIT_QUANT(float , 3 )
5859- INSTANTIATE_KBIT_QUANT(float , 4 )
5860- INSTANTIATE_KBIT_QUANT(float , 5 )
5861-
5862- // Dequant instantiations: all output types (half, __nv_bfloat16, float) × absmax types (unsigned char, half) × K values (2, 3, 4, 5), i.e. 24 explicit instantiations of dequantizeBlockwise_kbit<T, K, ABSMAX_T>.
5863- #define INSTANTIATE_KBIT_DEQUANT (T, K, ABSMAX_T ) \
5864- template void dequantizeBlockwise_kbit<T, K, ABSMAX_T>( \
5865- const unsigned int *, const float *, const ABSMAX_T*, T*, int , cudaStream_t \
5866- );
5867-
5868- // uint8 E4M4 absmax (default): ABSMAX_T = unsigned char
5869- INSTANTIATE_KBIT_DEQUANT (half, 2 , unsigned char )
5870- INSTANTIATE_KBIT_DEQUANT(half, 3 , unsigned char )
5871- INSTANTIATE_KBIT_DEQUANT(half, 4 , unsigned char )
5872- INSTANTIATE_KBIT_DEQUANT(half, 5 , unsigned char )
5873- INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 2 , unsigned char )
5874- INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 3 , unsigned char )
5875- INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 4 , unsigned char )
5876- INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 5 , unsigned char )
5877- INSTANTIATE_KBIT_DEQUANT(float , 2 , unsigned char )
5878- INSTANTIATE_KBIT_DEQUANT(float , 3 , unsigned char )
5879- INSTANTIATE_KBIT_DEQUANT(float , 4 , unsigned char )
5880- INSTANTIATE_KBIT_DEQUANT(float , 5 , unsigned char )
5881-
5882- // fp16 absmax (option): ABSMAX_T = half
5883- INSTANTIATE_KBIT_DEQUANT(half, 2 , half)
5884- INSTANTIATE_KBIT_DEQUANT(half, 3 , half)
5885- INSTANTIATE_KBIT_DEQUANT(half, 4 , half)
5886- INSTANTIATE_KBIT_DEQUANT(half, 5 , half)
5887- INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 2 , half)
5888- INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 3 , half)
5889- INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 4 , half)
5890- INSTANTIATE_KBIT_DEQUANT(__nv_bfloat16, 5 , half)
5891- INSTANTIATE_KBIT_DEQUANT(float , 2 , half)
5892- INSTANTIATE_KBIT_DEQUANT(float , 3 , half)
5893- INSTANTIATE_KBIT_DEQUANT(float , 4 , half)
5894- INSTANTIATE_KBIT_DEQUANT(float , 5 , half)
5895-
5896- // Repack instantiations: one explicit instantiation of repackKbit<K> per K value (2, 3, 4, 5).
5897- #define INSTANTIATE_KBIT_REPACK (K ) \
5898- template void repackKbit<K>(const unsigned int *, const float *, unsigned int *, unsigned char *, int , int );
5899-
5900- INSTANTIATE_KBIT_REPACK (2 )
5901- INSTANTIATE_KBIT_REPACK(3 )
5902- INSTANTIATE_KBIT_REPACK(4 )
5903- INSTANTIATE_KBIT_REPACK(5 )
5904-
5905- // GEMM instantiations: one per K value (fp16 only). Each use instantiates kbitGemmMinimal<K>, kbitGemmPipelined<K> and kbitGemmSplitK<K>; the split-K signature takes an extra float*, an extra int* and one extra int compared with the other two.
5906- #define INSTANTIATE_KBIT_GEMM (K ) \
5907- template void kbitGemmMinimal<K>( \
5908- const half*, const unsigned int *, const unsigned char *, const float *, half*, int , int , int \
5909- ); \
5910- template void kbitGemmPipelined<K>( \
5911- const half*, const unsigned int *, const unsigned char *, const float *, half*, int , int , int \
5912- ); \
5913- template void kbitGemmSplitK<K>( \
5914- const half*, const unsigned int *, const unsigned char *, const float *, half*, float *, int *, int , int , int , int \
5915- );
5916-
5917- INSTANTIATE_KBIT_GEMM (2 )
5918- INSTANTIATE_KBIT_GEMM(3 )
5919- INSTANTIATE_KBIT_GEMM(4 )
5920- INSTANTIATE_KBIT_GEMM(5 )
5921-
5922- // Production kernel instantiations (fp16 and bf16): each use instantiates kbitGemmProd<K, half> and kbitGemmProd<K, __nv_bfloat16>, for K in {2, 3, 4, 5}.
5923- #define INSTANTIATE_KBIT_GEMM_PROD (K ) \
5924- template void kbitGemmProd<K, half>( \
5925- const half*, const unsigned int *, const unsigned char *, const float *, half*, float *, int *, int , int , int , int \
5926- ); \
5927- template void kbitGemmProd<K, __nv_bfloat16>( \
5928- const __nv_bfloat16*, const unsigned int *, const unsigned char *, const float *, __nv_bfloat16*, float *, int *, \
5929- int , int , int , int \
5930- );
5931-
5932- INSTANTIATE_KBIT_GEMM_PROD (2 )
5933- INSTANTIATE_KBIT_GEMM_PROD(3 )
5934- INSTANTIATE_KBIT_GEMM_PROD(4 )
5935- INSTANTIATE_KBIT_GEMM_PROD(5 )
5936-
5937- // Grouped expert GEMM instantiations (fp16 and bf16): each use instantiates kbitGroupedGemmProd<K, half> and kbitGroupedGemmProd<K, __nv_bfloat16>, for K in {2, 3, 4, 5}.
5938- #define INSTANTIATE_KBIT_GROUPED_GEMM_PROD (K ) \
5939- template void kbitGroupedGemmProd<K, half>( \
5940- const half*, const unsigned int *, const unsigned char *, const float *, half*, const int *, int , int , int \
5941- ); \
5942- template void kbitGroupedGemmProd<K, __nv_bfloat16>( \
5943- const __nv_bfloat16*, const unsigned int *, const unsigned char *, const float *, __nv_bfloat16*, const int *, \
5944- int , int , int \
5945- );
5946-
5947- INSTANTIATE_KBIT_GROUPED_GEMM_PROD (2 )
5948- INSTANTIATE_KBIT_GROUPED_GEMM_PROD(3 )
5949- INSTANTIATE_KBIT_GROUPED_GEMM_PROD(4 )
5950- INSTANTIATE_KBIT_GROUPED_GEMM_PROD(5 )
5951-
5952- // Scalar GEMV instantiations (fp16 and bf16) — flat layout, float32 absmax, C=1. Each use instantiates kbitScalarGemv<K, half> and kbitScalarGemv<K, __nv_bfloat16>, for K in {2, 3, 4, 5}; the third parameter is const float* in both signatures.
5953- #define INSTANTIATE_KBIT_SCALAR_GEMV (K ) \
5954- template void kbitScalarGemv<K, half>( \
5955- const half*, const unsigned int *, const float *, const float *, half*, int , int , int \
5956- ); \
5957- template void kbitScalarGemv<K, __nv_bfloat16>( \
5958- const __nv_bfloat16*, const unsigned int *, const float *, const float *, __nv_bfloat16*, int , int , int \
5959- );
5960-
5961- INSTANTIATE_KBIT_SCALAR_GEMV (2 )
5962- INSTANTIATE_KBIT_SCALAR_GEMV(3 )
5963- INSTANTIATE_KBIT_SCALAR_GEMV(4 )
5964- INSTANTIATE_KBIT_SCALAR_GEMV(5 )
5965-
5966- // Grouped scalar GEMV instantiations (fp16 and bf16): each use instantiates kbitGroupedScalarGemv<K, half> and kbitGroupedScalarGemv<K, __nv_bfloat16>, for K in {2, 3, 4, 5}; the parameter lists match the kbitGroupedScalarGemv launcher template with scalar_t substituted.
5967- #define INSTANTIATE_KBIT_GROUPED_SCALAR_GEMV (K ) \
5968- template void kbitGroupedScalarGemv<K, half>( \
5969- const half*, const unsigned int *, const unsigned char *, const float *, half*, const int *, int , int , int \
5970- ); \
5971- template void kbitGroupedScalarGemv<K, __nv_bfloat16>( \
5972- const __nv_bfloat16*, const unsigned int *, const unsigned char *, const float *, __nv_bfloat16*, const int *, \
5973- int , int , int \
5974- );
5975-
5976- INSTANTIATE_KBIT_GROUPED_SCALAR_GEMV (2 )
5977- INSTANTIATE_KBIT_GROUPED_SCALAR_GEMV(3 )
5978- INSTANTIATE_KBIT_GROUPED_SCALAR_GEMV(4 )
5979- INSTANTIATE_KBIT_GROUPED_SCALAR_GEMV(5 )
5764+ // NOTE: kbitGroupedScalarGemv was removed (grouped MMA covers all MoE shapes).
5765+ // See commit ac7d6ff.
59805766
59815767// ============================================================================
59825768// Training Kernels: SwiGLU, RMSNorm, RoPE
0 commit comments