2121
2222#include " ../utils/cuda_utils.cuh"
2323#include " ../utils/verify.cuh"
24+ #include < climits>
2425#include < cuda_fp16.h>
2526#include < cuda_runtime.h>
2627#include < functional>
@@ -46,14 +47,7 @@ using tensor_core::WMMA_N;
4647/* *
4748 * Check whether the current device supports Tensor Cores (sm_70+).
 *
 * This span is a diff rendering of one function: the lines marked '-' are the
 * removed implementation (cudaGetDevice + cudaGetDeviceProperties, then the
 * test prop.major >= 7), and the line marked '+' is its replacement, which
 * returns DeviceInfoCache::instance().hasTensorCores().
 *
 * NOTE(review): DeviceInfoCache is not defined in the visible text; presumably
 * it caches the device properties -- confirm that hasTensorCores() applies the
 * same major >= 7 test and follows the currently selected device.
4849 */
49- inline bool tensorCoresAvailable () {
50- int device;
51- CUDA_CHECK (cudaGetDevice (&device));
52-
53- cudaDeviceProp prop;
54- CUDA_CHECK (cudaGetDeviceProperties (&prop, device));
55- return prop.major >= 7 ;
56- }
50+ inline bool tensorCoresAvailable () { return DeviceInfoCache::instance ().hasTensorCores (); }
5751
5852/* *
5953 * 检查给定维度是否适合 Tensor Core 加速
@@ -67,11 +61,7 @@ inline bool tensorCoreDimensionsSupported(int M, int K, int N) {
6761 * 获取当前设备的 Tensor Core 信息字符串
6862 */
6963inline const char *getTensorCoreArchName () {
70- int device;
71- CUDA_CHECK (cudaGetDevice (&device));
72-
73- cudaDeviceProp prop;
74- CUDA_CHECK (cudaGetDeviceProperties (&prop, device));
64+ const cudaDeviceProp &prop = DeviceInfoCache::instance ().prop ();
7565
7666 if (prop.major == 7 ) {
7767 return (prop.minor == 0 ) ? " Volta" : (prop.minor == 5 ) ? " Turing" : " Unknown sm_7x" ;
@@ -101,7 +91,8 @@ using FallbackKernel =
10191 *
10292 * 提供一个空的 fallback(用于测试或显式配置场景)
10393 */
104- inline void nullFallback (const float *, const float *, float *, int , int , int , cudaStream_t = 0 ) {
94+ [[maybe_unused]] inline void
95+ nullFallback (const float *, const float *, float *, int , int , int , cudaStream_t = 0 ) {
10596 // Empty implementation - for testing; every parameter is unnamed and ignored.
10697}
10798
@@ -179,7 +170,7 @@ __global__ void tensor_core_sgemm_kernel_fp16(const half *__restrict__ A,
179170 */
180171inline void launch_tensor_core_sgemm_fp16_fast_path (const half *A, const half *B, float *C, int M,
181172 int K, int N, cudaStream_t stream = 0 ) {
182- dim3 blockDim (32 , 1 );
173+ dim3 blockDim (kDefaultTileSize , 1 );
183174 dim3 gridDim ((N + WMMA_N - 1 ) / WMMA_N, (M + WMMA_M - 1 ) / WMMA_M);
184175
185176 tensor_core_sgemm_kernel_fp16<<<gridDim , blockDim , 0 , stream>>> (A, B, C, M, K, N);
@@ -255,9 +246,17 @@ inline void launch_tensor_core_sgemm_with_fallback(const float *A, const float *
255246 DeviceMemory<half> d_A_fp16 (num_A);
256247 DeviceMemory<half> d_B_fp16 (num_B);
257248
258- int blockSize = 256 ;
259- int gridSizeA = static_cast <int >((num_A + blockSize - 1 ) / blockSize);
260- int gridSizeB = static_cast <int >((num_B + blockSize - 1 ) / blockSize);
249+ int blockSize = kDefaultBlockSize ;
250+ // 安全计算 gridSize,检查溢出
251+ auto safeGridSize = [](size_t num, int blk) -> int {
252+ size_t grid = (num + blk - 1 ) / blk;
253+ if (grid > static_cast <size_t >(INT_MAX)) {
254+ throw CudaError (" Grid size overflow: matrix too large for kernel launch" );
255+ }
256+ return static_cast <int >(grid);
257+ };
258+ int gridSizeA = safeGridSize (num_A, blockSize);
259+ int gridSizeB = safeGridSize (num_B, blockSize);
261260
262261 float_to_half_kernel<<<gridSizeA, blockSize, 0 , stream>>> (A, d_A_fp16.get (),
263262 static_cast <int >(num_A));
0 commit comments