|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
#include <cublas_v2.h>
#include <cuda_runtime.h>

#include <executorch/backends/aoti/slim/c10/core/ScalarType.h>
#include <executorch/backends/aoti/slim/cuda/guard.h>
#include <executorch/backends/aoti/utils.h>
#include <executorch/backends/cuda/runtime/shims/mm.h>

#include <limits>
#include <mutex>
#include <vector>
| 19 | + |
| 20 | +namespace executorch::backends::cuda { |
| 21 | + |
| 22 | +namespace c10_slim = executorch::backends::aoti::slim::c10; |
| 23 | + |
| 24 | +namespace { |
| 25 | + |
| 26 | +constexpr int kMaxDevices = 16; |
| 27 | + |
| 28 | +struct CuBLASHandles { |
| 29 | + std::mutex mutex; |
| 30 | + cublasHandle_t handles[kMaxDevices] = {}; |
| 31 | + bool initialized[kMaxDevices] = {}; |
| 32 | + |
| 33 | + cublasHandle_t get(int device) { |
| 34 | + std::lock_guard<std::mutex> lock(mutex); |
| 35 | + if (!initialized[device]) { |
| 36 | + cudaSetDevice(device); |
| 37 | + cublasCreate(&handles[device]); |
| 38 | + cublasSetMathMode(handles[device], CUBLAS_DEFAULT_MATH); |
| 39 | + initialized[device] = true; |
| 40 | + } |
| 41 | + return handles[device]; |
| 42 | + } |
| 43 | +}; |
| 44 | + |
| 45 | +CuBLASHandles& cublas_handles() { |
| 46 | + static CuBLASHandles instance; |
| 47 | + return instance; |
| 48 | +} |
| 49 | + |
| 50 | +} // namespace |
| 51 | + |
| 52 | +#ifdef __cplusplus |
| 53 | +extern "C" { |
| 54 | +#endif |
| 55 | + |
| 56 | +AOTITorchError |
| 57 | +aoti_torch_cuda_mm_out(Tensor* out, Tensor* self, Tensor* mat2) { |
| 58 | + ET_CHECK_OR_RETURN_ERROR( |
| 59 | + out != nullptr, InvalidArgument, "mm_out: out is null"); |
| 60 | + ET_CHECK_OR_RETURN_ERROR( |
| 61 | + self != nullptr, InvalidArgument, "mm_out: self is null"); |
| 62 | + ET_CHECK_OR_RETURN_ERROR( |
| 63 | + mat2 != nullptr, InvalidArgument, "mm_out: mat2 is null"); |
| 64 | + ET_CHECK_OR_RETURN_ERROR( |
| 65 | + self->dim() == 2 && mat2->dim() == 2 && out->dim() == 2, |
| 66 | + InvalidArgument, |
| 67 | + "mm_out: all tensors must be 2D"); |
| 68 | + ET_CHECK_OR_RETURN_ERROR( |
| 69 | + self->is_contiguous() && mat2->is_contiguous() && out->is_contiguous(), |
| 70 | + InvalidArgument, |
| 71 | + "mm_out: all tensors must be contiguous"); |
| 72 | + |
| 73 | + int64_t M = self->size(0); |
| 74 | + int64_t K = self->size(1); |
| 75 | + int64_t N = mat2->size(1); |
| 76 | + |
| 77 | + ET_CHECK_OR_RETURN_ERROR( |
| 78 | + mat2->size(0) == K, |
| 79 | + InvalidArgument, |
| 80 | + "mm_out: self [%ld,%ld] x mat2 [%ld,%ld] inner dims mismatch", |
| 81 | + M, |
| 82 | + K, |
| 83 | + mat2->size(0), |
| 84 | + N); |
| 85 | + ET_CHECK_OR_RETURN_ERROR( |
| 86 | + out->size(0) == M && out->size(1) == N, |
| 87 | + InvalidArgument, |
| 88 | + "mm_out: out shape mismatch"); |
| 89 | + |
| 90 | + auto dtype = self->dtype(); |
| 91 | + ET_CHECK_OR_RETURN_ERROR( |
| 92 | + mat2->dtype() == dtype && out->dtype() == dtype, |
| 93 | + InvalidArgument, |
| 94 | + "mm_out: dtype mismatch"); |
| 95 | + |
| 96 | + cudaDataType_t cuda_dtype; |
| 97 | + cublasComputeType_t compute_type; |
| 98 | + if (dtype == c10_slim::ScalarType::BFloat16) { |
| 99 | + cuda_dtype = CUDA_R_16BF; |
| 100 | + compute_type = CUBLAS_COMPUTE_32F; |
| 101 | + } else if (dtype == c10_slim::ScalarType::Half) { |
| 102 | + cuda_dtype = CUDA_R_16F; |
| 103 | + compute_type = CUBLAS_COMPUTE_32F; |
| 104 | + } else if (dtype == c10_slim::ScalarType::Float) { |
| 105 | + cuda_dtype = CUDA_R_32F; |
| 106 | + compute_type = CUBLAS_COMPUTE_32F; |
| 107 | + } else { |
| 108 | + ET_CHECK_OR_RETURN_ERROR( |
| 109 | + false, InvalidArgument, "mm_out: unsupported dtype"); |
| 110 | + } |
| 111 | + |
| 112 | + int device = self->device_index(); |
| 113 | + ET_CHECK_OR_RETURN_ERROR( |
| 114 | + device >= 0 && device < kMaxDevices, |
| 115 | + InvalidArgument, |
| 116 | + "mm_out: device index %d out of range", |
| 117 | + device); |
| 118 | + |
| 119 | + auto stream_result = getCurrentCUDAStream(device); |
| 120 | + ET_CHECK_OR_RETURN_ERROR( |
| 121 | + stream_result.ok(), Internal, "mm_out: failed to get CUDA stream"); |
| 122 | + |
| 123 | + // Per-device handle; mutex in get() ensures thread-safe initialization. |
| 124 | + // cublasSetStream + cublasGemmEx are serialized under the same mutex to |
| 125 | + // prevent races when multiple threads share a device. |
| 126 | + auto& handles = cublas_handles(); |
| 127 | + std::lock_guard<std::mutex> lock(handles.mutex); |
| 128 | + cublasHandle_t handle = handles.get(device); |
| 129 | + cublasSetStream(handle, stream_result.get()); |
| 130 | + |
| 131 | + // cuBLAS is column-major. For row-major C = A @ B: |
| 132 | + // C^T = B^T @ A^T |
| 133 | + // With column-major interpretation of row-major data: |
| 134 | + // A_row[M,K] looks like A^T_col[K,M] with lda=K |
| 135 | + // B_row[K,N] looks like B^T_col[N,K] with ldb=N |
| 136 | + // C_row[M,N] looks like C^T_col[N,M] with ldc=N |
| 137 | + // So: C^T = B^T @ A^T → gemm(N, N, N, M, K, B, N, A, K, C, N) |
| 138 | + float alpha = 1.0f; |
| 139 | + float beta = 0.0f; |
| 140 | + |
| 141 | + auto status = cublasGemmEx( |
| 142 | + handle, |
| 143 | + CUBLAS_OP_N, |
| 144 | + CUBLAS_OP_N, |
| 145 | + N, // m (columns of C^T) |
| 146 | + M, // n (rows of C^T) |
| 147 | + K, // k |
| 148 | + &alpha, |
| 149 | + mat2->data_ptr(), // B^T in col-major = B in row-major |
| 150 | + cuda_dtype, |
| 151 | + N, // ldb (row-major stride of mat2) |
| 152 | + self->data_ptr(), // A^T in col-major = A in row-major |
| 153 | + cuda_dtype, |
| 154 | + K, // lda (row-major stride of self) |
| 155 | + &beta, |
| 156 | + out->data_ptr(), |
| 157 | + cuda_dtype, |
| 158 | + N, // ldc (row-major stride of out) |
| 159 | + compute_type, |
| 160 | + CUBLAS_GEMM_DEFAULT); |
| 161 | + |
| 162 | + ET_CHECK_OR_RETURN_ERROR( |
| 163 | + status == CUBLAS_STATUS_SUCCESS, |
| 164 | + Internal, |
| 165 | + "mm_out: cublasGemmEx failed with status %d", |
| 166 | + (int)status); |
| 167 | + |
| 168 | + return Error::Ok; |
| 169 | +} |
| 170 | + |
| 171 | +#ifdef __cplusplus |
| 172 | +} |
| 173 | +#endif |
| 174 | + |
| 175 | +} // namespace executorch::backends::cuda |
0 commit comments