Skip to content

Commit 1ac5630

Browse files
committed
fix: harden CUDA kernel validation and error handling
Propagate CUDA failures as exceptions and reject unsupported or invalid kernel configurations up front so incorrect launches fail fast instead of silently producing bad results.
1 parent cc97d93 commit 1ac5630

9 files changed

Lines changed: 129 additions & 42 deletions

File tree

src/03_gemm/gemm.cu

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,22 @@
11
#include "gemm.cuh"
22
#include "../common/cuda_check.cuh"
3+
#include <stdexcept>
34

45
namespace hpc::gemm {
56

7+
namespace {
8+
9+
// Shared precondition check invoked at the top of the gemm<> launch wrappers.
//
// Throws std::invalid_argument when any of A, B, or C is null, or when any of
// M, N, or K is zero or negative. The pointer check runs first, so a call that
// violates both conditions reports the pointer problem.
void validate_gemm_args(const void* A, const void* B, const void* C, int M, int N, int K) {
    const bool all_pointers_set = (A != nullptr) && (B != nullptr) && (C != nullptr);
    if (!all_pointers_set) {
        throw std::invalid_argument("gemm expects non-null A, B, and C pointers");
    }

    const bool all_dims_positive = (M > 0) && (N > 0) && (K > 0);
    if (!all_dims_positive) {
        throw std::invalid_argument("gemm expects positive M, N, and K");
    }
}
17+
18+
} // namespace
19+
620
constexpr int TILE_SIZE = 32;
721

822
// Naive GEMM: each thread computes one element
@@ -75,6 +89,7 @@ template <>
7589
void gemm<float, GemmOpt::Naive>(const float* A, const float* B, float* C,
7690
int M, int N, int K,
7791
float alpha, float beta, cudaStream_t stream) {
92+
validate_gemm_args(A, B, C, M, N, K);
7893
dim3 block(16, 16);
7994
dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
8095
gemm_naive_kernel<float><<<grid, block, 0, stream>>>(A, B, C, M, N, K, alpha, beta);
@@ -85,6 +100,7 @@ template <>
85100
void gemm<float, GemmOpt::SharedMemTiling>(const float* A, const float* B, float* C,
86101
int M, int N, int K,
87102
float alpha, float beta, cudaStream_t stream) {
103+
validate_gemm_args(A, B, C, M, N, K);
88104
dim3 block(TILE_SIZE, TILE_SIZE);
89105
dim3 grid((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);
90106
gemm_shared_kernel<float><<<grid, block, 0, stream>>>(A, B, C, M, N, K, alpha, beta);
@@ -177,6 +193,7 @@ template <>
177193
void gemm<float, GemmOpt::DoubleBuffer>(const float* A, const float* B, float* C,
178194
int M, int N, int K,
179195
float alpha, float beta, cudaStream_t stream) {
196+
validate_gemm_args(A, B, C, M, N, K);
180197
dim3 block(TILE_SIZE, TILE_SIZE);
181198
dim3 grid((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);
182199
gemm_double_buffer_kernel<float><<<grid, block, 0, stream>>>(A, B, C, M, N, K, alpha, beta);
@@ -301,6 +318,7 @@ template <>
301318
void gemm<float, GemmOpt::RegisterTiling>(const float* A, const float* B, float* C,
302319
int M, int N, int K,
303320
float alpha, float beta, cudaStream_t stream) {
321+
validate_gemm_args(A, B, C, M, N, K);
304322
constexpr int THREADS_PER_BLOCK = (BLK_M / REG_TILE_M) * (BLK_N / REG_TILE_N);
305323
dim3 block(THREADS_PER_BLOCK);
306324
dim3 grid((N + BLK_N - 1) / BLK_N, (M + BLK_M - 1) / BLK_M);
@@ -435,6 +453,10 @@ template <>
435453
void gemm<__half, GemmOpt::TensorCoreWMMA>(const __half* A, const __half* B, __half* C,
436454
int M, int N, int K,
437455
float alpha, float beta, cudaStream_t stream) {
456+
validate_gemm_args(A, B, C, M, N, K);
457+
if ((M % 16) != 0 || (N % 16) != 0 || (K % 16) != 0) {
458+
throw std::invalid_argument("TensorCoreWMMA requires M, N, and K to be multiples of 16");
459+
}
438460
// Each block has multiple warps
439461
constexpr int WARPS_PER_BLOCK = (WMMA_BLK_M / WMMA_M) * (WMMA_BLK_N / WMMA_N);
440462
constexpr int THREADS_PER_BLOCK = WARPS_PER_BLOCK * 32;
@@ -662,6 +684,7 @@ template <>
662684
void gemm<float, GemmOpt::SoftwarePipeline>(const float* A, const float* B, float* C,
663685
int M, int N, int K,
664686
float alpha, float beta, cudaStream_t stream) {
687+
validate_gemm_args(A, B, C, M, N, K);
665688
constexpr int THREADS_PER_BLOCK = 256;
666689
dim3 block(THREADS_PER_BLOCK);
667690
dim3 grid((N + PIPE_TILE_N - 1) / PIPE_TILE_N, (M + PIPE_TILE_M - 1) / PIPE_TILE_M);
@@ -733,6 +756,7 @@ template <>
733756
void gemm<int8_t, GemmOpt::SharedMemTiling>(const int8_t* A, const int8_t* B, int8_t* C,
734757
int M, int N, int K,
735758
float alpha, float beta, cudaStream_t stream) {
759+
validate_gemm_args(A, B, C, M, N, K);
736760
dim3 block(TILE_SIZE, TILE_SIZE);
737761
dim3 grid((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);
738762
gemm_shared_kernel<int8_t><<<grid, block, 0, stream>>>(A, B, C, M, N, K, alpha, beta);

src/04_convolution/conv_winograd.cu

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
11
#include "conv_winograd.cuh"
22
#include "conv_implicit_gemm.cuh"
33
#include "../common/cuda_check.cuh"
4+
#include <stdexcept>
45

56
namespace hpc::convolution {
67

7-
// Winograd F(2x2, 3x3) transformation matrices
8-
// TODO: Implement full Winograd convolution
9-
8+
// Experimental wrapper: until Winograd transforms are implemented, this path
9+
// intentionally falls back to the validated implicit GEMM implementation.
1010
template <>
1111
void conv2d_winograd<float>(const float* input, const float* weight, float* output,
1212
int batch, int in_channels, int out_channels,
1313
int height, int width, cudaStream_t stream) {
14-
// Placeholder - full implementation requires Winograd transforms
15-
// Fall back to implicit GEMM with 3x3 kernel, stride=1, pad=1
14+
if (input == nullptr || weight == nullptr || output == nullptr) {
15+
throw std::invalid_argument("conv2d_winograd expects non-null input, weight, and output pointers");
16+
}
17+
if (batch <= 0 || in_channels <= 0 || out_channels <= 0 || height <= 0 || width <= 0) {
18+
throw std::invalid_argument("conv2d_winograd expects positive batch/channel/spatial dimensions");
19+
}
20+
1621
ConvParams params{batch, in_channels, out_channels, height, width,
1722
3, 3, 1, 1, 1, 1, 1, 1};
1823
conv2d_implicit_gemm<float>(input, weight, output, params, stream);

src/05_attention/flash_attention.cu

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "flash_attention.cuh"
22
#include "../common/cuda_check.cuh"
33
#include <cfloat>
4+
#include <cmath>
5+
#include <stdexcept>
46

57
namespace hpc::attention {
68

@@ -19,7 +21,6 @@ __global__ void flash_attention_kernel(const T* __restrict__ Q,
1921
float* q_tile = smem;
2022
float* k_tile = q_tile + BLOCK_SIZE * HEAD_DIM;
2123
float* v_tile = k_tile + BLOCK_SIZE * HEAD_DIM;
22-
float* scores = v_tile + BLOCK_SIZE * HEAD_DIM;
2324

2425
int batch_head = blockIdx.x;
2526
int b = batch_head / num_heads;
@@ -118,20 +119,27 @@ void flash_attention_forward<float>(const float* Q, const float* K, const float*
118119
cudaStream_t stream) {
119120
constexpr int BLOCK_SIZE = 64;
120121
constexpr int HEAD_DIM = 64;
121-
122+
123+
if (Q == nullptr || K == nullptr || V == nullptr || O == nullptr) {
124+
throw std::invalid_argument("flash_attention_forward expects non-null Q, K, V, and O pointers");
125+
}
126+
if (config.batch_size <= 0 || config.num_heads <= 0 || config.seq_len <= 0 ||
127+
config.head_dim <= 0) {
128+
throw std::invalid_argument("flash_attention_forward expects positive batch_size, num_heads, seq_len, and head_dim");
129+
}
130+
if (!std::isfinite(config.scale) || config.scale <= 0.0f) {
131+
throw std::invalid_argument("flash_attention_forward expects a finite positive scale");
132+
}
133+
if (config.head_dim != HEAD_DIM) {
134+
throw std::invalid_argument("flash_attention_forward currently supports head_dim == 64 only");
135+
}
136+
122137
dim3 grid(config.batch_size * config.num_heads,
123138
(config.seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE);
124139
dim3 block(BLOCK_SIZE);
125-
126-
size_t smem_size = 3 * BLOCK_SIZE * HEAD_DIM * sizeof(float) +
140+
141+
size_t smem_size = 3 * BLOCK_SIZE * HEAD_DIM * sizeof(float) +
127142
BLOCK_SIZE * BLOCK_SIZE * sizeof(float);
128-
129-
// HEAD_DIM is a compile-time constant; assert config matches
130-
if (config.head_dim != HEAD_DIM) {
131-
fprintf(stderr, "flash_attention: config.head_dim=%d but compiled HEAD_DIM=%d\n",
132-
config.head_dim, HEAD_DIM);
133-
return;
134-
}
135143

136144
flash_attention_kernel<float, BLOCK_SIZE, HEAD_DIM><<<grid, block, smem_size, stream>>>(
137145
Q, K, V, O,

src/06_quantization/int8_quant.cu

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "../common/cuda_check.cuh"
33
#include "../common/reduce.cuh"
44
#include <cfloat>
5+
#include <stdexcept>
56

67
namespace hpc::quantization {
78

@@ -22,7 +23,7 @@ __global__ void compute_scale_kernel(const float* __restrict__ input,
2223
max_abs = hpc::block_reduce_max(max_abs);
2324

2425
if (threadIdx.x == 0) {
25-
scale[row] = max_abs / 127.0f;
26+
scale[row] = max_abs > 0.0f ? (max_abs / 127.0f) : 1.0f;
2627
}
2728
}
2829

@@ -35,7 +36,12 @@ __global__ void quantize_kernel(const float* __restrict__ input,
3536

3637
for (; idx < total; idx += blockDim.x * gridDim.x) {
3738
int row = idx / cols;
38-
float inv_scale = 1.0f / scale[row];
39+
float row_scale = scale[row];
40+
if (row_scale == 0.0f) {
41+
output[idx] = 0;
42+
continue;
43+
}
44+
float inv_scale = 1.0f / row_scale;
3945
float val = input[idx] * inv_scale;
4046
val = fminf(fmaxf(val, -127.0f), 127.0f);
4147
output[idx] = static_cast<int8_t>(roundf(val));
@@ -44,8 +50,16 @@ __global__ void quantize_kernel(const float* __restrict__ input,
4450

4551
void quantize_int8(const float* input, int8_t* output, float* scale,
4652
int rows, int cols, cudaStream_t stream) {
53+
if (input == nullptr || output == nullptr || scale == nullptr) {
54+
throw std::invalid_argument("quantize_int8 expects non-null input, output, and scale pointers");
55+
}
56+
if (rows <= 0 || cols <= 0) {
57+
throw std::invalid_argument("quantize_int8 expects rows and cols to be positive");
58+
}
59+
4760
compute_scale_kernel<<<rows, 256, 0, stream>>>(input, scale, rows, cols);
48-
61+
CUDA_CHECK_LAST();
62+
4963
int total = rows * cols;
5064
int block_size = 256;
5165
int grid_size = (total + block_size - 1) / block_size;
@@ -68,10 +82,17 @@ __global__ void dequantize_int8_kernel(const int8_t* __restrict__ input,
6882

6983
void dequantize_int8(const int8_t* input, const float* scale,
7084
float* output, int rows, int cols, cudaStream_t stream) {
85+
if (input == nullptr || output == nullptr || scale == nullptr) {
86+
throw std::invalid_argument("dequantize_int8 expects non-null input, output, and scale pointers");
87+
}
88+
if (rows <= 0 || cols <= 0) {
89+
throw std::invalid_argument("dequantize_int8 expects rows and cols to be positive");
90+
}
91+
7192
int total = rows * cols;
7293
int block_size = 256;
7394
int grid_size = (total + block_size - 1) / block_size;
74-
95+
7596
dequantize_int8_kernel<<<grid_size, block_size, 0, stream>>>(
7697
input, scale, output, rows, cols);
7798
CUDA_CHECK_LAST();

src/07_cuda13_features/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
# CUDA 13 features module
1+
# Experimental newer-CUDA feature demos and compatibility fallbacks.
2+
# These targets are built as part of the lab, but they currently do not imply
3+
# production-grade Hopper/Blackwell feature coverage.
24
hpc_add_cuda_library(hpc_cuda13
35
tma.cu
46
cluster.cu

src/07_cuda13_features/cluster.cu

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
#include "cluster.cuh"
22
#include "../common/cuda_check.cuh"
3+
#include <stdexcept>
34

45
namespace hpc::cuda13 {
56

6-
// Thread Block Clusters placeholder
7-
// Requires Hopper architecture (SM90+) and CUDA 12+
7+
// Experimental fallback for a future thread-block-cluster implementation.
8+
// Today this uses a portable block reduction and does not rely on SM90-only features.
89

910
template <typename T>
1011
__global__ void cluster_reduce_kernel(const T* __restrict__ input,
@@ -35,13 +36,22 @@ __global__ void cluster_reduce_kernel(const T* __restrict__ input,
3536
template <>
3637
void cluster_reduce<float>(const float* input, float* output, size_t n,
3738
const ClusterConfig& config, cudaStream_t stream) {
39+
if (input == nullptr || output == nullptr) {
40+
throw std::invalid_argument("cluster_reduce expects non-null input and output pointers");
41+
}
42+
if (n == 0) {
43+
throw std::invalid_argument("cluster_reduce expects n > 0");
44+
}
45+
if (config.block_dims.x == 0) {
46+
throw std::invalid_argument("cluster_reduce expects config.block_dims.x > 0");
47+
}
48+
3849
int block_size = config.block_dims.x;
3950
int grid_size = (n + block_size - 1) / block_size;
4051
size_t smem_size = block_size * sizeof(float);
41-
42-
// Initialize output to zero
43-
cudaMemsetAsync(output, 0, sizeof(float), stream);
44-
52+
53+
CUDA_CHECK(cudaMemsetAsync(output, 0, sizeof(float), stream));
54+
4555
cluster_reduce_kernel<float><<<grid_size, block_size, smem_size, stream>>>(
4656
input, output, n);
4757
CUDA_CHECK_LAST();

src/07_cuda13_features/fp8_gemm.cu

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33

44
namespace hpc::cuda13 {
55

6-
// FP8 GEMM placeholder
7-
// Requires Hopper architecture (SM90+) and CUDA 12+
8-
// Uses e4m3 and e5m2 data types
6+
// Experimental FP8-like demo path.
7+
// This currently scales float inputs in a standard kernel; it is not a true Hopper FP8 implementation.
98

109
template <typename T>
1110
__global__ void fp8_gemm_kernel(const T* __restrict__ A,

src/07_cuda13_features/tma.cu

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
#include "tma.cuh"
22
#include "../common/cuda_check.cuh"
3+
#include <stdexcept>
34

45
namespace hpc::cuda13 {
56

6-
// TMA (Tensor Memory Accelerator) placeholder
7-
// Requires Hopper architecture (SM90+) and CUDA 12+
7+
// Experimental fallback for the future TMA path.
8+
// This currently performs a regular kernel copy so behavior is portable and testable.
89

910
template <typename T>
1011
__global__ void async_copy_kernel(const T* __restrict__ src,
@@ -22,6 +23,13 @@ __global__ void async_copy_kernel(const T* __restrict__ src,
2223
template <>
2324
void tma_copy_2d<float>(const float* src, float* dst,
2425
int rows, int cols, cudaStream_t stream) {
26+
if (src == nullptr || dst == nullptr) {
27+
throw std::invalid_argument("tma_copy_2d expects non-null src and dst pointers");
28+
}
29+
if (rows <= 0 || cols <= 0) {
30+
throw std::invalid_argument("tma_copy_2d expects positive rows and cols");
31+
}
32+
2533
dim3 block(256);
2634
dim3 grid((cols + block.x - 1) / block.x, rows);
2735
async_copy_kernel<float><<<grid, block, 0, stream>>>(src, dst, rows, cols);

src/common/cuda_check.cuh

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,28 @@
11
#pragma once
22

33
#include <cuda_runtime.h>
4-
#include <cstdio>
5-
#include <cstdlib>
4+
#include <sstream>
5+
#include <stdexcept>
6+
#include <string>
7+
8+
namespace hpc::detail {
9+
10+
// Formats a CUDA failure as "CUDA error at <file>:<line>: <description>" and
// throws it as std::runtime_error. Never returns.
//
// err  - status code whose text is obtained from cudaGetErrorString.
// file - source file name to embed in the message (CUDA_CHECK passes __FILE__).
// line - source line number to embed in the message (CUDA_CHECK passes __LINE__).
[[noreturn]] inline void throw_cuda_error(cudaError_t err, const char* file, int line) {
    const char* description = cudaGetErrorString(err);

    std::ostringstream text;
    text << "CUDA error at " << file << ':' << line << ": " << description;

    throw std::runtime_error(text.str());
}
16+
17+
} // namespace hpc::detail
618

719
// Macros are not scoped by namespaces; define them at file scope.
8-
#define CUDA_CHECK(call) \
9-
do { \
10-
cudaError_t err = call; \
11-
if (err != cudaSuccess) { \
12-
fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, \
13-
cudaGetErrorString(err)); \
14-
exit(EXIT_FAILURE); \
15-
} \
20+
// Evaluates `call` exactly once and, when the result is not cudaSuccess, hands
// it to ::hpc::detail::throw_cuda_error together with the call site's
// __FILE__/__LINE__, which throws std::runtime_error.
//
// The temporary deliberately has a macro-specific name. With a plain `err`, a
// caller writing `CUDA_CHECK(err)` (or any expression mentioning its own `err`)
// would expand to `cudaError_t err = (err);`, which initializes the new
// variable from itself and never looks at the caller's status.
#define CUDA_CHECK(call)                                                          \
    do {                                                                          \
        const cudaError_t hpc_cuda_check_status_ = (call);                        \
        if (hpc_cuda_check_status_ != cudaSuccess) {                              \
            ::hpc::detail::throw_cuda_error(hpc_cuda_check_status_, __FILE__,     \
                                            __LINE__);                            \
        }                                                                         \
    } while (0)
1727

1828
// Routes cudaGetLastError() through CUDA_CHECK. The launch wrappers in this
// change invoke it right after a <<<...>>> launch so a failed launch is
// reported via CUDA_CHECK's error path. Note that cudaGetLastError() also
// resets the recorded error, so a second call will not see the same failure.
#define CUDA_CHECK_LAST() CUDA_CHECK(cudaGetLastError())

0 commit comments

Comments
 (0)