diff --git a/.gitmodules b/.gitmodules index d33d0a64..7b02a741 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,3 +10,6 @@ [submodule "third_party/googletest"] path = third_party/googletest url = git@github.com:google/googletest.git +[submodule "third_party/InfiniOps"] + path = third_party/InfiniOps + url = git@github.com:InfiniTensor/InfiniOps.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 4c6da822..83d289d8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,6 +4,7 @@ option(USE_CUDA "Support NVIDIA CUDA" OFF) option(PROFILE_MODE "ENABLE PROFILE MODE" OFF) option(USE_OMP "Use OpenMP as backend for Eigen" ON) option(USE_NCCL "Build project for distributed running" ON) +option(USE_INFINIOPS "Use InfiniOps as an optional kernel provider" OFF) option(BUILD_TEST "Build InfiniTrain tests" OFF) project(infini_train VERSION 0.5.0 LANGUAGES CXX) @@ -51,6 +52,32 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party/eigen) include_directories(${PROJECT_SOURCE_DIR}) +if(USE_INFINIOPS) + add_compile_definitions(USE_INFINIOPS=1) + + set(INFINIOPS_SOURCE_DIR "${PROJECT_SOURCE_DIR}/third_party/InfiniOps") + if(NOT EXISTS "${INFINIOPS_SOURCE_DIR}/CMakeLists.txt") + message(FATAL_ERROR + "USE_INFINIOPS=ON requires InfiniOps under third_party/InfiniOps. " + "Run: git submodule update --init third_party/InfiniOps") + endif() + + set(INFINIOPS_WITH_CPU OFF) + if(NOT USE_CUDA) + set(INFINIOPS_WITH_CPU ON) + endif() + + set(WITH_CPU ${INFINIOPS_WITH_CPU} CACHE BOOL "Enable InfiniOps CPU backend" FORCE) + set(WITH_NVIDIA ${USE_CUDA} CACHE BOOL "Enable InfiniOps NVIDIA backend" FORCE) + add_subdirectory(${INFINIOPS_SOURCE_DIR} ${CMAKE_BINARY_DIR}/third_party/InfiniOps EXCLUDE_FROM_ALL) + if(NOT TARGET infiniops) + message(FATAL_ERROR "InfiniOps third-party project did not define target `infiniops`") + endif() + if(NOT TARGET InfiniOps::infiniops) + add_library(InfiniOps::infiniops ALIAS infiniops) + endif() +endif() + if(PROFILE_MODE) add_compile_definitions(PROFILE_MODE=1) endif() @@ -62,9 +89,13 @@ endif() # Framework core sources (*.cc), excluding cpu kernels (they are built separately) file(GLOB_RECURSE SRC ${PROJECT_SOURCE_DIR}/infini_train/src/*.cc) list(FILTER SRC EXCLUDE REGEX ".*kernels/cpu/.*") +if(NOT USE_INFINIOPS) + list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/kernel_provider/infiniops/.*\.cc$") +endif() if(NOT USE_CUDA) list(FILTER SRC EXCLUDE REGEX ".*runtime/cuda/.*") list(FILTER SRC EXCLUDE REGEX ".*ccl/cuda/.*") + list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/kernel_provider/infiniops/cuda/.*") endif() if(NOT USE_NCCL) list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/ccl/cuda/.*") @@ -126,6 +157,9 @@ endif() # ------------------------------------------------------------------------------ add_library(infini_train STATIC ${SRC}) +if(USE_INFINIOPS) + target_link_libraries(infini_train PUBLIC InfiniOps::infiniops) +endif() target_link_libraries(infini_train PUBLIC glog diff --git a/infini_train/include/core/kernel_provider/infiniops/adapter.h b/infini_train/include/core/kernel_provider/infiniops/adapter.h new file mode 100644 index 00000000..5cd2079b --- /dev/null +++ b/infini_train/include/core/kernel_provider/infiniops/adapter.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include +#include +#include + +#include + +#include "data_type.h" +#include "tensor.h" + +#include "infini_train/include/datatype.h" +#include "infini_train/include/device.h" + +namespace infini_train { +class Tensor; +} // namespace infini_train + +namespace infini_train::core { +class Stream; +} // namespace infini_train::core + +namespace infini_train::kernel_provider::infiniops { + +infini::ops::DataType ToOpsDataType(DataType dtype); + +infini::ops::Device ToOpsDevice(const Device &device); + +std::mutex &InfiniOpsCallMutex(); + +using HandleFactory = infini::ops::Handle (*)(const Device &device, core::Stream *stream); + +void RegisterHandleFactory(Device::DeviceType type, HandleFactory factory); + +infini::ops::Handle GetHandle(const Device &device); + +infini::ops::Tensor ToOpsTensor(const std::shared_ptr &tensor); + +infini::ops::Tensor ToOpsTensor(void *data, const std::vector &dims, DataType dtype, const Device &device); + +infini::ops::Tensor ToOpsTensor(void *data, const std::vector &dims, DataType dtype, const Device &device, + const std::vector &strides); + +} // namespace infini_train::kernel_provider::infiniops diff --git a/infini_train/include/core/kernel_provider/infiniops_registry.h b/infini_train/include/core/kernel_provider/infiniops_registry.h new file mode 100644 index 00000000..4dda9578 --- /dev/null +++ b/infini_train/include/core/kernel_provider/infiniops_registry.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/device.h" +#include "infini_train/include/dispatcher.h" + +namespace infini_train::kernel_provider { + +using KeyT = std::pair; + +class InfiniOpsRegistry { +public: + static InfiniOpsRegistry &Instance() { + static InfiniOpsRegistry instance; + return instance; + } + + const KernelFunction *Lookup(const std::string &kernel_name) const { + auto it = name_to_kernel_map_.find(kernel_name); + return it == name_to_kernel_map_.end() ? nullptr : &it->second; + } + + template void Register(const std::string &kernel_name, FuncT &&kernel) { + CHECK(!name_to_kernel_map_.contains(kernel_name)) << "InfiniOps kernel already registered: " << kernel_name; + name_to_kernel_map_.emplace(kernel_name, kernel); + } + +private: + std::map name_to_kernel_map_; +}; + +// Bridge functions used by Dispatcher::GetKernel. Implemented in +// infiniops_registry.cc; declared here for users that already include +// the full registry header (e.g. unit tests). +bool InfiniOpsEnabled(); +bool InfiniOpsEnabled(const KeyT &key); +const KernelFunction *LookupInfiniOpsKernel(const KeyT &key); + +} // namespace infini_train::kernel_provider + +#define REGISTER_INFINIOPS_KERNEL(kernel_name, kernel_func) \ + static const bool _##kernel_name##_infiniops_registered##__COUNTER__ = []() { \ + infini_train::kernel_provider::InfiniOpsRegistry::Instance().Register(#kernel_name, kernel_func); \ + return true; \ + }(); diff --git a/infini_train/include/dispatcher.h b/infini_train/include/dispatcher.h index 638df76a..05c6a678 100644 --- a/infini_train/include/dispatcher.h +++ b/infini_train/include/dispatcher.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -47,6 +48,11 @@ class KernelFunction { void *func_ptr_ = nullptr; }; +namespace kernel_provider { +bool InfiniOpsEnabled(const std::pair &key); +const KernelFunction *LookupInfiniOpsKernel(const std::pair &key); +} // namespace kernel_provider + class Dispatcher { public: using KeyT = std::pair; @@ -57,6 +63,17 @@ class Dispatcher { } const KernelFunction &GetKernel(KeyT key) const { + if (kernel_provider::InfiniOpsEnabled(key)) { + if (const auto *kernel = kernel_provider::LookupInfiniOpsKernel(key)) { +#ifdef PROFILE_MODE + SetProfileContext(key.second, key.first); +#endif + return *kernel; + } + LOG(WARNING) << "InfiniOps kernel enabled but not registered: " << key.second + << " on device: " << static_cast(key.first) << "; falling back to default kernel"; + } + CHECK(key_to_kernel_map_.contains(key)) << "Kernel not found: " << key.second << " on device: " << static_cast(key.first); #ifdef PROFILE_MODE diff --git a/infini_train/include/kernels/common/gemm.h b/infini_train/include/kernels/common/gemm.h new file mode 100644 index 00000000..4af6a511 --- /dev/null +++ b/infini_train/include/kernels/common/gemm.h @@ -0,0 +1,48 @@ +#pragma once + +#include "infini_train/include/datatype.h" +#include "infini_train/include/device.h" + +namespace infini_train::kernels { + +enum class GemmTranspose : int { + kNoTranspose = 0, + kTranspose = 1, +}; + +/** + * Parameter bundle for a single GEMM call: + * C = alpha * op(A) * op(B) + beta * C + * + * batch_count == 1 describes a non-batched GEMM. batch_count > 1 describes a + * strided-batched GEMM. When batch_count == 1, stride_a/b/c are unused and must + * be left at 0. + */ +struct GemmParams { + GemmTranspose trans_a = GemmTranspose::kNoTranspose; + GemmTranspose trans_b = GemmTranspose::kNoTranspose; + + int m = 0; // rows of op(A) and C + int n = 0; // cols of op(B) and C + int k = 0; // cols of op(A) == rows of op(B) + + const void *A = nullptr; + int lda = 0; + const void *B = nullptr; + int ldb = 0; + void *C = nullptr; + int ldc = 0; + + float alpha = 1.0f; + float beta = 0.0f; + + int batch_count = 1; + long long stride_a = 0; + long long stride_b = 0; + long long stride_c = 0; + + DataType input_dtype; + DataType output_dtype; +}; + +} // namespace infini_train::kernels diff --git a/infini_train/include/tensor.h b/infini_train/include/tensor.h index 12f45f57..734ed018 100644 --- a/infini_train/include/tensor.h +++ b/infini_train/include/tensor.h @@ -139,7 +139,7 @@ class Tensor : public std::enable_shared_from_this { std::shared_ptr View(const std::vector &dims); std::shared_ptr Contiguous(); // FIXME: Currently returns true unconditionally. Requires stride tracking in the Tensor - // class before this can be implemented correctly. The guard in elementwise.cu ensures + // class before this can be implemented correctly. The elementwise broadcast guard ensures // non-contiguous tensors fall back to the broadcast path until this is resolved. bool IsContiguous() const; std::shared_ptr Flatten(int64_t start = 0, int64_t end = -1); diff --git a/infini_train/src/core/kernel_provider/infiniops/adapter.cc b/infini_train/src/core/kernel_provider/infiniops/adapter.cc new file mode 100644 index 00000000..cbe62c83 --- /dev/null +++ b/infini_train/src/core/kernel_provider/infiniops/adapter.cc @@ -0,0 +1,120 @@ +#include "infini_train/include/core/kernel_provider/infiniops/adapter.h" + +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/core/runtime/device_guard.h" + +namespace infini_train::kernel_provider::infiniops { + +namespace { + +inline const std::unordered_map kOpsDataTypeMap = { + {DataType::kFLOAT16, infini::ops::DataType::kFloat16}, {DataType::kBFLOAT16, infini::ops::DataType::kBFloat16}, + {DataType::kFLOAT32, infini::ops::DataType::kFloat32}, {DataType::kFLOAT64, infini::ops::DataType::kFloat64}, + {DataType::kINT8, infini::ops::DataType::kInt8}, {DataType::kINT16, infini::ops::DataType::kInt16}, + {DataType::kINT32, infini::ops::DataType::kInt32}, {DataType::kINT64, infini::ops::DataType::kInt64}, + {DataType::kUINT8, infini::ops::DataType::kUInt8}, {DataType::kUINT16, infini::ops::DataType::kUInt16}, + {DataType::kUINT32, infini::ops::DataType::kUInt32}, {DataType::kUINT64, infini::ops::DataType::kUInt64}, +}; + +inline const std::unordered_map kOpsDeviceTypeMap = { + {Device::DeviceType::kCUDA, infini::ops::Device::Type::kNvidia}, + {Device::DeviceType::kCPU, infini::ops::Device::Type::kCpu}, +}; + +std::map &HandleFactories() { + static std::map factories; + return factories; +} + +} // namespace + +void RegisterHandleFactory(Device::DeviceType type, HandleFactory factory) { + CHECK(factory != nullptr); + auto &factories = HandleFactories(); + CHECK(!factories.contains(type)) << "InfiniOps handle factory already registered for device type " + << static_cast(type); + factories.emplace(type, factory); +} + +infini::ops::Handle GetHandle(const Device &device) { + auto &factories = HandleFactories(); + auto it = factories.find(device.type()); + CHECK(it != factories.end()) << "InfiniOps handle factory is not registered for device type " + << static_cast(device.type()); + + auto *stream = core::GetDeviceGuardImpl(device.type())->GetStream(device); + return it->second(device, stream); +} + +infini::ops::DataType ToOpsDataType(DataType dtype) { + auto it = kOpsDataTypeMap.find(dtype); + if (it == kOpsDataTypeMap.end()) { + LOG(FATAL) << "Unsupported DataType for InfiniOps: " << static_cast(dtype); + __builtin_unreachable(); + } + return it->second; +} + +infini::ops::Device ToOpsDevice(const Device &device) { + auto it = kOpsDeviceTypeMap.find(device.type()); + if (it == kOpsDeviceTypeMap.end()) { + LOG(FATAL) << "Unsupported DeviceType for InfiniOps: " << static_cast(device.type()); + __builtin_unreachable(); + } + return {it->second, device.index()}; +} + +std::mutex &InfiniOpsCallMutex() { + static std::mutex mutex; + return mutex; +} + +namespace { +infini::ops::Tensor::Strides ComputeContiguousStrides(const std::vector &dims) { + infini::ops::Tensor::Strides strides(dims.size()); + if (dims.empty()) { + return strides; + } + strides.back() = 1; + for (int i = static_cast(dims.size()) - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * static_cast(dims[i + 1]); + } + return strides; +} + +infini::ops::Tensor::Shape ToShape(const std::vector &dims) { + infini::ops::Tensor::Shape shape(dims.size()); + for (size_t i = 0; i < dims.size(); ++i) { shape[i] = static_cast(dims[i]); } + return shape; +} + +infini::ops::Tensor::Strides ToStrides(const std::vector &strides) { + infini::ops::Tensor::Strides ops_strides(strides.size()); + for (size_t i = 0; i < strides.size(); ++i) { + ops_strides[i] = static_cast(strides[i]); + } + return ops_strides; +} +} // namespace + +infini::ops::Tensor ToOpsTensor(const std::shared_ptr &tensor) { + const auto &dims = tensor->Dims(); + return {tensor->DataPtr(), ToShape(dims), ToOpsDataType(tensor->Dtype()), ToOpsDevice(tensor->GetDevice()), + ComputeContiguousStrides(dims)}; +} + +infini::ops::Tensor ToOpsTensor(void *data, const std::vector &dims, DataType dtype, const Device &device) { + return {data, ToShape(dims), ToOpsDataType(dtype), ToOpsDevice(device), ComputeContiguousStrides(dims)}; +} + +infini::ops::Tensor ToOpsTensor(void *data, const std::vector &dims, DataType dtype, const Device &device, + const std::vector &strides) { + CHECK_EQ(dims.size(), strides.size()); + return {data, ToShape(dims), ToOpsDataType(dtype), ToOpsDevice(device), ToStrides(strides)}; +} + +} // namespace infini_train::kernel_provider::infiniops diff --git a/infini_train/src/core/kernel_provider/infiniops/cuda/handle.cc b/infini_train/src/core/kernel_provider/infiniops/cuda/handle.cc new file mode 100644 index 00000000..70a44073 --- /dev/null +++ b/infini_train/src/core/kernel_provider/infiniops/cuda/handle.cc @@ -0,0 +1,25 @@ +#include "infini_train/include/core/kernel_provider/infiniops/adapter.h" + +#include "glog/logging.h" + +#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" + +namespace infini_train::kernel_provider::infiniops { +namespace { + +infini::ops::Handle MakeCudaHandle(const Device &, core::Stream *stream) { + auto *cuda_stream = dynamic_cast(stream); + CHECK_NOTNULL(cuda_stream); + + infini::ops::Handle handle; + handle.set_stream(static_cast(cuda_stream->cuda_stream())); + return handle; +} + +const bool kCudaHandleFactoryRegistered = []() { + RegisterHandleFactory(Device::DeviceType::kCUDA, MakeCudaHandle); + return true; +}(); + +} // namespace +} // namespace infini_train::kernel_provider::infiniops diff --git a/infini_train/src/core/kernel_provider/infiniops/elementwise.cc b/infini_train/src/core/kernel_provider/infiniops/elementwise.cc new file mode 100644 index 00000000..77bed935 --- /dev/null +++ b/infini_train/src/core/kernel_provider/infiniops/elementwise.cc @@ -0,0 +1,71 @@ +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/core/kernel_provider/infiniops/adapter.h" +#include "infini_train/include/core/kernel_provider/infiniops_registry.h" +#include "infini_train/include/tensor.h" + +#include + +namespace infini_train::kernel_provider::infiniops { +namespace { + +std::vector ComputeBroadcastStrides(const std::vector &dims, const std::vector &out_dims) { + CHECK_LE(dims.size(), out_dims.size()); + + std::vector strides(dims.size()); + if (!dims.empty()) { + strides.back() = 1; + for (int i = static_cast(dims.size()) - 2; i >= 0; --i) { strides[i] = strides[i + 1] * dims[i + 1]; } + } + + const size_t pad = out_dims.size() - dims.size(); + std::vector out_strides(out_dims.size(), 0); + for (size_t i = 0; i < dims.size(); ++i) { + const int64_t dim = dims[i]; + const int64_t out_dim = out_dims[pad + i]; + CHECK(dim == out_dim || dim == 1) << "InfiniOps Add broadcast shape mismatch"; + out_strides[pad + i] = dim == 1 ? 0 : strides[i]; + } + return out_strides; +} + +infini::ops::Tensor ToBroadcastOpsTensor(const std::shared_ptr &tensor, const std::vector &out_dims, + DataType dtype) { + const auto strides = ComputeBroadcastStrides(tensor->Dims(), out_dims); + return ToOpsTensor(tensor->DataPtr(), out_dims, dtype, tensor->GetDevice(), strides); +} + +} // namespace + +std::shared_ptr AddForward(const std::shared_ptr &a, const std::shared_ptr &b) { + CHECK_GE(a->NumElements(), b->NumElements()); + CHECK_EQ(a->NumElements() % b->NumElements(), 0); + + auto a_dtype = a->Dtype(); + auto b_dtype = b->Dtype(); + DataType promoted_type = PromoteDataTypes(a_dtype, b_dtype); + + auto a_promoted = a_dtype == promoted_type ? a : std::make_shared(a->To(promoted_type)); + auto b_promoted = b_dtype == promoted_type ? b : std::make_shared(b->To(promoted_type)); + + auto output = std::make_shared(a->Dims(), promoted_type, a->GetDevice()); + + auto handle = GetHandle(a->GetDevice()); + auto a_ops = ToBroadcastOpsTensor(a_promoted, output->Dims(), promoted_type); + auto b_ops = ToBroadcastOpsTensor(b_promoted, output->Dims(), promoted_type); + auto c_ops = ToOpsTensor(output); + + { + std::lock_guard lock(InfiniOpsCallMutex()); + infini::ops::functional::Add(handle, {}, a_ops, b_ops, c_ops); + } + return output; +} + +} // namespace infini_train::kernel_provider::infiniops + +REGISTER_INFINIOPS_KERNEL(AddForward, infini_train::kernel_provider::infiniops::AddForward) diff --git a/infini_train/src/core/kernel_provider/infiniops/gemm.cc b/infini_train/src/core/kernel_provider/infiniops/gemm.cc new file mode 100644 index 00000000..2b04f549 --- /dev/null +++ b/infini_train/src/core/kernel_provider/infiniops/gemm.cc @@ -0,0 +1,73 @@ +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/core/kernel_provider/infiniops/adapter.h" +#include "infini_train/include/core/kernel_provider/infiniops_registry.h" +#include "infini_train/include/kernels/common/gemm.h" + +#include + +namespace infini_train::kernel_provider::infiniops { +namespace { + +int ToInfiniOpsTrans(kernels::GemmTranspose op) { + switch (op) { + case kernels::GemmTranspose::kNoTranspose: + return 0; + case kernels::GemmTranspose::kTranspose: + return 1; + } + LOG(FATAL) << "InfiniOps Gemm: unsupported transpose flag " << static_cast(op); + return 0; // unreachable +} + +std::vector MatrixShape(int batch_count, int rows, int cols) { + if (batch_count > 1) { + return {batch_count, rows, cols}; + } + return {rows, cols}; +} + +std::vector RowMajorStrides(int batch_count, int ld, long long batch_stride) { + if (batch_count > 1) { + return {batch_stride, ld, 1}; + } + return {ld, 1}; +} + +infini::ops::Tensor MakeRowMajorTransposeView(const void *data, int batch_count, int column_major_rows, + int column_major_cols, int ld, long long batch_stride, DataType dtype, + const Device &device) { + return ToOpsTensor(const_cast(data), MatrixShape(batch_count, column_major_cols, column_major_rows), dtype, + device, RowMajorStrides(batch_count, ld, batch_stride)); +} + +} // namespace + +void Gemm(Device device, kernels::GemmParams p) { + CHECK_GE(p.batch_count, 1); + + const bool trans_a = p.trans_a == kernels::GemmTranspose::kTranspose; + const bool trans_b = p.trans_b == kernels::GemmTranspose::kTranspose; + + const int a_rows = trans_a ? p.k : p.m; + const int a_cols = trans_a ? p.m : p.k; + const int b_rows = trans_b ? p.n : p.k; + const int b_cols = trans_b ? p.k : p.n; + + auto handle = GetHandle(device); + auto a = MakeRowMajorTransposeView(p.A, p.batch_count, a_rows, a_cols, p.lda, p.stride_a, p.input_dtype, device); + auto b = MakeRowMajorTransposeView(p.B, p.batch_count, b_rows, b_cols, p.ldb, p.stride_b, p.input_dtype, device); + auto c = MakeRowMajorTransposeView(p.C, p.batch_count, p.m, p.n, p.ldc, p.stride_c, p.output_dtype, device); + + std::lock_guard lock(InfiniOpsCallMutex()); + infini::ops::functional::Gemm(handle, {}, b, a, std::optional{p.alpha}, std::optional{p.beta}, + std::optional{ToInfiniOpsTrans(p.trans_b)}, + std::optional{ToInfiniOpsTrans(p.trans_a)}, c); +} + +} // namespace infini_train::kernel_provider::infiniops + +REGISTER_INFINIOPS_KERNEL(Gemm, infini_train::kernel_provider::infiniops::Gemm) diff --git a/infini_train/src/core/kernel_provider/infiniops_registry.cc b/infini_train/src/core/kernel_provider/infiniops_registry.cc new file mode 100644 index 00000000..b5862d87 --- /dev/null +++ b/infini_train/src/core/kernel_provider/infiniops_registry.cc @@ -0,0 +1,51 @@ +#include "infini_train/include/core/kernel_provider/infiniops_registry.h" + +#include +#include +#include +#include + +namespace infini_train::kernel_provider { + +namespace { + +const std::set kEnabledKernelWhitelist = { + "Gemm", + "AddForward", +}; + +} // namespace + +bool InfiniOpsEnabled() { +#ifdef USE_INFINIOPS + return true; +#else + return false; +#endif +} + +bool InfiniOpsEnabled(const KeyT &key) { + return InfiniOpsEnabled() && key.first == Device::DeviceType::kCUDA && kEnabledKernelWhitelist.contains(key.second); +} + +const KernelFunction *LookupInfiniOpsKernel(const KeyT &key) { + const auto *kernel = InfiniOpsRegistry::Instance().Lookup(key.second); + + static std::mutex log_mutex; + static std::set logged_use; + static std::set logged_fallback; + + const auto log_key = std::to_string(static_cast(key.first)) + ":" + key.second; + std::lock_guard lock(log_mutex); + if (kernel != nullptr) { + if (logged_use.insert(log_key).second) { + std::cout << "[InfiniOps] use " << key.second << " on device " << static_cast(key.first) << std::endl; + } + } else if (logged_fallback.insert(log_key).second) { + std::cout << "[InfiniOps] fallback " << key.second << " on device " << static_cast(key.first) << std::endl; + } + + return kernel; +} + +} // namespace infini_train::kernel_provider diff --git a/infini_train/src/kernels/cuda/common/gemm.cu b/infini_train/src/kernels/cuda/common/gemm.cu index b30a0545..80f84970 100644 --- a/infini_train/src/kernels/cuda/common/gemm.cu +++ b/infini_train/src/kernels/cuda/common/gemm.cu @@ -7,6 +7,7 @@ #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/core/runtime/device_guard.h" #include "infini_train/include/datatype.h" +#include "infini_train/include/dispatcher.h" #include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" @@ -14,6 +15,17 @@ namespace infini_train::kernels::cuda { namespace { +cublasOperation_t ToCublasOperation(GemmTranspose op) { + switch (op) { + case GemmTranspose::kNoTranspose: + return CUBLAS_OP_N; + case GemmTranspose::kTranspose: + return CUBLAS_OP_T; + } + LOG(FATAL) << "Gemm: unsupported transpose flag " << static_cast(op); + return CUBLAS_OP_N; // unreachable +} + cudaDataType_t ToCudaDataType(DataType dt) { switch (dt) { case DataType::kFLOAT32: @@ -23,14 +35,14 @@ cudaDataType_t ToCudaDataType(DataType dt) { case DataType::kFLOAT16: return CUDA_R_16F; default: - LOG(FATAL) << "GemmCuda: unsupported DataType " << static_cast(dt); + LOG(FATAL) << "Gemm: unsupported DataType " << static_cast(dt); return CUDA_R_32F; // unreachable } } } // namespace -void GemmCuda(const Device &device, const GemmParams &p) { +void Gemm(Device device, GemmParams p) { const cublasHandle_t blas_handle = dynamic_cast( infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) ->cublas_handle(); @@ -43,6 +55,8 @@ void GemmCuda(const Device &device, const GemmParams &p) { DCHECK_EQ(p.stride_c, 0LL); } + const cublasOperation_t trans_a = ToCublasOperation(p.trans_a); + const cublasOperation_t trans_b = ToCublasOperation(p.trans_b); const cudaDataType_t type_a = ToCudaDataType(p.input_dtype); const cudaDataType_t type_b = ToCudaDataType(p.input_dtype); const cudaDataType_t type_c = ToCudaDataType(p.output_dtype); @@ -51,10 +65,10 @@ void GemmCuda(const Device &device, const GemmParams &p) { const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; if (p.batch_count == 1) { - CUBLAS_CHECK(cublasGemmEx(blas_handle, p.trans_a, p.trans_b, p.m, p.n, p.k, &p.alpha, p.A, type_a, p.lda, p.B, + CUBLAS_CHECK(cublasGemmEx(blas_handle, trans_a, trans_b, p.m, p.n, p.k, &p.alpha, p.A, type_a, p.lda, p.B, type_b, p.ldb, &p.beta, p.C, type_c, p.ldc, compute_type, CUBLAS_GEMM_DEFAULT)); } else { - CUBLAS_CHECK(cublasGemmStridedBatchedEx(blas_handle, p.trans_a, p.trans_b, p.m, p.n, p.k, &p.alpha, p.A, type_a, + CUBLAS_CHECK(cublasGemmStridedBatchedEx(blas_handle, trans_a, trans_b, p.m, p.n, p.k, &p.alpha, p.A, type_a, p.lda, p.stride_a, p.B, type_b, p.ldb, p.stride_b, &p.beta, p.C, type_c, p.ldc, p.stride_c, p.batch_count, compute_type, CUBLAS_GEMM_DEFAULT)); } @@ -64,7 +78,10 @@ void SgemvCuda(const Device &device, const SgemvParams &p) { const cublasHandle_t blas_handle = dynamic_cast( infini_train::core::GetDeviceGuardImpl(device.type())->GetBlasHandle(device)) ->cublas_handle(); - CUBLAS_CHECK(cublasSgemv(blas_handle, p.trans, p.m, p.n, &p.alpha, p.A, p.lda, p.x, p.incx, &p.beta, p.y, p.incy)); + CUBLAS_CHECK(cublasSgemv(blas_handle, ToCublasOperation(p.trans), p.m, p.n, &p.alpha, p.A, p.lda, p.x, p.incx, + &p.beta, p.y, p.incy)); } } // namespace infini_train::kernels::cuda + +REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, Gemm, infini_train::kernels::cuda::Gemm) diff --git a/infini_train/src/kernels/cuda/common/gemm.cuh b/infini_train/src/kernels/cuda/common/gemm.cuh index 485b633d..1980a441 100644 --- a/infini_train/src/kernels/cuda/common/gemm.cuh +++ b/infini_train/src/kernels/cuda/common/gemm.cuh @@ -1,58 +1,15 @@ #pragma once -#include - -#include "infini_train/include/datatype.h" -#include "infini_train/include/device.h" +#include "infini_train/include/kernels/common/gemm.h" namespace infini_train::kernels::cuda { /** - * Parameter bundle for a single GEMM call: - * C = alpha * op(A) * op(B) + beta * C - * - * batch_count == 1 → non-batched path (cublasGemmEx) - * batch_count > 1 → strided-batched (cublasGemmStridedBatchedEx) + * Execute the GEMM described by `p` via the CUDA backend. * - * When batch_count == 1, stride_a/b/c are unused and must be left at 0. - */ -struct GemmParams { - cublasOperation_t trans_a = CUBLAS_OP_N; - cublasOperation_t trans_b = CUBLAS_OP_N; - - int m = 0; // rows of op(A) and C - int n = 0; // cols of op(B) and C - int k = 0; // cols of op(A) == rows of op(B) - - const void *A = nullptr; - int lda = 0; - const void *B = nullptr; - int ldb = 0; - void *C = nullptr; - int ldc = 0; - - float alpha = 1.0f; - float beta = 0.0f; - - // batch_count=1: non-batched (Linear path); stride_a/b/c must be 0 - // batch_count>1: strided-batched (Matmul path) - int batch_count = 1; - long long stride_a = 0; - long long stride_b = 0; - long long stride_c = 0; - - DataType input_dtype; // dtype of A and B - DataType output_dtype; // dtype of C (may differ, e.g. bf16 in → fp32 out) -}; - -/** - * Execute the GEMM described by `p` via cuBLAS. - * Dispatches to cublasGemmEx (batch_count==1) or - * cublasGemmStridedBatchedEx (batch_count>1). - * Uses CUBLAS_COMPUTE_32F for all input dtypes to ensure precision. - * Aborts on cuBLAS error (via CUBLAS_CHECK / LOG(FATAL)). + * Arguments are passed by value to match Dispatcher function erasure reliably. */ -void GemmCuda(const Device &device, const GemmParams &p); +void Gemm(Device device, GemmParams p); /** * Parameter bundle for a single SGEMV call (fp32 only): @@ -62,7 +19,7 @@ void GemmCuda(const Device &device, const GemmParams &p); * where m_phys and n_phys are the physical (pre-transpose) row/col counts of A. */ struct SgemvParams { - cublasOperation_t trans = CUBLAS_OP_N; + GemmTranspose trans = GemmTranspose::kNoTranspose; int m = 0; int n = 0; const float *A = nullptr; diff --git a/infini_train/src/kernels/cuda/linear.cu b/infini_train/src/kernels/cuda/linear.cu index b66c82d0..2725053e 100644 --- a/infini_train/src/kernels/cuda/linear.cu +++ b/infini_train/src/kernels/cuda/linear.cu @@ -83,10 +83,10 @@ std::shared_ptr LinearForward(const std::shared_ptr &input, cons } // When bs==1 and fp32, use cublasSgemv (more efficient than GEMM for matrix-vector). - // cublasSgemv does not support bf16, so bf16 falls through to GemmCuda. + // cublasSgemv does not support bf16, so bf16 falls through to Gemm. if (bs == 1 && dtype == DataType::kFLOAT32) { SgemvCuda(device, SgemvParams{ - .trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N, + .trans = transpose ? GemmTranspose::kTranspose : GemmTranspose::kNoTranspose, .m = static_cast(transpose ? in_features : out_features), .n = static_cast(transpose ? out_features : in_features), .A = static_cast(weight->DataPtr()), @@ -110,24 +110,26 @@ std::shared_ptr LinearForward(const std::shared_ptr &input, cons // C = output.T[out_features, bs] // A = weight.T[out_features, in_features] // B = input.T[in_features, bs] - GemmCuda(device, GemmParams{ - .trans_a = transpose ? CUBLAS_OP_T : CUBLAS_OP_N, - .trans_b = CUBLAS_OP_N, - .m = static_cast(out_features), - .n = static_cast(bs), - .k = static_cast(in_features), - .A = weight->DataPtr(), - .lda = static_cast(transpose ? in_features : out_features), - .B = input->DataPtr(), - .ldb = static_cast(in_features), - .C = output->DataPtr(), - .ldc = static_cast(out_features), - .alpha = 1.0f, - .beta = 1.0f, // bias already written into output; beta=1 accumulates - .batch_count = 1, - .input_dtype = dtype, - .output_dtype = dtype, - }); + Dispatcher::Instance().Call( + {device.type(), "Gemm"}, device, + GemmParams{ + .trans_a = transpose ? GemmTranspose::kTranspose : GemmTranspose::kNoTranspose, + .trans_b = GemmTranspose::kNoTranspose, + .m = static_cast(out_features), + .n = static_cast(bs), + .k = static_cast(in_features), + .A = weight->DataPtr(), + .lda = static_cast(transpose ? in_features : out_features), + .B = input->DataPtr(), + .ldb = static_cast(in_features), + .C = output->DataPtr(), + .ldc = static_cast(out_features), + .alpha = 1.0f, + .beta = 1.0f, // bias already written into output; beta=1 accumulates + .batch_count = 1, + .input_dtype = dtype, + .output_dtype = dtype, + }); } return output; @@ -172,19 +174,20 @@ std::shared_ptr LinearBackwardInput(const std::shared_ptr &weigh auto grad_input = std::make_shared(input_dims, output_dtype, grad_output->GetDevice()); // When bs==1 and fp32, use cublasSgemv (more efficient than GEMM for matrix-vector). - // cublasSgemv does not support bf16, so bf16 falls through to GemmCuda. + // cublasSgemv does not support bf16, so bf16 falls through to Gemm. if (bs == 1 && compute_dtype == DataType::kFLOAT32) { - SgemvCuda(grad_output->GetDevice(), SgemvParams{ - .trans = transpose ? CUBLAS_OP_N : CUBLAS_OP_T, - .m = static_cast(transpose ? in_features : out_features), - .n = static_cast(transpose ? out_features : in_features), - .A = static_cast(weight->DataPtr()), - .lda = static_cast(transpose ? in_features : out_features), - .x = static_cast(grad_output_promoted->DataPtr()), - .y = static_cast(grad_input->DataPtr()), - .alpha = 1.0f, - .beta = 0.0f, - }); + SgemvCuda(grad_output->GetDevice(), + SgemvParams{ + .trans = transpose ? GemmTranspose::kNoTranspose : GemmTranspose::kTranspose, + .m = static_cast(transpose ? in_features : out_features), + .n = static_cast(transpose ? out_features : in_features), + .A = static_cast(weight->DataPtr()), + .lda = static_cast(transpose ? in_features : out_features), + .x = static_cast(grad_output_promoted->DataPtr()), + .y = static_cast(grad_input->DataPtr()), + .alpha = 1.0f, + .beta = 0.0f, + }); } else { // - if transpose: // weight is [out_features, in_features] here @@ -199,24 +202,26 @@ std::shared_ptr LinearBackwardInput(const std::shared_ptr &weigh // C = d_input.T[in_features, bs] // A = weight.T[out_features, in_features] // B = d_output.T[out_features, bs] - GemmCuda(grad_output->GetDevice(), GemmParams{ - .trans_a = transpose ? CUBLAS_OP_N : CUBLAS_OP_T, - .trans_b = CUBLAS_OP_N, - .m = static_cast(in_features), - .n = static_cast(bs), - .k = static_cast(out_features), - .A = weight->DataPtr(), - .lda = static_cast(transpose ? in_features : out_features), - .B = grad_output_promoted->DataPtr(), - .ldb = static_cast(out_features), - .C = grad_input->DataPtr(), - .ldc = static_cast(in_features), - .alpha = 1.0f, - .beta = 0.0f, - .batch_count = 1, - .input_dtype = compute_dtype, - .output_dtype = output_dtype, - }); + Dispatcher::Instance().Call( + {grad_output->GetDevice().type(), "Gemm"}, grad_output->GetDevice(), + GemmParams{ + .trans_a = transpose ? GemmTranspose::kNoTranspose : GemmTranspose::kTranspose, + .trans_b = GemmTranspose::kNoTranspose, + .m = static_cast(in_features), + .n = static_cast(bs), + .k = static_cast(out_features), + .A = weight->DataPtr(), + .lda = static_cast(transpose ? in_features : out_features), + .B = grad_output_promoted->DataPtr(), + .ldb = static_cast(out_features), + .C = grad_input->DataPtr(), + .ldc = static_cast(in_features), + .alpha = 1.0f, + .beta = 0.0f, + .batch_count = 1, + .input_dtype = compute_dtype, + .output_dtype = output_dtype, + }); } return grad_input; @@ -258,24 +263,25 @@ std::shared_ptr LinearBackwardWeight(const std::shared_ptr &inpu const int lda = static_cast(transpose ? in_features : out_features); const int ldb = static_cast(transpose ? out_features : in_features); - GemmCuda(grad_output->GetDevice(), GemmParams{ - .trans_a = CUBLAS_OP_N, - .trans_b = CUBLAS_OP_T, - .m = static_cast(transpose ? in_features : out_features), - .n = static_cast(transpose ? out_features : in_features), - .k = static_cast(bs), - .A = a, - .lda = lda, - .B = b, - .ldb = ldb, - .C = grad_weight->DataPtr(), - .ldc = static_cast(transpose ? in_features : out_features), - .alpha = 1.0f, - .beta = 0.0f, - .batch_count = 1, - .input_dtype = compute_dtype, - .output_dtype = output_dtype, - }); + Dispatcher::Instance().Call({grad_output->GetDevice().type(), "Gemm"}, grad_output->GetDevice(), + GemmParams{ + .trans_a = GemmTranspose::kNoTranspose, + .trans_b = GemmTranspose::kTranspose, + .m = static_cast(transpose ? in_features : out_features), + .n = static_cast(transpose ? out_features : in_features), + .k = static_cast(bs), + .A = a, + .lda = lda, + .B = b, + .ldb = ldb, + .C = grad_weight->DataPtr(), + .ldc = static_cast(transpose ? in_features : out_features), + .alpha = 1.0f, + .beta = 0.0f, + .batch_count = 1, + .input_dtype = compute_dtype, + .output_dtype = output_dtype, + }); return grad_weight; } diff --git a/infini_train/src/kernels/cuda/matmul.cu b/infini_train/src/kernels/cuda/matmul.cu index 7e301039..b1c6e381 100644 --- a/infini_train/src/kernels/cuda/matmul.cu +++ b/infini_train/src/kernels/cuda/matmul.cu @@ -48,27 +48,28 @@ std::shared_ptr MatmulForward(const std::shared_ptr &input, cons // A = other.T[*, n, k] // B = input.T[*, k, m] // NOTE(zbl): the last cublasGemmAlgo_t param has no effect on GPU arch >= sm_80(Ampere) - GemmCuda(device, GemmParams{ - .trans_a = CUBLAS_OP_N, - .trans_b = CUBLAS_OP_N, - .m = static_cast(n), - .n = static_cast(m), - .k = static_cast(k), - .A = other->DataPtr(), - .lda = static_cast(n), - .B = input->DataPtr(), - .ldb = static_cast(k), - .C = output->DataPtr(), - .ldc = static_cast(n), - .alpha = 1.0f, - .beta = 0.0f, - .batch_count = static_cast(bs), - .stride_a = bs > 1 ? n * k : 0, - .stride_b = bs > 1 ? k * m : 0, - .stride_c = bs > 1 ? m * n : 0, - .input_dtype = dtype, - .output_dtype = dtype, - }); + Dispatcher::Instance().Call({device.type(), "Gemm"}, device, + GemmParams{ + .trans_a = GemmTranspose::kNoTranspose, + .trans_b = GemmTranspose::kNoTranspose, + .m = static_cast(n), + .n = static_cast(m), + .k = static_cast(k), + .A = other->DataPtr(), + .lda = static_cast(n), + .B = input->DataPtr(), + .ldb = static_cast(k), + .C = output->DataPtr(), + .ldc = static_cast(n), + .alpha = 1.0f, + .beta = 0.0f, + .batch_count = static_cast(bs), + .stride_a = bs > 1 ? n * k : 0, + .stride_b = bs > 1 ? k * m : 0, + .stride_c = bs > 1 ? m * n : 0, + .input_dtype = dtype, + .output_dtype = dtype, + }); return output; } @@ -118,27 +119,28 @@ std::shared_ptr MatmulBackwardInput(const std::shared_ptr &other // C = grad_input.T[*, k, m] // A = other.T[*, n, k] // B = grad_output.T[*, n, m] - GemmCuda(device, GemmParams{ - .trans_a = CUBLAS_OP_T, - .trans_b = CUBLAS_OP_N, - .m = static_cast(k), - .n = static_cast(m), - .k = static_cast(n), - .A = other->DataPtr(), - .lda = static_cast(n), - .B = grad_output_promoted->DataPtr(), - .ldb = static_cast(n), - .C = grad_input->DataPtr(), - .ldc = static_cast(k), - .alpha = 1.0f, - .beta = 0.0f, - .batch_count = static_cast(bs), - .stride_a = bs > 1 ? k * n : 0, - .stride_b = bs > 1 ? n * m : 0, - .stride_c = bs > 1 ? m * k : 0, - .input_dtype = compute_dtype, - .output_dtype = output_dtype, - }); + Dispatcher::Instance().Call({device.type(), "Gemm"}, device, + GemmParams{ + .trans_a = GemmTranspose::kTranspose, + .trans_b = GemmTranspose::kNoTranspose, + .m = static_cast(k), + .n = static_cast(m), + .k = static_cast(n), + .A = other->DataPtr(), + .lda = static_cast(n), + .B = grad_output_promoted->DataPtr(), + .ldb = static_cast(n), + .C = grad_input->DataPtr(), + .ldc = static_cast(k), + .alpha = 1.0f, + .beta = 0.0f, + .batch_count = static_cast(bs), + .stride_a = bs > 1 ? k * n : 0, + .stride_b = bs > 1 ? n * m : 0, + .stride_c = bs > 1 ? m * k : 0, + .input_dtype = compute_dtype, + .output_dtype = output_dtype, + }); return grad_input; } @@ -187,27 +189,28 @@ std::shared_ptr MatmulBackwardOther(const std::shared_ptr &input // C = grad_other.T[*, n, k] // A = grad_output.T[*, n, m] // B = input.T[*, k, m] - GemmCuda(device, GemmParams{ - .trans_a = CUBLAS_OP_N, - .trans_b = CUBLAS_OP_T, - .m = static_cast(n), - .n = static_cast(k), - .k = static_cast(m), - .A = grad_output_promoted->DataPtr(), - .lda = static_cast(n), - .B = input1->DataPtr(), - .ldb = static_cast(k), - .C = grad_other->DataPtr(), - .ldc = static_cast(n), - .alpha = 1.0f, - .beta = 0.0f, - .batch_count = static_cast(bs), - .stride_a = bs > 1 ? n * m : 0, - .stride_b = bs > 1 ? k * m : 0, - .stride_c = bs > 1 ? n * k : 0, - .input_dtype = compute_dtype, - .output_dtype = output_dtype, - }); + Dispatcher::Instance().Call({device.type(), "Gemm"}, device, + GemmParams{ + .trans_a = GemmTranspose::kNoTranspose, + .trans_b = GemmTranspose::kTranspose, + .m = static_cast(n), + .n = static_cast(k), + .k = static_cast(m), + .A = grad_output_promoted->DataPtr(), + .lda = static_cast(n), + .B = input1->DataPtr(), + .ldb = static_cast(k), + .C = grad_other->DataPtr(), + .ldc = static_cast(n), + .alpha = 1.0f, + .beta = 0.0f, + .batch_count = static_cast(bs), + .stride_a = bs > 1 ? n * m : 0, + .stride_b = bs > 1 ? k * m : 0, + .stride_c = bs > 1 ? n * k : 0, + .input_dtype = compute_dtype, + .output_dtype = output_dtype, + }); return grad_other; } diff --git a/infini_train/src/kernels/cuda/outer.cu b/infini_train/src/kernels/cuda/outer.cu index 4a6c8ab3..0b8252b7 100644 --- a/infini_train/src/kernels/cuda/outer.cu +++ b/infini_train/src/kernels/cuda/outer.cu @@ -39,24 +39,25 @@ std::shared_ptr OuterForward(const std::shared_ptr &input, const // output[M, N] = input[M, 1] * other.T[1, N] // output.T[N, M] = other[N, 1] * input.T[1, M] // This is a GEMM with k=1: C[N,M] = A[N,1] * B[1,M] - GemmCuda(device, GemmParams{ - .trans_a = CUBLAS_OP_N, - .trans_b = CUBLAS_OP_N, - .m = static_cast(N), - .n = static_cast(M), - .k = 1, - .A = other->DataPtr(), - .lda = static_cast(N), - .B = input->DataPtr(), - .ldb = 1, - .C = output->DataPtr(), - .ldc = static_cast(N), - .alpha = 1.0f, - .beta = 0.0f, - .batch_count = 1, - .input_dtype = dtype, - .output_dtype = dtype, - }); + Dispatcher::Instance().Call({device.type(), "Gemm"}, device, + GemmParams{ + .trans_a = GemmTranspose::kNoTranspose, + .trans_b = GemmTranspose::kNoTranspose, + .m = static_cast(N), + .n = static_cast(M), + .k = 1, + .A = other->DataPtr(), + .lda = static_cast(N), + .B = input->DataPtr(), + .ldb = 1, + .C = output->DataPtr(), + .ldc = static_cast(N), + .alpha = 1.0f, + .beta = 0.0f, + .batch_count = 1, + .input_dtype = dtype, + .output_dtype = dtype, + }); return output; } @@ -98,7 +99,7 @@ std::tuple, std::shared_ptr> OuterBackward(const case DataType::kFLOAT32: { // grad_input[M] = grad_output[M, N] × other[N] SgemvCuda(device, SgemvParams{ - .trans = CUBLAS_OP_T, + .trans = GemmTranspose::kTranspose, .m = static_cast(N), .n = static_cast(M), .A = static_cast(grad_output_promoted->DataPtr()), @@ -109,7 +110,7 @@ std::tuple, std::shared_ptr> OuterBackward(const // grad_other[N] = grad_output.T[N, M] × input[M] SgemvCuda(device, SgemvParams{ - .trans = CUBLAS_OP_N, + .trans = GemmTranspose::kNoTranspose, .m = static_cast(N), .n = static_cast(M), .A = static_cast(grad_output_promoted->DataPtr()), @@ -120,51 +121,53 @@ std::tuple, std::shared_ptr> OuterBackward(const break; } case DataType::kBFLOAT16: { - // bf16: cublasSgemv does not support bf16; use GemmCuda (GEMM with k=M or k=N). + // bf16: cublasSgemv does not support bf16; use Gemm (GEMM with k=M or k=N). // grad_input[M] = grad_output[M, N] × other[N] // grad_input.T[1, M] = other.T[1, N] × grad_output.T[N, M] // C[1,M] = A[1,N] * B[N,M] - GemmCuda(device, GemmParams{ - .trans_a = CUBLAS_OP_N, - .trans_b = CUBLAS_OP_N, - .m = 1, - .n = static_cast(M), - .k = static_cast(N), - .A = other_promoted->DataPtr(), - .lda = 1, - .B = grad_output_promoted->DataPtr(), - .ldb = static_cast(N), - .C = grad_input->DataPtr(), - .ldc = 1, - .alpha = 1.0f, - .beta = 0.0f, - .batch_count = 1, - .input_dtype = promoted_type, - .output_dtype = output_dtype, - }); + Dispatcher::Instance().Call({device.type(), "Gemm"}, device, + GemmParams{ + .trans_a = GemmTranspose::kNoTranspose, + .trans_b = GemmTranspose::kNoTranspose, + .m = 1, + .n = static_cast(M), + .k = static_cast(N), + .A = other_promoted->DataPtr(), + .lda = 1, + .B = grad_output_promoted->DataPtr(), + .ldb = static_cast(N), + .C = grad_input->DataPtr(), + .ldc = 1, + .alpha = 1.0f, + .beta = 0.0f, + .batch_count = 1, + .input_dtype = promoted_type, + .output_dtype = output_dtype, + }); // grad_other[N] = grad_output.T[N, M] × input[M] // grad_other.T[1, N] = input.T[1, M] × grad_output[M, N] // C[1,N] = A[1,M] * B[M,N] (B stored as grad_output.T[N,M], so ldb=N, trans_b=T) - GemmCuda(device, GemmParams{ - .trans_a = CUBLAS_OP_N, - .trans_b = CUBLAS_OP_T, - .m = 1, - .n = static_cast(N), - .k = static_cast(M), - .A = input_promoted->DataPtr(), - .lda = 1, - .B = grad_output_promoted->DataPtr(), - .ldb = static_cast(N), - .C = grad_other->DataPtr(), - .ldc = 1, - .alpha = 1.0f, - .beta = 0.0f, - .batch_count = 1, - .input_dtype = promoted_type, - .output_dtype = output_dtype, - }); + Dispatcher::Instance().Call({device.type(), "Gemm"}, device, + GemmParams{ + .trans_a = GemmTranspose::kNoTranspose, + .trans_b = GemmTranspose::kTranspose, + .m = 1, + .n = static_cast(N), + .k = static_cast(M), + .A = input_promoted->DataPtr(), + .lda = 1, + .B = grad_output_promoted->DataPtr(), + .ldb = static_cast(N), + .C = grad_other->DataPtr(), + .ldc = 1, + .alpha = 1.0f, + .beta = 0.0f, + .batch_count = 1, + .input_dtype = promoted_type, + .output_dtype = output_dtype, + }); break; } default: diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index f7947030..b7605a24 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -376,7 +376,7 @@ std::shared_ptr Tensor::Contiguous() { // FIXME: Requires stride tracking in the Tensor class before this can be implemented // correctly. Currently always returns true as a placeholder. The contiguous guard in -// elementwise.cu ensures non-contiguous tensors fall back to the broadcast path. +// the elementwise provider ensures non-contiguous tensors fall back to the broadcast path. bool Tensor::IsContiguous() const { return true; } std::shared_ptr Tensor::Flatten(int64_t start, int64_t end) { diff --git a/third_party/InfiniOps b/third_party/InfiniOps new file mode 160000 index 00000000..c4141be5 --- /dev/null +++ b/third_party/InfiniOps @@ -0,0 +1 @@ +Subproject commit c4141be5ff779edb6af94cad720bc55018c334e4