From 865b51c6b10a1c83bd593ef0f28286fcd103de25 Mon Sep 17 00:00:00 2001 From: cx Date: Thu, 28 May 2026 08:59:14 +0000 Subject: [PATCH] feat: add InfiniOps as optional kernel provider Wire InfiniOps in as a pluggable kernel provider keyed at the GEMM level: Dispatcher consults a per-key whitelist hook and routes registered ops to InfiniOps, falling back to the default CUDA kernel otherwise. linear, matmul and outer now invoke Gemm via Dispatcher rather than calling the cuBLAS wrapper directly, so InfiniOps Gemm transparently covers all three. --- .gitmodules | 3 + CMakeLists.txt | 34 +++++ .../core/kernel_provider/infiniops/adapter.h | 45 ++++++ .../core/kernel_provider/infiniops_registry.h | 50 ++++++ infini_train/include/dispatcher.h | 17 +++ infini_train/include/kernels/common/gemm.h | 48 ++++++ infini_train/include/tensor.h | 2 +- .../core/kernel_provider/infiniops/adapter.cc | 120 +++++++++++++++ .../kernel_provider/infiniops/cuda/handle.cc | 25 +++ .../kernel_provider/infiniops/elementwise.cc | 71 +++++++++ .../core/kernel_provider/infiniops/gemm.cc | 73 +++++++++ .../kernel_provider/infiniops_registry.cc | 51 +++++++ infini_train/src/kernels/cuda/common/gemm.cu | 27 +++- infini_train/src/kernels/cuda/common/gemm.cuh | 53 +------ infini_train/src/kernels/cuda/linear.cu | 142 +++++++++--------- infini_train/src/kernels/cuda/matmul.cu | 129 ++++++++-------- infini_train/src/kernels/cuda/outer.cu | 117 ++++++++------- infini_train/src/tensor.cc | 2 +- third_party/InfiniOps | 1 + 19 files changed, 767 insertions(+), 243 deletions(-) create mode 100644 infini_train/include/core/kernel_provider/infiniops/adapter.h create mode 100644 infini_train/include/core/kernel_provider/infiniops_registry.h create mode 100644 infini_train/include/kernels/common/gemm.h create mode 100644 infini_train/src/core/kernel_provider/infiniops/adapter.cc create mode 100644 infini_train/src/core/kernel_provider/infiniops/cuda/handle.cc create mode 100644 infini_train/src/core/kernel_provider/infiniops/elementwise.cc create mode 100644 infini_train/src/core/kernel_provider/infiniops/gemm.cc create mode 100644 infini_train/src/core/kernel_provider/infiniops_registry.cc create mode 160000 third_party/InfiniOps diff --git a/.gitmodules b/.gitmodules index d33d0a64..7b02a741 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,3 +10,6 @@ [submodule "third_party/googletest"] path = third_party/googletest url = git@github.com:google/googletest.git +[submodule "third_party/InfiniOps"] + path = third_party/InfiniOps + url = git@github.com:InfiniTensor/InfiniOps.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 4c6da822..83d289d8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,6 +4,7 @@ option(USE_CUDA "Support NVIDIA CUDA" OFF) option(PROFILE_MODE "ENABLE PROFILE MODE" OFF) option(USE_OMP "Use OpenMP as backend for Eigen" ON) option(USE_NCCL "Build project for distributed running" ON) +option(USE_INFINIOPS "Use InfiniOps as an optional kernel provider" OFF) option(BUILD_TEST "Build InfiniTrain tests" OFF) project(infini_train VERSION 0.5.0 LANGUAGES CXX) @@ -51,6 +52,32 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party/eigen) include_directories(${PROJECT_SOURCE_DIR}) +if(USE_INFINIOPS) + add_compile_definitions(USE_INFINIOPS=1) + + set(INFINIOPS_SOURCE_DIR "${PROJECT_SOURCE_DIR}/third_party/InfiniOps") + if(NOT EXISTS "${INFINIOPS_SOURCE_DIR}/CMakeLists.txt") + message(FATAL_ERROR + "USE_INFINIOPS=ON requires InfiniOps under third_party/InfiniOps. " + "Run: git submodule update --init third_party/InfiniOps") + endif() + + set(INFINIOPS_WITH_CPU OFF) + if(NOT USE_CUDA) + set(INFINIOPS_WITH_CPU ON) + endif() + + set(WITH_CPU ${INFINIOPS_WITH_CPU} CACHE BOOL "Enable InfiniOps CPU backend" FORCE) + set(WITH_NVIDIA ${USE_CUDA} CACHE BOOL "Enable InfiniOps NVIDIA backend" FORCE) + add_subdirectory(${INFINIOPS_SOURCE_DIR} ${CMAKE_BINARY_DIR}/third_party/InfiniOps EXCLUDE_FROM_ALL) + if(NOT TARGET infiniops) + message(FATAL_ERROR "InfiniOps third-party project did not define target `infiniops`") + endif() + if(NOT TARGET InfiniOps::infiniops) + add_library(InfiniOps::infiniops ALIAS infiniops) + endif() +endif() + if(PROFILE_MODE) add_compile_definitions(PROFILE_MODE=1) endif() @@ -62,9 +89,13 @@ endif() # Framework core sources (*.cc), excluding cpu kernels (they are built separately) file(GLOB_RECURSE SRC ${PROJECT_SOURCE_DIR}/infini_train/src/*.cc) list(FILTER SRC EXCLUDE REGEX ".*kernels/cpu/.*") +if(NOT USE_INFINIOPS) + list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/kernel_provider/infiniops/.*\.cc$") +endif() if(NOT USE_CUDA) list(FILTER SRC EXCLUDE REGEX ".*runtime/cuda/.*") list(FILTER SRC EXCLUDE REGEX ".*ccl/cuda/.*") + list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/kernel_provider/infiniops/cuda/.*") endif() if(NOT USE_NCCL) list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/ccl/cuda/.*") @@ -126,6 +157,9 @@ endif() # ------------------------------------------------------------------------------ add_library(infini_train STATIC ${SRC}) +if(USE_INFINIOPS) + target_link_libraries(infini_train PUBLIC InfiniOps::infiniops) +endif() target_link_libraries(infini_train PUBLIC glog diff --git a/infini_train/include/core/kernel_provider/infiniops/adapter.h b/infini_train/include/core/kernel_provider/infiniops/adapter.h new file mode 100644 index 00000000..5cd2079b --- /dev/null +++ b/infini_train/include/core/kernel_provider/infiniops/adapter.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include +#include +#include + +#include + +#include "data_type.h" +#include "tensor.h" + +#include "infini_train/include/datatype.h" +#include "infini_train/include/device.h" + +namespace infini_train { +class Tensor; +} // namespace infini_train + +namespace infini_train::core { +class Stream; +} // namespace infini_train::core + +namespace infini_train::kernel_provider::infiniops { + +infini::ops::DataType ToOpsDataType(DataType dtype); + +infini::ops::Device ToOpsDevice(const Device &device); + +std::mutex &InfiniOpsCallMutex(); + +using HandleFactory = infini::ops::Handle (*)(const Device &device, core::Stream *stream); + +void RegisterHandleFactory(Device::DeviceType type, HandleFactory factory); + +infini::ops::Handle GetHandle(const Device &device); + +infini::ops::Tensor ToOpsTensor(const std::shared_ptr &tensor); + +infini::ops::Tensor ToOpsTensor(void *data, const std::vector &dims, DataType dtype, const Device &device); + +infini::ops::Tensor ToOpsTensor(void *data, const std::vector &dims, DataType dtype, const Device &device, + const std::vector &strides); + +} // namespace infini_train::kernel_provider::infiniops diff --git a/infini_train/include/core/kernel_provider/infiniops_registry.h b/infini_train/include/core/kernel_provider/infiniops_registry.h new file mode 100644 index 00000000..4dda9578 --- /dev/null +++ b/infini_train/include/core/kernel_provider/infiniops_registry.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/device.h" +#include "infini_train/include/dispatcher.h" + +namespace infini_train::kernel_provider { + +using KeyT = std::pair; + +class InfiniOpsRegistry { +public: + static InfiniOpsRegistry &Instance() { + static InfiniOpsRegistry instance; + return instance; + } + + const KernelFunction *Lookup(const std::string &kernel_name) const { + auto it = name_to_kernel_map_.find(kernel_name); + return it == name_to_kernel_map_.end() ? nullptr : &it->second; + } + + template void Register(const std::string &kernel_name, FuncT &&kernel) { + CHECK(!name_to_kernel_map_.contains(kernel_name)) << "InfiniOps kernel already registered: " << kernel_name; + name_to_kernel_map_.emplace(kernel_name, kernel); + } + +private: + std::map name_to_kernel_map_; +}; + +// Bridge functions used by Dispatcher::GetKernel. Implemented in +// infiniops_registry.cc; declared here for users that already include +// the full registry header (e.g. unit tests). +bool InfiniOpsEnabled(); +bool InfiniOpsEnabled(const KeyT &key); +const KernelFunction *LookupInfiniOpsKernel(const KeyT &key); + +} // namespace infini_train::kernel_provider + +#define REGISTER_INFINIOPS_KERNEL(kernel_name, kernel_func) \ + static const bool _##kernel_name##_infiniops_registered##__COUNTER__ = []() { \ + infini_train::kernel_provider::InfiniOpsRegistry::Instance().Register(#kernel_name, kernel_func); \ + return true; \ + }(); diff --git a/infini_train/include/dispatcher.h b/infini_train/include/dispatcher.h index 638df76a..05c6a678 100644 --- a/infini_train/include/dispatcher.h +++ b/infini_train/include/dispatcher.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -47,6 +48,11 @@ class KernelFunction { void *func_ptr_ = nullptr; }; +namespace kernel_provider { +bool InfiniOpsEnabled(const std::pair &key); +const KernelFunction *LookupInfiniOpsKernel(const std::pair &key); +} // namespace kernel_provider + class Dispatcher { public: using KeyT = std::pair; @@ -57,6 +63,17 @@ class Dispatcher { } const KernelFunction &GetKernel(KeyT key) const { + if (kernel_provider::InfiniOpsEnabled(key)) { + if (const auto *kernel = kernel_provider::LookupInfiniOpsKernel(key)) { +#ifdef PROFILE_MODE + SetProfileContext(key.second, key.first); +#endif + return *kernel; + } + LOG(WARNING) << "InfiniOps kernel enabled but not registered: " << key.second + << " on device: " << static_cast(key.first) << "; falling back to default kernel"; + } + CHECK(key_to_kernel_map_.contains(key)) << "Kernel not found: " << key.second << " on device: " << static_cast(key.first); #ifdef PROFILE_MODE diff --git a/infini_train/include/kernels/common/gemm.h b/infini_train/include/kernels/common/gemm.h new file mode 100644 index 00000000..4af6a511 --- /dev/null +++ b/infini_train/include/kernels/common/gemm.h @@ -0,0 +1,48 @@ +#pragma once + +#include "infini_train/include/datatype.h" +#include "infini_train/include/device.h" + +namespace infini_train::kernels { + +enum class GemmTranspose : int { + kNoTranspose = 0, + kTranspose = 1, +}; + +/** + * Parameter bundle for a single GEMM call: + * C = alpha * op(A) * op(B) + beta * C + * + * batch_count == 1 describes a non-batched GEMM. batch_count > 1 describes a + * strided-batched GEMM. When batch_count == 1, stride_a/b/c are unused and must + * be left at 0. + */ +struct GemmParams { + GemmTranspose trans_a = GemmTranspose::kNoTranspose; + GemmTranspose trans_b = GemmTranspose::kNoTranspose; + + int m = 0; // rows of op(A) and C + int n = 0; // cols of op(B) and C + int k = 0; // cols of op(A) == rows of op(B) + + const void *A = nullptr; + int lda = 0; + const void *B = nullptr; + int ldb = 0; + void *C = nullptr; + int ldc = 0; + + float alpha = 1.0f; + float beta = 0.0f; + + int batch_count = 1; + long long stride_a = 0; + long long stride_b = 0; + long long stride_c = 0; + + DataType input_dtype; + DataType output_dtype; +}; + +} // namespace infini_train::kernels diff --git a/infini_train/include/tensor.h b/infini_train/include/tensor.h index 12f45f57..734ed018 100644 --- a/infini_train/include/tensor.h +++ b/infini_train/include/tensor.h @@ -139,7 +139,7 @@ class Tensor : public std::enable_shared_from_this { std::shared_ptr View(const std::vector &dims); std::shared_ptr Contiguous(); // FIXME: Currently returns true unconditionally. Requires stride tracking in the Tensor - // class before this can be implemented correctly. The guard in elementwise.cu ensures + // class before this can be implemented correctly. The elementwise broadcast guard ensures // non-contiguous tensors fall back to the broadcast path until this is resolved. bool IsContiguous() const; std::shared_ptr Flatten(int64_t start = 0, int64_t end = -1); diff --git a/infini_train/src/core/kernel_provider/infiniops/adapter.cc b/infini_train/src/core/kernel_provider/infiniops/adapter.cc new file mode 100644 index 00000000..cbe62c83 --- /dev/null +++ b/infini_train/src/core/kernel_provider/infiniops/adapter.cc @@ -0,0 +1,120 @@ +#include "infini_train/include/core/kernel_provider/infiniops/adapter.h" + +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/core/runtime/device_guard.h" + +namespace infini_train::kernel_provider::infiniops { + +namespace { + +inline const std::unordered_map kOpsDataTypeMap = { + {DataType::kFLOAT16, infini::ops::DataType::kFloat16}, {DataType::kBFLOAT16, infini::ops::DataType::kBFloat16}, + {DataType::kFLOAT32, infini::ops::DataType::kFloat32}, {DataType::kFLOAT64, infini::ops::DataType::kFloat64}, + {DataType::kINT8, infini::ops::DataType::kInt8}, {DataType::kINT16, infini::ops::DataType::kInt16}, + {DataType::kINT32, infini::ops::DataType::kInt32}, {DataType::kINT64, infini::ops::DataType::kInt64}, + {DataType::kUINT8, infini::ops::DataType::kUInt8}, {DataType::kUINT16, infini::ops::DataType::kUInt16}, + {DataType::kUINT32, infini::ops::DataType::kUInt32}, {DataType::kUINT64, infini::ops::DataType::kUInt64}, +}; + +inline const std::unordered_map kOpsDeviceTypeMap = { + {Device::DeviceType::kCUDA, infini::ops::Device::Type::kNvidia}, + {Device::DeviceType::kCPU, infini::ops::Device::Type::kCpu}, +}; + +std::map &HandleFactories() { + static std::map factories; + return factories; +} + +} // namespace + +void RegisterHandleFactory(Device::DeviceType type, HandleFactory factory) { + CHECK(factory != nullptr); + auto &factories = HandleFactories(); + CHECK(!factories.contains(type)) << "InfiniOps handle factory already registered for device type " + << static_cast(type); + factories.emplace(type, factory); +} + +infini::ops::Handle GetHandle(const Device &device) { + auto &factories = HandleFactories(); + auto it = factories.find(device.type()); + CHECK(it != factories.end()) << "InfiniOps handle factory is not registered for device type " + << static_cast(device.type()); + + auto *stream = core::GetDeviceGuardImpl(device.type())->GetStream(device); + return it->second(device, stream); +} + +infini::ops::DataType ToOpsDataType(DataType dtype) { + auto it = kOpsDataTypeMap.find(dtype); + if (it == kOpsDataTypeMap.end()) { + LOG(FATAL) << "Unsupported DataType for InfiniOps: " << static_cast(dtype); + __builtin_unreachable(); + } + return it->second; +} + +infini::ops::Device ToOpsDevice(const Device &device) { + auto it = kOpsDeviceTypeMap.find(device.type()); + if (it == kOpsDeviceTypeMap.end()) { + LOG(FATAL) << "Unsupported DeviceType for InfiniOps: " << static_cast(device.type()); + __builtin_unreachable(); + } + return {it->second, device.index()}; +} + +std::mutex &InfiniOpsCallMutex() { + static std::mutex mutex; + return mutex; +} + +namespace { +infini::ops::Tensor::Strides ComputeContiguousStrides(const std::vector &dims) { + infini::ops::Tensor::Strides strides(dims.size()); + if (dims.empty()) { + return strides; + } + strides.back() = 1; + for (int i = static_cast(dims.size()) - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * static_cast(dims[i + 1]); + } + return strides; +} + +infini::ops::Tensor::Shape ToShape(const std::vector &dims) { + infini::ops::Tensor::Shape shape(dims.size()); + for (size_t i = 0; i < dims.size(); ++i) { shape[i] = static_cast(dims[i]); } + return shape; +} + +infini::ops::Tensor::Strides ToStrides(const std::vector &strides) { + infini::ops::Tensor::Strides ops_strides(strides.size()); + for (size_t i = 0; i < strides.size(); ++i) { + ops_strides[i] = static_cast(strides[i]); + } + return ops_strides; +} +} // namespace + +infini::ops::Tensor ToOpsTensor(const std::shared_ptr &tensor) { + const auto &dims = tensor->Dims(); + return {tensor->DataPtr(), ToShape(dims), ToOpsDataType(tensor->Dtype()), ToOpsDevice(tensor->GetDevice()), + ComputeContiguousStrides(dims)}; +} + +infini::ops::Tensor ToOpsTensor(void *data, const std::vector &dims, DataType dtype, const Device &device) { + return {data, ToShape(dims), ToOpsDataType(dtype), ToOpsDevice(device), ComputeContiguousStrides(dims)}; +} + +infini::ops::Tensor ToOpsTensor(void *data, const std::vector &dims, DataType dtype, const Device &device, + const std::vector &strides) { + CHECK_EQ(dims.size(), strides.size()); + return {data, ToShape(dims), ToOpsDataType(dtype), ToOpsDevice(device), ToStrides(strides)}; +} + +} // namespace infini_train::kernel_provider::infiniops diff --git a/infini_train/src/core/kernel_provider/infiniops/cuda/handle.cc b/infini_train/src/core/kernel_provider/infiniops/cuda/handle.cc new file mode 100644 index 00000000..70a44073 --- /dev/null +++ b/infini_train/src/core/kernel_provider/infiniops/cuda/handle.cc @@ -0,0 +1,25 @@ +#include "infini_train/include/core/kernel_provider/infiniops/adapter.h" + +#include "glog/logging.h" + +#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" + +namespace infini_train::kernel_provider::infiniops { +namespace { + +infini::ops::Handle MakeCudaHandle(const Device &, core::Stream *stream) { + auto *cuda_stream = dynamic_cast(stream); + CHECK_NOTNULL(cuda_stream); + + infini::ops::Handle handle; + handle.set_stream(static_cast(cuda_stream->cuda_stream())); + return handle; +} + +const bool kCudaHandleFactoryRegistered = []() { + RegisterHandleFactory(Device::DeviceType::kCUDA, MakeCudaHandle); + return true; +}(); + +} // namespace +} // namespace infini_train::kernel_provider::infiniops diff --git a/infini_train/src/core/kernel_provider/infiniops/elementwise.cc b/infini_train/src/core/kernel_provider/infiniops/elementwise.cc new file mode 100644 index 00000000..77bed935 --- /dev/null +++ b/infini_train/src/core/kernel_provider/infiniops/elementwise.cc @@ -0,0 +1,71 @@ +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/core/kernel_provider/infiniops/adapter.h" +#include "infini_train/include/core/kernel_provider/infiniops_registry.h" +#include "infini_train/include/tensor.h" + +#include + +namespace infini_train::kernel_provider::infiniops { +namespace { + +std::vector ComputeBroadcastStrides(const std::vector &dims, const std::vector &out_dims) { + CHECK_LE(dims.size(), out_dims.size()); + + std::vector strides(dims.size()); + if (!dims.empty()) { + strides.back() = 1; + for (int i = static_cast(dims.size()) - 2; i >= 0; --i) { strides[i] = strides[i + 1] * dims[i + 1]; } + } + + const size_t pad = out_dims.size() - dims.size(); + std::vector out_strides(out_dims.size(), 0); + for (size_t i = 0; i < dims.size(); ++i) { + const int64_t dim = dims[i]; + const int64_t out_dim = out_dims[pad + i]; + CHECK(dim == out_dim || dim == 1) << "InfiniOps Add broadcast shape mismatch"; + out_strides[pad + i] = dim == 1 ? 0 : strides[i]; + } + return out_strides; +} + +infini::ops::Tensor ToBroadcastOpsTensor(const std::shared_ptr &tensor, const std::vector &out_dims, + DataType dtype) { + const auto strides = ComputeBroadcastStrides(tensor->Dims(), out_dims); + return ToOpsTensor(tensor->DataPtr(), out_dims, dtype, tensor->GetDevice(), strides); +} + +} // namespace + +std::shared_ptr AddForward(const std::shared_ptr &a, const std::shared_ptr &b) { + CHECK_GE(a->NumElements(), b->NumElements()); + CHECK_EQ(a->NumElements() % b->NumElements(), 0); + + auto a_dtype = a->Dtype(); + auto b_dtype = b->Dtype(); + DataType promoted_type = PromoteDataTypes(a_dtype, b_dtype); + + auto a_promoted = a_dtype == promoted_type ? a : std::make_shared(a->To(promoted_type)); + auto b_promoted = b_dtype == promoted_type ? b : std::make_shared(b->To(promoted_type)); + + auto output = std::make_shared(a->Dims(), promoted_type, a->GetDevice()); + + auto handle = GetHandle(a->GetDevice()); + auto a_ops = ToBroadcastOpsTensor(a_promoted, output->Dims(), promoted_type); + auto b_ops = ToBroadcastOpsTensor(b_promoted, output->Dims(), promoted_type); + auto c_ops = ToOpsTensor(output); + + { + std::lock_guard lock(InfiniOpsCallMutex()); + infini::ops::functional::Add(handle, {}, a_ops, b_ops, c_ops); + } + return output; +} + +} // namespace infini_train::kernel_provider::infiniops + +REGISTER_INFINIOPS_KERNEL(AddForward, infini_train::kernel_provider::infiniops::AddForward) diff --git a/infini_train/src/core/kernel_provider/infiniops/gemm.cc b/infini_train/src/core/kernel_provider/infiniops/gemm.cc new file mode 100644 index 00000000..2b04f549 --- /dev/null +++ b/infini_train/src/core/kernel_provider/infiniops/gemm.cc @@ -0,0 +1,73 @@ +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/core/kernel_provider/infiniops/adapter.h" +#include "infini_train/include/core/kernel_provider/infiniops_registry.h" +#include "infini_train/include/kernels/common/gemm.h" + +#include + +namespace infini_train::kernel_provider::infiniops { +namespace { + +int ToInfiniOpsTrans(kernels::GemmTranspose op) { + switch (op) { + case kernels::GemmTranspose::kNoTranspose: + return 0; + case kernels::GemmTranspose::kTranspose: + return 1; + } + LOG(FATAL) << "InfiniOps Gemm: unsupported transpose flag " << static_cast(op); + return 0; // unreachable +} + +std::vector MatrixShape(int batch_count, int rows, int cols) { + if (batch_count > 1) { + return {batch_count, rows, cols}; + } + return {rows, cols}; +} + +std::vector RowMajorStrides(int batch_count, int ld, long long batch_stride) { + if (batch_count > 1) { + return {batch_stride, ld, 1}; + } + return {ld, 1}; +} + +infini::ops::Tensor MakeRowMajorTransposeView(const void *data, int batch_count, int column_major_rows, + int column_major_cols, int ld, long long batch_stride, DataType dtype, + const Device &device) { + return ToOpsTensor(const_cast(data), MatrixShape(batch_count, column_major_cols, column_major_rows), dtype, + device, RowMajorStrides(batch_count, ld, batch_stride)); +} + +} // namespace + +void Gemm(Device device, kernels::GemmParams p) { + CHECK_GE(p.batch_count, 1); + + const bool trans_a = p.trans_a == kernels::GemmTranspose::kTranspose; + const bool trans_b = p.trans_b == kernels::GemmTranspose::kTranspose; + + const int a_rows = trans_a ? p.k : p.m; + const int a_cols = trans_a ? p.m : p.k; + const int b_rows = trans_b ? p.n : p.k; + const int b_cols = trans_b ? p.k : p.n; + + auto handle = GetHandle(device); + auto a = MakeRowMajorTransposeView(p.A, p.batch_count, a_rows, a_cols, p.lda, p.stride_a, p.input_dtype, device); + auto b = MakeRowMajorTransposeView(p.B, p.batch_count, b_rows, b_cols, p.ldb, p.stride_b, p.input_dtype, device); + auto c = MakeRowMajorTransposeView(p.C, p.batch_count, p.m, p.n, p.ldc, p.stride_c, p.output_dtype, device); + + std::lock_guard lock(InfiniOpsCallMutex()); + infini::ops::functional::Gemm(handle, {}, b, a, std::optional{p.alpha}, std::optional{p.beta}, + std::optional{ToInfiniOpsTrans(p.trans_b)}, + std::optional{ToInfiniOpsTrans(p.trans_a)}, c); +} + +} // namespace infini_train::kernel_provider::infiniops + +REGISTER_INFINIOPS_KERNEL(Gemm, infini_train::kernel_provider::infiniops::Gemm) diff --git a/infini_train/src/core/kernel_provider/infiniops_registry.cc b/infini_train/src/core/kernel_provider/infiniops_registry.cc new file mode 100644 index 00000000..b5862d87 --- /dev/null +++ b/infini_train/src/core/kernel_provider/infiniops_registry.cc @@ -0,0 +1,51 @@ +#include "infini_train/include/core/kernel_provider/infiniops_registry.h" + +#include +#include +#include +#include + +namespace infini_train::kernel_provider { + +namespace { + +const std::set kEnabledKernelWhitelist = { + "Gemm", + "AddForward", +}; + +} // namespace + +bool InfiniOpsEnabled() { +#ifdef USE_INFINIOPS + return true; +#else + return false; +#endif +} + +bool InfiniOpsEnabled(const KeyT &key) { + return InfiniOpsEnabled() && key.first == Device::DeviceType::kCUDA && kEnabledKernelWhitelist.contains(key.second); +} + +const KernelFunction *LookupInfiniOpsKernel(const KeyT &key) { + const auto *kernel = InfiniOpsRegistry::Instance().Lookup(key.second); + + static std::mutex log_mutex; + static std::set logged_use; + static std::set logged_fallback; + + const auto log_key = std::to_string(static_cast(key.first)) + ":" + key.second; + std::lock_guard lock(log_mutex); + if (kernel != nullptr) { + if (logged_use.insert(log_key).second) { + std::cout << "[InfiniOps] use " << key.second << " on device " << static_cast(key.first) << std::endl; + } + } else if (logged_fallback.insert(log_key).second) { + std::cout << "[InfiniOps] fallback " << key.second << " on device " << static_cast(key.first) << std::endl; + } + + return kernel; +} + +} // namespace infini_train::kernel_provider diff --git a/infini_train/src/kernels/cuda/common/gemm.cu b/infini_train/src/kernels/cuda/common/gemm.cu index b30a0545..80f84970 100644 --- a/infini_train/src/kernels/cuda/common/gemm.cu +++ b/infini_train/src/kernels/cuda/common/gemm.cu @@ -7,6 +7,7 @@ #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/core/runtime/device_guard.h" #include "infini_train/include/datatype.h" +#include "infini_train/include/dispatcher.h" #include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" @@ -14,6 +15,17 @@ namespace infini_train::kernels::cuda { namespace { +cublasOperation_t ToCublasOperation(GemmTranspose op) { + switch (op) { + case GemmTranspose::kNoTranspose: + return CUBLAS_OP_N; + case GemmTranspose::kTranspose: + return CUBLAS_OP_T; + } + LOG(FATAL) << "Gemm: unsupported transpose flag " << static_cast(op); + return CUBLAS_OP_N; // unreachable +} + cudaDataType_t ToCudaDataType(DataType dt) { switch (dt) { case DataType::kFLOAT32: @@ -23,14 +35,14 @@ cudaDataType_t ToCudaDataType(DataType dt) { case DataType::kFLOAT16: return CUDA_R_16F; default: - LOG(FATAL) << "GemmCuda: unsupported DataType " << static_cast(dt); + LOG(FATAL) << "Gemm: unsupported DataType " << static_cast(dt); return CUDA_R_32F; // unreachable } } } // namespace -void GemmCuda(const Device &device, const GemmParams &p) { +void Gemm(Device device, GemmParams p) { const cublasHandle_t blas_handle = dynamic_cast( infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) ->cublas_handle(); @@ -43,6 +55,8 @@ void GemmCuda(const Device &device, const GemmParams &p) { DCHECK_EQ(p.stride_c, 0LL); } + const cublasOperation_t trans_a = ToCublasOperation(p.trans_a); + const cublasOperation_t trans_b = ToCublasOperation(p.trans_b); const cudaDataType_t type_a = ToCudaDataType(p.input_dtype); const cudaDataType_t type_b = ToCudaDataType(p.input_dtype); const cudaDataType_t type_c = ToCudaDataType(p.output_dtype); @@ -51,10 +65,10 @@ void GemmCuda(const Device &device, const GemmParams &p) { const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; if (p.batch_count == 1) { - CUBLAS_CHECK(cublasGemmEx(blas_handle, p.trans_a, p.trans_b, p.m, p.n, p.k, &p.alpha, p.A, type_a, p.lda, p.B, + CUBLAS_CHECK(cublasGemmEx(blas_handle, trans_a, trans_b, p.m, p.n, p.k, &p.alpha, p.A, type_a, p.lda, p.B, type_b, p.ldb, &p.beta, p.C, type_c, p.ldc, compute_type, CUBLAS_GEMM_DEFAULT)); } else { - CUBLAS_CHECK(cublasGemmStridedBatchedEx(blas_handle, p.trans_a, p.trans_b, p.m, p.n, p.k, &p.alpha, p.A, type_a, + CUBLAS_CHECK(cublasGemmStridedBatchedEx(blas_handle, trans_a, trans_b, p.m, p.n, p.k, &p.alpha, p.A, type_a, p.lda, p.stride_a, p.B, type_b, p.ldb, p.stride_b, &p.beta, p.C, type_c, p.ldc, p.stride_c, p.batch_count, compute_type, CUBLAS_GEMM_DEFAULT)); } @@ -64,7 +78,10 @@ void SgemvCuda(const Device &device, const SgemvParams &p) { const cublasHandle_t blas_handle = dynamic_cast( infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) ->cublas_handle(); - CUBLAS_CHECK(cublasSgemv(blas_handle, p.trans, p.m, p.n, &p.alpha, p.A, p.lda, p.x, p.incx, &p.beta, p.y, p.incy)); + CUBLAS_CHECK(cublasSgemv(blas_handle, ToCublasOperation(p.trans), p.m, p.n, &p.alpha, p.A, p.lda, p.x, p.incx, + &p.beta, p.y, p.incy)); } } // namespace infini_train::kernels::cuda + +REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, Gemm, infini_train::kernels::cuda::Gemm) diff --git a/infini_train/src/kernels/cuda/common/gemm.cuh b/infini_train/src/kernels/cuda/common/gemm.cuh index 485b633d..1980a441 100644 --- a/infini_train/src/kernels/cuda/common/gemm.cuh +++ b/infini_train/src/kernels/cuda/common/gemm.cuh @@ -1,58 +1,15 @@ #pragma once -#include - -#include "infini_train/include/datatype.h" -#include "infini_train/include/device.h" +#include "infini_train/include/kernels/common/gemm.h" namespace infini_train::kernels::cuda { /** - * Parameter bundle for a single GEMM call: - * C = alpha * op(A) * op(B) + beta * C - * - * batch_count == 1 → non-batched path (cublasGemmEx) - * batch_count > 1 → strided-batched (cublasGemmStridedBatchedEx) + * Execute the GEMM described by `p` via the CUDA backend. * - * When batch_count == 1, stride_a/b/c are unused and must be left at 0. - */ -struct GemmParams { - cublasOperation_t trans_a = CUBLAS_OP_N; - cublasOperation_t trans_b = CUBLAS_OP_N; - - int m = 0; // rows of op(A) and C - int n = 0; // cols of op(B) and C - int k = 0; // cols of op(A) == rows of op(B) - - const void *A = nullptr; - int lda = 0; - const void *B = nullptr; - int ldb = 0; - void *C = nullptr; - int ldc = 0; - - float alpha = 1.0f; - float beta = 0.0f; - - // batch_count=1: non-batched (Linear path); stride_a/b/c must be 0 - // batch_count>1: strided-batched (Matmul path) - int batch_count = 1; - long long stride_a = 0; - long long stride_b = 0; - long long stride_c = 0; - - DataType input_dtype; // dtype of A and B - DataType output_dtype; // dtype of C (may differ, e.g. bf16 in → fp32 out) -}; - -/** - * Execute the GEMM described by `p` via cuBLAS. - * Dispatches to cublasGemmEx (batch_count==1) or - * cublasGemmStridedBatchedEx (batch_count>1). - * Uses CUBLAS_COMPUTE_32F for all input dtypes to ensure precision. - * Aborts on cuBLAS error (via CUBLAS_CHECK / LOG(FATAL)). + * Arguments are passed by value to match Dispatcher function erasure reliably. */ -void GemmCuda(const Device &device, const GemmParams &p); +void Gemm(Device device, GemmParams p); /** * Parameter bundle for a single SGEMV call (fp32 only): @@ -62,7 +19,7 @@ void GemmCuda(const Device &device, const GemmParams &p); * where m_phys and n_phys are the physical (pre-transpose) row/col counts of A. */ struct SgemvParams { - cublasOperation_t trans = CUBLAS_OP_N; + GemmTranspose trans = GemmTranspose::kNoTranspose; int m = 0; int n = 0; const float *A = nullptr; diff --git a/infini_train/src/kernels/cuda/linear.cu b/infini_train/src/kernels/cuda/linear.cu index b66c82d0..2725053e 100644 --- a/infini_train/src/kernels/cuda/linear.cu +++ b/infini_train/src/kernels/cuda/linear.cu @@ -83,10 +83,10 @@ std::shared_ptr LinearForward(const std::shared_ptr &input, cons } // When bs==1 and fp32, use cublasSgemv (more efficient than GEMM for matrix-vector). - // cublasSgemv does not support bf16, so bf16 falls through to GemmCuda. + // cublasSgemv does not support bf16, so bf16 falls through to Gemm. if (bs == 1 && dtype == DataType::kFLOAT32) { SgemvCuda(device, SgemvParams{ - .trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N, + .trans = transpose ? GemmTranspose::kTranspose : GemmTranspose::kNoTranspose, .m = static_cast(transpose ? in_features : out_features), .n = static_cast(transpose ? out_features : in_features), .A = static_cast(weight->DataPtr()), @@ -110,24 +110,26 @@ std::shared_ptr LinearForward(const std::shared_ptr &input, cons // C = output.T[out_features, bs] // A = weight.T[out_features, in_features] // B = input.T[in_features, bs] - GemmCuda(device, GemmParams{ - .trans_a = transpose ? CUBLAS_OP_T : CUBLAS_OP_N, - .trans_b = CUBLAS_OP_N, - .m = static_cast(out_features), - .n = static_cast(bs), - .k = static_cast(in_features), - .A = weight->DataPtr(), - .lda = static_cast(transpose ? in_features : out_features), - .B = input->DataPtr(), - .ldb = static_cast(in_features), - .C = output->DataPtr(), - .ldc = static_cast(out_features), - .alpha = 1.0f, - .beta = 1.0f, // bias already written into output; beta=1 accumulates - .batch_count = 1, - .input_dtype = dtype, - .output_dtype = dtype, - }); + Dispatcher::Instance().Call( + {device.type(), "Gemm"}, device, + GemmParams{ + .trans_a = transpose ? GemmTranspose::kTranspose : GemmTranspose::kNoTranspose, + .trans_b = GemmTranspose::kNoTranspose, + .m = static_cast(out_features), + .n = static_cast(bs), + .k = static_cast(in_features), + .A = weight->DataPtr(), + .lda = static_cast(transpose ? in_features : out_features), + .B = input->DataPtr(), + .ldb = static_cast(in_features), + .C = output->DataPtr(), + .ldc = static_cast(out_features), + .alpha = 1.0f, + .beta = 1.0f, // bias already written into output; beta=1 accumulates + .batch_count = 1, + .input_dtype = dtype, + .output_dtype = dtype, + }); } return output; @@ -172,19 +174,20 @@ std::shared_ptr LinearBackwardInput(const std::shared_ptr &weigh auto grad_input = std::make_shared(input_dims, output_dtype, grad_output->GetDevice()); // When bs==1 and fp32, use cublasSgemv (more efficient than GEMM for matrix-vector). - // cublasSgemv does not support bf16, so bf16 falls through to GemmCuda. + // cublasSgemv does not support bf16, so bf16 falls through to Gemm. if (bs == 1 && compute_dtype == DataType::kFLOAT32) { - SgemvCuda(grad_output->GetDevice(), SgemvParams{ - .trans = transpose ? CUBLAS_OP_N : CUBLAS_OP_T, - .m = static_cast(transpose ? in_features : out_features), - .n = static_cast(transpose ? out_features : in_features), - .A = static_cast(weight->DataPtr()), - .lda = static_cast(transpose ? in_features : out_features), - .x = static_cast(grad_output_promoted->DataPtr()), - .y = static_cast(grad_input->DataPtr()), - .alpha = 1.0f, - .beta = 0.0f, - }); + SgemvCuda(grad_output->GetDevice(), + SgemvParams{ + .trans = transpose ? GemmTranspose::kNoTranspose : GemmTranspose::kTranspose, + .m = static_cast(transpose ? in_features : out_features), + .n = static_cast(transpose ? out_features : in_features), + .A = static_cast(weight->DataPtr()), + .lda = static_cast(transpose ? in_features : out_features), + .x = static_cast(grad_output_promoted->DataPtr()), + .y = static_cast(grad_input->DataPtr()), + .alpha = 1.0f, + .beta = 0.0f, + }); } else { // - if transpose: // weight is [out_features, in_features] here @@ -199,24 +202,26 @@ std::shared_ptr LinearBackwardInput(const std::shared_ptr &weigh // C = d_input.T[in_features, bs] // A = weight.T[out_features, in_features] // B = d_output.T[out_features, bs] - GemmCuda(grad_output->GetDevice(), GemmParams{ - .trans_a = transpose ? CUBLAS_OP_N : CUBLAS_OP_T, - .trans_b = CUBLAS_OP_N, - .m = static_cast(in_features), - .n = static_cast(bs), - .k = static_cast(out_features), - .A = weight->DataPtr(), - .lda = static_cast(transpose ? in_features : out_features), - .B = grad_output_promoted->DataPtr(), - .ldb = static_cast(out_features), - .C = grad_input->DataPtr(), - .ldc = static_cast(in_features), - .alpha = 1.0f, - .beta = 0.0f, - .batch_count = 1, - .input_dtype = compute_dtype, - .output_dtype = output_dtype, - }); + Dispatcher::Instance().Call( + {grad_output->GetDevice().type(), "Gemm"}, grad_output->GetDevice(), + GemmParams{ + .trans_a = transpose ? GemmTranspose::kNoTranspose : GemmTranspose::kTranspose, + .trans_b = GemmTranspose::kNoTranspose, + .m = static_cast(in_features), + .n = static_cast(bs), + .k = static_cast(out_features), + .A = weight->DataPtr(), + .lda = static_cast(transpose ? in_features : out_features), + .B = grad_output_promoted->DataPtr(), + .ldb = static_cast(out_features), + .C = grad_input->DataPtr(), + .ldc = static_cast(in_features), + .alpha = 1.0f, + .beta = 0.0f, + .batch_count = 1, + .input_dtype = compute_dtype, + .output_dtype = output_dtype, + }); } return grad_input; @@ -258,24 +263,25 @@ std::shared_ptr LinearBackwardWeight(const std::shared_ptr &inpu const int lda = static_cast(transpose ? in_features : out_features); const int ldb = static_cast(transpose ? out_features : in_features); - GemmCuda(grad_output->GetDevice(), GemmParams{ - .trans_a = CUBLAS_OP_N, - .trans_b = CUBLAS_OP_T, - .m = static_cast(transpose ? in_features : out_features), - .n = static_cast(transpose ? out_features : in_features), - .k = static_cast(bs), - .A = a, - .lda = lda, - .B = b, - .ldb = ldb, - .C = grad_weight->DataPtr(), - .ldc = static_cast(transpose ? in_features : out_features), - .alpha = 1.0f, - .beta = 0.0f, - .batch_count = 1, - .input_dtype = compute_dtype, - .output_dtype = output_dtype, - }); + Dispatcher::Instance().Call({grad_output->GetDevice().type(), "Gemm"}, grad_output->GetDevice(), + GemmParams{ + .trans_a = GemmTranspose::kNoTranspose, + .trans_b = GemmTranspose::kTranspose, + .m = static_cast(transpose ? in_features : out_features), + .n = static_cast(transpose ? out_features : in_features), + .k = static_cast(bs), + .A = a, + .lda = lda, + .B = b, + .ldb = ldb, + .C = grad_weight->DataPtr(), + .ldc = static_cast(transpose ? in_features : out_features), + .alpha = 1.0f, + .beta = 0.0f, + .batch_count = 1, + .input_dtype = compute_dtype, + .output_dtype = output_dtype, + }); return grad_weight; } diff --git a/infini_train/src/kernels/cuda/matmul.cu b/infini_train/src/kernels/cuda/matmul.cu index 7e301039..b1c6e381 100644 --- a/infini_train/src/kernels/cuda/matmul.cu +++ b/infini_train/src/kernels/cuda/matmul.cu @@ -48,27 +48,28 @@ std::shared_ptr MatmulForward(const std::shared_ptr &input, cons // A = other.T[*, n, k] // B = input.T[*, k, m] // NOTE(zbl): the last cublasGemmAlgo_t param has no effect on GPU arch >= sm_80(Ampere) - GemmCuda(device, GemmParams{ - .trans_a = CUBLAS_OP_N, - .trans_b = CUBLAS_OP_N, - .m = static_cast(n), - .n = static_cast(m), - .k = static_cast(k), - .A = other->DataPtr(), - .lda = static_cast(n), - .B = input->DataPtr(), - .ldb = static_cast(k), - .C = output->DataPtr(), - .ldc = static_cast(n), - .alpha = 1.0f, - .beta = 0.0f, - .batch_count = static_cast(bs), - .stride_a = bs > 1 ? n * k : 0, - .stride_b = bs > 1 ? k * m : 0, - .stride_c = bs > 1 ? m * n : 0, - .input_dtype = dtype, - .output_dtype = dtype, - }); + Dispatcher::Instance().Call({device.type(), "Gemm"}, device, + GemmParams{ + .trans_a = GemmTranspose::kNoTranspose, + .trans_b = GemmTranspose::kNoTranspose, + .m = static_cast(n), + .n = static_cast(m), + .k = static_cast(k), + .A = other->DataPtr(), + .lda = static_cast(n), + .B = input->DataPtr(), + .ldb = static_cast(k), + .C = output->DataPtr(), + .ldc = static_cast(n), + .alpha = 1.0f, + .beta = 0.0f, + .batch_count = static_cast(bs), + .stride_a = bs > 1 ? n * k : 0, + .stride_b = bs > 1 ? k * m : 0, + .stride_c = bs > 1 ? m * n : 0, + .input_dtype = dtype, + .output_dtype = dtype, + }); return output; } @@ -118,27 +119,28 @@ std::shared_ptr MatmulBackwardInput(const std::shared_ptr &other // C = grad_input.T[*, k, m] // A = other.T[*, n, k] // B = grad_output.T[*, n, m] - GemmCuda(device, GemmParams{ - .trans_a = CUBLAS_OP_T, - .trans_b = CUBLAS_OP_N, - .m = static_cast(k), - .n = static_cast(m), - .k = static_cast(n), - .A = other->DataPtr(), - .lda = static_cast(n), - .B = grad_output_promoted->DataPtr(), - .ldb = static_cast(n), - .C = grad_input->DataPtr(), - .ldc = static_cast(k), - .alpha = 1.0f, - .beta = 0.0f, - .batch_count = static_cast(bs), - .stride_a = bs > 1 ? k * n : 0, - .stride_b = bs > 1 ? n * m : 0, - .stride_c = bs > 1 ? m * k : 0, - .input_dtype = compute_dtype, - .output_dtype = output_dtype, - }); + Dispatcher::Instance().Call({device.type(), "Gemm"}, device, + GemmParams{ + .trans_a = GemmTranspose::kTranspose, + .trans_b = GemmTranspose::kNoTranspose, + .m = static_cast(k), + .n = static_cast(m), + .k = static_cast(n), + .A = other->DataPtr(), + .lda = static_cast(n), + .B = grad_output_promoted->DataPtr(), + .ldb = static_cast(n), + .C = grad_input->DataPtr(), + .ldc = static_cast(k), + .alpha = 1.0f, + .beta = 0.0f, + .batch_count = static_cast(bs), + .stride_a = bs > 1 ? k * n : 0, + .stride_b = bs > 1 ? n * m : 0, + .stride_c = bs > 1 ? m * k : 0, + .input_dtype = compute_dtype, + .output_dtype = output_dtype, + }); return grad_input; } @@ -187,27 +189,28 @@ std::shared_ptr MatmulBackwardOther(const std::shared_ptr &input // C = grad_other.T[*, n, k] // A = grad_output.T[*, n, m] // B = input.T[*, k, m] - GemmCuda(device, GemmParams{ - .trans_a = CUBLAS_OP_N, - .trans_b = CUBLAS_OP_T, - .m = static_cast(n), - .n = static_cast(k), - .k = static_cast(m), - .A = grad_output_promoted->DataPtr(), - .lda = static_cast(n), - .B = input1->DataPtr(), - .ldb = static_cast(k), - .C = grad_other->DataPtr(), - .ldc = static_cast(n), - .alpha = 1.0f, - .beta = 0.0f, - .batch_count = static_cast(bs), - .stride_a = bs > 1 ? n * m : 0, - .stride_b = bs > 1 ? k * m : 0, - .stride_c = bs > 1 ? n * k : 0, - .input_dtype = compute_dtype, - .output_dtype = output_dtype, - }); + Dispatcher::Instance().Call({device.type(), "Gemm"}, device, + GemmParams{ + .trans_a = GemmTranspose::kNoTranspose, + .trans_b = GemmTranspose::kTranspose, + .m = static_cast(n), + .n = static_cast(k), + .k = static_cast(m), + .A = grad_output_promoted->DataPtr(), + .lda = static_cast(n), + .B = input1->DataPtr(), + .ldb = static_cast(k), + .C = grad_other->DataPtr(), + .ldc = static_cast(n), + .alpha = 1.0f, + .beta = 0.0f, + .batch_count = static_cast(bs), + .stride_a = bs > 1 ? n * m : 0, + .stride_b = bs > 1 ? k * m : 0, + .stride_c = bs > 1 ? n * k : 0, + .input_dtype = compute_dtype, + .output_dtype = output_dtype, + }); return grad_other; } diff --git a/infini_train/src/kernels/cuda/outer.cu b/infini_train/src/kernels/cuda/outer.cu index 4a6c8ab3..0b8252b7 100644 --- a/infini_train/src/kernels/cuda/outer.cu +++ b/infini_train/src/kernels/cuda/outer.cu @@ -39,24 +39,25 @@ std::shared_ptr OuterForward(const std::shared_ptr &input, const // output[M, N] = input[M, 1] * other.T[1, N] // output.T[N, M] = other[N, 1] * input.T[1, M] // This is a GEMM with k=1: C[N,M] = A[N,1] * B[1,M] - GemmCuda(device, GemmParams{ - .trans_a = CUBLAS_OP_N, - .trans_b = CUBLAS_OP_N, - .m = static_cast(N), - .n = static_cast(M), - .k = 1, - .A = other->DataPtr(), - .lda = static_cast(N), - .B = input->DataPtr(), - .ldb = 1, - .C = output->DataPtr(), - .ldc = static_cast(N), - .alpha = 1.0f, - .beta = 0.0f, - .batch_count = 1, - .input_dtype = dtype, - .output_dtype = dtype, - }); + Dispatcher::Instance().Call({device.type(), "Gemm"}, device, + GemmParams{ + .trans_a = GemmTranspose::kNoTranspose, + .trans_b = GemmTranspose::kNoTranspose, + .m = static_cast(N), + .n = static_cast(M), + .k = 1, + .A = other->DataPtr(), + .lda = static_cast(N), + .B = input->DataPtr(), + .ldb = 1, + .C = output->DataPtr(), + .ldc = static_cast(N), + .alpha = 1.0f, + .beta = 0.0f, + .batch_count = 1, + .input_dtype = dtype, + .output_dtype = dtype, + }); return output; } @@ -98,7 +99,7 @@ std::tuple, std::shared_ptr> OuterBackward(const case DataType::kFLOAT32: { // grad_input[M] = grad_output[M, N] × other[N] SgemvCuda(device, SgemvParams{ - .trans = CUBLAS_OP_T, + .trans = GemmTranspose::kTranspose, .m = static_cast(N), .n = static_cast(M), .A = static_cast(grad_output_promoted->DataPtr()), @@ -109,7 +110,7 @@ std::tuple, std::shared_ptr> OuterBackward(const // grad_other[N] = grad_output.T[N, M] × input[M] SgemvCuda(device, SgemvParams{ - .trans = CUBLAS_OP_N, + .trans = GemmTranspose::kNoTranspose, .m = static_cast(N), .n = static_cast(M), .A = static_cast(grad_output_promoted->DataPtr()), @@ -120,51 +121,53 @@ std::tuple, std::shared_ptr> OuterBackward(const break; } case DataType::kBFLOAT16: { - // bf16: cublasSgemv does not support bf16; use GemmCuda (GEMM with k=M or k=N). + // bf16: cublasSgemv does not support bf16; use Gemm (GEMM with k=M or k=N). // grad_input[M] = grad_output[M, N] × other[N] // grad_input.T[1, M] = other.T[1, N] × grad_output.T[N, M] // C[1,M] = A[1,N] * B[N,M] - GemmCuda(device, GemmParams{ - .trans_a = CUBLAS_OP_N, - .trans_b = CUBLAS_OP_N, - .m = 1, - .n = static_cast(M), - .k = static_cast(N), - .A = other_promoted->DataPtr(), - .lda = 1, - .B = grad_output_promoted->DataPtr(), - .ldb = static_cast(N), - .C = grad_input->DataPtr(), - .ldc = 1, - .alpha = 1.0f, - .beta = 0.0f, - .batch_count = 1, - .input_dtype = promoted_type, - .output_dtype = output_dtype, - }); + Dispatcher::Instance().Call({device.type(), "Gemm"}, device, + GemmParams{ + .trans_a = GemmTranspose::kNoTranspose, + .trans_b = GemmTranspose::kNoTranspose, + .m = 1, + .n = static_cast(M), + .k = static_cast(N), + .A = other_promoted->DataPtr(), + .lda = 1, + .B = grad_output_promoted->DataPtr(), + .ldb = static_cast(N), + .C = grad_input->DataPtr(), + .ldc = 1, + .alpha = 1.0f, + .beta = 0.0f, + .batch_count = 1, + .input_dtype = promoted_type, + .output_dtype = output_dtype, + }); // grad_other[N] = grad_output.T[N, M] × input[M] // grad_other.T[1, N] = input.T[1, M] × grad_output[M, N] // C[1,N] = A[1,M] * B[M,N] (B stored as grad_output.T[N,M], so ldb=N, trans_b=T) - GemmCuda(device, GemmParams{ - .trans_a = CUBLAS_OP_N, - .trans_b = CUBLAS_OP_T, - .m = 1, - .n = static_cast(N), - .k = static_cast(M), - .A = input_promoted->DataPtr(), - .lda = 1, - .B = grad_output_promoted->DataPtr(), - .ldb = static_cast(N), - .C = grad_other->DataPtr(), - .ldc = 1, - .alpha = 1.0f, - .beta = 0.0f, - .batch_count = 1, - .input_dtype = promoted_type, - .output_dtype = output_dtype, - }); + Dispatcher::Instance().Call({device.type(), "Gemm"}, device, + GemmParams{ + .trans_a = GemmTranspose::kNoTranspose, + .trans_b = GemmTranspose::kTranspose, + .m = 1, + .n = static_cast(N), + .k = static_cast(M), + .A = input_promoted->DataPtr(), + .lda = 1, + .B = grad_output_promoted->DataPtr(), + .ldb = static_cast(N), + .C = grad_other->DataPtr(), + .ldc = 1, + .alpha = 1.0f, + .beta = 0.0f, + .batch_count = 1, + .input_dtype = promoted_type, + .output_dtype = output_dtype, + }); break; } default: diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index f7947030..b7605a24 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -376,7 +376,7 @@ std::shared_ptr Tensor::Contiguous() { // FIXME: Requires stride tracking in the Tensor class before this can be implemented // correctly. Currently always returns true as a placeholder. The contiguous guard in -// elementwise.cu ensures non-contiguous tensors fall back to the broadcast path. +// the elementwise provider ensures non-contiguous tensors fall back to the broadcast path. bool Tensor::IsContiguous() const { return true; } std::shared_ptr Tensor::Flatten(int64_t start, int64_t end) { diff --git a/third_party/InfiniOps b/third_party/InfiniOps new file mode 160000 index 00000000..c4141be5 --- /dev/null +++ b/third_party/InfiniOps @@ -0,0 +1 @@ +Subproject commit c4141be5ff779edb6af94cad720bc55018c334e4