Commit bc88a0c

fix: update cublas api to torch 2.10+
Merge commit, 2 parents: 55140f8 + 291228a

4 files changed

Lines changed: 87 additions & 22 deletions
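Every C++ change below follows one pattern: gate the code on the PyTorch version exposed by <torch/version.h>, so the same source builds against both pre-2.10 and 2.10+ libtorch. A minimal sketch of that gate, lifted out of the real GEMM wrappers (the SFAST_TORCH_AT_LEAST_2_10 macro and the function name are illustrative only; the diff writes the version comparison out inline at every call site):

#include <ATen/Context.h>   // at::globalContext()
#include <torch/version.h>  // TORCH_VERSION_MAJOR / TORCH_VERSION_MINOR

// Illustrative convenience macro, not part of the commit.
#define SFAST_TORCH_AT_LEAST_2_10 \
  (TORCH_VERSION_MAJOR > 2 ||     \
   (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))

void gemm_prologue_sketch() {
#if !SFAST_TORCH_AT_LEAST_2_10
  // Pre-2.10 only: the determinism warning helper is still part of the
  // public Context API, so keep calling it before launching cuBLAS work.
  at::globalContext().alertCuBLASConfigNotDeterministic();
#endif
  // On 2.10+ the call is compiled out; the rest of the GEMM setup
  // (handle, transpose flags, leading dimensions) is unchanged.
}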

.github/workflows/wheels.yml

Lines changed: 26 additions & 19 deletions
@@ -11,9 +11,15 @@ jobs:
         os:
           - ubuntu-22.04
           # - windows-2019
-        python: ['3.10', '3.11', '3.12']
-        torch_version: ['2.10.0']
+        python: ['3.10', '3.11', '3.12', '3.13', '3.14']
+        # torch_version: ['2.7.0', '2.8.0', '2.9.0', '2.10.0', '2.11.0']
+        torch_version: ['2.11.0']
         cuda_short_version: ['126']
+        exclude:
+          - torch_version: '2.7.0'
+            python: '3.14'
+          - torch_version: '2.8.0'
+            python: '3.14'

     uses: ./.github/workflows/wheels_build.yml
     with:
@@ -22,21 +28,21 @@ jobs:
       torch_version: ${{ matrix.torch_version }}
       cuda_short_version: ${{ matrix.cuda_short_version }}

-  build-pypi:
-    # Single canonical build intended for PyPI: no local CUDA/torch suffix
-    strategy:
-      fail-fast: false
-      matrix:
-        os: ['ubuntu-22.04']
-        python: ['3.10', '3.11', '3.12']
+  # build-pypi:
+  #   # Single canonical build intended for PyPI: no local CUDA/torch suffix
+  #   strategy:
+  #     fail-fast: false
+  #     matrix:
+  #       os: ['ubuntu-22.04']
+  #       python: ['3.10', '3.11', '3.12', '3.13', '3.14']

-    uses: ./.github/workflows/wheels_build.yml
-    with:
-      os: ${{ matrix.os }}
-      python: ${{ matrix.python }}
-      torch_version: '2.10.0'
-      cuda_short_version: '128'
-      append_local_version: '0' # 0 to disable local version suffix
+  #   uses: ./.github/workflows/wheels_build.yml
+  #   with:
+  #     os: ${{ matrix.os }}
+  #     python: ${{ matrix.python }}
+  #     torch_version: '2.9.0'
+  #     cuda_short_version: '128'
+  #     append_local_version: '0' # 0 to disable local version suffix

   # publish to GitHub Release
   # gh_release:
@@ -79,11 +85,12 @@ jobs:


   consolidate-wheels:
-    needs: [build-local, build-pypi]
+    # needs: [build-local, build-pypi]
+    needs: [build-local]
     runs-on: ubuntu-latest
     steps:
       - name: Download all wheel artifacts
-        uses: actions/download-artifact@v4
+        uses: actions/download-artifact@v7
         with:
           path: dist

@@ -94,7 +101,7 @@ jobs:
           ls -l consolidated_wheels

       - name: Upload consolidated wheels
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v6
         with:
           name: built-wheels
           path: consolidated_wheels

.github/workflows/wheels_build.yml

Lines changed: 3 additions & 3 deletions
@@ -172,7 +172,7 @@ jobs:
           sudo apt autoremove -y

       - name: Recursive checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v5
         with:
          submodules: recursive
          path: "."
@@ -236,14 +236,14 @@ jobs:

       - name: Upload artifact (local build)
         if: ${{ inputs.append_local_version != '0' }}
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v6
         with:
          name: ${{ inputs.os }}-py${{ inputs.python }}-torch${{ inputs.torch_version }}+cu${{ inputs.cuda_short_version }}
          path: dist/*.whl

       - name: Upload artifact (pypi build)
         if: ${{ inputs.append_local_version == '0' }}
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v6
         with:
          name: ${{ inputs.os }}-py${{ inputs.python }}
          path: dist/*.whl

src/sfast/csrc/operators/cublas/CUDABlas.cc

Lines changed: 49 additions & 0 deletions
@@ -7,6 +7,7 @@
 #include <c10/cuda/CUDAFunctions.h>
 #include <c10/macros/Export.h>
 #include <c10/util/irange.h>
+#include <torch/version.h>

 // cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
 // added bf16 support
@@ -226,7 +227,9 @@ cublasStatus_t cublasGemmStridedBatchedExFix(cublasHandle_t &handle,
 template <>
 void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -239,7 +242,9 @@ void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
 template <>
 void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -252,7 +257,9 @@ void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
 template <>
 void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -267,7 +274,9 @@ void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>))
 template <>
 void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -282,7 +291,9 @@ void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
 template <>
 void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -311,7 +322,11 @@ void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {

   cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
   if (prop->major >= 5){
+#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10)
+    if (at::globalContext().allowFP16ReductionCuBLAS() == at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
+#else
     if (at::globalContext().allowFP16ReductionCuBLAS()) {
+#endif
       at::Half falpha = alpha;
       at::Half fbeta = beta;
       TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(
@@ -350,7 +365,9 @@ void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
 template <>
 void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   BGEMM_CHECK_ARGVALUES(at::BFloat16);
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
@@ -383,7 +400,9 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
 template <>
 void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -396,7 +415,9 @@ void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
 template <>
 void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -410,7 +431,9 @@ void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
 template <>
 void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -427,7 +450,9 @@ void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
 template <>
 void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -443,7 +468,9 @@ void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
 template <>
 void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -490,12 +517,20 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
   TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
 #else
   cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
+#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10)
+  if (at::globalContext().allowFP16ReductionCuBLAS() != at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
+#else
   if (!at::globalContext().allowFP16ReductionCuBLAS()) {
+#endif
     cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
   }
   TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
 #endif // defined(CUDA_VERSION) && CUDA_VERSION < 11000
+#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10)
+  if (at::globalContext().allowFP16ReductionCuBLAS() == at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
+#else
   if (at::globalContext().allowFP16ReductionCuBLAS()) {
+#endif
     at::Half falpha = alpha;
     at::Half fbeta = beta;
     TORCH_CUDABLAS_CHECK(cublasGemmEx_(
@@ -606,7 +641,9 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
 template <>
 void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t opa = _cublasOpFromChar(transa);
   cublasOperation_t opb = _cublasOpFromChar(transb);
@@ -617,7 +654,11 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
 #if TORCH_VERSION_MAJOR > 2 || \
     (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 2)
   cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
+#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10)
+  if (at::globalContext().allowBF16ReductionCuBLAS() != at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
+#else
   if (!at::globalContext().allowBF16ReductionCuBLAS()) {
+#endif
     cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
   }
   TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
@@ -1126,7 +1167,9 @@ void trsmBatched<c10::complex<double>>(
 template <>
 void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t op = _cublasOpFromChar(trans);
   _cublasAdjustLdLevel2(m, n, &lda);
@@ -1145,7 +1188,9 @@ void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
   // loss still happens on TF32. So we disable it here.
   NoTF32Guard disable_tf32;
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t op = _cublasOpFromChar(trans);
   _cublasAdjustLdLevel2(m, n, &lda);
@@ -1160,7 +1205,9 @@ void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
 template <>
 void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double)) {
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t op = _cublasOpFromChar(trans);
   _cublasAdjustLdLevel2(m, n, &lda);
@@ -1175,7 +1222,9 @@ void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float)) {
   // loss still happens on TF32. So we disable it here.
   NoTF32Guard disable_tf32;
   // See Note [Writing Nondeterministic Operations]
+#if !(TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10))
   globalContext().alertCuBLASConfigNotDeterministic();
+#endif
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
   cublasOperation_t op = _cublasOpFromChar(trans);
   _cublasAdjustLdLevel2(m, n, &lda);

src/sfast/csrc/operators/cutlass/cutlass_dual_linear_kernel.cu

Lines changed: 9 additions & 0 deletions
@@ -1,4 +1,5 @@
 #include <torch/extension.h>
+#include <torch/version.h>

 #include <c10/cuda/CUDAMathCompat.h>
 #include <c10/cuda/CUDAStream.h>
@@ -486,7 +487,11 @@ torch::Tensor cutlass_linear_geglu(const torch::Tensor &input,
   auto dispatch_bf16 = [&] {
 #if TORCH_VERSION_MAJOR > 2 || \
     (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 2)
+#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10)
+    if (at::globalContext().allowBF16ReductionCuBLAS() == at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
+#else
     if (at::globalContext().allowBF16ReductionCuBLAS()) {
+#endif
       output =
           CutlassDualGemmLauncher<at::BFloat16, GemmGEGLUWrapper,
                                   cutlass::epilogue::thread::GELU_taylor_fast,
@@ -506,7 +511,11 @@ torch::Tensor cutlass_linear_geglu(const torch::Tensor &input,
         AT_DISPATCH_CASE(
             at::kHalf,
             [&] {
+#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 10)
+              if (at::globalContext().allowFP16ReductionCuBLAS() == at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
+#else
               if (at::globalContext().allowFP16ReductionCuBLAS()) {
+#endif
                 output = CutlassDualGemmLauncher<
                     at::Half, GemmGEGLUWrapper,
                     cutlass::epilogue::thread::GELU_taylor_fast,
