11#pragma once
22
3+ #include " ../utils/cuda_utils.cuh"
34#include " bank_conflict_free_sgemm.cuh"
45#include " tiled_sgemm.cuh"
5- #include " ../utils/cuda_utils.cuh"
66#include < cuda_fp16.h>
77#include < cuda_runtime.h>
8+
9+ // WMMA is only available on sm_70+
10+ // When compiling for host (__CUDA_ARCH__ not defined), always include WMMA
11+ // When compiling for device, only include for sm_70+
12+ #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700
813#include < mma.h>
14+ #endif
915
1016namespace tensor_core {
1117inline constexpr int WMMA_M = 16 ;
@@ -18,8 +24,7 @@ using tensor_core::WMMA_M;
1824using tensor_core::WMMA_N ;
1925
2026inline bool tensorCoreDimensionsSupported (int M, int K, int N) {
21- return M > 0 && K > 0 && N > 0 && M % WMMA_M == 0 && K % WMMA_K == 0 &&
22- N % WMMA_N == 0 ;
27+ return M > 0 && K > 0 && N > 0 && M % WMMA_M == 0 && K % WMMA_K == 0 && N % WMMA_N == 0 ;
2328}
2429
2530inline bool tensorCoresAvailable () {
@@ -34,14 +39,16 @@ inline bool tensorCoresAvailable() {
3439/* *
3540 * Kernel to convert FP32 to FP16
3641 */
37- __global__ void float_to_half_kernel (const float *__restrict__ input,
38- half * __restrict__ output, int size) {
42+ __global__ void float_to_half_kernel (const float *__restrict__ input, half * __restrict__ output,
43+ int size) {
3944 int idx = blockIdx .x * blockDim .x + threadIdx .x ;
4045 if (idx < size) {
4146 output[idx] = __float2half (input[idx]);
4247 }
4348}
4449
50+ // WMMA kernel is only available on sm_70+
51+ #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700
4552/* *
4653 * Basic Tensor Core SGEMM Kernel
4754 *
@@ -52,9 +59,8 @@ __global__ void float_to_half_kernel(const float *__restrict__ input,
5259 * Callers must validate dimensions before launching it.
5360 */
5461__global__ void tensor_core_sgemm_kernel_fp16 (const half *__restrict__ A,
55- const half *__restrict__ B,
56- float *__restrict__ C, int M,
57- int K, int N) {
62+ const half *__restrict__ B, float *__restrict__ C,
63+ int M, int K, int N) {
5864 int warpM = blockIdx .y ;
5965 int warpN = blockIdx .x ;
6066
@@ -71,9 +77,7 @@ __global__ void tensor_core_sgemm_kernel_fp16(const half *__restrict__ A,
7177 nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, WMMA_M , WMMA_N , WMMA_K , half,
7278 nvcuda::wmma::row_major>
7379 b_frag;
74- nvcuda::wmma::fragment<nvcuda::wmma::accumulator, WMMA_M , WMMA_N , WMMA_K ,
75- float >
76- c_frag;
80+ nvcuda::wmma::fragment<nvcuda::wmma::accumulator, WMMA_M , WMMA_N , WMMA_K , float > c_frag;
7781
7882 nvcuda::wmma::fill_fragment (c_frag, 0 .0f );
7983
@@ -83,30 +87,32 @@ __global__ void tensor_core_sgemm_kernel_fp16(const half *__restrict__ A,
8387 nvcuda::wmma::mma_sync (c_frag, a_frag, b_frag, c_frag);
8488 }
8589
86- nvcuda::wmma::store_matrix_sync (C + aRow * N + bCol, c_frag, N,
87- nvcuda::wmma::mem_row_major);
90+ nvcuda::wmma::store_matrix_sync (C + aRow * N + bCol, c_frag, N, nvcuda::wmma::mem_row_major);
8891}
8992
90- inline void launch_tensor_core_sgemm_fp16_fast_path (const half *A, const half *B,
91- float *C, int M, int K,
92- int N,
93- cudaStream_t stream = 0 ) {
93+ inline void launch_tensor_core_sgemm_fp16_fast_path (const half *A, const half *B, float *C, int M,
94+ int K, int N, cudaStream_t stream = 0 ) {
9495 dim3 blockDim (32 , 1 );
9596 dim3 gridDim ((N + WMMA_N - 1 ) / WMMA_N , (M + WMMA_M - 1 ) / WMMA_M );
9697
97- tensor_core_sgemm_kernel_fp16<<<gridDim , blockDim , 0 , stream>>> (A, B, C, M, K,
98- N);
98+ tensor_core_sgemm_kernel_fp16<<<gridDim , blockDim , 0 , stream>>> (A, B, C, M, K, N);
9999
100100 CUDA_CHECK (cudaGetLastError ());
101101}
102+ #else
103+ // Stub implementations for older architectures (will not be called)
104+ inline void launch_tensor_core_sgemm_fp16_fast_path (const half *, const half *, float *, int , int ,
105+ int , cudaStream_t) {
106+ // This function should never be called on pre-sm_70 GPUs
107+ }
108+ #endif
102109
103110/* *
104111 * Launch wrapper for Tensor Core SGEMM
105112 * Handles FP32 to FP16 conversion internally and safely falls back when WMMA
106113 * constraints are not met.
107114 */
108- inline void launch_tensor_core_sgemm (const float *A, const float *B, float *C,
109- int M, int K, int N,
115+ inline void launch_tensor_core_sgemm (const float *A, const float *B, float *C, int M, int K, int N,
110116 cudaStream_t stream = 0 ) {
111117 if (M <= 0 || K <= 0 || N <= 0 ) {
112118 return ;
@@ -124,31 +130,26 @@ inline void launch_tensor_core_sgemm(const float *A, const float *B, float *C,
124130 int gridSizeA = (M * K + blockSize - 1 ) / blockSize;
125131 int gridSizeB = (K * N + blockSize - 1 ) / blockSize;
126132
127- float_to_half_kernel<<<gridSizeA, blockSize, 0 , stream>>> (A, d_A_fp16.get (),
128- M * K);
129- float_to_half_kernel<<<gridSizeB, blockSize, 0 , stream>>> (B, d_B_fp16.get (),
130- K * N);
133+ float_to_half_kernel<<<gridSizeA, blockSize, 0 , stream>>> (A, d_A_fp16.get (), M * K);
134+ float_to_half_kernel<<<gridSizeB, blockSize, 0 , stream>>> (B, d_B_fp16.get (), K * N);
131135 CUDA_CHECK (cudaGetLastError ());
132136
133- launch_tensor_core_sgemm_fp16_fast_path (d_A_fp16.get (), d_B_fp16.get (), C, M,
134- K, N, stream);
137+ launch_tensor_core_sgemm_fp16_fast_path (d_A_fp16.get (), d_B_fp16.get (), C, M, K, N, stream);
135138}
136139
137140/* *
138141 * Tensor Core SGEMM with pre-converted FP16 inputs.
139142 * Falls back to a safe FP32 kernel when the WMMA fast path is not applicable.
140143 */
141- inline void launch_tensor_core_sgemm_fp16 (const half *A, const half *B, float *C,
142- int M, int K, int N,
143- cudaStream_t stream = 0 ) {
144+ inline void launch_tensor_core_sgemm_fp16 (const half *A, const half *B, float *C, int M, int K,
145+ int N, cudaStream_t stream = 0 ) {
144146 if (M <= 0 || K <= 0 || N <= 0 ) {
145147 return ;
146148 }
147149
148150 if (!tensorCoresAvailable () || !tensorCoreDimensionsSupported (M, K, N)) {
149- throw CudaError (
150- " launch_tensor_core_sgemm_fp16 requires sm_70+ and dimensions aligned "
151- " to 16" );
151+ throw CudaError (" launch_tensor_core_sgemm_fp16 requires sm_70+ and dimensions aligned "
152+ " to 16" );
152153 }
153154
154155 launch_tensor_core_sgemm_fp16_fast_path (A, B, C, M, K, N, stream);
0 commit comments