Skip to content

Commit f531c1f

Browse files
committed
fix: 修复整数溢出风险
修复了多个文件中潜在的整数溢出问题:

1. src/kernels/tensor_core_benchmark.cuh
   - 添加矩阵元素数量溢出检查
   - 先用 size_t 进行乘法运算,检查通过后再转换为 int
2. src/kernels/tensor_core_sgemm.cuh
   - 添加矩阵元素数量溢出检查
   - 确保传递给 kernel 的 size 参数安全
3. src/utils/benchmark.cuh
   - 使用 size_t 进行矩阵大小计算
   - 添加溢出检查防止内存分配错误
4. tests/test_sgemm.cu
   - 删除多余的命名空间闭合标记

这些问题在大矩阵场景下(例如两个维度都超过 46340,使 M*K、K*N 或 M*N 的乘积超过 INT_MAX)可能导致:
- 数据损坏
- 内存访问越界
- 程序崩溃
1 parent 0e455d5 commit f531c1f

4 files changed

Lines changed: 58 additions & 18 deletions

File tree

src/kernels/tensor_core_benchmark.cuh

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,19 @@ runTensorCoreComputeOnlyBenchmark(cublasHandle_t cublas_handle, int M, int K, in
7979
int gridSizeA = safeGridSize(static_cast<size_t>(M) * K, blockSize);
8080
int gridSizeB = safeGridSize(static_cast<size_t>(K) * N, blockSize);
8181

82-
float_to_half_kernel<<<gridSizeA, blockSize>>>(d_A.get(), d_A_fp16.get(), M * K);
83-
float_to_half_kernel<<<gridSizeB, blockSize>>>(d_B.get(), d_B_fp16.get(), K * N);
82+
size_t num_A = static_cast<size_t>(M) * K;
83+
size_t num_B = static_cast<size_t>(K) * N;
84+
85+
// 检查矩阵元素数量是否超过 int 最大值
86+
if (num_A > static_cast<size_t>(INT_MAX)) {
87+
throw CudaError("Matrix A size overflow: too many elements for int parameter");
88+
}
89+
if (num_B > static_cast<size_t>(INT_MAX)) {
90+
throw CudaError("Matrix B size overflow: too many elements for int parameter");
91+
}
92+
93+
float_to_half_kernel<<<gridSizeA, blockSize>>>(d_A.get(), d_A_fp16.get(), static_cast<int>(num_A));
94+
float_to_half_kernel<<<gridSizeB, blockSize>>>(d_B.get(), d_B_fp16.get(), static_cast<int>(num_B));
8495
CUDA_CHECK(cudaGetLastError());
8596
CUDA_CHECK(cudaDeviceSynchronize());
8697

src/kernels/tensor_core_sgemm.cuh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,14 @@ inline void launch_tensor_core_sgemm_with_fallback(const float *A, const float *
258258
int gridSizeA = safeGridSize(num_A, blockSize);
259259
int gridSizeB = safeGridSize(num_B, blockSize);
260260

261+
// 检查矩阵元素数量是否超过 int 最大值
262+
if (num_A > static_cast<size_t>(INT_MAX)) {
263+
throw CudaError("Matrix A size overflow: too many elements for int parameter");
264+
}
265+
if (num_B > static_cast<size_t>(INT_MAX)) {
266+
throw CudaError("Matrix B size overflow: too many elements for int parameter");
267+
}
268+
261269
float_to_half_kernel<<<gridSizeA, blockSize, 0, stream>>>(A, d_A_fp16.get(),
262270
static_cast<int>(num_A));
263271
float_to_half_kernel<<<gridSizeB, blockSize, 0, stream>>>(B, d_B_fp16.get(),

src/utils/benchmark.cuh

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,28 @@ class SGEMMBenchmark {
6565
result.N = N;
6666

6767
// 初始化数据
68-
std::vector<float> h_A(M * K), h_B(K * N), h_C(M * N), h_C_ref(M * N);
69-
DeviceMemory<float> d_A(M * K);
70-
DeviceMemory<float> d_B(K * N);
71-
DeviceMemory<float> d_C(M * N);
72-
DeviceMemory<float> d_C_ref(M * N);
68+
// 安全计算矩阵大小,避免整数溢出
69+
size_t size_A = static_cast<size_t>(M) * K;
70+
size_t size_B = static_cast<size_t>(K) * N;
71+
size_t size_C = static_cast<size_t>(M) * N;
72+
73+
// 检查元素数量是否超过 INT_MAX(确保矩阵规模在合理范围内)
74+
if (size_A > static_cast<size_t>(INT_MAX) || size_B > static_cast<size_t>(INT_MAX) ||
75+
size_C > static_cast<size_t>(INT_MAX)) {
76+
throw CudaError("Matrix dimensions too large for benchmark");
77+
}
78+
79+
std::vector<float> h_A(size_A), h_B(size_B), h_C(size_C), h_C_ref(size_C);
80+
DeviceMemory<float> d_A(size_A);
81+
DeviceMemory<float> d_B(size_B);
82+
DeviceMemory<float> d_C(size_C);
83+
DeviceMemory<float> d_C_ref(size_C);
7384

7485
initRandomMatrix(h_A.data(), M, K, -1.0f, 1.0f, 42);
7586
initRandomMatrix(h_B.data(), K, N, -1.0f, 1.0f, 123);
7687

77-
d_A.copyFromHost(h_A.data(), M * K);
78-
d_B.copyFromHost(h_B.data(), K * N);
88+
d_A.copyFromHost(h_A.data(), size_A);
89+
d_B.copyFromHost(h_B.data(), size_B);
7990

8091
// 计算参考结果
8192
float alpha = 1.0f, beta = 0.0f;
@@ -94,8 +105,8 @@ class SGEMMBenchmark {
94105
result.efficiency = calculateEfficiency(result.gflops, getTheoreticalPeakGflops());
95106

96107
// 验证正确性
97-
d_C.copyToHost(h_C.data(), M * N);
98-
d_C_ref.copyToHost(h_C_ref.data(), M * N);
108+
d_C.copyToHost(h_C.data(), size_C);
109+
d_C_ref.copyToHost(h_C_ref.data(), size_C);
99110

100111
VerifyResult verify_result = compareMatrices(h_C.data(), h_C_ref.data(), M, N, tolerance);
101112
result.correct = verify_result.passed;
@@ -115,16 +126,27 @@ class SGEMMBenchmark {
115126
result.K = K;
116127
result.N = N;
117128

118-
std::vector<float> h_A(M * K), h_B(K * N);
119-
DeviceMemory<float> d_A(M * K);
120-
DeviceMemory<float> d_B(K * N);
121-
DeviceMemory<float> d_C(M * N);
129+
// 安全计算矩阵大小,避免整数溢出
130+
size_t size_A = static_cast<size_t>(M) * K;
131+
size_t size_B = static_cast<size_t>(K) * N;
132+
size_t size_C = static_cast<size_t>(M) * N;
133+
134+
// 检查元素数量是否超过 INT_MAX
135+
if (size_A > static_cast<size_t>(INT_MAX) || size_B > static_cast<size_t>(INT_MAX) ||
136+
size_C > static_cast<size_t>(INT_MAX)) {
137+
throw CudaError("Matrix dimensions too large for benchmark");
138+
}
139+
140+
std::vector<float> h_A(size_A), h_B(size_B);
141+
DeviceMemory<float> d_A(size_A);
142+
DeviceMemory<float> d_B(size_B);
143+
DeviceMemory<float> d_C(size_C);
122144

123145
initRandomMatrix(h_A.data(), M, K, -1.0f, 1.0f, 42);
124146
initRandomMatrix(h_B.data(), K, N, -1.0f, 1.0f, 123);
125147

126-
d_A.copyFromHost(h_A.data(), M * K);
127-
d_B.copyFromHost(h_B.data(), K * N);
148+
d_A.copyFromHost(h_A.data(), size_A);
149+
d_B.copyFromHost(h_B.data(), size_B);
128150

129151
float alpha = 1.0f, beta = 0.0f;
130152

tests/test_sgemm.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ void computeReference(cublasHandle_t handle, const float *d_A, const float *d_B,
6060
CUBLAS_CHECK(cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha, d_B, N, d_A, K,
6161
&beta, d_C, N));
6262
}
63-
} // namespace
6463

6564
class ErrorDetectionTest : public ::testing::Test {
6665
protected:

0 commit comments

Comments
 (0)