 #include <c10/cuda/CUDAFunctions.h>
 #include <c10/macros/Export.h>
 #include <c10/util/irange.h>
+#include <torch/version.h>

 // cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
 // added bf16 support
@@ -226,7 +227,9 @@ cublasStatus_t cublasGemmStridedBatchedExFix(cublasHandle_t &handle,
 template <>
 void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -239,7 +242,9 @@ void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
 template <>
 void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -252,7 +257,9 @@ void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
 template <>
 void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -267,7 +274,9 @@ void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>))
 template <>
 void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -282,7 +291,9 @@ void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
 template <>
 void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -311,7 +322,11 @@ void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
 
   cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
   if (prop->major >= 5){
+#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10)
+    if (at::globalContext().allowFP16ReductionCuBLAS() == at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
+#else
     if (at::globalContext().allowFP16ReductionCuBLAS()) {
+#endif
       at::Half falpha = alpha;
       at::Half fbeta = beta;
       TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(
@@ -350,7 +365,9 @@ void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
 template <>
 void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   BGEMM_CHECK_ARGVALUES(at::BFloat16);
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
@@ -383,7 +400,9 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
 template <>
 void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -396,7 +415,9 @@ void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
 template <>
 void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -410,7 +431,9 @@ void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
 template <>
 void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -427,7 +450,9 @@ void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
 template <>
 void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -443,7 +468,9 @@ void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
 template <>
 void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -490,12 +517,20 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
   TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
 #else
   cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
+#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10)
+  if (at::globalContext().allowFP16ReductionCuBLAS() != at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
+#else
   if (!at::globalContext().allowFP16ReductionCuBLAS()) {
+#endif
     cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
   }
   TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
 #endif // defined(CUDA_VERSION) && CUDA_VERSION < 11000
+#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10)
+  if (at::globalContext().allowFP16ReductionCuBLAS() == at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
+#else
   if (at::globalContext().allowFP16ReductionCuBLAS()) {
+#endif
     at::Half falpha = alpha;
     at::Half fbeta = beta;
     TORCH_CUDABLAS_CHECK(cublasGemmEx_(
@@ -606,7 +641,9 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
 template <>
 void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -617,7 +654,11 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
 #if TORCH_VERSION_MAJOR > 2 || \
     (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 2)
   cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
+#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10)
+  if (at::globalContext().allowBF16ReductionCuBLAS() != at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
+#else
   if (!at::globalContext().allowBF16ReductionCuBLAS()) {
+#endif
     cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
   }
   TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
@@ -1126,7 +1167,9 @@ void trsmBatched<c10::complex<double>>(
 template <>
 void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t op = _cublasOpFromChar(trans);
   _cublasAdjustLdLevel2(m, n, &lda);
@@ -1145,7 +1188,9 @@ void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
   // loss still happens on TF32. So we disable it here.
   NoTF32Guard disable_tf32;
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t op = _cublasOpFromChar(trans);
   _cublasAdjustLdLevel2(m, n, &lda);
@@ -1160,7 +1205,9 @@ void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
 template <>
 void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t op = _cublasOpFromChar(trans);
   _cublasAdjustLdLevel2(m, n, &lda);
@@ -1175,7 +1222,9 @@ void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float)) {
   // loss still happens on TF32. So we disable it here.
   NoTF32Guard disable_tf32;
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t op = _cublasOpFromChar(trans);
   _cublasAdjustLdLevel2(m, n, &lda);
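
For context, every hunk above applies the same version gate from <torch/version.h>: when building against PyTorch 2.10 or newer, the explicit alertCuBLASConfigNotDeterministic() calls are skipped and allowFP16ReductionCuBLAS()/allowBF16ReductionCuBLAS() are compared against at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK instead of being used as a bool, as the diff shows. Below is a minimal, standalone C++ sketch of that gating pattern; the AT_LEAST_TORCH_2_10 helper macro and the fallback macro definitions are illustrative only and are not part of the patch.

// Standalone illustration of the version gate used in the hunks above.
#include <cstdio>

// In a real build these come from <torch/version.h> (included at the top of
// the patched file); the fallbacks only keep this example compilable alone.
#ifndef TORCH_VERSION_MAJOR
#define TORCH_VERSION_MAJOR 2
#endif
#ifndef TORCH_VERSION_MINOR
#define TORCH_VERSION_MINOR 10
#endif

// Hypothetical convenience macro: true when targeting PyTorch 2.10 or newer.
#define AT_LEAST_TORCH_2_10 \
  (TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))

int main() {
#if AT_LEAST_TORCH_2_10
  // 2.10+ path: the alert call is omitted and the reduction setting is
  // compared against the at::CuBLASReductionOption enum.
  std::printf("compiled against PyTorch >= 2.10\n");
#else
  // Pre-2.10 path: keep the legacy alert call and the boolean reduction check.
  std::printf("compiled against PyTorch < 2.10\n");
#endif
  return 0;
}

Because the decision is made entirely by the preprocessor, one source file can support both API generations without any runtime cost, which is why the patch repeats the same #if/#else/#endif wrapper at every call site rather than introducing a runtime check.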