Skip to content

Commit a6a4477

Browse files
committed
refactor: deepen modules and consolidate utilities
- Delete unused CublasSgemm class (duplicate of SGEMMVerifier) - Move BenchmarkResult to benchmark_core.cuh for kernel layer access - Move kTensorCoreVerifyTolerance to verify.cuh (centralize tolerances) - Refactor SGEMMBenchmark to use measureGpuTime (unified timing) - Refactor tensor_core_benchmark to use measureGpuTime and calculateSgemmMetrics - Replace int return codes with ParseResult enum (self-documenting API) Architecture improvements: - Reduce shallow modules by eliminating duplicate timing logic - Kernel layer can now use BenchmarkResult without depending on high-level orchestrator - All verification tolerances now in verify.cuh (improved locality)
1 parent 5cdbbfa commit a6a4477

9 files changed

Lines changed: 118 additions & 234 deletions

File tree

CONTEXT.md

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,6 @@ Tensor Core 特有的 benchmark 功能,提供:
5757
- `getTheoreticalPeakGflops()` / `getTheoreticalPeakBandwidth()` - 理论峰值查询
5858
- `calculateEfficiency()` / `calculateBandwidthUtilization()` - 效率计算
5959

60-
### Benchmark cuBLAS
61-
**位置**: `src/utils/benchmark_cublas.cuh`
62-
63-
cuBLAS 参考实现:
64-
- `CublasSgemm` - cuBLAS SGEMM 参考调用器
65-
- `SgemmReferenceCalculator` - 完整参考计算流程
66-
6760
### 高级接口
6861
**位置**: `src/utils/benchmark.cuh`
6962

src/cli_parser.cuh

Lines changed: 48 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,21 @@ const std::vector<std::tuple<int, int, int>> BenchmarkConfig::DEFAULT_CASES = {
3535
{511, 513, 1025},
3636
};
3737

38+
// ============================================================================
39+
// CLI 解析结果
40+
// ============================================================================
41+
42+
/**
43+
* CLI 解析结果枚举
44+
*
45+
* 自文档化的返回类型,替代 int 返回码。
46+
*/
47+
enum class ParseResult {
48+
Success, // 解析成功,继续执行
49+
Error, // 解析错误,返回错误码
50+
HelpShown // 显示帮助后正常退出
51+
};
52+
3853
// ============================================================================
3954
// CLI 解析器
4055
// ============================================================================
@@ -86,28 +101,26 @@ class CliParser {
86101
* 解析命令行参数
87102
*
88103
* @param config 输出配置对象
89-
* @return 0 成功,1 错误,2 显示帮助后退出
104+
* @return ParseResult 枚举指示解析结果
90105
*/
91-
int parse(BenchmarkConfig &config) {
106+
ParseResult parse(BenchmarkConfig &config) {
92107
for (int i = 1; i < argc_; ++i) {
93108
std::string arg = argv_[i];
94109

95110
if (arg == "-h" || arg == "--help") {
96111
printUsage(argv_[0]);
97-
return 2;
112+
return ParseResult::HelpShown;
98113
}
99114

100115
if (arg == "-s" || arg == "--size") {
101-
int result = parseSizeArg(i, config);
102-
if (result != 0)
103-
return result;
116+
if (!parseSizeArg(i, config))
117+
return ParseResult::Error;
104118
continue;
105119
}
106120

107121
if (arg == "--dims") {
108-
int result = parseDimsArg(i, config);
109-
if (result != 0)
110-
return result;
122+
if (!parseDimsArg(i, config))
123+
return ParseResult::Error;
111124
continue;
112125
}
113126

@@ -117,30 +130,28 @@ class CliParser {
117130
}
118131

119132
if (arg == "--warmup") {
120-
int result = parseWarmupArg(i, config);
121-
if (result != 0)
122-
return result;
133+
if (!parseWarmupArg(i, config))
134+
return ParseResult::Error;
123135
continue;
124136
}
125137

126138
if (arg == "--benchmark") {
127-
int result = parseBenchmarkArg(i, config);
128-
if (result != 0)
129-
return result;
139+
if (!parseBenchmarkArg(i, config))
140+
return ParseResult::Error;
130141
continue;
131142
}
132143

133144
fprintf(stderr, "Unknown argument: %s\n", arg.c_str());
134145
printUsage(argv_[0]);
135-
return 1;
146+
return ParseResult::Error;
136147
}
137148

138149
// 默认添加 1024x1024x1024
139150
if (config.empty()) {
140151
config.addCase(1024, 1024, 1024);
141152
}
142153

143-
return 0;
154+
return ParseResult::Success;
144155
}
145156

146157
void printUsage(const char *program) const {
@@ -163,82 +174,82 @@ class CliParser {
163174
}
164175

165176
private:
166-
int parseSizeArg(int &i, BenchmarkConfig &config) {
177+
bool parseSizeArg(int &i, BenchmarkConfig &config) {
167178
if (i + 1 >= argc_) {
168179
fprintf(stderr, "Error: -s requires a size argument\n");
169-
return 1;
180+
return false;
170181
}
171182

172183
int size;
173184
if (!detail::safeStrToInt(argv_[++i], &size, "size")) {
174-
return 1;
185+
return false;
175186
}
176187
if (size <= 0) {
177188
fprintf(stderr, "Error: Size must be positive\n");
178-
return 1;
189+
return false;
179190
}
180191

181192
config.addCase(size, size, size);
182-
return 0;
193+
return true;
183194
}
184195

185-
int parseDimsArg(int &i, BenchmarkConfig &config) {
196+
bool parseDimsArg(int &i, BenchmarkConfig &config) {
186197
if (i + 3 >= argc_) {
187198
fprintf(stderr, "Error: --dims requires M K N arguments\n");
188-
return 1;
199+
return false;
189200
}
190201

191202
int M, K, N;
192203
if (!detail::safeStrToInt(argv_[++i], &M, "M dimension") ||
193204
!detail::safeStrToInt(argv_[++i], &K, "K dimension") ||
194205
!detail::safeStrToInt(argv_[++i], &N, "N dimension")) {
195-
return 1;
206+
return false;
196207
}
197208
if (M <= 0 || K <= 0 || N <= 0) {
198209
fprintf(stderr, "Error: Dimensions must be positive\n");
199-
return 1;
210+
return false;
200211
}
201212

202213
config.addCase(M, K, N);
203-
return 0;
214+
return true;
204215
}
205216

206-
int parseWarmupArg(int &i, BenchmarkConfig &config) {
217+
bool parseWarmupArg(int &i, BenchmarkConfig &config) {
207218
if (i + 1 >= argc_) {
208219
fprintf(stderr, "Error: --warmup requires a number argument\n");
209-
return 1;
220+
return false;
210221
}
211222

212223
int warmup;
213224
if (!detail::safeStrToInt(argv_[++i], &warmup, "warmup")) {
214-
return 1;
225+
return false;
215226
}
216227
if (warmup < 0) {
217228
fprintf(stderr, "Error: Warmup runs must be non-negative\n");
218-
return 1;
229+
return false;
219230
}
220231

221232
config.warmup_runs = warmup;
222-
return 0;
233+
return true;
223234
}
224235

225-
int parseBenchmarkArg(int &i, BenchmarkConfig &config) {
236+
bool parseBenchmarkArg(int &i, BenchmarkConfig &config) {
226237
if (i + 1 >= argc_) {
227238
fprintf(stderr, "Error: --benchmark requires a number argument\n");
228-
return 1;
239+
return false;
229240
}
230241

231242
int bench;
232243
if (!detail::safeStrToInt(argv_[++i], &bench, "benchmark")) {
233-
return 1;
244+
return false;
234245
}
235246
if (bench <= 0) {
236247
fprintf(stderr, "Error: Benchmark runs must be positive\n");
237-
return 1;
248+
return false;
238249
}
239250

240251
config.benchmark_runs = bench;
241-
return 0;
252+
return true;
242253
}
243254

244255
int argc_;

src/kernels/tensor_core_benchmark.cuh

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
#pragma once
22

33
#include "tensor_core_sgemm.cuh"
4+
#include "../utils/benchmark_core.cuh"
5+
#include "../utils/benchmark_metrics.cuh"
6+
#include "../utils/verify.cuh"
7+
48
#include <cublas_v2.h>
59
#include <cuda_fp16.h>
610
#include <cuda_runtime.h>
@@ -71,36 +75,17 @@ runTensorCoreComputeOnlyBenchmark(cublasHandle_t cublas_handle, int M, int K, in
7175
CUDA_CHECK(cudaGetLastError());
7276
CUDA_CHECK(cudaDeviceSynchronize());
7377

74-
for (int i = 0; i < warmup_runs; ++i) {
75-
d_C.zero();
76-
launch_tensor_core_sgemm_fp16(d_A_fp16.get(), d_B_fp16.get(), d_C.get(), M, K, N);
77-
}
78-
CUDA_CHECK(cudaDeviceSynchronize());
79-
80-
cudaEvent_t start, stop;
81-
CUDA_CHECK(cudaEventCreate(&start));
82-
CUDA_CHECK(cudaEventCreate(&stop));
83-
84-
CUDA_CHECK(cudaEventRecord(start));
85-
for (int i = 0; i < benchmark_runs; ++i) {
86-
launch_tensor_core_sgemm_fp16(d_A_fp16.get(), d_B_fp16.get(), d_C.get(), M, K, N);
87-
}
88-
CUDA_CHECK(cudaEventRecord(stop));
89-
CUDA_CHECK(cudaEventSynchronize(stop));
90-
91-
float total_time_ms;
92-
CUDA_CHECK(cudaEventElapsedTime(&total_time_ms, start, stop));
93-
94-
cudaEventDestroy(start);
95-
cudaEventDestroy(stop);
96-
97-
// 填充性能指标
98-
result.time_ms = total_time_ms / benchmark_runs;
99-
double flops = 2.0 * result.M * result.N * result.K;
100-
result.gflops = (flops / (result.time_ms * 1e-3)) / 1e9;
101-
double bytes =
102-
(result.M * result.K + result.K * result.N + result.M * result.N) * sizeof(float);
103-
result.bandwidth_gb_s = (bytes / (result.time_ms * 1e-3)) / 1e9;
78+
// 使用统一的 measureGpuTime 计时
79+
float time_ms =
80+
measureGpuTime([&]() { launch_tensor_core_sgemm_fp16(d_A_fp16.get(), d_B_fp16.get(), d_C.get(), M, K, N); },
81+
warmup_runs, benchmark_runs);
82+
83+
// 计算指标
84+
PerformanceMetrics metrics = calculateSgemmMetrics(M, K, N, time_ms);
85+
result.time_ms = metrics.time_ms;
86+
result.gflops = metrics.gflops;
87+
result.bandwidth_gb_s = metrics.bandwidth_gb_s;
88+
result.efficiency = calculateEfficiency(result.gflops, getTheoreticalPeakGflops());
10489

10590
d_C.copyToHost(h_C.data(), M * N);
10691
d_C_ref.copyToHost(h_C_ref.data(), M * N);

src/kernels/tensor_core_sgemm.cuh

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,18 @@
99
// - WMMA FP16→FP32 计算
1010
// - FP32→FP16 类型转换
1111
// - 统一启动接口(强制显式 fallback)
12-
// - 验证容差常量
1312
//
1413
// 设计原则:
1514
// - 深层模块:小接口,大实现
1615
// - 不提供默认 fallback,强制调用者显式指定
1716
// - 消除与具体内核的循环依赖
17+
//
18+
// 验证容差:
19+
// - 使用 verify.cuh 中定义的 kTensorCoreVerifyTolerance
1820
// ============================================================================
1921

2022
#include "../utils/cuda_utils.cuh"
23+
#include "../utils/verify.cuh"
2124
#include <cuda_fp16.h>
2225
#include <cuda_runtime.h>
2326
#include <functional>
@@ -80,14 +83,6 @@ inline const char *getTensorCoreArchName() {
8083
return "Unknown";
8184
}
8285

83-
// ============================================================================
84-
// Tensor Core Verification Tolerance
85-
// ============================================================================
86-
87-
// Tensor Core 使用 FP16 中间精度,需要更宽松的容差
88-
// 此容差定义在 Tensor Core 模块中,保持精度相关常量与其实现在一起
89-
inline constexpr VerifyTolerance kTensorCoreVerifyTolerance{5e-2f, 1e-2f};
90-
9186
// ============================================================================
9287
// Fallback 策略接口
9388
// ============================================================================

src/main.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@ int main(int argc, char **argv) {
1717
BenchmarkConfig config;
1818
CliParser parser(argc, argv);
1919

20-
int parse_result = parser.parse(config);
21-
if (parse_result == 2) {
22-
// 显示帮助后正常退出
20+
ParseResult result = parser.parse(config);
21+
switch (result) {
22+
case ParseResult::HelpShown:
2323
return 0;
24-
}
25-
if (parse_result != 0) {
26-
// 解析错误
24+
case ParseResult::Error:
2725
return 1;
26+
case ParseResult::Success:
27+
break;
2828
}
2929

3030
BenchmarkRunner runner(config);

0 commit comments

Comments
 (0)