Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions backends/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp
# Only build CUDA shims when CUDA language/toolchain is available.
if(CMAKE_CUDA_COMPILER)
list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu
runtime/shims/sort.cu runtime/shims/rand.cu
runtime/shims/sort.cu runtime/shims/rand.cu runtime/shims/mm.cu
)
endif()

Expand Down Expand Up @@ -153,7 +153,7 @@ endif()
if(_cuda_is_msvc_toolchain)
target_link_libraries(
aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart CUDA::curand
${CMAKE_DL_LIBS}
CUDA::cublas ${CMAKE_DL_LIBS}
)
# Link object library directly so symbols are pulled exactly once while
# avoiding duplicate static/object inclusion and interface leakage.
Expand All @@ -162,8 +162,13 @@ else()
target_link_libraries(
aoti_cuda_shims
PRIVATE cuda_platform
PUBLIC -Wl,--whole-archive aoti_common_shims_slim -Wl,--no-whole-archive
CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS}
PUBLIC -Wl,--whole-archive
aoti_common_shims_slim
-Wl,--no-whole-archive
CUDA::cudart
CUDA::curand
CUDA::cublas
${CMAKE_DL_LIBS}
)
endif()

Expand Down
175 changes: 175 additions & 0 deletions backends/cuda/runtime/shims/mm.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <cublas_v2.h>
#include <cuda_runtime.h>

#include <executorch/backends/aoti/slim/c10/core/ScalarType.h>
#include <executorch/backends/aoti/slim/cuda/guard.h>
#include <executorch/backends/aoti/utils.h>
#include <executorch/backends/cuda/runtime/shims/mm.h>

#include <climits>
#include <mutex>
#include <vector>

namespace executorch::backends::cuda {

namespace c10_slim = executorch::backends::aoti::slim::c10;

namespace {

constexpr int kMaxDevices = 16;

// Lazily-created cuBLAS handles, one slot per CUDA device index.
// First use of a device creates its handle under `mutex`; callers may
// also take `mutex` themselves (after get() has returned) to serialize
// subsequent use of a shared handle.
struct CuBLASHandles {
  std::mutex mutex;
  cublasHandle_t handles[kMaxDevices] = {};
  bool initialized[kMaxDevices] = {};

  // Returns the handle for `device`, creating it on first use.
  // Initialization is thread-safe; `mutex` is held for the duration.
  // NOTE(review): cudaSetDevice changes the calling thread's current
  // device without restoring it — presumably callers re-set it; confirm.
  cublasHandle_t get(int device) {
    std::lock_guard<std::mutex> guard(mutex);
    if (initialized[device]) {
      return handles[device];
    }
    cudaSetDevice(device);
    cublasCreate(&handles[device]);
    cublasSetMathMode(handles[device], CUBLAS_DEFAULT_MATH);
    initialized[device] = true;
    return handles[device];
  }
};

// Accessor for the process-wide handle table. The function-local static
// gives lazy, thread-safe, one-time construction (C++11 magic static).
CuBLASHandles& cublas_handles() {
  static CuBLASHandles instance;
  return instance;
}

} // namespace

#ifdef __cplusplus
extern "C" {
#endif

/**
 * Computes out = self @ mat2 on the GPU via cublasGemmEx.
 *
 * Preconditions (all checked): tensors non-null, 2D, contiguous, and of a
 * single shared dtype in {bf16, fp16, fp32}; shapes self=[M,K],
 * mat2=[K,N], out=[M,N]; M, N, K each fit in a 32-bit int (cublasGemmEx
 * takes `int` dimensions). Accumulation is in fp32 (CUBLAS_COMPUTE_32F).
 * The GEMM is issued on the current CUDA stream of self's device.
 */
AOTITorchError
aoti_torch_cuda_mm_out(Tensor* out, Tensor* self, Tensor* mat2) {
  ET_CHECK_OR_RETURN_ERROR(
      out != nullptr, InvalidArgument, "mm_out: out is null");
  ET_CHECK_OR_RETURN_ERROR(
      self != nullptr, InvalidArgument, "mm_out: self is null");
  ET_CHECK_OR_RETURN_ERROR(
      mat2 != nullptr, InvalidArgument, "mm_out: mat2 is null");
  ET_CHECK_OR_RETURN_ERROR(
      self->dim() == 2 && mat2->dim() == 2 && out->dim() == 2,
      InvalidArgument,
      "mm_out: all tensors must be 2D");
  ET_CHECK_OR_RETURN_ERROR(
      self->is_contiguous() && mat2->is_contiguous() && out->is_contiguous(),
      InvalidArgument,
      "mm_out: all tensors must be contiguous");

  int64_t M = self->size(0);
  int64_t K = self->size(1);
  int64_t N = mat2->size(1);

  // %lld + explicit casts: int64_t is not `long` on every platform
  // (e.g. LLP64/Windows), so the previous %ld was undefined behavior there.
  ET_CHECK_OR_RETURN_ERROR(
      mat2->size(0) == K,
      InvalidArgument,
      "mm_out: self [%lld,%lld] x mat2 [%lld,%lld] inner dims mismatch",
      static_cast<long long>(M),
      static_cast<long long>(K),
      static_cast<long long>(mat2->size(0)),
      static_cast<long long>(N));
  ET_CHECK_OR_RETURN_ERROR(
      out->size(0) == M && out->size(1) == N,
      InvalidArgument,
      "mm_out: out shape mismatch");

  // cublasGemmEx takes `int` dimensions and leading dims; reject
  // oversized matrices instead of silently truncating the casts below.
  ET_CHECK_OR_RETURN_ERROR(
      M <= INT_MAX && N <= INT_MAX && K <= INT_MAX,
      InvalidArgument,
      "mm_out: dimensions exceed INT_MAX");

  auto dtype = self->dtype();
  ET_CHECK_OR_RETURN_ERROR(
      mat2->dtype() == dtype && out->dtype() == dtype,
      InvalidArgument,
      "mm_out: dtype mismatch");

  // Map tensor dtype to the cuBLAS storage type; always accumulate in
  // fp32 for accuracy with half-precision inputs.
  cudaDataType_t cuda_dtype;
  cublasComputeType_t compute_type;
  if (dtype == c10_slim::ScalarType::BFloat16) {
    cuda_dtype = CUDA_R_16BF;
    compute_type = CUBLAS_COMPUTE_32F;
  } else if (dtype == c10_slim::ScalarType::Half) {
    cuda_dtype = CUDA_R_16F;
    compute_type = CUBLAS_COMPUTE_32F;
  } else if (dtype == c10_slim::ScalarType::Float) {
    cuda_dtype = CUDA_R_32F;
    compute_type = CUBLAS_COMPUTE_32F;
  } else {
    ET_CHECK_OR_RETURN_ERROR(
        false, InvalidArgument, "mm_out: unsupported dtype");
  }

  int device = self->device_index();
  ET_CHECK_OR_RETURN_ERROR(
      device >= 0 && device < kMaxDevices,
      InvalidArgument,
      "mm_out: device index %d out of range",
      device);

  auto stream_result = getCurrentCUDAStream(device);
  ET_CHECK_OR_RETURN_ERROR(
      stream_result.ok(), Internal, "mm_out: failed to get CUDA stream");

  // BUGFIX: fetch the handle BEFORE taking handles.mutex. get() locks the
  // same non-recursive std::mutex internally, so the previous order
  // (lock_guard first, then get()) self-deadlocked on the very first call.
  // After get() returns we take the mutex ourselves to serialize
  // cublasSetStream + cublasGemmEx for threads sharing a device.
  auto& handles = cublas_handles();
  cublasHandle_t handle = handles.get(device);
  std::lock_guard<std::mutex> lock(handles.mutex);
  auto set_stream_status = cublasSetStream(handle, stream_result.get());
  ET_CHECK_OR_RETURN_ERROR(
      set_stream_status == CUBLAS_STATUS_SUCCESS,
      Internal,
      "mm_out: cublasSetStream failed with status %d",
      (int)set_stream_status);

  // cuBLAS is column-major. For row-major C = A @ B:
  //   C^T = B^T @ A^T
  // With column-major interpretation of row-major data:
  //   A_row[M,K] looks like A^T_col[K,M] with lda=K
  //   B_row[K,N] looks like B^T_col[N,K] with ldb=N
  //   C_row[M,N] looks like C^T_col[N,M] with ldc=N
  // So: C^T = B^T @ A^T → gemm(N, N, N, M, K, B, N, A, K, C, N)
  float alpha = 1.0f;
  float beta = 0.0f;

  auto status = cublasGemmEx(
      handle,
      CUBLAS_OP_N,
      CUBLAS_OP_N,
      static_cast<int>(N), // m (columns of C^T)
      static_cast<int>(M), // n (rows of C^T)
      static_cast<int>(K), // k
      &alpha,
      mat2->data_ptr(), // B^T in col-major = B in row-major
      cuda_dtype,
      static_cast<int>(N), // ldb (row-major stride of mat2)
      self->data_ptr(), // A^T in col-major = A in row-major
      cuda_dtype,
      static_cast<int>(K), // lda (row-major stride of self)
      &beta,
      out->data_ptr(),
      cuda_dtype,
      static_cast<int>(N), // ldc (row-major stride of out)
      compute_type,
      CUBLAS_GEMM_DEFAULT);

  ET_CHECK_OR_RETURN_ERROR(
      status == CUBLAS_STATUS_SUCCESS,
      Internal,
      "mm_out: cublasGemmEx failed with status %d",
      (int)status);

  return Error::Ok;
}

#ifdef __cplusplus
}
#endif

} // namespace executorch::backends::cuda
41 changes: 41 additions & 0 deletions backends/cuda/runtime/shims/mm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cuda_runtime.h>
#include <executorch/backends/aoti/common_shims_slim.h>
#include <executorch/backends/aoti/export.h>

namespace executorch::backends::cuda {

using executorch::backends::aoti::AOTITorchError;
using executorch::backends::aoti::Tensor;

#ifdef __cplusplus
extern "C" {
#endif

/**
 * Matrix multiplication via cuBLAS: out = self @ mat2.
 *
 * Replaces libtorch's aoti_torch_cuda_mm_out so the AOTI CUDA backend
 * can run without libtorch_cuda.so. Calls cublasGemmEx directly.
 *
 * All three tensors must be 2D, contiguous, and share one dtype:
 * bf16, fp16, or fp32 (the implementation in mm.cu accepts Float in
 * addition to the half-precision types); accumulation is in fp32.
 *
 * @param out Pre-allocated output [M, N], same dtype as inputs.
 * @param self Input matrix [M, K].
 * @param mat2 Input matrix [K, N].
 */
AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_cuda_mm_out(Tensor* out, Tensor* self, Tensor* mat2);

#ifdef __cplusplus
}
#endif

} // namespace executorch::backends::cuda
18 changes: 18 additions & 0 deletions backends/cuda/runtime/shims/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,21 @@ foreach(test_name ${CUDA_SHIM_TESTS})

add_test(NAME ${test_name} COMMAND ${test_name})
endforeach()

# mm_out test — cuBLAS is already linked into aoti_cuda_shims
# (via its target_link_libraries), so only cudart is linked here.
add_executable(test_aoti_torch_cuda_mm_out test_aoti_torch_cuda_mm_out.cpp)

# Resolve both repo-root-relative and executorch-relative includes,
# plus CUDA toolkit headers (cublas_v2.h / cuda_runtime.h).
target_include_directories(
  test_aoti_torch_cuda_mm_out PRIVATE ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}
  ${CUDAToolkit_INCLUDE_DIRS}
)

# Gate GPU-dependent test bodies in the source.
target_compile_definitions(test_aoti_torch_cuda_mm_out PRIVATE CUDA_AVAILABLE=1)

target_link_libraries(
  test_aoti_torch_cuda_mm_out
  PRIVATE GTest::gtest GTest::gtest_main aoti_cuda_shims executorch_core
  CUDA::cudart
)

# Register with CTest under the same name as the executable.
add_test(NAME test_aoti_torch_cuda_mm_out COMMAND test_aoti_torch_cuda_mm_out)
1 change: 1 addition & 0 deletions backends/cuda/runtime/shims/tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@ def define_common_targets():
cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle")
cuda_shim_cpp_unittest("aoti_torch_item_bool")
cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out")
cuda_shim_cpp_unittest("aoti_torch_cuda_mm_out")
Loading
Loading