From 4534154a05f8f6743e5ef38c5f5951f2a322be91 Mon Sep 17 00:00:00 2001 From: chen Date: Fri, 10 Apr 2026 07:26:54 +0000 Subject: [PATCH 1/3] Refactor(linear): split LinearBackward kernel into 3 independent kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move grad_flags logic from kernel to autograd layer. The monolithic LinearBackward kernel is replaced by LinearBackwardInput, LinearBackwardWeight, and LinearBackwardBias — each a pure compute operation with no autograd-related parameters. --- infini_train/include/autograd/linear.h | 6 - infini_train/src/autograd/linear.cc | 30 ++- infini_train/src/kernels/cpu/linear.cc | 84 ++++--- infini_train/src/kernels/cuda/linear.cu | 297 ++++++++++++------------ 4 files changed, 203 insertions(+), 214 deletions(-) diff --git a/infini_train/include/autograd/linear.h b/infini_train/include/autograd/linear.h index 21d107b9..cebed3b2 100644 --- a/infini_train/include/autograd/linear.h +++ b/infini_train/include/autograd/linear.h @@ -12,12 +12,6 @@ class Tensor; namespace infini_train::autograd { -struct LinearGradFlags { - bool input = false; - bool weight = false; - bool bias = false; -}; - class Linear : public Function { public: static constexpr char kType[] = "LinearFunction"; diff --git a/infini_train/src/autograd/linear.cc b/infini_train/src/autograd/linear.cc index c9ed1dbb..ff0283ce 100644 --- a/infini_train/src/autograd/linear.cc +++ b/infini_train/src/autograd/linear.cc @@ -56,17 +56,29 @@ std::vector> Linear::Backward(const std::vector 1 && needs_input_grad_[1], - .bias = bias_ && needs_input_grad_.size() > 2 && needs_input_grad_[2]}; + bool need_grad_input = needs_input_grad_[0]; + bool need_grad_weight = needs_input_grad_.size() > 1 && needs_input_grad_[1]; + bool need_grad_bias = bias_ && needs_input_grad_.size() > 2 && needs_input_grad_[2]; auto device = grad_output->GetDevice().type(); - // TODO: skip autograd graph construction entirely when no input requires grad 
- auto [grad_input, grad_weight, grad_bias] - = Dispatcher::Instance() - .Call, std::shared_ptr, std::shared_ptr>>( - {device, "LinearBackward"}, input, weight, transpose_, in_features_, out_features_, input_dims_, - grad_output, bias_, grad_flags); + + std::shared_ptr grad_input = nullptr; + std::shared_ptr grad_weight = nullptr; + std::shared_ptr grad_bias = nullptr; + + if (need_grad_input) { + grad_input = Dispatcher::Instance().Call>( + {device, "LinearBackwardInput"}, weight, grad_output, transpose_, in_features_, out_features_, input_dims_); + } + if (need_grad_weight) { + grad_weight = Dispatcher::Instance().Call>( + {device, "LinearBackwardWeight"}, input, grad_output, transpose_, in_features_, out_features_); + } + if (need_grad_bias) { + grad_bias = Dispatcher::Instance().Call>({device, "LinearBackwardBias"}, grad_output, + out_features_); + } + if (bias_) { return {grad_input, grad_weight, grad_bias}; } else { diff --git a/infini_train/src/kernels/cpu/linear.cc b/infini_train/src/kernels/cpu/linear.cc index 2b209417..f238135c 100644 --- a/infini_train/src/kernels/cpu/linear.cc +++ b/infini_train/src/kernels/cpu/linear.cc @@ -5,7 +5,6 @@ #include "glog/logging.h" -#include "infini_train/include/autograd/linear.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" @@ -146,62 +145,55 @@ std::shared_ptr LinearForward(const std::shared_ptr &input, cons return output; } -// TODO(dcj): support linear without bias later -std::tuple, std::shared_ptr, std::shared_ptr> -LinearBackward(const std::shared_ptr &input, const std::shared_ptr &weight, bool transpose, - int64_t in_features, int64_t out_features, const std::vector &input_dims, - const std::shared_ptr &grad_output, bool bias, - infini_train::autograd::LinearGradFlags grad_flags) { +std::shared_ptr LinearBackwardInput(const std::shared_ptr &weight, + const std::shared_ptr &grad_output, bool transpose, + int64_t in_features, int64_t out_features, + const std::vector 
&input_dims) { /* transpose: grad_input = grad_output * weight grad_input[*, in_features] = grad_output[*, out_features] * weight[out_features, in_features] - grad_weight[out_features, in_features] = grad_output[*, out_features]^T * input[*, in_features] - grad_bias[out_features] = grad_output[*, out_features].sum(axis=0) !transpose: grad_input = grad_output * weight^T grad_input[*, in_features] = grad_output[_, out_features] * weight[in_features, out_features]^T - grad_weight[in_features, out_features] = input[*, in_features]^T * grad_output[*, out_features] - grad_bias[out_features] = grad_output[*, out_features].sum(axis=0) */ - const auto compute_grad_input = grad_flags.input; - const auto compute_grad_weight = grad_flags.weight; - const auto compute_grad_bias = grad_flags.bias; - CHECK_GE(input_dims.size(), 2); - - std::vector weight_dims - = transpose ? std::vector{out_features, in_features} : std::vector{in_features, out_features}; - - std::shared_ptr grad_input = nullptr; - std::shared_ptr grad_weight = nullptr; - std::shared_ptr grad_bias = nullptr; - - if (compute_grad_input) { - CHECK(weight != nullptr) << "compute_grad_input=true but weight is nullptr (selective save mismatch)"; - grad_input = std::make_shared(input_dims, DataType::kFLOAT32); - if (transpose) { - grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix(); - } else { - grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix().transpose(); - } + auto grad_input = std::make_shared(input_dims, DataType::kFLOAT32); + if (transpose) { + grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix(); + } else { + grad_input->EigenMatrix() = grad_output->EigenMatrix() * weight->EigenMatrix().transpose(); } + return grad_input; +} - if (compute_grad_weight) { - CHECK(input != nullptr) << "compute_grad_weight=true but input is nullptr (selective save mismatch)"; - grad_weight = std::make_shared(weight_dims, DataType::kFLOAT32); - if 
(transpose) { - grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix(); - } else { - grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix(); - } - } +std::shared_ptr LinearBackwardWeight(const std::shared_ptr &input, + const std::shared_ptr &grad_output, bool transpose, + int64_t in_features, int64_t out_features) { + /* + transpose: + grad_weight[out_features, in_features] = grad_output[*, out_features]^T * input[*, in_features] - if (compute_grad_bias && bias) { - grad_bias = std::make_shared(std::vector{out_features}, DataType::kFLOAT32); - grad_bias->EigenVector() = grad_output->EigenMatrix().colwise().sum(); + !transpose: + grad_weight[in_features, out_features] = input[*, in_features]^T * grad_output[*, out_features] + */ + std::vector weight_dims + = transpose ? std::vector{out_features, in_features} : std::vector{in_features, out_features}; + auto grad_weight = std::make_shared(weight_dims, DataType::kFLOAT32); + if (transpose) { + grad_weight->EigenMatrix() = grad_output->EigenMatrix().transpose() * input->EigenMatrix(); + } else { + grad_weight->EigenMatrix() = input->EigenMatrix().transpose() * grad_output->EigenMatrix(); } + return grad_weight; +} - return {grad_input, grad_weight, grad_bias}; +std::shared_ptr LinearBackwardBias(const std::shared_ptr &grad_output, int64_t out_features) { + /* + grad_bias[out_features] = grad_output[*, out_features].sum(axis=0) + */ + auto grad_bias = std::make_shared(std::vector{out_features}, DataType::kFLOAT32); + grad_bias->EigenVector() = grad_output->EigenMatrix().colwise().sum(); + return grad_bias; } } // namespace infini_train::kernels::cpu @@ -211,6 +203,8 @@ LinearBackward(const std::shared_ptr &input, const std::shared_ptr #include -#include "infini_train/include/autograd/linear.h" #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/common/cuda/kernel_helper.cuh" #include 
"infini_train/include/core/runtime/device_guard.h" @@ -317,183 +316,171 @@ __global__ void ReduceColumnsKernel(const TIn *__restrict__ input, TOut *__restr } } -std::tuple, std::shared_ptr, std::shared_ptr> -LinearBackward(const std::shared_ptr &input, const std::shared_ptr &weight, bool transpose, - int64_t in_features, int64_t out_features, const std::vector &input_dims, - const std::shared_ptr &grad_output, bool bias, - infini_train::autograd::LinearGradFlags grad_flags) { - const auto compute_grad_input = grad_flags.input; - const auto compute_grad_weight = grad_flags.weight; - const auto compute_grad_bias = grad_flags.bias; - +std::shared_ptr LinearBackwardInput(const std::shared_ptr &weight, + const std::shared_ptr &grad_output, bool transpose, + int64_t in_features, int64_t out_features, + const std::vector &input_dims) { CHECK_GE(input_dims.size(), 2); const int64_t bs = std::accumulate(input_dims.rbegin() + 1, input_dims.rend(), 1, std::multiplies{}); - const std::vector weight_dims - = transpose ? std::vector{out_features, in_features} : std::vector{in_features, out_features}; + auto compute_dtype = weight->Dtype(); + auto grad_output_dtype = grad_output->Dtype(); + auto grad_output_promoted + = grad_output_dtype == compute_dtype ? grad_output : std::make_shared(grad_output->To(compute_dtype)); - auto dtype = grad_output->Dtype(); + // For bf16 compute, accumulate in fp32 to preserve precision (matches PyTorch behavior). + auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype; + // No Fill(0) needed: cuBLAS beta=0.0f fully overwrites output. + auto grad_input = std::make_shared(input_dims, output_dtype, grad_output->GetDevice()); - // For type promotion, use available tensors - DataType input_dtype = input ? input->Dtype() : (weight ? weight->Dtype() : dtype); - DataType weight_dtype = weight ? weight->Dtype() : (input ? 
input->Dtype() : dtype); - // Compute dtype determined by saved tensors (forward compute dtype), not grad_output - DataType compute_dtype = DispatchFunc, DataTypeList>( - {input_dtype, weight_dtype}, [=]() { return DataTypeMap_v>; }, - "CUDA LinearBackward"); + auto device = grad_output->GetDevice(); + float alpha = 1.0f; + float beta = 0.0f; + cublasHandle_t handle = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) + ->cublas_handle(); + + // TODO(zbl): use cublasSgemv if possible + // - if transpose: + // weight is [out_features, in_features] here + // d_input = d_output * weight --> d_input.T = weight.T * d_output.T + // C = d_input.T[in_features, bs] + // A = weight.T[in_features, out_features] + // B = d_output.T[out_features, bs] + // + // - if not transpose: + // weight is [in_features, out_features] here + // d_input = d_output * weight.T --> d_input.T = weight * d_output.T + // C = d_input.T[in_features, bs] + // A = weight.T[out_features, in_features] + // B = d_output.T[out_features, bs] + auto trans_a = transpose ? CUBLAS_OP_N : CUBLAS_OP_T; + auto lda = transpose ? 
in_features : out_features; + switch (compute_dtype) { + DISPATCH_CASE(WRAP({ + CUBLAS_CHECK(cublasSgemm(handle, trans_a, CUBLAS_OP_N, in_features, bs, out_features, &alpha, + static_cast(weight->DataPtr()), lda, + static_cast(grad_output_promoted->DataPtr()), + out_features, &beta, static_cast(grad_input->DataPtr()), + in_features)); + }), + DataType::kFLOAT32) + DISPATCH_CASE(WRAP({ + CUBLAS_CHECK(cublasGemmEx( + handle, trans_a, CUBLAS_OP_N, in_features, bs, out_features, &alpha, weight->DataPtr(), + CUDA_R_16BF, lda, grad_output_promoted->DataPtr(), CUDA_R_16BF, out_features, &beta, + grad_input->DataPtr(), CUDA_R_32F, in_features, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); + }), + DataType::kBFLOAT16) + } + + return grad_input; +} + +std::shared_ptr LinearBackwardWeight(const std::shared_ptr &input, + const std::shared_ptr &grad_output, bool transpose, + int64_t in_features, int64_t out_features) { + const auto &grad_output_dims = grad_output->Dims(); + CHECK_GE(grad_output_dims.size(), 2); + const int64_t bs + = std::accumulate(grad_output_dims.rbegin() + 1, grad_output_dims.rend(), 1, std::multiplies{}); + + auto compute_dtype = input->Dtype(); + auto grad_output_dtype = grad_output->Dtype(); auto grad_output_promoted - = dtype == compute_dtype ? grad_output : std::make_shared(grad_output->To(compute_dtype)); + = grad_output_dtype == compute_dtype ? grad_output : std::make_shared(grad_output->To(compute_dtype)); // For bf16 compute, accumulate in fp32 to preserve precision (matches PyTorch behavior). auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype; + const std::vector weight_dims + = transpose ? std::vector{out_features, in_features} : std::vector{in_features, out_features}; + // No Fill(0) needed: cuBLAS beta=0.0f fully overwrites output. + auto grad_weight = std::make_shared(weight_dims, output_dtype, grad_output->GetDevice()); - // Allocate only needed gradient tensors (selective save: input/weight may be nullptr). 
- std::shared_ptr grad_input = nullptr; - std::shared_ptr grad_weight = nullptr; - std::shared_ptr grad_bias = nullptr; + auto device = grad_output->GetDevice(); + float alpha = 1.0f; + float beta = 0.0f; + cublasHandle_t handle = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) + ->cublas_handle(); - if (compute_grad_input) { - grad_input = std::make_shared(input_dims, output_dtype, grad_output->GetDevice()); - } - if (compute_grad_weight) { - grad_weight = std::make_shared(weight_dims, output_dtype, grad_output->GetDevice()); - } - // No Fill(0) needed: cuBLAS beta=0.0f fully overwrites output, and ReduceColumnsKernel assigns directly. - if (compute_grad_bias && bias) { - grad_bias - = std::make_shared(std::vector{out_features}, output_dtype, grad_output->GetDevice()); + // - if transpose: + // d_weight = d_output.T * input --> d_weight.T = input.T * d_output + // C = d_weight.T[in_features, out_features] + // A = input.T[in_features, bs] + // B = d_output.T[out_features, bs] + // + // - if not transpose: + // d_weight = input.T * d_output --> d_weight.T = d_output.T * input + // C = d_weight.T[out_features, in_features] + // A = d_output.T[out_features, bs] + // B = input.T[in_features, bs] + int m = transpose ? in_features : out_features; + int n = transpose ? out_features : in_features; + auto ldc = transpose ? in_features : out_features; + + switch (compute_dtype) { + DISPATCH_CASE(WRAP({ + const void *a = transpose ? input->DataPtr() : grad_output_promoted->DataPtr(); + const void *b = transpose ? grad_output_promoted->DataPtr() : input->DataPtr(); + auto lda = transpose ? in_features : out_features; + auto ldb = transpose ? out_features : in_features; + CUBLAS_CHECK(cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, m, n, bs, &alpha, + static_cast(a), lda, static_cast(b), + ldb, &beta, static_cast(grad_weight->DataPtr()), ldc)); + }), + DataType::kFLOAT32) + DISPATCH_CASE(WRAP({ + const void *a = transpose ? 
input->DataPtr() : grad_output_promoted->DataPtr(); + const void *b = transpose ? grad_output_promoted->DataPtr() : input->DataPtr(); + auto lda = transpose ? in_features : out_features; + auto ldb = transpose ? out_features : in_features; + CUBLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, m, n, bs, &alpha, a, CUDA_R_16BF, + lda, b, CUDA_R_16BF, ldb, &beta, grad_weight->DataPtr(), CUDA_R_32F, + ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); + }), + DataType::kBFLOAT16) } + return grad_weight; +} + +std::shared_ptr LinearBackwardBias(const std::shared_ptr &grad_output, int64_t out_features) { + const auto &dims = grad_output->Dims(); + CHECK_GE(dims.size(), 2); + const int64_t bs = std::accumulate(dims.rbegin() + 1, dims.rend(), 1, std::multiplies{}); + + auto compute_dtype = grad_output->Dtype(); + // For bf16 compute, accumulate in fp32 to preserve precision (matches PyTorch behavior). + auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype; + auto grad_bias + = std::make_shared(std::vector{out_features}, output_dtype, grad_output->GetDevice()); + auto device = grad_output->GetDevice(); const auto &cuda_stream = dynamic_cast( infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - float alpha = 1.0f; - float beta = 0.0f; - - cublasHandle_t handle = dynamic_cast( - infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) - ->cublas_handle(); - + // d_bias = \sum_i(i=0, bs-1) d_output[i] + // TODO(dcj): use thrust::fill or reduce kernel do this + constexpr int BLOCK_SIZE = 256; switch (compute_dtype) { - // TODO(zbl): use cublasSgemv if possible - DISPATCH_CASE( - WRAP({ - if (compute_grad_input) { - // - if transpose: - // weight is [out_features, in_features] here - // d_input = d_output * weight --> d_input.T = weight.T * d_output.T - // C = d_input.T[in_features, bs] - // A = weight.T[in_features, out_features] - // B = d_output.T[out_features, bs] - // - // - 
if not transpose: - // weight is [in_features, out_features] here - // d_input = d_output * weight.T --> d_input.T = weight * d_output.T - // C = d_input.T[in_features, bs] - // A = weight.T[out_features, in_features] - // B = d_output.T[out_features, bs] - CHECK(weight != nullptr) - << "compute_grad_input=true but weight is nullptr (selective save mismatch)"; - auto weight_promoted - = weight_dtype == compute_dtype ? weight : std::make_shared(weight->To(compute_dtype)); - auto trans_a1 = transpose ? CUBLAS_OP_N : CUBLAS_OP_T; - auto lda1 = transpose ? in_features : out_features; - CUBLAS_CHECK(cublasSgemm(handle, trans_a1, CUBLAS_OP_N, in_features, bs, out_features, &alpha, - static_cast(weight_promoted->DataPtr()), lda1, - static_cast(grad_output_promoted->DataPtr()), out_features, - &beta, static_cast(grad_input->DataPtr()), in_features)); - } - if (compute_grad_weight) { - // - if transpose: - // d_weight = d_output.T * input --> d_weight.T = input.T * d_output - // C = d_weight.T[in_features, out_features] - // A = input.T[in_features, bs] - // B = d_output.T[out_features, bs] - // - // - if not transpose: - // d_weight = input.T * d_output --> d_weight.T = d_output.T * input - // C = d_weight.T[out_features, in_features] - // A = d_output.T[out_features, bs] - // B = input.T[in_features, bs] - CHECK(input != nullptr) - << "compute_grad_weight=true but input is nullptr (selective save mismatch)"; - auto input_promoted - = input_dtype == compute_dtype ? input : std::make_shared(input->To(compute_dtype)); - auto trans_a2 = CUBLAS_OP_N; - auto trans_b2 = CUBLAS_OP_T; - int m2 = transpose ? in_features : out_features; - int n2 = transpose ? out_features : in_features; - const void *a2 = transpose ? input_promoted->DataPtr() : grad_output_promoted->DataPtr(); - const void *b2 = transpose ? grad_output_promoted->DataPtr() : input_promoted->DataPtr(); - auto lda2 = transpose ? in_features : out_features; - auto ldb2 = transpose ? 
out_features : in_features; - auto ldc2 = transpose ? in_features : out_features; - CUBLAS_CHECK(cublasSgemm(handle, trans_a2, trans_b2, m2, n2, bs, &alpha, - static_cast(a2), lda2, static_cast(b2), ldb2, - &beta, static_cast(grad_weight->DataPtr()), ldc2)); - } - // d_bias = \sum_i(i=0, bs-1) d_output[i] - // TODO(dcj): use thrust::fill or reduce kernel do this - if (compute_grad_bias && bias) { - constexpr int BLOCK_SIZE = 256; - int threads_per_block = BLOCK_SIZE; - int num_blocks = out_features; - ReduceColumnsKernel<<>>( - static_cast(grad_output_promoted->DataPtr()), - static_cast(grad_bias->DataPtr()), out_features, bs); - } - }), - DataType::kFLOAT32) DISPATCH_CASE(WRAP({ - if (compute_grad_input) { - CHECK(weight != nullptr) - << "compute_grad_input=true but weight is nullptr (selective save mismatch)"; - auto weight_promoted = weight_dtype == compute_dtype - ? weight - : std::make_shared(weight->To(compute_dtype)); - auto trans_a1 = transpose ? CUBLAS_OP_N : CUBLAS_OP_T; - auto lda1 = transpose ? in_features : out_features; - CUBLAS_CHECK(cublasGemmEx(handle, trans_a1, CUBLAS_OP_N, in_features, bs, out_features, - &alpha, weight_promoted->DataPtr(), CUDA_R_16BF, lda1, - grad_output_promoted->DataPtr(), CUDA_R_16BF, out_features, - &beta, grad_input->DataPtr(), CUDA_R_32F, in_features, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); - } - if (compute_grad_weight) { - CHECK(input != nullptr) - << "compute_grad_weight=true but input is nullptr (selective save mismatch)"; - auto input_promoted = input_dtype == compute_dtype - ? input - : std::make_shared(input->To(compute_dtype)); - auto trans_a2 = CUBLAS_OP_N; - auto trans_b2 = CUBLAS_OP_T; - int m2 = transpose ? in_features : out_features; - int n2 = transpose ? out_features : in_features; - const void *a2 = transpose ? input_promoted->DataPtr() : grad_output_promoted->DataPtr(); - const void *b2 = transpose ? grad_output_promoted->DataPtr() : input_promoted->DataPtr(); - auto lda2 = transpose ? 
in_features : out_features; - auto ldb2 = transpose ? out_features : in_features; - auto ldc2 = transpose ? in_features : out_features; - CUBLAS_CHECK(cublasGemmEx(handle, trans_a2, trans_b2, m2, n2, bs, &alpha, a2, CUDA_R_16BF, - lda2, b2, CUDA_R_16BF, ldb2, &beta, grad_weight->DataPtr(), - CUDA_R_32F, ldc2, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); - } - if (compute_grad_bias && bias) { - constexpr int BLOCK_SIZE = 256; - int threads_per_block = BLOCK_SIZE; - int num_blocks = out_features; - ReduceColumnsKernel<<>>( - static_cast(grad_output_promoted->DataPtr()), - static_cast(grad_bias->DataPtr()), out_features, bs); - } + ReduceColumnsKernel<<>>( + static_cast(grad_output->DataPtr()), + static_cast(grad_bias->DataPtr()), out_features, bs); + }), + DataType::kFLOAT32) + DISPATCH_CASE(WRAP({ + ReduceColumnsKernel<<>>( + static_cast(grad_output->DataPtr()), + static_cast(grad_bias->DataPtr()), out_features, bs); }), DataType::kBFLOAT16) } - return {grad_input, grad_weight, grad_bias}; + return grad_bias; } } // namespace infini_train::kernels::cuda @@ -503,6 +490,8 @@ LinearBackward(const std::shared_ptr &input, const std::shared_ptr Date: Fri, 10 Apr 2026 08:15:48 +0000 Subject: [PATCH 2/3] refactor(matmul): split MatmulBackward kernel into 2 independent kernels Move needs_input_grad logic from kernel to autograd layer. The monolithic MatmulBackward kernel is replaced by MatmulBackwardInput1 and MatmulBackwardInput2. 
--- infini_train/src/autograd/matmul.cc | 33 ++++- infini_train/src/kernels/cpu/linear.cc | 71 ++++++--- infini_train/src/kernels/cuda/linear.cu | 182 ++++++++++++++---------- infini_train/src/kernels/cuda/outer.cu | 2 +- 4 files changed, 185 insertions(+), 103 deletions(-) diff --git a/infini_train/src/autograd/matmul.cc b/infini_train/src/autograd/matmul.cc index 49f593bf..259cb4a4 100644 --- a/infini_train/src/autograd/matmul.cc +++ b/infini_train/src/autograd/matmul.cc @@ -31,10 +31,17 @@ void Matmul::SetupContext(const std::vector> &input_tens // FIXME: compute_dtype is not necessarily the dtype of output_tensor; it should be // determined by autocast, not derived from output->Dtype(). auto compute_dtype = output->Dtype(); - saved_tensors_ = { - input1->Dtype() == compute_dtype ? input1 : std::make_shared(input1->To(compute_dtype)), - input2->Dtype() == compute_dtype ? input2 : std::make_shared(input2->To(compute_dtype)), + + // grad_input1 = grad_output @ input2^T, so input2 is needed + // grad_input2 = grad_output^T @ input1, so input1 is needed + bool need_grad_input1 = needs_input_grad_.size() > 0 && needs_input_grad_[0]; + bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1]; + + auto cast = [&](const std::shared_ptr &t) { + return t->Dtype() == compute_dtype ? t : std::make_shared(t->To(compute_dtype)); }; + + saved_tensors_ = {need_grad_input2 ? cast(input1) : nullptr, need_grad_input1 ? 
cast(input2) : nullptr}; out_features_ = output->Dims()[0]; } @@ -45,10 +52,24 @@ std::vector> Matmul::Backward(const std::vector 0 && needs_input_grad_[0]; + bool need_grad_input2 = needs_input_grad_.size() > 1 && needs_input_grad_[1]; + auto device = input1->GetDevice().type(); - auto [grad_input1, grad_input2] - = Dispatcher::Instance().Call, std::shared_ptr>>( - {device, "MatmulBackward"}, input1, input2, grad_output); + + std::shared_ptr grad_input1 = nullptr; + std::shared_ptr grad_input2 = nullptr; + + if (need_grad_input1) { + grad_input1 = Dispatcher::Instance().Call>({device, "MatmulBackwardInput1"}, input2, + grad_output, input1->Dims()); + } + if (need_grad_input2) { + grad_input2 = Dispatcher::Instance().Call>({device, "MatmulBackwardInput2"}, input1, + grad_output, input2->Dims()); + } + return {grad_input1, grad_input2}; } } // namespace infini_train::autograd diff --git a/infini_train/src/kernels/cpu/linear.cc b/infini_train/src/kernels/cpu/linear.cc index f238135c..361c56f8 100644 --- a/infini_train/src/kernels/cpu/linear.cc +++ b/infini_train/src/kernels/cpu/linear.cc @@ -50,38 +50,71 @@ std::shared_ptr MatmulForward(const std::shared_ptr &input, cons return {output}; } -std::tuple, std::shared_ptr> -MatmulBackward(const std::shared_ptr &input, const std::shared_ptr &other, - const std::shared_ptr &grad_output) { +std::shared_ptr MatmulBackwardInput1(const std::shared_ptr &other, + const std::shared_ptr &grad_output, + const std::vector &input_dims) { /* grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T - grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n] */ - const auto &input_dims = input->Dims(); const auto &other_dims = other->Dims(); const auto &grad_output_dims = grad_output->Dims(); + CHECK_GE(other_dims.size(), 2); + CHECK_EQ(other_dims.size(), grad_output_dims.size()); + + const int64_t m = grad_output_dims[grad_output_dims.size() - 2]; + const int64_t k = other_dims[other_dims.size() - 2]; + const int64_t n = 
grad_output_dims[grad_output_dims.size() - 1]; + + const int64_t bs + = std::accumulate(grad_output_dims.rbegin() + 2, grad_output_dims.rend(), 1, std::multiplies{}); + for (int64_t i = 0; i < grad_output_dims.size() - 2; ++i) { + CHECK_EQ(grad_output_dims[i], other_dims[i]) << "Batch dims must match"; + } + + auto grad_input = std::make_shared(input_dims, DataType::kFLOAT32); + grad_input->Fill(0.0f); + + for (int64_t b = 0; b < bs; ++b) { + for (int64_t i = 0; i < m; ++i) { + for (int64_t j = 0; j < n; ++j) { + const float grad = static_cast(grad_output->DataPtr())[b * m * n + i * n + j]; + for (int64_t p = 0; p < k; ++p) { + const auto other_idx = b * k * n + p * n + j; + static_cast(grad_input->DataPtr())[b * m * k + i * k + p] + += grad * static_cast(other->DataPtr())[other_idx]; + } + } + } + } + return grad_input; +} + +std::shared_ptr MatmulBackwardInput2(const std::shared_ptr &input1, + const std::shared_ptr &grad_output, + const std::vector &other_dims) { + /* + grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n] + */ + const auto &input_dims = input1->Dims(); + const auto &grad_output_dims = grad_output->Dims(); + CHECK_GE(input_dims.size(), 2); - CHECK_EQ(input_dims.size(), other_dims.size()); CHECK_EQ(input_dims.size(), grad_output_dims.size()); const int64_t m = input_dims[input_dims.size() - 2]; const int64_t k = input_dims[input_dims.size() - 1]; - CHECK_EQ(k, other_dims[other_dims.size() - 2]); - const int64_t n = other_dims[other_dims.size() - 1]; - + const int64_t n = grad_output_dims[grad_output_dims.size() - 1]; CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]); - CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]); + CHECK_EQ(k, other_dims[other_dims.size() - 2]); const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); for (int64_t i = 0; i < input_dims.size() - 2; ++i) { - CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; CHECK_EQ(input_dims[i], 
grad_output_dims[i]) << "Batch dims must match"; + CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; } - auto grad_input = std::make_shared(input_dims, DataType::kFLOAT32); auto grad_other = std::make_shared(other_dims, DataType::kFLOAT32); - grad_input->Fill(0.0f); grad_other->Fill(0.0f); for (int64_t b = 0; b < bs; ++b) { @@ -90,16 +123,13 @@ MatmulBackward(const std::shared_ptr &input, const std::shared_ptr(grad_output->DataPtr())[b * m * n + i * n + j]; for (int64_t p = 0; p < k; ++p) { const auto input_idx = b * m * k + i * k + p; - const auto other_idx = b * k * n + p * n + j; - static_cast(grad_input->DataPtr())[input_idx] - += grad * static_cast(other->DataPtr())[other_idx]; - static_cast(grad_other->DataPtr())[other_idx] - += grad * static_cast(input->DataPtr())[input_idx]; + static_cast(grad_other->DataPtr())[b * k * n + p * n + j] + += grad * static_cast(input1->DataPtr())[input_idx]; } } } } - return {grad_input, grad_other}; + return grad_other; } std::shared_ptr LinearForward(const std::shared_ptr &input, const std::shared_ptr &weight, @@ -201,7 +231,8 @@ std::shared_ptr LinearBackwardBias(const std::shared_ptr &grad_o REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_LINEAR_KERNEL(MatmulForward) -REGISTER_CPU_LINEAR_KERNEL(MatmulBackward) +REGISTER_CPU_LINEAR_KERNEL(MatmulBackwardInput1) +REGISTER_CPU_LINEAR_KERNEL(MatmulBackwardInput2) REGISTER_CPU_LINEAR_KERNEL(LinearForward) REGISTER_CPU_LINEAR_KERNEL(LinearBackwardInput) REGISTER_CPU_LINEAR_KERNEL(LinearBackwardWeight) diff --git a/infini_train/src/kernels/cuda/linear.cu b/infini_train/src/kernels/cuda/linear.cu index ec079b5d..cbc74c5e 100644 --- a/infini_train/src/kernels/cuda/linear.cu +++ b/infini_train/src/kernels/cuda/linear.cu @@ -80,112 +80,141 @@ std::shared_ptr MatmulForward(const std::shared_ptr &input, cons return output; } -std::tuple, std::shared_ptr> -MatmulBackward(const std::shared_ptr 
&input, const std::shared_ptr &other, - const std::shared_ptr &grad_output) { +std::shared_ptr MatmulBackwardInput1(const std::shared_ptr &other, + const std::shared_ptr &grad_output, + const std::vector &input_dims) { /* grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T - grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n] */ - auto input_dtype = input->Dtype(); - auto other_dtype = other->Dtype(); - auto grad_output_dtype = grad_output->Dtype(); - // Compute dtype determined by saved tensors (forward compute dtype), not grad_output - DataType compute_dtype = DispatchFunc, DataTypeList>( - {input_dtype, other_dtype}, [=]() { return DataTypeMap_v>; }, - "CUDA MatmulBackward"); + const auto &other_dims = other->Dims(); + const auto &grad_output_dims = grad_output->Dims(); - auto input_promoted = input_dtype == compute_dtype ? input : std::make_shared(input->To(compute_dtype)); - auto other_promoted = other_dtype == compute_dtype ? other : std::make_shared(other->To(compute_dtype)); + CHECK_GE(other_dims.size(), 2); + CHECK_EQ(other_dims.size(), grad_output_dims.size()); + + const int64_t m = grad_output_dims[grad_output_dims.size() - 2]; + const int64_t k = other_dims[other_dims.size() - 2]; + const int64_t n = other_dims[other_dims.size() - 1]; + CHECK_EQ(k, other_dims[other_dims.size() - 2]); + CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]); + CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]); + + const int64_t bs + = std::accumulate(grad_output_dims.rbegin() + 2, grad_output_dims.rend(), 1, std::multiplies{}); + for (int64_t i = 0; i < grad_output_dims.size() - 2; ++i) { + CHECK_EQ(grad_output_dims[i], other_dims[i]) << "Batch dims must match"; + } + + auto compute_dtype = other->Dtype(); + auto grad_output_dtype = grad_output->Dtype(); auto grad_output_promoted = grad_output_dtype == compute_dtype ? 
grad_output : std::make_shared(grad_output->To(compute_dtype)); - const auto &input_dims = input->Dims(); - const auto &other_dims = other->Dims(); - const auto &grad_output_dims = grad_output->Dims(); + // For bf16 compute, output in fp32 to preserve accumulation precision. + auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype; + auto grad_input = std::make_shared(input_dims, output_dtype, grad_output->GetDevice()); + + // No Fill(0) needed: cuBLAS beta=0.0f means C is fully overwritten, never read. + + auto device = grad_output->GetDevice(); + const float alpha = 1.0f, beta = 0.0f; + cublasHandle_t handle = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) + ->cublas_handle(); + // cuBLAS is colmun-major + // grad_input = grad_output * other.T --> grad_input.T = other * grad_output.T + // C = A.T * B ==> grad_input.T[*, k, m] = other[*, k, n] * grad_output.T[*, n, m] + // C = grad_input.T[*, k, m] + // A = other.T[*, n, k] + // B = grad_output.T[*, n, m] + const int lda = n, ldb = n, ldc = k; + const int64_t stride_a = k * n; + const int64_t stride_b = n * m; + const int64_t stride_c = m * k; + switch (compute_dtype) { + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other->DataPtr(), CUDA_R_32F, lda, + stride_a, grad_output_promoted->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, + grad_input->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kFLOAT32) + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other->DataPtr(), CUDA_R_16BF, lda, + stride_a, grad_output_promoted->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, + grad_input->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kBFLOAT16) + } + + return grad_input; +} + +std::shared_ptr MatmulBackwardInput2(const 
std::shared_ptr &input1, + const std::shared_ptr &grad_output, + const std::vector &other_dims) { + /* + grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n] + */ + + const auto &input_dims = input1->Dims(); + const auto &grad_output_dims = grad_output->Dims(); CHECK_GE(input_dims.size(), 2); - CHECK_EQ(input_dims.size(), other_dims.size()); CHECK_EQ(input_dims.size(), grad_output_dims.size()); const int64_t m = input_dims[input_dims.size() - 2]; const int64_t k = input_dims[input_dims.size() - 1]; - const int64_t n = other_dims[other_dims.size() - 1]; - CHECK_EQ(k, other_dims[other_dims.size() - 2]); + const int64_t n = grad_output_dims[grad_output_dims.size() - 1]; CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]); CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]); + CHECK_EQ(input_dims[input_dims.size() - 1], other_dims[other_dims.size() - 2]); const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); for (int64_t i = 0; i < input_dims.size() - 2; ++i) { - CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; CHECK_EQ(input_dims[i], grad_output_dims[i]) << "Batch dims must match"; + CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; } - // For bf16 compute, output in fp32 to preserve accumulation precision (matches PyTorch behavior) + auto compute_dtype = input1->Dtype(); + auto grad_output_dtype = grad_output->Dtype(); + auto grad_output_promoted + = grad_output_dtype == compute_dtype ? grad_output : std::make_shared(grad_output->To(compute_dtype)); + + // For bf16 compute, output in fp32 to preserve accumulation precision. auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? 
DataType::kFLOAT32 : compute_dtype; - auto grad_input = std::make_shared(input_dims, output_dtype, grad_output->GetDevice()); auto grad_other = std::make_shared(other_dims, output_dtype, grad_output->GetDevice()); // No Fill(0) needed: cuBLAS beta=0.0f means C is fully overwritten, never read. - auto device = input_promoted->GetDevice(); + auto device = grad_output->GetDevice(); const float alpha = 1.0f, beta = 0.0f; cublasHandle_t handle = dynamic_cast( infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) ->cublas_handle(); - { - // cuBLAS is colmun-major - // grad_input = grad_output * other.T --> grad_input.T = other * grad_output.T - // C = A.T * B ==> grad_input.T[*, k, m] = other[*, k, n] * grad_output.T[*, n, m] - // C = grad_input.T[*, k, m] - // A = other.T[*, n, k] - // B = grad_output.T[*, n, m] - const int lda = n, ldb = n, ldc = k; - const int64_t stride_a = k * n; - const int64_t stride_b = n * m; - const int64_t stride_c = m * k; - switch (compute_dtype) { - DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other_promoted->DataPtr(), CUDA_R_32F, - lda, stride_a, grad_output_promoted->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, - grad_input->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kFLOAT32) - DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other_promoted->DataPtr(), CUDA_R_16BF, - lda, stride_a, grad_output_promoted->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, - grad_input->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kBFLOAT16) - } - } - - { - // cuBLAS is colmun-major - // grad_other = input.T * grad_output --> grad_other.T = grad_output.T * input - // C = A * B.T ==> grad_other.T[*, n, k] = grad_output.T[*, n, m] * input[*, m, k] - // C = grad_other.T[*, n, k] - // A = grad_output.T[*, n, m] - // B = 
input.T[*, k, m] - const int lda = n, ldb = k, ldc = n; - const int64_t stride_a = n * m; - const int64_t stride_b = k * m; - const int64_t stride_c = n * k; - switch (compute_dtype) { - DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output_promoted->DataPtr(), - CUDA_R_32F, lda, stride_a, input_promoted->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, - grad_other->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kFLOAT32) - DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output_promoted->DataPtr(), - CUDA_R_16BF, lda, stride_a, input_promoted->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, - grad_other->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kBFLOAT16) - } + // cuBLAS is column-major + // grad_other = input.T * grad_output --> grad_other.T = grad_output.T * input + // C = A * B.T ==> grad_other.T[*, n, k] = grad_output.T[*, n, m] * input[*, m, k] + // C = grad_other.T[*, n, k] + // A = grad_output.T[*, n, m] + // B = input.T[*, k, m] + const int lda = n, ldb = k, ldc = n; + const int64_t stride_a = n * m; + const int64_t stride_b = k * m; + const int64_t stride_c = n * k; + switch (compute_dtype) { + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output_promoted->DataPtr(), + CUDA_R_32F, lda, stride_a, input1->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, + grad_other->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kFLOAT32) + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output_promoted->DataPtr(), + CUDA_R_16BF, lda, stride_a, input1->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, + grad_other->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F,
CUBLAS_GEMM_DEFAULT));), + DataType::kBFLOAT16) } - return {grad_input, grad_other}; + return grad_other; } template __global__ void BiasCopyKernel(T *output, const T *bias, int bs, int out_features) { @@ -328,7 +357,7 @@ std::shared_ptr LinearBackwardInput(const std::shared_ptr &weigh auto grad_output_promoted = grad_output_dtype == compute_dtype ? grad_output : std::make_shared(grad_output->To(compute_dtype)); - // For bf16 compute, accumulate in fp32 to preserve precision (matches PyTorch behavior). + // For bf16 compute, accumulate in fp32 to preserve precision. auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype; // No Fill(0) needed: cuBLAS beta=0.0f fully overwrites output. auto grad_input = std::make_shared(input_dims, output_dtype, grad_output->GetDevice()); @@ -391,7 +420,7 @@ std::shared_ptr LinearBackwardWeight(const std::shared_ptr &inpu auto grad_output_promoted = grad_output_dtype == compute_dtype ? grad_output : std::make_shared(grad_output->To(compute_dtype)); - // For bf16 compute, accumulate in fp32 to preserve precision (matches PyTorch behavior). + // For bf16 compute, accumulate in fp32 to preserve precision. auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype; const std::vector weight_dims = transpose ? std::vector{out_features, in_features} : std::vector{in_features, out_features}; @@ -452,7 +481,7 @@ std::shared_ptr LinearBackwardBias(const std::shared_ptr &grad_o const int64_t bs = std::accumulate(dims.rbegin() + 1, dims.rend(), 1, std::multiplies{}); auto compute_dtype = grad_output->Dtype(); - // For bf16 compute, accumulate in fp32 to preserve precision (matches PyTorch behavior). + // For bf16 compute, accumulate in fp32 to preserve precision. auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? 
DataType::kFLOAT32 : compute_dtype; auto grad_bias = std::make_shared(std::vector{out_features}, output_dtype, grad_output->GetDevice()); @@ -488,7 +517,8 @@ std::shared_ptr LinearBackwardBias(const std::shared_ptr &grad_o REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) REGISTER_CUDA_LINEAR_KERNEL(MatmulForward) -REGISTER_CUDA_LINEAR_KERNEL(MatmulBackward) +REGISTER_CUDA_LINEAR_KERNEL(MatmulBackwardInput1) +REGISTER_CUDA_LINEAR_KERNEL(MatmulBackwardInput2) REGISTER_CUDA_LINEAR_KERNEL(LinearForward) REGISTER_CUDA_LINEAR_KERNEL(LinearBackwardInput) REGISTER_CUDA_LINEAR_KERNEL(LinearBackwardWeight) diff --git a/infini_train/src/kernels/cuda/outer.cu b/infini_train/src/kernels/cuda/outer.cu index ae7c9f7b..f3140ca5 100644 --- a/infini_train/src/kernels/cuda/outer.cu +++ b/infini_train/src/kernels/cuda/outer.cu @@ -90,7 +90,7 @@ std::tuple, std::shared_ptr> OuterBackward(const auto grad_output_promoted = grad_output_dtype == promoted_type ? grad_output : std::make_shared(grad_output->To(promoted_type)); - // For bf16 compute, output in fp32 to preserve accumulation precision (matches PyTorch behavior) + // For bf16 compute, output in fp32 to preserve accumulation precision. auto output_dtype = (promoted_type == DataType::kBFLOAT16) ? 
DataType::kFLOAT32 : promoted_type; auto grad_input = std::make_shared(std::vector{M}, output_dtype, grad_output->GetDevice()); auto grad_other = std::make_shared(std::vector{N}, output_dtype, grad_output->GetDevice()); From 23d301bae71d88415d5acd42b35729a0c9299427 Mon Sep 17 00:00:00 2001 From: chen Date: Wed, 15 Apr 2026 01:55:08 +0000 Subject: [PATCH 3/3] refactor(gemm): extract shared GemmCuda primitive; split matmul kernels; rename MatmulBackwardInput1/2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add gemm.cuh / gemm.cu: GemmParams struct + GemmCuda() dispatch (cublasGemmEx or cublasGemmStridedBatchedEx based on batch_count), GetCublasHandle(), GetCudaStream() shared across all GEMM-using kernels - Split matmul kernels (CPU + CUDA) out of linear.cc / linear.cu into dedicated matmul.cc / matmul.cu; linear.* now only contains the four Linear kernels - Rename MatmulBackwardInput1 → MatmulBackwardInput, MatmulBackwardInput2 → MatmulBackwardOther for semantic clarity matching MatmulForward(input, other) parameter names - Rewrite outer.cu to use GemmCuda() (OuterForward + bf16 backward paths); keep cublasSgemv for the fp32 backward path (more efficient, bf16 unsupported) --- infini_train/include/common/cuda/gemm.cuh | 72 +++++ infini_train/src/autograd/matmul.cc | 14 +- infini_train/src/kernels/cpu/linear.cc | 126 +------- infini_train/src/kernels/cpu/matmul.cc | 145 +++++++++ infini_train/src/kernels/cuda/gemm.cu | 73 +++++ infini_train/src/kernels/cuda/linear.cu | 366 ++++------------------ infini_train/src/kernels/cuda/matmul.cu | 229 ++++++++++++++ infini_train/src/kernels/cuda/outer.cu | 179 ++++++----- 8 files changed, 694 insertions(+), 510 deletions(-) create mode 100644 infini_train/include/common/cuda/gemm.cuh create mode 100644 infini_train/src/kernels/cpu/matmul.cc create mode 100644 infini_train/src/kernels/cuda/gemm.cu create mode 100644 infini_train/src/kernels/cuda/matmul.cu diff --git 
a/infini_train/include/common/cuda/gemm.cuh b/infini_train/include/common/cuda/gemm.cuh new file mode 100644 index 00000000..a258d3e7 --- /dev/null +++ b/infini_train/include/common/cuda/gemm.cuh @@ -0,0 +1,72 @@ +#pragma once + +#include +#include + +#include "infini_train/include/datatype.h" +#include "infini_train/include/device.h" + +namespace infini_train::kernels::cuda { + +/** + * Return the cuBLAS handle associated with the given device. + * Shared by linear.cu, matmul.cu, and any future GEMM-using kernels. + */ +cublasHandle_t GetCublasHandle(const Device &device); + +/** + * Return the CUDA stream associated with the given device. + * Shared by kernels that need to launch device-side code directly. + */ +cudaStream_t GetCudaStream(const Device &device); + +/** + * Parameter bundle for a single GEMM call: + * C = alpha * op(A) * op(B) + beta * C + * + * batch_count == 1 → non-batched path (cublasGemmEx) + * batch_count > 1 → strided-batched (cublasGemmStridedBatchedEx) + * + * When batch_count == 1, stride_a/b/c are unused and must be left at 0. + */ +struct GemmParams { + cublasOperation_t trans_a = CUBLAS_OP_N; + cublasOperation_t trans_b = CUBLAS_OP_N; + + int m = 0; // rows of op(A) and C + int n = 0; // cols of op(B) and C + int k = 0; // cols of op(A) == rows of op(B) + + const void *A = nullptr; + int lda = 0; + const void *B = nullptr; + int ldb = 0; + void *C = nullptr; + int ldc = 0; + + float alpha = 1.0f; + float beta = 0.0f; + + // batch_count=1: non-batched (Linear path); stride_a/b/c must be 0 + // batch_count>1: strided-batched (Matmul path) + int batch_count = 1; + long long stride_a = 0; + long long stride_b = 0; + long long stride_c = 0; + + DataType input_dtype; // dtype of A and B + DataType output_dtype; // dtype of C (may differ, e.g. bf16 in → fp32 out) + + cublasHandle_t blas_handle = nullptr; +}; + +/** + * Execute the GEMM described by `p` via cuBLAS. 
+ * Dispatches to cublasGemmEx (batch_count==1) or + * cublasGemmStridedBatchedEx (batch_count>1). + * Uses CUBLAS_COMPUTE_32F for all input dtypes to ensure precision. + * Aborts on cuBLAS error (via CUBLAS_CHECK / LOG(FATAL)). + */ +void GemmCuda(const GemmParams &p); + +} // namespace infini_train::kernels::cuda diff --git a/infini_train/src/autograd/matmul.cc b/infini_train/src/autograd/matmul.cc index 259cb4a4..47151662 100644 --- a/infini_train/src/autograd/matmul.cc +++ b/infini_train/src/autograd/matmul.cc @@ -58,18 +58,18 @@ std::vector> Matmul::Backward(const std::vectorGetDevice().type(); - std::shared_ptr grad_input1 = nullptr; - std::shared_ptr grad_input2 = nullptr; + std::shared_ptr grad_input = nullptr; + std::shared_ptr grad_other = nullptr; if (need_grad_input1) { - grad_input1 = Dispatcher::Instance().Call>({device, "MatmulBackwardInput1"}, input2, - grad_output, input1->Dims()); + grad_input = Dispatcher::Instance().Call>({device, "MatmulBackwardInput"}, input2, + grad_output, input1->Dims()); } if (need_grad_input2) { - grad_input2 = Dispatcher::Instance().Call>({device, "MatmulBackwardInput2"}, input1, - grad_output, input2->Dims()); + grad_other = Dispatcher::Instance().Call>({device, "MatmulBackwardOther"}, input1, + grad_output, input2->Dims()); } - return {grad_input1, grad_input2}; + return {grad_input, grad_other}; } } // namespace infini_train::autograd diff --git a/infini_train/src/kernels/cpu/linear.cc b/infini_train/src/kernels/cpu/linear.cc index 361c56f8..9d28a92a 100644 --- a/infini_train/src/kernels/cpu/linear.cc +++ b/infini_train/src/kernels/cpu/linear.cc @@ -9,128 +9,6 @@ #include "infini_train/include/tensor.h" namespace infini_train::kernels::cpu { -std::shared_ptr MatmulForward(const std::shared_ptr &input, const std::shared_ptr &other) { - /* - output[*, m, n] = input[*, m, k] * other[*, k, n] - */ - // TODO(dcj): support broadcast later - const auto &input_dims = input->Dims(); - const auto &other_dims = other->Dims(); - 
- CHECK_GE(input_dims.size(), 2); - CHECK_GE(other_dims.size(), 2); - CHECK_EQ(input_dims.size(), other_dims.size()); - - const int64_t m = input_dims[input_dims.size() - 2]; - const int64_t k = input_dims[input_dims.size() - 1]; - CHECK_EQ(k, other_dims[other_dims.size() - 2]); - const int64_t n = other_dims[other_dims.size() - 1]; - - const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); - for (int64_t i = 0; i < input_dims.size() - 2; ++i) { - CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; - } - - std::vector output_dims = input_dims; - output_dims[output_dims.size() - 1] = n; - auto output = std::make_shared(output_dims, DataType::kFLOAT32); - - for (int64_t b = 0; b < bs; ++b) { - for (int64_t i = 0; i < m; ++i) { - for (int64_t j = 0; j < n; ++j) { - float acc = 0.0f; - for (int64_t p = 0; p < k; ++p) { - acc += static_cast(input->DataPtr())[b * m * k + i * k + p] - * static_cast(other->DataPtr())[b * k * n + p * n + j]; - } - static_cast(output->DataPtr())[b * m * n + i * n + j] = acc; - } - } - } - return {output}; -} - -std::shared_ptr MatmulBackwardInput1(const std::shared_ptr &other, - const std::shared_ptr &grad_output, - const std::vector &input_dims) { - /* - grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T - */ - const auto &other_dims = other->Dims(); - const auto &grad_output_dims = grad_output->Dims(); - - CHECK_GE(other_dims.size(), 2); - CHECK_EQ(other_dims.size(), grad_output_dims.size()); - - const int64_t m = grad_output_dims[grad_output_dims.size() - 2]; - const int64_t k = other_dims[other_dims.size() - 2]; - const int64_t n = grad_output_dims[grad_output_dims.size() - 1]; - - const int64_t bs - = std::accumulate(grad_output_dims.rbegin() + 2, grad_output_dims.rend(), 1, std::multiplies{}); - for (int64_t i = 0; i < grad_output_dims.size() - 2; ++i) { - CHECK_EQ(grad_output_dims[i], other_dims[i]) << "Batch dims must match"; - } - - auto grad_input = 
std::make_shared(input_dims, DataType::kFLOAT32); - grad_input->Fill(0.0f); - - for (int64_t b = 0; b < bs; ++b) { - for (int64_t i = 0; i < m; ++i) { - for (int64_t j = 0; j < n; ++j) { - const float grad = static_cast(grad_output->DataPtr())[b * m * n + i * n + j]; - for (int64_t p = 0; p < k; ++p) { - const auto other_idx = b * k * n + p * n + j; - static_cast(grad_input->DataPtr())[b * m * k + i * k + p] - += grad * static_cast(other->DataPtr())[other_idx]; - } - } - } - } - return grad_input; -} - -std::shared_ptr MatmulBackwardInput2(const std::shared_ptr &input1, - const std::shared_ptr &grad_output, - const std::vector &other_dims) { - /* - grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n] - */ - const auto &input_dims = input1->Dims(); - const auto &grad_output_dims = grad_output->Dims(); - - CHECK_GE(input_dims.size(), 2); - CHECK_EQ(input_dims.size(), grad_output_dims.size()); - - const int64_t m = input_dims[input_dims.size() - 2]; - const int64_t k = input_dims[input_dims.size() - 1]; - const int64_t n = grad_output_dims[grad_output_dims.size() - 1]; - CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]); - CHECK_EQ(k, other_dims[other_dims.size() - 2]); - - const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); - for (int64_t i = 0; i < input_dims.size() - 2; ++i) { - CHECK_EQ(input_dims[i], grad_output_dims[i]) << "Batch dims must match"; - CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; - } - - auto grad_other = std::make_shared(other_dims, DataType::kFLOAT32); - grad_other->Fill(0.0f); - - for (int64_t b = 0; b < bs; ++b) { - for (int64_t i = 0; i < m; ++i) { - for (int64_t j = 0; j < n; ++j) { - const float grad = static_cast(grad_output->DataPtr())[b * m * n + i * n + j]; - for (int64_t p = 0; p < k; ++p) { - const auto input_idx = b * m * k + i * k + p; - static_cast(grad_other->DataPtr())[b * k * n + p * n + j] - += grad * 
static_cast(input1->DataPtr())[input_idx]; - } - } - } - } - return grad_other; -} std::shared_ptr LinearForward(const std::shared_ptr &input, const std::shared_ptr &weight, bool transpose, const std::shared_ptr &bias) { @@ -225,14 +103,12 @@ std::shared_ptr LinearBackwardBias(const std::shared_ptr &grad_o grad_bias->EigenVector() = grad_output->EigenMatrix().colwise().sum(); return grad_bias; } + } // namespace infini_train::kernels::cpu #define REGISTER_CPU_LINEAR_KERNEL(kernel_name) \ REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) -REGISTER_CPU_LINEAR_KERNEL(MatmulForward) -REGISTER_CPU_LINEAR_KERNEL(MatmulBackwardInput1) -REGISTER_CPU_LINEAR_KERNEL(MatmulBackwardInput2) REGISTER_CPU_LINEAR_KERNEL(LinearForward) REGISTER_CPU_LINEAR_KERNEL(LinearBackwardInput) REGISTER_CPU_LINEAR_KERNEL(LinearBackwardWeight) diff --git a/infini_train/src/kernels/cpu/matmul.cc b/infini_train/src/kernels/cpu/matmul.cc new file mode 100644 index 00000000..8228434f --- /dev/null +++ b/infini_train/src/kernels/cpu/matmul.cc @@ -0,0 +1,145 @@ +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::kernels::cpu { + +std::shared_ptr MatmulForward(const std::shared_ptr &input, const std::shared_ptr &other) { + /* + output[*, m, n] = input[*, m, k] * other[*, k, n] + */ + // TODO(dcj): support broadcast later + const auto &input_dims = input->Dims(); + const auto &other_dims = other->Dims(); + + CHECK_GE(input_dims.size(), 2); + CHECK_GE(other_dims.size(), 2); + CHECK_EQ(input_dims.size(), other_dims.size()); + + const int64_t m = input_dims[input_dims.size() - 2]; + const int64_t k = input_dims[input_dims.size() - 1]; + CHECK_EQ(k, other_dims[other_dims.size() - 2]); + const int64_t n = other_dims[other_dims.size() - 1]; + + const int64_t bs = std::accumulate(input_dims.rbegin() + 2, 
input_dims.rend(), 1, std::multiplies{}); + for (int64_t i = 0; i < static_cast(input_dims.size()) - 2; ++i) { + CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; + } + + std::vector output_dims = input_dims; + output_dims[output_dims.size() - 1] = n; + auto output = std::make_shared(output_dims, DataType::kFLOAT32); + + for (int64_t b = 0; b < bs; ++b) { + for (int64_t i = 0; i < m; ++i) { + for (int64_t j = 0; j < n; ++j) { + float acc = 0.0f; + for (int64_t p = 0; p < k; ++p) { + acc += static_cast(input->DataPtr())[b * m * k + i * k + p] + * static_cast(other->DataPtr())[b * k * n + p * n + j]; + } + static_cast(output->DataPtr())[b * m * n + i * n + j] = acc; + } + } + } + return {output}; +} + +std::shared_ptr MatmulBackwardInput(const std::shared_ptr &other, + const std::shared_ptr &grad_output, + const std::vector &input_dims) { + /* + grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T + */ + const auto &other_dims = other->Dims(); + const auto &grad_output_dims = grad_output->Dims(); + + CHECK_GE(other_dims.size(), 2); + CHECK_EQ(other_dims.size(), grad_output_dims.size()); + + const int64_t m = grad_output_dims[grad_output_dims.size() - 2]; + const int64_t k = other_dims[other_dims.size() - 2]; + const int64_t n = grad_output_dims[grad_output_dims.size() - 1]; + + const int64_t bs + = std::accumulate(grad_output_dims.rbegin() + 2, grad_output_dims.rend(), 1, std::multiplies{}); + for (int64_t i = 0; i < static_cast(grad_output_dims.size()) - 2; ++i) { + CHECK_EQ(grad_output_dims[i], other_dims[i]) << "Batch dims must match"; + } + + auto grad_input = std::make_shared(input_dims, DataType::kFLOAT32); + grad_input->Fill(0.0f); + + for (int64_t b = 0; b < bs; ++b) { + for (int64_t i = 0; i < m; ++i) { + for (int64_t j = 0; j < n; ++j) { + const float grad = static_cast(grad_output->DataPtr())[b * m * n + i * n + j]; + for (int64_t p = 0; p < k; ++p) { + const auto other_idx = b * k * n + p * n + j; + 
static_cast(grad_input->DataPtr())[b * m * k + i * k + p] + += grad * static_cast(other->DataPtr())[other_idx]; + } + } + } + } + return grad_input; +} + +std::shared_ptr MatmulBackwardOther(const std::shared_ptr &input1, + const std::shared_ptr &grad_output, + const std::vector &other_dims) { + /* + grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n] + */ + const auto &input_dims = input1->Dims(); + const auto &grad_output_dims = grad_output->Dims(); + + CHECK_GE(input_dims.size(), 2); + CHECK_EQ(input_dims.size(), grad_output_dims.size()); + + const int64_t m = input_dims[input_dims.size() - 2]; + const int64_t k = input_dims[input_dims.size() - 1]; + const int64_t n = grad_output_dims[grad_output_dims.size() - 1]; + CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]); + CHECK_EQ(k, other_dims[other_dims.size() - 2]); + + const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); + for (int64_t i = 0; i < static_cast(input_dims.size()) - 2; ++i) { + CHECK_EQ(input_dims[i], grad_output_dims[i]) << "Batch dims must match"; + CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; + } + + auto grad_other = std::make_shared(other_dims, DataType::kFLOAT32); + grad_other->Fill(0.0f); + + for (int64_t b = 0; b < bs; ++b) { + for (int64_t i = 0; i < m; ++i) { + for (int64_t j = 0; j < n; ++j) { + const float grad = static_cast(grad_output->DataPtr())[b * m * n + i * n + j]; + for (int64_t p = 0; p < k; ++p) { + const auto input_idx = b * m * k + i * k + p; + static_cast(grad_other->DataPtr())[b * k * n + p * n + j] + += grad * static_cast(input1->DataPtr())[input_idx]; + } + } + } + } + return grad_other; +} + +} // namespace infini_train::kernels::cpu + +#define REGISTER_CPU_MATMUL_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + +REGISTER_CPU_MATMUL_KERNEL(MatmulForward) 
+REGISTER_CPU_MATMUL_KERNEL(MatmulBackwardInput) +REGISTER_CPU_MATMUL_KERNEL(MatmulBackwardOther) + +#undef REGISTER_CPU_MATMUL_KERNEL diff --git a/infini_train/src/kernels/cuda/gemm.cu b/infini_train/src/kernels/cuda/gemm.cu new file mode 100644 index 00000000..3b1efa38 --- /dev/null +++ b/infini_train/src/kernels/cuda/gemm.cu @@ -0,0 +1,73 @@ +#include "infini_train/include/common/cuda/gemm.cuh" + +#include + +#include "glog/logging.h" + +#include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/datatype.h" + +#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" + +namespace infini_train::kernels::cuda { + +cublasHandle_t GetCublasHandle(const Device &device) { + return dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) + ->cublas_handle(); +} + +cudaStream_t GetCudaStream(const Device &device) { + return dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); +} + +namespace { + +cudaDataType_t ToCudaDataType(DataType dt) { + switch (dt) { + case DataType::kFLOAT32: + return CUDA_R_32F; + case DataType::kBFLOAT16: + return CUDA_R_16BF; + case DataType::kFLOAT16: + return CUDA_R_16F; + default: + LOG(FATAL) << "GemmCuda: unsupported DataType " << static_cast(dt); + return CUDA_R_32F; // unreachable + } +} + +} // namespace + +void GemmCuda(const GemmParams &p) { + DCHECK(p.blas_handle != nullptr); + + if (p.batch_count == 1) { + // strides are unused in the non-batched path; assert they are left at 0 + // to catch accidental misuse early. 
+ DCHECK_EQ(p.stride_a, 0LL); + DCHECK_EQ(p.stride_b, 0LL); + DCHECK_EQ(p.stride_c, 0LL); + } + + const cudaDataType_t type_a = ToCudaDataType(p.input_dtype); + const cudaDataType_t type_b = ToCudaDataType(p.input_dtype); + const cudaDataType_t type_c = ToCudaDataType(p.output_dtype); + // Always use CUBLAS_COMPUTE_32F: required for bf16/fp16 correctness, + // and fine for fp32 (same compute path). + const cublasComputeType_t ctype = CUBLAS_COMPUTE_32F; + + if (p.batch_count == 1) { + CUBLAS_CHECK(cublasGemmEx(p.blas_handle, p.trans_a, p.trans_b, p.m, p.n, p.k, &p.alpha, p.A, type_a, p.lda, p.B, + type_b, p.ldb, &p.beta, p.C, type_c, p.ldc, ctype, CUBLAS_GEMM_DEFAULT)); + } else { + CUBLAS_CHECK(cublasGemmStridedBatchedEx(p.blas_handle, p.trans_a, p.trans_b, p.m, p.n, p.k, &p.alpha, p.A, + type_a, p.lda, p.stride_a, p.B, type_b, p.ldb, p.stride_b, &p.beta, p.C, + type_c, p.ldc, p.stride_c, p.batch_count, ctype, CUBLAS_GEMM_DEFAULT)); + } +} + +} // namespace infini_train::kernels::cuda diff --git a/infini_train/src/kernels/cuda/linear.cu b/infini_train/src/kernels/cuda/linear.cu index cbc74c5e..9efb3688 100644 --- a/infini_train/src/kernels/cuda/linear.cu +++ b/infini_train/src/kernels/cuda/linear.cu @@ -7,216 +7,13 @@ #include #include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/common/cuda/gemm.cuh" #include "infini_train/include/common/cuda/kernel_helper.cuh" -#include "infini_train/include/core/runtime/device_guard.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" -#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" - namespace infini_train::kernels::cuda { -std::shared_ptr MatmulForward(const std::shared_ptr &input, const std::shared_ptr &other) { - /* - output[*, m, n] = input[*, m, k] * other[*, k, n] - */ - const auto &input_dims = input->Dims(); - const auto &other_dims = other->Dims(); - - CHECK_GE(input_dims.size(), 2); - CHECK_GE(other_dims.size(), 2); - 
CHECK_EQ(input_dims.size(), other_dims.size()); - - const int64_t m = input_dims[input_dims.size() - 2]; - const int64_t k = input_dims[input_dims.size() - 1]; - CHECK_EQ(k, other_dims[other_dims.size() - 2]); - const int64_t n = other_dims[other_dims.size() - 1]; - - const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); - for (int64_t i = 0; i < input_dims.size() - 2; ++i) { - CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; - } - - auto dtype = input->Dtype(); - std::vector output_dims = input_dims; - output_dims[output_dims.size() - 1] = n; - auto output = std::make_shared(output_dims, dtype, input->GetDevice()); - - auto device = input->GetDevice(); - const float alpha = 1.0f, beta = 0.0f; - cublasHandle_t handle = dynamic_cast( - infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) - ->cublas_handle(); - - // cuBLAS is colmun-major - // output = input * other --> output.T = other.T * input.T - // C = A * B ==> output.T[*, n, m] = other.T[*, n, k] * input.T[*, k, m] - // C = output.T[*, n, m] - // A = other.T[*, n, k] - // B = input.T[*, k, m] - int lda = n; - int ldb = k; - int ldc = n; - int64_t stride_a = n * k; - int64_t stride_b = k * m; - int64_t stride_c = m * n; - // NOTE(zbl): the last cublasGemmAlgo_t param has no effect on GPU arch >= sm_80(Ampere) - - switch (dtype) { - DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, other->DataPtr(), CUDA_R_32F, lda, - stride_a, input->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, output->DataPtr(), CUDA_R_32F, - ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kFLOAT32) - DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, other->DataPtr(), CUDA_R_16BF, lda, - stride_a, input->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, output->DataPtr(), CUDA_R_16BF, - ldc, stride_c, bs, 
CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kBFLOAT16) - default: - LOG_UNSUPPORTED_DTYPE(dtype, "CUDA MatmulForward"); - } - - return output; -} - -std::shared_ptr MatmulBackwardInput1(const std::shared_ptr &other, - const std::shared_ptr &grad_output, - const std::vector &input_dims) { - /* - grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T - */ - - const auto &other_dims = other->Dims(); - const auto &grad_output_dims = grad_output->Dims(); - - CHECK_GE(other_dims.size(), 2); - CHECK_EQ(other_dims.size(), grad_output_dims.size()); - - const int64_t m = grad_output_dims[grad_output_dims.size() - 2]; - const int64_t k = other_dims[other_dims.size() - 2]; - const int64_t n = other_dims[other_dims.size() - 1]; - CHECK_EQ(k, other_dims[other_dims.size() - 2]); - CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]); - CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]); - - const int64_t bs - = std::accumulate(grad_output_dims.rbegin() + 2, grad_output_dims.rend(), 1, std::multiplies{}); - for (int64_t i = 0; i < grad_output_dims.size() - 2; ++i) { - CHECK_EQ(grad_output_dims[i], other_dims[i]) << "Batch dims must match"; - } - - auto compute_dtype = other->Dtype(); - auto grad_output_dtype = grad_output->Dtype(); - auto grad_output_promoted - = grad_output_dtype == compute_dtype ? grad_output : std::make_shared(grad_output->To(compute_dtype)); - - // For bf16 compute, output in fp32 to preserve accumulation precision. - auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype; - auto grad_input = std::make_shared(input_dims, output_dtype, grad_output->GetDevice()); - - // No Fill(0) needed: cuBLAS beta=0.0f means C is fully overwritten, never read. 
- - auto device = grad_output->GetDevice(); - const float alpha = 1.0f, beta = 0.0f; - cublasHandle_t handle = dynamic_cast( - infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) - ->cublas_handle(); - - // cuBLAS is colmun-major - // grad_input = grad_output * other.T --> grad_input.T = other * grad_output.T - // C = A.T * B ==> grad_input.T[*, k, m] = other[*, k, n] * grad_output.T[*, n, m] - // C = grad_input.T[*, k, m] - // A = other.T[*, n, k] - // B = grad_output.T[*, n, m] - const int lda = n, ldb = n, ldc = k; - const int64_t stride_a = k * n; - const int64_t stride_b = n * m; - const int64_t stride_c = m * k; - switch (compute_dtype) { - DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other->DataPtr(), CUDA_R_32F, lda, - stride_a, grad_output_promoted->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, - grad_input->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kFLOAT32) - DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other->DataPtr(), CUDA_R_16BF, lda, - stride_a, grad_output_promoted->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, - grad_input->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kBFLOAT16) - } - - return grad_input; -} - -std::shared_ptr MatmulBackwardInput2(const std::shared_ptr &input1, - const std::shared_ptr &grad_output, - const std::vector &other_dims) { - /* - grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n] - */ - - const auto &input_dims = input1->Dims(); - const auto &grad_output_dims = grad_output->Dims(); - CHECK_GE(input_dims.size(), 2); - CHECK_EQ(input_dims.size(), grad_output_dims.size()); - - const int64_t m = input_dims[input_dims.size() - 2]; - const int64_t k = input_dims[input_dims.size() - 1]; - const int64_t n = grad_output_dims[grad_output_dims.size() - 1]; - CHECK_EQ(m, 
grad_output_dims[grad_output_dims.size() - 2]); - CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]); - CHECK_EQ(input_dims[input_dims.size() - 1], other_dims[other_dims.size() - 2]); - - const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); - for (int64_t i = 0; i < input_dims.size() - 2; ++i) { - CHECK_EQ(input_dims[i], grad_output_dims[i]) << "Batch dims must match"; - CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; - } - - auto compute_dtype = input1->Dtype(); - auto grad_output_dtype = grad_output->Dtype(); - auto grad_output_promoted - = grad_output_dtype == compute_dtype ? grad_output : std::make_shared(grad_output->To(compute_dtype)); - - // For bf16 compute, output in fp32 to preserve accumulation precision. - auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype; - auto grad_other = std::make_shared(other_dims, output_dtype, grad_output->GetDevice()); - - // No Fill(0) needed: cuBLAS beta=0.0f means C is fully overwritten, never read. 
- - auto device = grad_output->GetDevice(); - const float alpha = 1.0f, beta = 0.0f; - cublasHandle_t handle = dynamic_cast( - infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) - ->cublas_handle(); - - // cuBLAS is colmun-major - // grad_other = input.T * grad_output --> grad_other.T = grad_output.T * input - // C = A * B.T ==> grad_other.T[*, n, k] = grad_output.T[*, n, m] * input[*, m, k] - // C = grad_other.T[*, n, k] - // A = grad_output.T[*, n, m] - // B = input.T[*, k, m] - const int lda = n, ldb = k, ldc = n; - const int64_t stride_a = n * m; - const int64_t stride_b = k * m; - const int64_t stride_c = n * k; - switch (compute_dtype) { - DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output_promoted->DataPtr(), - CUDA_R_32F, lda, stride_a, input1->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, - grad_other->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kFLOAT32) - DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( - handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output_promoted->DataPtr(), - CUDA_R_16BF, lda, stride_a, input1->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, - grad_other->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), - DataType::kBFLOAT16) - } - - return grad_other; -} - template __global__ void BiasCopyKernel(T *output, const T *bias, int bs, int out_features) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= bs * out_features) { @@ -260,9 +57,7 @@ std::shared_ptr LinearForward(const std::shared_ptr &input, cons auto output = std::make_shared(output_dims, dtype, input->GetDevice()); auto device = input->GetDevice(); - const auto &cuda_stream = dynamic_cast( - infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) - ->cuda_stream(); + const auto cuda_stream = GetCudaStream(device); if (bias) { CHECK_EQ(bias->Dims().size(), 
1); @@ -282,15 +77,6 @@ std::shared_ptr LinearForward(const std::shared_ptr &input, cons input->Dtype(), [=]() { output->Fill(0); }, "CUDA LinearForward"); } - const float alpha = 1.0f; - const float beta = 1.0f; - auto trans_a = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; - auto trans_b = CUBLAS_OP_N; - auto lda = transpose ? in_features : out_features; - cublasHandle_t handle = dynamic_cast( - infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) - ->cublas_handle(); - // TODO(zbl): use cublasSgemv if possible for convenience and simplicity // // - if a is transposed: @@ -305,22 +91,26 @@ std::shared_ptr LinearForward(const std::shared_ptr &input, cons // C = output.T[out_features, bs] // A = weight.T[out_features, in_features] // B = input.T[in_features, bs] - switch (input->Dtype()) { - DISPATCH_CASE(WRAP({ - CUBLAS_CHECK(cublasSgemm(handle, trans_a, trans_b, out_features, bs, in_features, &alpha, - static_cast(weight->DataPtr()), lda, - static_cast(input->DataPtr()), in_features, &beta, - static_cast(output->DataPtr()), out_features)); - }), - DataType::kFLOAT32) - DISPATCH_CASE(WRAP({ - CUBLAS_CHECK(cublasGemmEx(handle, trans_a, trans_b, out_features, bs, in_features, &alpha, - weight->DataPtr(), CUDA_R_16BF, lda, input->DataPtr(), CUDA_R_16BF, - in_features, &beta, output->DataPtr(), CUDA_R_16BF, out_features, - CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); - }), - DataType::kBFLOAT16) - } + GemmParams p; + p.trans_a = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; + p.trans_b = CUBLAS_OP_N; + p.m = static_cast(out_features); + p.n = static_cast(bs); + p.k = static_cast(in_features); + p.A = weight->DataPtr(); + p.lda = static_cast(transpose ? 
in_features : out_features); + p.B = input->DataPtr(); + p.ldb = static_cast(in_features); + p.C = output->DataPtr(); + p.ldc = static_cast(out_features); + p.alpha = 1.0f; + p.beta = 1.0f; // bias already written into output; beta=1 accumulates + p.batch_count = 1; + p.input_dtype = dtype; + p.output_dtype = dtype; + p.blas_handle = GetCublasHandle(device); + + GemmCuda(p); return output; } @@ -362,13 +152,6 @@ std::shared_ptr LinearBackwardInput(const std::shared_ptr &weigh // No Fill(0) needed: cuBLAS beta=0.0f fully overwrites output. auto grad_input = std::make_shared(input_dims, output_dtype, grad_output->GetDevice()); - auto device = grad_output->GetDevice(); - float alpha = 1.0f; - float beta = 0.0f; - cublasHandle_t handle = dynamic_cast( - infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) - ->cublas_handle(); - // TODO(zbl): use cublasSgemv if possible // - if transpose: // weight is [out_features, in_features] here @@ -383,26 +166,26 @@ std::shared_ptr LinearBackwardInput(const std::shared_ptr &weigh // C = d_input.T[in_features, bs] // A = weight.T[out_features, in_features] // B = d_output.T[out_features, bs] - auto trans_a = transpose ? CUBLAS_OP_N : CUBLAS_OP_T; - auto lda = transpose ? 
in_features : out_features; - - switch (compute_dtype) { - DISPATCH_CASE(WRAP({ - CUBLAS_CHECK(cublasSgemm(handle, trans_a, CUBLAS_OP_N, in_features, bs, out_features, &alpha, - static_cast(weight->DataPtr()), lda, - static_cast(grad_output_promoted->DataPtr()), - out_features, &beta, static_cast(grad_input->DataPtr()), - in_features)); - }), - DataType::kFLOAT32) - DISPATCH_CASE(WRAP({ - CUBLAS_CHECK(cublasGemmEx( - handle, trans_a, CUBLAS_OP_N, in_features, bs, out_features, &alpha, weight->DataPtr(), - CUDA_R_16BF, lda, grad_output_promoted->DataPtr(), CUDA_R_16BF, out_features, &beta, - grad_input->DataPtr(), CUDA_R_32F, in_features, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); - }), - DataType::kBFLOAT16) - } + GemmParams p; + p.trans_a = transpose ? CUBLAS_OP_N : CUBLAS_OP_T; + p.trans_b = CUBLAS_OP_N; + p.m = static_cast(in_features); + p.n = static_cast(bs); + p.k = static_cast(out_features); + p.A = weight->DataPtr(); + p.lda = static_cast(transpose ? in_features : out_features); + p.B = grad_output_promoted->DataPtr(); + p.ldb = static_cast(out_features); + p.C = grad_input->DataPtr(); + p.ldc = static_cast(in_features); + p.alpha = 1.0f; + p.beta = 0.0f; + p.batch_count = 1; + p.input_dtype = compute_dtype; + p.output_dtype = output_dtype; + p.blas_handle = GetCublasHandle(grad_output->GetDevice()); + + GemmCuda(p); return grad_input; } @@ -427,13 +210,6 @@ std::shared_ptr LinearBackwardWeight(const std::shared_ptr &inpu // No Fill(0) needed: cuBLAS beta=0.0f fully overwrites output. 
auto grad_weight = std::make_shared(weight_dims, output_dtype, grad_output->GetDevice()); - auto device = grad_output->GetDevice(); - float alpha = 1.0f; - float beta = 0.0f; - cublasHandle_t handle = dynamic_cast( - infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) - ->cublas_handle(); - // - if transpose: // d_weight = d_output.T * input --> d_weight.T = input.T * d_output // C = d_weight.T[in_features, out_features] @@ -445,32 +221,31 @@ std::shared_ptr LinearBackwardWeight(const std::shared_ptr &inpu // C = d_weight.T[out_features, in_features] // A = d_output.T[out_features, bs] // B = input.T[in_features, bs] - int m = transpose ? in_features : out_features; - int n = transpose ? out_features : in_features; - auto ldc = transpose ? in_features : out_features; - - switch (compute_dtype) { - DISPATCH_CASE(WRAP({ - const void *a = transpose ? input->DataPtr() : grad_output_promoted->DataPtr(); - const void *b = transpose ? grad_output_promoted->DataPtr() : input->DataPtr(); - auto lda = transpose ? in_features : out_features; - auto ldb = transpose ? out_features : in_features; - CUBLAS_CHECK(cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, m, n, bs, &alpha, - static_cast(a), lda, static_cast(b), - ldb, &beta, static_cast(grad_weight->DataPtr()), ldc)); - }), - DataType::kFLOAT32) - DISPATCH_CASE(WRAP({ - const void *a = transpose ? input->DataPtr() : grad_output_promoted->DataPtr(); - const void *b = transpose ? grad_output_promoted->DataPtr() : input->DataPtr(); - auto lda = transpose ? in_features : out_features; - auto ldb = transpose ? out_features : in_features; - CUBLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, m, n, bs, &alpha, a, CUDA_R_16BF, - lda, b, CUDA_R_16BF, ldb, &beta, grad_weight->DataPtr(), CUDA_R_32F, - ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); - }), - DataType::kBFLOAT16) - } + const void *a = transpose ? input->DataPtr() : grad_output_promoted->DataPtr(); + const void *b = transpose ? 
grad_output_promoted->DataPtr() : input->DataPtr(); + const int lda = static_cast(transpose ? in_features : out_features); + const int ldb = static_cast(transpose ? out_features : in_features); + + GemmParams p; + p.trans_a = CUBLAS_OP_N; + p.trans_b = CUBLAS_OP_T; + p.m = static_cast(transpose ? in_features : out_features); + p.n = static_cast(transpose ? out_features : in_features); + p.k = static_cast(bs); + p.A = a; + p.lda = lda; + p.B = b; + p.ldb = ldb; + p.C = grad_weight->DataPtr(); + p.ldc = static_cast(transpose ? in_features : out_features); + p.alpha = 1.0f; + p.beta = 0.0f; + p.batch_count = 1; + p.input_dtype = compute_dtype; + p.output_dtype = output_dtype; + p.blas_handle = GetCublasHandle(grad_output->GetDevice()); + + GemmCuda(p); return grad_weight; } @@ -487,9 +262,7 @@ std::shared_ptr LinearBackwardBias(const std::shared_ptr &grad_o = std::make_shared(std::vector{out_features}, output_dtype, grad_output->GetDevice()); auto device = grad_output->GetDevice(); - const auto &cuda_stream = dynamic_cast( - infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) - ->cuda_stream(); + const auto cuda_stream = GetCudaStream(device); // d_bias = \sum_i(i=0, bs-1) d_output[i] // TODO(dcj): use thrust::fill or reduce kernel do this @@ -516,9 +289,6 @@ std::shared_ptr LinearBackwardBias(const std::shared_ptr &grad_o #define REGISTER_CUDA_LINEAR_KERNEL(kernel_name) \ REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) -REGISTER_CUDA_LINEAR_KERNEL(MatmulForward) -REGISTER_CUDA_LINEAR_KERNEL(MatmulBackwardInput1) -REGISTER_CUDA_LINEAR_KERNEL(MatmulBackwardInput2) REGISTER_CUDA_LINEAR_KERNEL(LinearForward) REGISTER_CUDA_LINEAR_KERNEL(LinearBackwardInput) REGISTER_CUDA_LINEAR_KERNEL(LinearBackwardWeight) diff --git a/infini_train/src/kernels/cuda/matmul.cu b/infini_train/src/kernels/cuda/matmul.cu new file mode 100644 index 00000000..7e3fa0c0 --- /dev/null +++ 
b/infini_train/src/kernels/cuda/matmul.cu @@ -0,0 +1,229 @@ +#include +#include +#include +#include + +#include + +#include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/common/cuda/gemm.cuh" +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::kernels::cuda { + +std::shared_ptr MatmulForward(const std::shared_ptr &input, const std::shared_ptr &other) { + /* + output[*, m, n] = input[*, m, k] * other[*, k, n] + */ + const auto &input_dims = input->Dims(); + const auto &other_dims = other->Dims(); + + CHECK_GE(input_dims.size(), 2); + CHECK_GE(other_dims.size(), 2); + CHECK_EQ(input_dims.size(), other_dims.size()); + + const int64_t m = input_dims[input_dims.size() - 2]; + const int64_t k = input_dims[input_dims.size() - 1]; + CHECK_EQ(k, other_dims[other_dims.size() - 2]); + const int64_t n = other_dims[other_dims.size() - 1]; + + const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); + for (int64_t i = 0; i < static_cast(input_dims.size()) - 2; ++i) { + CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; + } + + auto dtype = input->Dtype(); + std::vector output_dims = input_dims; + output_dims[output_dims.size() - 1] = n; + auto output = std::make_shared(output_dims, dtype, input->GetDevice()); + + auto device = input->GetDevice(); + + // cuBLAS is column-major + // output = input * other --> output.T = other.T * input.T + // C = A * B ==> output.T[*, n, m] = other.T[*, n, k] * input.T[*, k, m] + // C = output.T[*, n, m] + // A = other.T[*, n, k] + // B = input.T[*, k, m] + // NOTE(zbl): the last cublasGemmAlgo_t param has no effect on GPU arch >= sm_80(Ampere) + GemmParams p; + p.trans_a = CUBLAS_OP_N; + p.trans_b = CUBLAS_OP_N; + p.m = static_cast(n); + p.n = static_cast(m); + p.k = static_cast(k); + p.A = other->DataPtr(); + p.lda = static_cast(n); + p.stride_a = n * k; + p.B = input->DataPtr(); + p.ldb = 
static_cast(k); + p.stride_b = k * m; + p.C = output->DataPtr(); + p.ldc = static_cast(n); + p.stride_c = m * n; + p.alpha = 1.0f; + p.beta = 0.0f; + p.batch_count = static_cast(bs); + p.input_dtype = dtype; + p.output_dtype = dtype; + p.blas_handle = GetCublasHandle(device); + + GemmCuda(p); + + return output; +} + +std::shared_ptr MatmulBackwardInput(const std::shared_ptr &other, + const std::shared_ptr &grad_output, + const std::vector &input_dims) { + /* + grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T + */ + + const auto &other_dims = other->Dims(); + const auto &grad_output_dims = grad_output->Dims(); + + CHECK_GE(other_dims.size(), 2); + CHECK_EQ(other_dims.size(), grad_output_dims.size()); + + const int64_t m = grad_output_dims[grad_output_dims.size() - 2]; + const int64_t k = other_dims[other_dims.size() - 2]; + const int64_t n = other_dims[other_dims.size() - 1]; + CHECK_EQ(k, other_dims[other_dims.size() - 2]); + CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]); + CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]); + + const int64_t bs + = std::accumulate(grad_output_dims.rbegin() + 2, grad_output_dims.rend(), 1, std::multiplies{}); + for (int64_t i = 0; i < static_cast(grad_output_dims.size()) - 2; ++i) { + CHECK_EQ(grad_output_dims[i], other_dims[i]) << "Batch dims must match"; + } + + auto compute_dtype = other->Dtype(); + auto grad_output_dtype = grad_output->Dtype(); + auto grad_output_promoted + = grad_output_dtype == compute_dtype ? grad_output : std::make_shared(grad_output->To(compute_dtype)); + + // For bf16 compute, output in fp32 to preserve accumulation precision. + auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype; + auto grad_input = std::make_shared(input_dims, output_dtype, grad_output->GetDevice()); + + // No Fill(0) needed: cuBLAS beta=0.0f means C is fully overwritten, never read. 
+ + auto device = grad_output->GetDevice(); + + // cuBLAS is column-major + // grad_input = grad_output * other.T --> grad_input.T = other * grad_output.T + // C = A.T * B ==> grad_input.T[*, k, m] = other[*, k, n] * grad_output.T[*, n, m] + // C = grad_input.T[*, k, m] + // A = other.T[*, n, k] + // B = grad_output.T[*, n, m] + GemmParams p; + p.trans_a = CUBLAS_OP_T; + p.trans_b = CUBLAS_OP_N; + p.m = static_cast(k); + p.n = static_cast(m); + p.k = static_cast(n); + p.A = other->DataPtr(); + p.lda = static_cast(n); + p.stride_a = k * n; + p.B = grad_output_promoted->DataPtr(); + p.ldb = static_cast(n); + p.stride_b = n * m; + p.C = grad_input->DataPtr(); + p.ldc = static_cast(k); + p.stride_c = m * k; + p.alpha = 1.0f; + p.beta = 0.0f; + p.batch_count = static_cast(bs); + p.input_dtype = compute_dtype; + p.output_dtype = output_dtype; + p.blas_handle = GetCublasHandle(device); + + GemmCuda(p); + + return grad_input; +} + +std::shared_ptr MatmulBackwardOther(const std::shared_ptr &input1, + const std::shared_ptr &grad_output, + const std::vector &other_dims) { + /* + grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n] + */ + + const auto &input_dims = input1->Dims(); + const auto &grad_output_dims = grad_output->Dims(); + CHECK_GE(input_dims.size(), 2); + CHECK_EQ(input_dims.size(), grad_output_dims.size()); + + const int64_t m = input_dims[input_dims.size() - 2]; + const int64_t k = input_dims[input_dims.size() - 1]; + const int64_t n = grad_output_dims[grad_output_dims.size() - 1]; + CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]); + CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]); + CHECK_EQ(input_dims[input_dims.size() - 1], other_dims[other_dims.size() - 2]); + + const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); + for (int64_t i = 0; i < static_cast(input_dims.size()) - 2; ++i) { + CHECK_EQ(input_dims[i], grad_output_dims[i]) << "Batch dims must match"; + 
CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; + } + + auto compute_dtype = input1->Dtype(); + auto grad_output_dtype = grad_output->Dtype(); + auto grad_output_promoted + = grad_output_dtype == compute_dtype ? grad_output : std::make_shared(grad_output->To(compute_dtype)); + + // For bf16 compute, output in fp32 to preserve accumulation precision. + auto output_dtype = (compute_dtype == DataType::kBFLOAT16) ? DataType::kFLOAT32 : compute_dtype; + auto grad_other = std::make_shared(other_dims, output_dtype, grad_output->GetDevice()); + + // No Fill(0) needed: cuBLAS beta=0.0f means C is fully overwritten, never read. + + auto device = grad_output->GetDevice(); + + // cuBLAS is column-major + // grad_other = input.T * grad_output --> grad_other.T = grad_output.T * input + // C = A * B.T ==> grad_other.T[*, n, k] = grad_output.T[*, n, m] * input[*, m, k] + // C = grad_other.T[*, n, k] + // A = grad_output.T[*, n, m] + // B = input.T[*, k, m] + GemmParams p; + p.trans_a = CUBLAS_OP_N; + p.trans_b = CUBLAS_OP_T; + p.m = static_cast(n); + p.n = static_cast(k); + p.k = static_cast(m); + p.A = grad_output_promoted->DataPtr(); + p.lda = static_cast(n); + p.stride_a = n * m; + p.B = input1->DataPtr(); + p.ldb = static_cast(k); + p.stride_b = k * m; + p.C = grad_other->DataPtr(); + p.ldc = static_cast(n); + p.stride_c = n * k; + p.alpha = 1.0f; + p.beta = 0.0f; + p.batch_count = static_cast(bs); + p.input_dtype = compute_dtype; + p.output_dtype = output_dtype; + p.blas_handle = GetCublasHandle(device); + + GemmCuda(p); + + return grad_other; +} + +} // namespace infini_train::kernels::cuda + +#define REGISTER_CUDA_MATMUL_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + +REGISTER_CUDA_MATMUL_KERNEL(MatmulForward) +REGISTER_CUDA_MATMUL_KERNEL(MatmulBackwardInput) +REGISTER_CUDA_MATMUL_KERNEL(MatmulBackwardOther) + +#undef REGISTER_CUDA_MATMUL_KERNEL diff --git 
a/infini_train/src/kernels/cuda/outer.cu b/infini_train/src/kernels/cuda/outer.cu index f3140ca5..5ac854c6 100644 --- a/infini_train/src/kernels/cuda/outer.cu +++ b/infini_train/src/kernels/cuda/outer.cu @@ -7,13 +7,12 @@ #include "glog/logging.h" #include "infini_train/include/common/cuda/common_cuda.h" -#include "infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/common/cuda/gemm.cuh" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" -#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" - namespace infini_train::kernels::cuda { + std::shared_ptr OuterForward(const std::shared_ptr &input, const std::shared_ptr &other) { /* Computes outer product: output[i, j] = input[i] * other[j] @@ -29,35 +28,36 @@ std::shared_ptr OuterForward(const std::shared_ptr &input, const const int64_t M = in_dims[0]; const int64_t N = ot_dims[0]; - auto output = std::make_shared(std::vector{M, N}, input->Dtype(), input->GetDevice()); + auto dtype = input->Dtype(); + auto output = std::make_shared(std::vector{M, N}, dtype, input->GetDevice()); auto device = input->GetDevice(); + // reinterpret input: [M] as column vector [M, 1] // reinterpret other: [N] as row vector [1, N] // output[M, N] = input[M, 1] * other.T[1, N] // output.T[N, M] = other[N, 1] * input.T[1, M] - float alpha = 1.0f; - float beta = 0.0f; - cublasHandle_t handle = dynamic_cast( - infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) - ->cublas_handle(); - - switch (input->Dtype()) { - DISPATCH_CASE(WRAP({ - CUBLAS_CHECK(cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, 1, &alpha, - static_cast(other->DataPtr()), N, - static_cast(input->DataPtr()), 1, &beta, - static_cast(output->DataPtr()), N)); - }), - DataType::kFLOAT32) - DISPATCH_CASE(WRAP({ - CUBLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, 1, &alpha, other->DataPtr(), - CUDA_R_16BF, N, input->DataPtr(), CUDA_R_16BF, 1, &beta, - 
output->DataPtr(), CUDA_R_16BF, N, CUDA_R_32F, - CUBLAS_GEMM_DEFAULT)); - }), - DataType::kBFLOAT16) - } + // This is a GEMM with k=1: C[N,M] = A[N,1] * B[1,M] + GemmParams p; + p.trans_a = CUBLAS_OP_N; + p.trans_b = CUBLAS_OP_N; + p.m = static_cast(N); + p.n = static_cast(M); + p.k = 1; + p.A = other->DataPtr(); + p.lda = static_cast(N); + p.B = input->DataPtr(); + p.ldb = 1; + p.C = output->DataPtr(); + p.ldc = static_cast(N); + p.alpha = 1.0f; + p.beta = 0.0f; + p.batch_count = 1; + p.input_dtype = dtype; + p.output_dtype = dtype; + p.blas_handle = GetCublasHandle(device); + + GemmCuda(p); return output; } @@ -95,63 +95,82 @@ std::tuple, std::shared_ptr> OuterBackward(const auto grad_input = std::make_shared(std::vector{M}, output_dtype, grad_output->GetDevice()); auto grad_other = std::make_shared(std::vector{N}, output_dtype, grad_output->GetDevice()); - DispatchFunc( - promoted_type, - [=]() { - grad_input->Fill(0); - grad_other->Fill(0); - }, - "CUDA OuterBackward"); - auto device = input->GetDevice(); - float alpha = 1.0f; - float beta = 0.0f; - cublasHandle_t handle = dynamic_cast( - infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) - ->cublas_handle(); switch (promoted_type) { - DISPATCH_CASE(WRAP({ - // grad_input[M, 1] = grad_output[M, N] × other[N, 1] - // y = grad_input[M] - // A = grad_output.T[N, M] - // x = other[N] - CUBLAS_CHECK(cublasSgemv(handle, CUBLAS_OP_T, N, M, &alpha, - static_cast(grad_output_promoted->DataPtr()), N, - static_cast(other_promoted->DataPtr()), 1, &beta, - static_cast(grad_input->DataPtr()), 1)); - - // grad_other[N, 1] = grad_output.T[N, M] × input[M, 1] - // y = grad_other[N] - // A = grad_output.T[N, M] - // x = input[M] - CUBLAS_CHECK(cublasSgemv(handle, CUBLAS_OP_N, N, M, &alpha, - static_cast(grad_output_promoted->DataPtr()), N, - static_cast(input_promoted->DataPtr()), 1, &beta, - static_cast(grad_other->DataPtr()), 1)); - }), - DataType::kFLOAT32) - DISPATCH_CASE( - // cublasgemv does not 
support bf16, use cublasGemmEx to workaround - WRAP({ - // grad_input[M, 1] = grad_output[M, N] × other[N, 1] - // grad_input.T[1, M] = other.T[1, N] × grad_output.T[N, M] - // C = grad_input.T[1, M] - // A = other.T[1, N] - // B = grad_output.T[N, M] - CUBLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, 1, M, N, &alpha, other_promoted->DataPtr(), - CUDA_R_16BF, 1, grad_output_promoted->DataPtr(), CUDA_R_16BF, N, &beta, - grad_input->DataPtr(), CUDA_R_32F, 1, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); - // grad_other[N, 1] = grad_output.T[N, M] × input[M, 1] - // grad_other.T[1, N] = input.T[1, M] × grad_output[M, N] - // C = grad_other.T[1, N] - // A = input.T[1, M] - // B = grad_output.T[N, M] - CUBLAS_CHECK(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, 1, N, M, &alpha, input_promoted->DataPtr(), - CUDA_R_16BF, 1, grad_output_promoted->DataPtr(), CUDA_R_16BF, N, &beta, - grad_other->DataPtr(), CUDA_R_32F, 1, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); - }), - DataType::kBFLOAT16) + case DataType::kFLOAT32: { + // fp32: use cublasSgemv (specialized matrix-vector kernel, more efficient than GEMM for this shape) + // cublasSgemv does not support bf16, so bf16 falls through to GemmCuda below. 
+ float alpha = 1.0f, beta = 0.0f; + cublasHandle_t handle = GetCublasHandle(device); + + // grad_input[M] = grad_output[M, N] × other[N] + // y = grad_input[M], A = grad_output.T[N, M], x = other[N] + CUBLAS_CHECK(cublasSgemv(handle, CUBLAS_OP_T, N, M, &alpha, + static_cast(grad_output_promoted->DataPtr()), N, + static_cast(other_promoted->DataPtr()), 1, &beta, + static_cast(grad_input->DataPtr()), 1)); + + // grad_other[N] = grad_output.T[N, M] × input[M] + // y = grad_other[N], A = grad_output.T[N, M], x = input[M] + CUBLAS_CHECK(cublasSgemv(handle, CUBLAS_OP_N, N, M, &alpha, + static_cast(grad_output_promoted->DataPtr()), N, + static_cast(input_promoted->DataPtr()), 1, &beta, + static_cast(grad_other->DataPtr()), 1)); + break; + } + case DataType::kBFLOAT16: { + // bf16: cublasSgemv does not support bf16; use GemmCuda (GEMM with k=M or k=N). + + // grad_input[M] = grad_output[M, N] × other[N] + // grad_input.T[1, M] = other.T[1, N] × grad_output.T[N, M] + // C[1,M] = A[1,N] * B[N,M] + GemmParams p_input; + p_input.trans_a = CUBLAS_OP_N; + p_input.trans_b = CUBLAS_OP_N; + p_input.m = 1; + p_input.n = static_cast(M); + p_input.k = static_cast(N); + p_input.A = other_promoted->DataPtr(); + p_input.lda = 1; + p_input.B = grad_output_promoted->DataPtr(); + p_input.ldb = static_cast(N); + p_input.C = grad_input->DataPtr(); + p_input.ldc = 1; + p_input.alpha = 1.0f; + p_input.beta = 0.0f; + p_input.batch_count = 1; + p_input.input_dtype = promoted_type; + p_input.output_dtype = output_dtype; + p_input.blas_handle = GetCublasHandle(device); + GemmCuda(p_input); + + // grad_other[N] = grad_output.T[N, M] × input[M] + // grad_other.T[1, N] = input.T[1, M] × grad_output[M, N] + // C[1,N] = A[1,M] * B[M,N] (B stored as grad_output.T[N,M], so ldb=N, trans_b=T) + GemmParams p_other; + p_other.trans_a = CUBLAS_OP_N; + p_other.trans_b = CUBLAS_OP_T; + p_other.m = 1; + p_other.n = static_cast(N); + p_other.k = static_cast(M); + p_other.A = input_promoted->DataPtr(); + 
p_other.lda = 1; + p_other.B = grad_output_promoted->DataPtr(); + p_other.ldb = static_cast(N); + p_other.C = grad_other->DataPtr(); + p_other.ldc = 1; + p_other.alpha = 1.0f; + p_other.beta = 0.0f; + p_other.batch_count = 1; + p_other.input_dtype = promoted_type; + p_other.output_dtype = output_dtype; + p_other.blas_handle = GetCublasHandle(device); + GemmCuda(p_other); + break; + } + default: + LOG(FATAL) << "CUDA OuterBackward: unsupported dtype"; } return {grad_input, grad_other};