-
Notifications
You must be signed in to change notification settings - Fork 45
[WIP] feat: add InfiniOps as optional kernel provider #161
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| #pragma once | ||
|
|
||
| #include <cstdint> | ||
| #include <memory> | ||
| #include <mutex> | ||
| #include <vector> | ||
|
|
||
| #include <handle.h> | ||
|
|
||
| #include "data_type.h" | ||
| #include "tensor.h" | ||
|
|
||
| #include "infini_train/include/datatype.h" | ||
| #include "infini_train/include/device.h" | ||
|
|
||
| namespace infini_train { | ||
| class Tensor; | ||
| } // namespace infini_train | ||
|
|
||
| namespace infini_train::core { | ||
| class Stream; | ||
| } // namespace infini_train::core | ||
|
|
||
| namespace infini_train::kernel_provider::infiniops { | ||
|
|
||
| infini::ops::DataType ToOpsDataType(DataType dtype); | ||
|
|
||
| infini::ops::Device ToOpsDevice(const Device &device); | ||
|
|
||
| std::mutex &InfiniOpsCallMutex(); | ||
|
|
||
| using HandleFactory = infini::ops::Handle (*)(const Device &device, core::Stream *stream); | ||
|
|
||
| void RegisterHandleFactory(Device::DeviceType type, HandleFactory factory); | ||
|
|
||
| infini::ops::Handle GetHandle(const Device &device); | ||
|
|
||
| infini::ops::Tensor ToOpsTensor(const std::shared_ptr<Tensor> &tensor); | ||
|
|
||
| infini::ops::Tensor ToOpsTensor(void *data, const std::vector<int64_t> &dims, DataType dtype, const Device &device); | ||
|
|
||
| infini::ops::Tensor ToOpsTensor(void *data, const std::vector<int64_t> &dims, DataType dtype, const Device &device, | ||
| const std::vector<int64_t> &strides); | ||
|
|
||
| } // namespace infini_train::kernel_provider::infiniops |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| #pragma once | ||
|
|
||
| #include <map> | ||
| #include <string> | ||
| #include <utility> | ||
|
|
||
| #include "glog/logging.h" | ||
|
|
||
| #include "infini_train/include/device.h" | ||
| #include "infini_train/include/dispatcher.h" | ||
|
|
||
| namespace infini_train::kernel_provider { | ||
|
|
||
| using KeyT = std::pair<Device::DeviceType, std::string>; | ||
|
|
||
| class InfiniOpsRegistry { | ||
| public: | ||
| static InfiniOpsRegistry &Instance() { | ||
| static InfiniOpsRegistry instance; | ||
| return instance; | ||
| } | ||
|
|
||
| const KernelFunction *Lookup(const std::string &kernel_name) const { | ||
| auto it = name_to_kernel_map_.find(kernel_name); | ||
| return it == name_to_kernel_map_.end() ? nullptr : &it->second; | ||
| } | ||
|
|
||
| template <typename FuncT> void Register(const std::string &kernel_name, FuncT &&kernel) { | ||
| CHECK(!name_to_kernel_map_.contains(kernel_name)) << "InfiniOps kernel already registered: " << kernel_name; | ||
| name_to_kernel_map_.emplace(kernel_name, kernel); | ||
| } | ||
|
|
||
| private: | ||
| std::map<std::string, KernelFunction> name_to_kernel_map_; | ||
| }; | ||
|
|
||
| // Bridge functions used by Dispatcher::GetKernel. Implemented in | ||
| // infiniops_registry.cc; declared here for users that already include | ||
| // the full registry header (e.g. unit tests). | ||
| bool InfiniOpsEnabled(); | ||
| bool InfiniOpsEnabled(const KeyT &key); | ||
| const KernelFunction *LookupInfiniOpsKernel(const KeyT &key); | ||
|
|
||
| } // namespace infini_train::kernel_provider | ||
|
|
||
| #define REGISTER_INFINIOPS_KERNEL(kernel_name, kernel_func) \ | ||
| static const bool _##kernel_name##_infiniops_registered##__COUNTER__ = []() { \ | ||
| infini_train::kernel_provider::InfiniOpsRegistry::Instance().Register(#kernel_name, kernel_func); \ | ||
| return true; \ | ||
| }(); |
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里不应该给 infinops 开额外分支,之前接沐曦 kernel 这块是不需要动的。 |
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个头文件内容没什么问题,但不适合放到 include 里作为公共头文件暴露,先放 infini_train/src/kernels/common 里吧 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| #pragma once | ||
|
|
||
| #include "infini_train/include/datatype.h" | ||
| #include "infini_train/include/device.h" | ||
|
|
||
| namespace infini_train::kernels { | ||
|
|
||
| enum class GemmTranspose : int { | ||
| kNoTranspose = 0, | ||
| kTranspose = 1, | ||
| }; | ||
|
|
||
| /** | ||
| * Parameter bundle for a single GEMM call: | ||
| * C = alpha * op(A) * op(B) + beta * C | ||
| * | ||
| * batch_count == 1 describes a non-batched GEMM. batch_count > 1 describes a | ||
| * strided-batched GEMM. When batch_count == 1, stride_a/b/c are unused and must | ||
| * be left at 0. | ||
| */ | ||
| struct GemmParams { | ||
| GemmTranspose trans_a = GemmTranspose::kNoTranspose; | ||
| GemmTranspose trans_b = GemmTranspose::kNoTranspose; | ||
|
|
||
| int m = 0; // rows of op(A) and C | ||
| int n = 0; // cols of op(B) and C | ||
| int k = 0; // cols of op(A) == rows of op(B) | ||
|
|
||
| const void *A = nullptr; | ||
| int lda = 0; | ||
| const void *B = nullptr; | ||
| int ldb = 0; | ||
| void *C = nullptr; | ||
| int ldc = 0; | ||
|
|
||
| float alpha = 1.0f; | ||
| float beta = 0.0f; | ||
|
|
||
| int batch_count = 1; | ||
| long long stride_a = 0; | ||
| long long stride_b = 0; | ||
| long long stride_c = 0; | ||
|
|
||
| DataType input_dtype; | ||
| DataType output_dtype; | ||
| }; | ||
|
|
||
| } // namespace infini_train::kernels |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -139,7 +139,7 @@ class Tensor : public std::enable_shared_from_this<Tensor> { | |
| std::shared_ptr<Tensor> View(const std::vector<int64_t> &dims); | ||
| std::shared_ptr<Tensor> Contiguous(); | ||
| // FIXME: Currently returns true unconditionally. Requires stride tracking in the Tensor | ||
| // class before this can be implemented correctly. The guard in elementwise.cu ensures | ||
| // class before this can be implemented correctly. The elementwise broadcast guard ensures | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 不用改吧 |
||
| // non-contiguous tensors fall back to the broadcast path until this is resolved. | ||
| bool IsContiguous() const; | ||
| std::shared_ptr<Tensor> Flatten(int64_t start = 0, int64_t end = -1); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,120 @@ | ||
| #include "infini_train/include/core/kernel_provider/infiniops/adapter.h" | ||
|
|
||
| #include <map> | ||
| #include <unordered_map> | ||
|
|
||
| #include "glog/logging.h" | ||
|
|
||
| #include "infini_train/include/core/runtime/device_guard.h" | ||
|
|
||
| namespace infini_train::kernel_provider::infiniops { | ||
|
|
||
| namespace { | ||
|
|
||
| inline const std::unordered_map<DataType, infini::ops::DataType> kOpsDataTypeMap = { | ||
| {DataType::kFLOAT16, infini::ops::DataType::kFloat16}, {DataType::kBFLOAT16, infini::ops::DataType::kBFloat16}, | ||
| {DataType::kFLOAT32, infini::ops::DataType::kFloat32}, {DataType::kFLOAT64, infini::ops::DataType::kFloat64}, | ||
| {DataType::kINT8, infini::ops::DataType::kInt8}, {DataType::kINT16, infini::ops::DataType::kInt16}, | ||
| {DataType::kINT32, infini::ops::DataType::kInt32}, {DataType::kINT64, infini::ops::DataType::kInt64}, | ||
| {DataType::kUINT8, infini::ops::DataType::kUInt8}, {DataType::kUINT16, infini::ops::DataType::kUInt16}, | ||
| {DataType::kUINT32, infini::ops::DataType::kUInt32}, {DataType::kUINT64, infini::ops::DataType::kUInt64}, | ||
| }; | ||
|
|
||
| inline const std::unordered_map<Device::DeviceType, infini::ops::Device::Type> kOpsDeviceTypeMap = { | ||
| {Device::DeviceType::kCUDA, infini::ops::Device::Type::kNvidia}, | ||
| {Device::DeviceType::kCPU, infini::ops::Device::Type::kCpu}, | ||
| }; | ||
|
|
||
| std::map<Device::DeviceType, HandleFactory> &HandleFactories() { | ||
| static std::map<Device::DeviceType, HandleFactory> factories; | ||
| return factories; | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| void RegisterHandleFactory(Device::DeviceType type, HandleFactory factory) { | ||
| CHECK(factory != nullptr); | ||
| auto &factories = HandleFactories(); | ||
| CHECK(!factories.contains(type)) << "InfiniOps handle factory already registered for device type " | ||
| << static_cast<int>(type); | ||
| factories.emplace(type, factory); | ||
| } | ||
|
|
||
| infini::ops::Handle GetHandle(const Device &device) { | ||
| auto &factories = HandleFactories(); | ||
| auto it = factories.find(device.type()); | ||
| CHECK(it != factories.end()) << "InfiniOps handle factory is not registered for device type " | ||
| << static_cast<int>(device.type()); | ||
|
|
||
| auto *stream = core::GetDeviceGuardImpl(device.type())->GetStream(device); | ||
| return it->second(device, stream); | ||
| } | ||
|
|
||
| infini::ops::DataType ToOpsDataType(DataType dtype) { | ||
| auto it = kOpsDataTypeMap.find(dtype); | ||
| if (it == kOpsDataTypeMap.end()) { | ||
| LOG(FATAL) << "Unsupported DataType for InfiniOps: " << static_cast<int>(dtype); | ||
| __builtin_unreachable(); | ||
| } | ||
| return it->second; | ||
| } | ||
|
|
||
| infini::ops::Device ToOpsDevice(const Device &device) { | ||
| auto it = kOpsDeviceTypeMap.find(device.type()); | ||
| if (it == kOpsDeviceTypeMap.end()) { | ||
| LOG(FATAL) << "Unsupported DeviceType for InfiniOps: " << static_cast<int>(device.type()); | ||
| __builtin_unreachable(); | ||
| } | ||
| return {it->second, device.index()}; | ||
| } | ||
|
|
||
| std::mutex &InfiniOpsCallMutex() { | ||
| static std::mutex mutex; | ||
| return mutex; | ||
| } | ||
|
|
||
| namespace { | ||
| infini::ops::Tensor::Strides ComputeContiguousStrides(const std::vector<int64_t> &dims) { | ||
| infini::ops::Tensor::Strides strides(dims.size()); | ||
| if (dims.empty()) { | ||
| return strides; | ||
| } | ||
| strides.back() = 1; | ||
| for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) { | ||
| strides[i] = strides[i + 1] * static_cast<infini::ops::Tensor::Stride>(dims[i + 1]); | ||
| } | ||
| return strides; | ||
| } | ||
|
|
||
| infini::ops::Tensor::Shape ToShape(const std::vector<int64_t> &dims) { | ||
| infini::ops::Tensor::Shape shape(dims.size()); | ||
| for (size_t i = 0; i < dims.size(); ++i) { shape[i] = static_cast<infini::ops::Tensor::Size>(dims[i]); } | ||
| return shape; | ||
| } | ||
|
|
||
| infini::ops::Tensor::Strides ToStrides(const std::vector<int64_t> &strides) { | ||
| infini::ops::Tensor::Strides ops_strides(strides.size()); | ||
| for (size_t i = 0; i < strides.size(); ++i) { | ||
| ops_strides[i] = static_cast<infini::ops::Tensor::Stride>(strides[i]); | ||
| } | ||
| return ops_strides; | ||
| } | ||
| } // namespace | ||
|
|
||
| infini::ops::Tensor ToOpsTensor(const std::shared_ptr<Tensor> &tensor) { | ||
| const auto &dims = tensor->Dims(); | ||
| return {tensor->DataPtr(), ToShape(dims), ToOpsDataType(tensor->Dtype()), ToOpsDevice(tensor->GetDevice()), | ||
| ComputeContiguousStrides(dims)}; | ||
| } | ||
|
|
||
| infini::ops::Tensor ToOpsTensor(void *data, const std::vector<int64_t> &dims, DataType dtype, const Device &device) { | ||
| return {data, ToShape(dims), ToOpsDataType(dtype), ToOpsDevice(device), ComputeContiguousStrides(dims)}; | ||
| } | ||
|
|
||
| infini::ops::Tensor ToOpsTensor(void *data, const std::vector<int64_t> &dims, DataType dtype, const Device &device, | ||
| const std::vector<int64_t> &strides) { | ||
| CHECK_EQ(dims.size(), strides.size()); | ||
| return {data, ToShape(dims), ToOpsDataType(dtype), ToOpsDevice(device), ToStrides(strides)}; | ||
| } | ||
|
|
||
| } // namespace infini_train::kernel_provider::infiniops |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是出于什么原因要单独写一套 registry,而不能直接复用 InfiniTrain 原有的注册表呢?