Skip to content

Commit b6b4ad7

Browse files
committed
Add cuBLAS mm_out shim to eliminate libtorch runtime dependency
Implements aoti_torch_cuda_mm_out as a thin cuBLAS wrapper in the ExecuTorch AOTI CUDA shims. When Inductor picks cuBLAS over Triton templates for aten::mm (F.linear), the compiled .so requires this symbol at runtime. Without this shim, it resolves from libtorch_cuda.so, pulling in the full libtorch runtime. In practice, Inductor's autotune on A100 picks Triton templates for the Qwen3.5 MoE dense projections (bf16 [M,2048]x[2048,N]), so the shim is not exercised for this model. It serves as a safety net for models or shapes where cuBLAS wins the autotune, ensuring fully libtorch-free AOTI CUDA deployment in all cases. Co-authored-by: Claude <noreplyanthropic.com>
1 parent 8ae05c2 commit b6b4ad7

6 files changed

Lines changed: 559 additions & 4 deletions

File tree

backends/cuda/CMakeLists.txt

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp
110110
# Only build CUDA shims when CUDA language/toolchain is available.
111111
if(CMAKE_CUDA_COMPILER)
112112
list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu
113-
runtime/shims/sort.cu runtime/shims/rand.cu
113+
runtime/shims/sort.cu runtime/shims/rand.cu runtime/shims/mm.cu
114114
)
115115
endif()
116116

@@ -153,7 +153,7 @@ endif()
153153
if(_cuda_is_msvc_toolchain)
154154
target_link_libraries(
155155
aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart CUDA::curand
156-
${CMAKE_DL_LIBS}
156+
CUDA::cublas ${CMAKE_DL_LIBS}
157157
)
158158
# Link object library directly so symbols are pulled exactly once while
159159
# avoiding duplicate static/object inclusion and interface leakage.
@@ -162,8 +162,13 @@ else()
162162
target_link_libraries(
163163
aoti_cuda_shims
164164
PRIVATE cuda_platform
165-
PUBLIC -Wl,--whole-archive aoti_common_shims_slim -Wl,--no-whole-archive
166-
CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS}
165+
PUBLIC -Wl,--whole-archive
166+
aoti_common_shims_slim
167+
-Wl,--no-whole-archive
168+
CUDA::cudart
169+
CUDA::curand
170+
CUDA::cublas
171+
${CMAKE_DL_LIBS}
167172
)
168173
endif()
169174

backends/cuda/runtime/shims/mm.cu

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include <executorch/backends/aoti/slim/c10/core/ScalarType.h>
#include <executorch/backends/aoti/slim/cuda/guard.h>
#include <executorch/backends/aoti/utils.h>
#include <executorch/backends/cuda/runtime/shims/mm.h>

#include <climits>
#include <mutex>
#include <vector>
19+
20+
namespace executorch::backends::cuda {

namespace c10_slim = executorch::backends::aoti::slim::c10;

namespace {

// Handles are cached per device index; the shim rejects indices >= this
// bound before touching the cache.
constexpr int kMaxDevices = 16;

// Lazily-initialized, process-wide cache of one cuBLAS handle per device.
// `mutex` guards lazy creation and is also the lock callers hold across
// cublasSetStream + cublasGemmEx to serialize use of a shared handle.
struct CuBLASHandles {
  std::mutex mutex;
  cublasHandle_t handles[kMaxDevices] = {};
  bool initialized[kMaxDevices] = {};

  // Returns the handle for `device`, creating it on first use.
  // Locks `mutex` internally — callers must NOT already hold it
  // (std::mutex is non-recursive; re-locking from the same thread is UB).
  cublasHandle_t get(int device) {
    std::lock_guard<std::mutex> lock(mutex);
    if (!initialized[device]) {
      // Bind the new handle to the requested GPU.
      // NOTE(review): this changes the calling thread's current device as a
      // side effect and does not restore it — confirm whether a device
      // guard (slim/cuda/guard.h) should wrap this.
      cudaError_t cuda_err = cudaSetDevice(device);
      cublasStatus_t blas_status = CUBLAS_STATUS_NOT_INITIALIZED;
      if (cuda_err == cudaSuccess) {
        blas_status = cublasCreate(&handles[device]);
      }
      if (cuda_err == cudaSuccess && blas_status == CUBLAS_STATUS_SUCCESS) {
        // Keep the library's default math mode (no opt-in to reduced
        // precision paths beyond what the gemm call itself requests).
        cublasSetMathMode(handles[device], CUBLAS_DEFAULT_MATH);
        initialized[device] = true;
      }
      // On failure, leave initialized=false so the next call retries
      // instead of permanently caching a null/broken handle; the caller's
      // status check on the subsequent cuBLAS call reports the error.
    }
    return handles[device];
  }
};

// Meyers singleton: thread-safe static initialization of the cache.
CuBLASHandles& cublas_handles() {
  static CuBLASHandles instance;
  return instance;
}

} // namespace
51+
52+
#ifdef __cplusplus
extern "C" {
#endif

/**
 * Computes out = self @ mat2 with a single cublasGemmEx call.
 *
 * All three tensors must be 2D, contiguous, and share one dtype
 * (bf16 / fp16 / fp32); accumulation is always performed in fp32
 * (CUBLAS_COMPUTE_32F).
 *
 * @param out  Pre-allocated [M, N] output tensor.
 * @param self Left operand, [M, K].
 * @param mat2 Right operand, [K, N].
 */
AOTITorchError
aoti_torch_cuda_mm_out(Tensor* out, Tensor* self, Tensor* mat2) {
  // Argument validation: presence, rank, and layout.
  ET_CHECK_OR_RETURN_ERROR(
      out != nullptr, InvalidArgument, "mm_out: out is null");
  ET_CHECK_OR_RETURN_ERROR(
      self != nullptr, InvalidArgument, "mm_out: self is null");
  ET_CHECK_OR_RETURN_ERROR(
      mat2 != nullptr, InvalidArgument, "mm_out: mat2 is null");
  ET_CHECK_OR_RETURN_ERROR(
      self->dim() == 2 && mat2->dim() == 2 && out->dim() == 2,
      InvalidArgument,
      "mm_out: all tensors must be 2D");
  ET_CHECK_OR_RETURN_ERROR(
      self->is_contiguous() && mat2->is_contiguous() && out->is_contiguous(),
      InvalidArgument,
      "mm_out: all tensors must be contiguous");

  int64_t M = self->size(0);
  int64_t K = self->size(1);
  int64_t N = mat2->size(1);

  // int64_t is `long long` on LLP64 platforms, so cast explicitly and use
  // %lld — a bare %ld would be undefined behavior there.
  ET_CHECK_OR_RETURN_ERROR(
      mat2->size(0) == K,
      InvalidArgument,
      "mm_out: self [%lld,%lld] x mat2 [%lld,%lld] inner dims mismatch",
      static_cast<long long>(M),
      static_cast<long long>(K),
      static_cast<long long>(mat2->size(0)),
      static_cast<long long>(N));
  ET_CHECK_OR_RETURN_ERROR(
      out->size(0) == M && out->size(1) == N,
      InvalidArgument,
      "mm_out: out shape mismatch");
  // cublasGemmEx takes `int` dimensions; reject sizes that would silently
  // truncate when narrowed below.
  ET_CHECK_OR_RETURN_ERROR(
      M <= INT_MAX && N <= INT_MAX && K <= INT_MAX,
      InvalidArgument,
      "mm_out: dimensions exceed cuBLAS int range");

  auto dtype = self->dtype();
  ET_CHECK_OR_RETURN_ERROR(
      mat2->dtype() == dtype && out->dtype() == dtype,
      InvalidArgument,
      "mm_out: dtype mismatch");

  // Map the element dtype onto cuBLAS descriptors. Every supported dtype
  // accumulates in fp32.
  cudaDataType_t cuda_dtype;
  cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
  if (dtype == c10_slim::ScalarType::BFloat16) {
    cuda_dtype = CUDA_R_16BF;
  } else if (dtype == c10_slim::ScalarType::Half) {
    cuda_dtype = CUDA_R_16F;
  } else if (dtype == c10_slim::ScalarType::Float) {
    cuda_dtype = CUDA_R_32F;
  } else {
    ET_CHECK_OR_RETURN_ERROR(
        false, InvalidArgument, "mm_out: unsupported dtype");
  }

  int device = self->device_index();
  ET_CHECK_OR_RETURN_ERROR(
      device >= 0 && device < kMaxDevices,
      InvalidArgument,
      "mm_out: device index %d out of range",
      device);

  auto stream_result = getCurrentCUDAStream(device);
  ET_CHECK_OR_RETURN_ERROR(
      stream_result.ok(), Internal, "mm_out: failed to get CUDA stream");

  // Fetch (and lazily create) the per-device handle BEFORE taking the cache
  // mutex: get() locks the same non-recursive std::mutex internally, so
  // locking first would self-deadlock on the very first call. Once the
  // handle exists, hold the mutex across cublasSetStream + cublasGemmEx so
  // concurrent callers sharing a device cannot interleave a stream change
  // with another thread's GEMM.
  auto& handles = cublas_handles();
  cublasHandle_t handle = handles.get(device);
  std::lock_guard<std::mutex> lock(handles.mutex);
  cublasStatus_t stream_status = cublasSetStream(handle, stream_result.get());
  ET_CHECK_OR_RETURN_ERROR(
      stream_status == CUBLAS_STATUS_SUCCESS,
      Internal,
      "mm_out: cublasSetStream failed with status %d",
      static_cast<int>(stream_status));

  // cuBLAS is column-major. For row-major C = A @ B, compute C^T = B^T @ A^T
  // instead, reinterpreting the row-major buffers as their column-major
  // transposes:
  //   A_row[M,K] reads as A^T_col[K,M] with leading dim K
  //   B_row[K,N] reads as B^T_col[N,K] with leading dim N
  //   C_row[M,N] reads as C^T_col[N,M] with leading dim N
  // So: gemm(OP_N, OP_N, N, M, K, B, N, A, K, C, N) — no transposes needed.
  float alpha = 1.0f;
  float beta = 0.0f;

  cublasStatus_t status = cublasGemmEx(
      handle,
      CUBLAS_OP_N,
      CUBLAS_OP_N,
      static_cast<int>(N), // m: rows of C^T (= columns of C)
      static_cast<int>(M), // n: columns of C^T (= rows of C)
      static_cast<int>(K), // k: shared inner dimension
      &alpha,
      mat2->data_ptr(), // B^T in col-major == B in row-major
      cuda_dtype,
      static_cast<int>(N), // leading dim of mat2
      self->data_ptr(), // A^T in col-major == A in row-major
      cuda_dtype,
      static_cast<int>(K), // leading dim of self
      &beta,
      out->data_ptr(),
      cuda_dtype,
      static_cast<int>(N), // leading dim of out
      compute_type,
      CUBLAS_GEMM_DEFAULT);

  ET_CHECK_OR_RETURN_ERROR(
      status == CUBLAS_STATUS_SUCCESS,
      Internal,
      "mm_out: cublasGemmEx failed with status %d",
      static_cast<int>(status));

  return Error::Ok;
}

#ifdef __cplusplus
}
#endif

} // namespace executorch::backends::cuda

backends/cuda/runtime/shims/mm.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cuda_runtime.h>
#include <executorch/backends/aoti/common_shims_slim.h>
#include <executorch/backends/aoti/export.h>

namespace executorch::backends::cuda {

using executorch::backends::aoti::AOTITorchError;
using executorch::backends::aoti::Tensor;

#ifdef __cplusplus
extern "C" {
#endif

/**
 * Matrix multiplication via cuBLAS: out = self @ mat2.
 *
 * Replaces libtorch's aoti_torch_cuda_mm_out so the AOTI CUDA backend
 * can run without libtorch_cuda.so. Calls cublasGemmEx directly.
 *
 * All three tensors must be 2D, contiguous, and share a single dtype:
 * bf16, fp16, or fp32 (the implementation accepts all three; accumulation
 * is performed in fp32).
 *
 * @param out  Pre-allocated output [M, N], same dtype as the inputs.
 * @param self Left input matrix [M, K].
 * @param mat2 Right input matrix [K, N].
 */
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_cuda_mm_out(Tensor* out, Tensor* self, Tensor* mat2);

#ifdef __cplusplus
}
#endif

} // namespace executorch::backends::cuda

backends/cuda/runtime/shims/tests/CMakeLists.txt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,21 @@ foreach(test_name ${CUDA_SHIM_TESTS})
6767

6868
add_test(NAME ${test_name} COMMAND ${test_name})
6969
endforeach()
70+
71+
# mm_out test — cuBLAS is already linked into aoti_cuda_shims
72+
add_executable(test_aoti_torch_cuda_mm_out test_aoti_torch_cuda_mm_out.cpp)
73+
74+
target_include_directories(
75+
test_aoti_torch_cuda_mm_out PRIVATE ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}
76+
${CUDAToolkit_INCLUDE_DIRS}
77+
)
78+
79+
target_compile_definitions(test_aoti_torch_cuda_mm_out PRIVATE CUDA_AVAILABLE=1)
80+
81+
target_link_libraries(
82+
test_aoti_torch_cuda_mm_out
83+
PRIVATE GTest::gtest GTest::gtest_main aoti_cuda_shims executorch_core
84+
CUDA::cudart
85+
)
86+
87+
add_test(NAME test_aoti_torch_cuda_mm_out COMMAND test_aoti_torch_cuda_mm_out)

backends/cuda/runtime/shims/tests/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,4 @@ def define_common_targets():
4242
cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle")
4343
cuda_shim_cpp_unittest("aoti_torch_item_bool")
4444
cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out")
45+
cuda_shim_cpp_unittest("aoti_torch_cuda_mm_out")

0 commit comments

Comments
 (0)