@@ -25,29 +25,16 @@ constexpr int PBT_ITERATIONS = 100;
2525
2626std::vector<std::tuple<int , int , int >> getStandardDimensions () {
2727 return {
28- {1 , 1 , 1 },
29- {16 , 16 , 16 },
30- {32 , 32 , 32 },
31- {64 , 64 , 64 },
32- {128 , 128 , 128 },
33- {256 , 256 , 256 },
34- {512 , 512 , 512 },
35- {64 , 128 , 256 },
36- {256 , 64 , 128 },
37- {128 , 256 , 64 },
38- {511 , 513 , 1025 },
28+ {1 , 1 , 1 }, {16 , 16 , 16 }, {32 , 32 , 32 }, {64 , 64 , 64 },
29+ {128 , 128 , 128 }, {256 , 256 , 256 }, {512 , 512 , 512 }, {64 , 128 , 256 },
30+ {256 , 64 , 128 }, {128 , 256 , 64 }, {511 , 513 , 1025 },
3931 };
4032}
4133
4234std::vector<std::tuple<int , int , int >> getTensorCoreFastPathDimensions () {
4335 return {
44- {16 , 16 , 16 },
45- {32 , 32 , 32 },
46- {64 , 64 , 64 },
47- {128 , 128 , 128 },
48- {256 , 256 , 256 },
49- {64 , 128 , 256 },
50- {256 , 64 , 128 },
36+ {16 , 16 , 16 }, {32 , 32 , 32 }, {64 , 64 , 64 }, {128 , 128 , 128 },
37+ {256 , 256 , 256 }, {64 , 128 , 256 }, {256 , 64 , 128 },
5138 };
5239}
5340
@@ -96,9 +83,8 @@ TEST_F(ErrorDetectionTest, StandardKernelErrorDetection) {
9683 h_test_[i] = h_ref_[i] + error_magnitude * (dist (gen) > 0 ? 1 : -1 );
9784 }
9885
99- VerifyResult result =
100- compareMatrices (h_test_.data (), h_ref_.data (), 64 , 64 ,
101- kStandardVerifyTolerance );
86+ VerifyResult result = compareMatrices (h_test_.data (), h_ref_.data (), 64 , 64 ,
87+ kStandardVerifyTolerance );
10288
10389 EXPECT_TRUE (SGEMMVerifier::shouldFlagAsIncorrect (result))
10490 << " Iteration " << iter << " : error above tolerance should be flagged" ;
@@ -119,9 +105,8 @@ TEST_F(ErrorDetectionTest, StandardKernelPassesWithinTolerance) {
119105 h_test_[i] = h_ref_[i] + error_magnitude * dist (gen);
120106 }
121107
122- VerifyResult result =
123- compareMatrices (h_test_.data (), h_ref_.data (), 64 , 64 ,
124- kStandardVerifyTolerance );
108+ VerifyResult result = compareMatrices (h_test_.data (), h_ref_.data (), 64 , 64 ,
109+ kStandardVerifyTolerance );
125110
126111 EXPECT_TRUE (result.passed )
127112 << " Iteration " << iter << " : error within tolerance should pass" ;
@@ -142,9 +127,8 @@ TEST_F(ErrorDetectionTest, TensorCoreErrorDetection) {
142127 h_test_[i] = h_ref_[i] + error_magnitude * (dist (gen) > 0 ? 1 : -1 );
143128 }
144129
145- VerifyResult result =
146- compareMatrices (h_test_.data (), h_ref_.data (), 64 , 64 ,
147- kTensorCoreVerifyTolerance );
130+ VerifyResult result = compareMatrices (h_test_.data (), h_ref_.data (), 64 , 64 ,
131+ kTensorCoreVerifyTolerance );
148132
149133 EXPECT_TRUE (SGEMMVerifier::shouldFlagAsIncorrect (result))
150134 << " Iteration " << iter
@@ -193,9 +177,9 @@ protected:
193177 }
194178
195179 template <typename LaunchFn>
196- VerifyResult runKernelAndCompare (LaunchFn launch_fn,
197- VerifyTolerance tolerance =
198- kStandardVerifyTolerance ) {
180+ VerifyResult
181+ runKernelAndCompare (LaunchFn launch_fn,
182+ VerifyTolerance tolerance = kStandardVerifyTolerance ) {
199183 CUDA_CHECK (cudaMemset (d_C_, 0 , M_ * N_ * sizeof (float )));
200184 launch_fn ();
201185 CUDA_CHECK (cudaDeviceSynchronize ());
@@ -244,8 +228,9 @@ INSTANTIATE_TEST_SUITE_P(StandardDimensions, TiledSGEMMTest,
244228class BankConflictFreeSGEMMTest : public SGEMMKernelTest {};
245229
246230TEST_P (BankConflictFreeSGEMMTest, CorrectnessProperty) {
247- VerifyResult result = runKernelAndCompare (
248- [&] { launch_bank_conflict_free_sgemm<32 >(d_A_, d_B_, d_C_, M_ , K_ , N_ ); });
231+ VerifyResult result = runKernelAndCompare ([&] {
232+ launch_bank_conflict_free_sgemm<32 >(d_A_, d_B_, d_C_, M_ , K_ , N_ );
233+ });
249234
250235 EXPECT_TRUE (result.passed )
251236 << " BankConflictFree SGEMM failed for dimensions " << M_ << " x" << K_
@@ -287,8 +272,9 @@ TEST_P(TensorCoreSGEMMTest, FastPathCorrectnessProperty) {
287272 << " x" << N_ << " (max_rel_error: " << result.max_rel_error << " )" ;
288273}
289274
290- INSTANTIATE_TEST_SUITE_P (TensorCoreFastPathDimensions, TensorCoreSGEMMTest,
291- ::testing::ValuesIn (getTensorCoreFastPathDimensions()));
275+ INSTANTIATE_TEST_SUITE_P (
276+ TensorCoreFastPathDimensions, TensorCoreSGEMMTest,
277+ ::testing::ValuesIn (getTensorCoreFastPathDimensions()));
292278
293279class TensorCoreFallbackTest : public SGEMMKernelTest {};
294280
@@ -302,13 +288,17 @@ TEST_P(TensorCoreFallbackTest, NonAlignedInputsFallbackSafely) {
302288 << N_ << " (max_rel_error: " << result.max_rel_error << " )" ;
303289}
304290
305- INSTANTIATE_TEST_SUITE_P (TensorCoreFallbackDimensions, TensorCoreFallbackTest,
306- ::testing::ValuesIn (getTensorCoreFallbackDimensions()));
291+ INSTANTIATE_TEST_SUITE_P (
292+ TensorCoreFallbackDimensions, TensorCoreFallbackTest,
293+ ::testing::ValuesIn (getTensorCoreFallbackDimensions()));
307294
308295TEST (TensorCoreWrapperTest, ZeroSizeInputsReturnSafely) {
309- EXPECT_NO_THROW (launch_tensor_core_sgemm (nullptr , nullptr , nullptr , 0 , 16 , 16 ));
310- EXPECT_NO_THROW (launch_tensor_core_sgemm (nullptr , nullptr , nullptr , 16 , 0 , 16 ));
311- EXPECT_NO_THROW (launch_tensor_core_sgemm (nullptr , nullptr , nullptr , 16 , 16 , 0 ));
296+ EXPECT_NO_THROW (
297+ launch_tensor_core_sgemm (nullptr , nullptr , nullptr , 0 , 16 , 16 ));
298+ EXPECT_NO_THROW (
299+ launch_tensor_core_sgemm (nullptr , nullptr , nullptr , 16 , 0 , 16 ));
300+ EXPECT_NO_THROW (
301+ launch_tensor_core_sgemm (nullptr , nullptr , nullptr , 16 , 16 , 0 ));
312302}
313303
314304class DimensionInvarianceTest : public ::testing::Test {
@@ -358,9 +348,8 @@ TEST_F(DimensionInvarianceTest, AllStandardKernelsWorkWithVariousDimensions) {
358348 CUDA_CHECK (cudaMemcpy (h_C.data (), d_C, M * N * sizeof (float ),
359349 cudaMemcpyDeviceToHost));
360350
361- VerifyResult result =
362- compareMatrices (h_C.data (), h_ref.data (), M, N,
363- kStandardVerifyTolerance );
351+ VerifyResult result = compareMatrices (h_C.data (), h_ref.data (), M, N,
352+ kStandardVerifyTolerance );
364353 EXPECT_TRUE (result.passed )
365354 << name << " failed at iteration " << iter << " with dimensions " << M
366355 << " x" << K << " x" << N;
0 commit comments