11#include " gemm.cuh"
22#include " ../common/cuda_check.cuh"
3+ #include < stdexcept>
34
45namespace hpc ::gemm {
56
7+ namespace {
8+
9+ void validate_gemm_args (const void * A, const void * B, const void * C, int M, int N, int K) {
10+ if (A == nullptr || B == nullptr || C == nullptr ) {
11+ throw std::invalid_argument (" gemm expects non-null A, B, and C pointers" );
12+ }
13+ if (M <= 0 || N <= 0 || K <= 0 ) {
14+ throw std::invalid_argument (" gemm expects positive M, N, and K" );
15+ }
16+ }
17+
18+ } // namespace
19+
620constexpr int TILE_SIZE = 32 ;
721
822// Naive GEMM: each thread computes one element
@@ -75,6 +89,7 @@ template <>
7589void gemm<float , GemmOpt::Naive>(const float * A, const float * B, float * C,
7690 int M, int N, int K,
7791 float alpha, float beta, cudaStream_t stream) {
92+ validate_gemm_args (A, B, C, M, N, K);
7893 dim3 block (16 , 16 );
7994 dim3 grid ((N + block.x - 1 ) / block.x , (M + block.y - 1 ) / block.y );
8095 gemm_naive_kernel<float ><<<grid, block, 0 , stream>>> (A, B, C, M, N, K, alpha, beta);
@@ -85,6 +100,7 @@ template <>
85100void gemm<float , GemmOpt::SharedMemTiling>(const float * A, const float * B, float * C,
86101 int M, int N, int K,
87102 float alpha, float beta, cudaStream_t stream) {
103+ validate_gemm_args (A, B, C, M, N, K);
88104 dim3 block (TILE_SIZE, TILE_SIZE);
89105 dim3 grid ((N + TILE_SIZE - 1 ) / TILE_SIZE, (M + TILE_SIZE - 1 ) / TILE_SIZE);
90106 gemm_shared_kernel<float ><<<grid, block, 0 , stream>>> (A, B, C, M, N, K, alpha, beta);
@@ -177,6 +193,7 @@ template <>
177193void gemm<float , GemmOpt::DoubleBuffer>(const float * A, const float * B, float * C,
178194 int M, int N, int K,
179195 float alpha, float beta, cudaStream_t stream) {
196+ validate_gemm_args (A, B, C, M, N, K);
180197 dim3 block (TILE_SIZE, TILE_SIZE);
181198 dim3 grid ((N + TILE_SIZE - 1 ) / TILE_SIZE, (M + TILE_SIZE - 1 ) / TILE_SIZE);
182199 gemm_double_buffer_kernel<float ><<<grid, block, 0 , stream>>> (A, B, C, M, N, K, alpha, beta);
@@ -301,6 +318,7 @@ template <>
301318void gemm<float , GemmOpt::RegisterTiling>(const float * A, const float * B, float * C,
302319 int M, int N, int K,
303320 float alpha, float beta, cudaStream_t stream) {
321+ validate_gemm_args (A, B, C, M, N, K);
304322 constexpr int THREADS_PER_BLOCK = (BLK_M / REG_TILE_M) * (BLK_N / REG_TILE_N);
305323 dim3 block (THREADS_PER_BLOCK);
306324 dim3 grid ((N + BLK_N - 1 ) / BLK_N, (M + BLK_M - 1 ) / BLK_M);
@@ -435,6 +453,10 @@ template <>
435453void gemm<__half, GemmOpt::TensorCoreWMMA>(const __half* A, const __half* B, __half* C,
436454 int M, int N, int K,
437455 float alpha, float beta, cudaStream_t stream) {
456+ validate_gemm_args (A, B, C, M, N, K);
457+ if ((M % 16 ) != 0 || (N % 16 ) != 0 || (K % 16 ) != 0 ) {
458+ throw std::invalid_argument (" TensorCoreWMMA requires M, N, and K to be multiples of 16" );
459+ }
438460 // Each block has multiple warps
439461 constexpr int WARPS_PER_BLOCK = (WMMA_BLK_M / WMMA_M) * (WMMA_BLK_N / WMMA_N);
440462 constexpr int THREADS_PER_BLOCK = WARPS_PER_BLOCK * 32 ;
@@ -662,6 +684,7 @@ template <>
662684void gemm<float , GemmOpt::SoftwarePipeline>(const float * A, const float * B, float * C,
663685 int M, int N, int K,
664686 float alpha, float beta, cudaStream_t stream) {
687+ validate_gemm_args (A, B, C, M, N, K);
665688 constexpr int THREADS_PER_BLOCK = 256 ;
666689 dim3 block (THREADS_PER_BLOCK);
667690 dim3 grid ((N + PIPE_TILE_N - 1 ) / PIPE_TILE_N, (M + PIPE_TILE_M - 1 ) / PIPE_TILE_M);
@@ -733,6 +756,7 @@ template <>
733756void gemm<int8_t , GemmOpt::SharedMemTiling>(const int8_t * A, const int8_t * B, int8_t * C,
734757 int M, int N, int K,
735758 float alpha, float beta, cudaStream_t stream) {
759+ validate_gemm_args (A, B, C, M, N, K);
736760 dim3 block (TILE_SIZE, TILE_SIZE);
737761 dim3 grid ((N + TILE_SIZE - 1 ) / TILE_SIZE, (M + TILE_SIZE - 1 ) / TILE_SIZE);
738762 gemm_shared_kernel<int8_t ><<<grid, block, 0 , stream>>> (A, B, C, M, N, K, alpha, beta);
0 commit comments