Skip to content

Commit 865b51c

Browse files
committed
feat: add InfiniOps as optional kernel provider
Wire InfiniOps in as a pluggable kernel provider keyed at the GEMM level: Dispatcher consults a per-key whitelist hook and routes registered ops to InfiniOps, falling back to the default CUDA kernel otherwise. linear, matmul and outer now invoke Gemm via Dispatcher rather than calling the cuBLAS wrapper directly, so InfiniOps Gemm transparently covers all three.
1 parent 49cb34b commit 865b51c

19 files changed

Lines changed: 767 additions & 243 deletions

File tree

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,6 @@
1010
[submodule "third_party/googletest"]
1111
path = third_party/googletest
1212
url = git@github.com:google/googletest.git
13+
[submodule "third_party/InfiniOps"]
14+
path = third_party/InfiniOps
15+
url = git@github.com:InfiniTensor/InfiniOps.git

CMakeLists.txt

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ option(USE_CUDA "Support NVIDIA CUDA" OFF)
44
option(PROFILE_MODE "ENABLE PROFILE MODE" OFF)
55
option(USE_OMP "Use OpenMP as backend for Eigen" ON)
66
option(USE_NCCL "Build project for distributed running" ON)
7+
option(USE_INFINIOPS "Use InfiniOps as an optional kernel provider" OFF)
78
option(BUILD_TEST "Build InfiniTrain tests" OFF)
89

910
project(infini_train VERSION 0.5.0 LANGUAGES CXX)
@@ -51,6 +52,32 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party/eigen)
5152

5253
include_directories(${PROJECT_SOURCE_DIR})
5354

55+
if(USE_INFINIOPS)
56+
add_compile_definitions(USE_INFINIOPS=1)
57+
58+
set(INFINIOPS_SOURCE_DIR "${PROJECT_SOURCE_DIR}/third_party/InfiniOps")
59+
if(NOT EXISTS "${INFINIOPS_SOURCE_DIR}/CMakeLists.txt")
60+
message(FATAL_ERROR
61+
"USE_INFINIOPS=ON requires InfiniOps under third_party/InfiniOps. "
62+
"Run: git submodule update --init third_party/InfiniOps")
63+
endif()
64+
65+
set(INFINIOPS_WITH_CPU OFF)
66+
if(NOT USE_CUDA)
67+
set(INFINIOPS_WITH_CPU ON)
68+
endif()
69+
70+
set(WITH_CPU ${INFINIOPS_WITH_CPU} CACHE BOOL "Enable InfiniOps CPU backend" FORCE)
71+
set(WITH_NVIDIA ${USE_CUDA} CACHE BOOL "Enable InfiniOps NVIDIA backend" FORCE)
72+
add_subdirectory(${INFINIOPS_SOURCE_DIR} ${CMAKE_BINARY_DIR}/third_party/InfiniOps EXCLUDE_FROM_ALL)
73+
if(NOT TARGET infiniops)
74+
message(FATAL_ERROR "InfiniOps third-party project did not define target `infiniops`")
75+
endif()
76+
if(NOT TARGET InfiniOps::infiniops)
77+
add_library(InfiniOps::infiniops ALIAS infiniops)
78+
endif()
79+
endif()
80+
5481
if(PROFILE_MODE)
5582
add_compile_definitions(PROFILE_MODE=1)
5683
endif()
@@ -62,9 +89,13 @@ endif()
6289
# Framework core sources (*.cc), excluding cpu kernels (they are built separately)
6390
file(GLOB_RECURSE SRC ${PROJECT_SOURCE_DIR}/infini_train/src/*.cc)
6491
list(FILTER SRC EXCLUDE REGEX ".*kernels/cpu/.*")
92+
if(NOT USE_INFINIOPS)
93+
list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/kernel_provider/infiniops/.*\.cc$")
94+
endif()
6595
if(NOT USE_CUDA)
6696
list(FILTER SRC EXCLUDE REGEX ".*runtime/cuda/.*")
6797
list(FILTER SRC EXCLUDE REGEX ".*ccl/cuda/.*")
98+
list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/kernel_provider/infiniops/cuda/.*")
6899
endif()
69100
if(NOT USE_NCCL)
70101
list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/ccl/cuda/.*")
@@ -126,6 +157,9 @@ endif()
126157
# ------------------------------------------------------------------------------
127158

128159
add_library(infini_train STATIC ${SRC})
160+
if(USE_INFINIOPS)
161+
target_link_libraries(infini_train PUBLIC InfiniOps::infiniops)
162+
endif()
129163
target_link_libraries(infini_train
130164
PUBLIC
131165
glog
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#pragma once
2+
3+
#include <cstdint>
4+
#include <memory>
5+
#include <mutex>
6+
#include <vector>
7+
8+
#include <handle.h>
9+
10+
#include "data_type.h"
11+
#include "tensor.h"
12+
13+
#include "infini_train/include/datatype.h"
14+
#include "infini_train/include/device.h"
15+
16+
namespace infini_train {
17+
class Tensor;
18+
} // namespace infini_train
19+
20+
namespace infini_train::core {
21+
class Stream;
22+
} // namespace infini_train::core
23+
24+
namespace infini_train::kernel_provider::infiniops {
25+
26+
infini::ops::DataType ToOpsDataType(DataType dtype);
27+
28+
infini::ops::Device ToOpsDevice(const Device &device);
29+
30+
std::mutex &InfiniOpsCallMutex();
31+
32+
using HandleFactory = infini::ops::Handle (*)(const Device &device, core::Stream *stream);
33+
34+
void RegisterHandleFactory(Device::DeviceType type, HandleFactory factory);
35+
36+
infini::ops::Handle GetHandle(const Device &device);
37+
38+
infini::ops::Tensor ToOpsTensor(const std::shared_ptr<Tensor> &tensor);
39+
40+
infini::ops::Tensor ToOpsTensor(void *data, const std::vector<int64_t> &dims, DataType dtype, const Device &device);
41+
42+
infini::ops::Tensor ToOpsTensor(void *data, const std::vector<int64_t> &dims, DataType dtype, const Device &device,
43+
const std::vector<int64_t> &strides);
44+
45+
} // namespace infini_train::kernel_provider::infiniops
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#pragma once
2+
3+
#include <map>
4+
#include <string>
5+
#include <utility>
6+
7+
#include "glog/logging.h"
8+
9+
#include "infini_train/include/device.h"
10+
#include "infini_train/include/dispatcher.h"
11+
12+
namespace infini_train::kernel_provider {
13+
14+
using KeyT = std::pair<Device::DeviceType, std::string>;
15+
16+
class InfiniOpsRegistry {
17+
public:
18+
static InfiniOpsRegistry &Instance() {
19+
static InfiniOpsRegistry instance;
20+
return instance;
21+
}
22+
23+
const KernelFunction *Lookup(const std::string &kernel_name) const {
24+
auto it = name_to_kernel_map_.find(kernel_name);
25+
return it == name_to_kernel_map_.end() ? nullptr : &it->second;
26+
}
27+
28+
template <typename FuncT> void Register(const std::string &kernel_name, FuncT &&kernel) {
29+
CHECK(!name_to_kernel_map_.contains(kernel_name)) << "InfiniOps kernel already registered: " << kernel_name;
30+
name_to_kernel_map_.emplace(kernel_name, kernel);
31+
}
32+
33+
private:
34+
std::map<std::string, KernelFunction> name_to_kernel_map_;
35+
};
36+
37+
// Bridge functions used by Dispatcher::GetKernel. Implemented in
38+
// infiniops_registry.cc; declared here for users that already include
39+
// the full registry header (e.g. unit tests).
40+
bool InfiniOpsEnabled();
41+
bool InfiniOpsEnabled(const KeyT &key);
42+
const KernelFunction *LookupInfiniOpsKernel(const KeyT &key);
43+
44+
} // namespace infini_train::kernel_provider
45+
46+
#define REGISTER_INFINIOPS_KERNEL(kernel_name, kernel_func) \
47+
static const bool _##kernel_name##_infiniops_registered##__COUNTER__ = []() { \
48+
infini_train::kernel_provider::InfiniOpsRegistry::Instance().Register(#kernel_name, kernel_func); \
49+
return true; \
50+
}();

infini_train/include/dispatcher.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <map>
4+
#include <string>
45
#include <type_traits>
56
#include <utility>
67

@@ -47,6 +48,11 @@ class KernelFunction {
4748
void *func_ptr_ = nullptr;
4849
};
4950

51+
namespace kernel_provider {
52+
bool InfiniOpsEnabled(const std::pair<Device::DeviceType, std::string> &key);
53+
const KernelFunction *LookupInfiniOpsKernel(const std::pair<Device::DeviceType, std::string> &key);
54+
} // namespace kernel_provider
55+
5056
class Dispatcher {
5157
public:
5258
using KeyT = std::pair<Device::DeviceType, std::string>;
@@ -57,6 +63,17 @@ class Dispatcher {
5763
}
5864

5965
const KernelFunction &GetKernel(KeyT key) const {
66+
if (kernel_provider::InfiniOpsEnabled(key)) {
67+
if (const auto *kernel = kernel_provider::LookupInfiniOpsKernel(key)) {
68+
#ifdef PROFILE_MODE
69+
SetProfileContext(key.second, key.first);
70+
#endif
71+
return *kernel;
72+
}
73+
LOG(WARNING) << "InfiniOps kernel enabled but not registered: " << key.second
74+
<< " on device: " << static_cast<int>(key.first) << "; falling back to default kernel";
75+
}
76+
6077
CHECK(key_to_kernel_map_.contains(key))
6178
<< "Kernel not found: " << key.second << " on device: " << static_cast<int>(key.first);
6279
#ifdef PROFILE_MODE
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#pragma once
2+
3+
#include "infini_train/include/datatype.h"
4+
#include "infini_train/include/device.h"
5+
6+
namespace infini_train::kernels {
7+
8+
enum class GemmTranspose : int {
9+
kNoTranspose = 0,
10+
kTranspose = 1,
11+
};
12+
13+
/**
14+
* Parameter bundle for a single GEMM call:
15+
* C = alpha * op(A) * op(B) + beta * C
16+
*
17+
* batch_count == 1 describes a non-batched GEMM. batch_count > 1 describes a
18+
* strided-batched GEMM. When batch_count == 1, stride_a/b/c are unused and must
19+
* be left at 0.
20+
*/
21+
struct GemmParams {
22+
GemmTranspose trans_a = GemmTranspose::kNoTranspose;
23+
GemmTranspose trans_b = GemmTranspose::kNoTranspose;
24+
25+
int m = 0; // rows of op(A) and C
26+
int n = 0; // cols of op(B) and C
27+
int k = 0; // cols of op(A) == rows of op(B)
28+
29+
const void *A = nullptr;
30+
int lda = 0;
31+
const void *B = nullptr;
32+
int ldb = 0;
33+
void *C = nullptr;
34+
int ldc = 0;
35+
36+
float alpha = 1.0f;
37+
float beta = 0.0f;
38+
39+
int batch_count = 1;
40+
long long stride_a = 0;
41+
long long stride_b = 0;
42+
long long stride_c = 0;
43+
44+
DataType input_dtype;
45+
DataType output_dtype;
46+
};
47+
48+
} // namespace infini_train::kernels

infini_train/include/tensor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
139139
std::shared_ptr<Tensor> View(const std::vector<int64_t> &dims);
140140
std::shared_ptr<Tensor> Contiguous();
141141
// FIXME: Currently returns true unconditionally. Requires stride tracking in the Tensor
142-
// class before this can be implemented correctly. The guard in elementwise.cu ensures
142+
// class before this can be implemented correctly. The elementwise broadcast guard ensures
143143
// non-contiguous tensors fall back to the broadcast path until this is resolved.
144144
bool IsContiguous() const;
145145
std::shared_ptr<Tensor> Flatten(int64_t start = 0, int64_t end = -1);
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#include "infini_train/include/core/kernel_provider/infiniops/adapter.h"
2+
3+
#include <map>
4+
#include <unordered_map>
5+
6+
#include "glog/logging.h"
7+
8+
#include "infini_train/include/core/runtime/device_guard.h"
9+
10+
namespace infini_train::kernel_provider::infiniops {
11+
12+
namespace {
13+
14+
inline const std::unordered_map<DataType, infini::ops::DataType> kOpsDataTypeMap = {
15+
{DataType::kFLOAT16, infini::ops::DataType::kFloat16}, {DataType::kBFLOAT16, infini::ops::DataType::kBFloat16},
16+
{DataType::kFLOAT32, infini::ops::DataType::kFloat32}, {DataType::kFLOAT64, infini::ops::DataType::kFloat64},
17+
{DataType::kINT8, infini::ops::DataType::kInt8}, {DataType::kINT16, infini::ops::DataType::kInt16},
18+
{DataType::kINT32, infini::ops::DataType::kInt32}, {DataType::kINT64, infini::ops::DataType::kInt64},
19+
{DataType::kUINT8, infini::ops::DataType::kUInt8}, {DataType::kUINT16, infini::ops::DataType::kUInt16},
20+
{DataType::kUINT32, infini::ops::DataType::kUInt32}, {DataType::kUINT64, infini::ops::DataType::kUInt64},
21+
};
22+
23+
inline const std::unordered_map<Device::DeviceType, infini::ops::Device::Type> kOpsDeviceTypeMap = {
24+
{Device::DeviceType::kCUDA, infini::ops::Device::Type::kNvidia},
25+
{Device::DeviceType::kCPU, infini::ops::Device::Type::kCpu},
26+
};
27+
28+
std::map<Device::DeviceType, HandleFactory> &HandleFactories() {
29+
static std::map<Device::DeviceType, HandleFactory> factories;
30+
return factories;
31+
}
32+
33+
} // namespace
34+
35+
void RegisterHandleFactory(Device::DeviceType type, HandleFactory factory) {
36+
CHECK(factory != nullptr);
37+
auto &factories = HandleFactories();
38+
CHECK(!factories.contains(type)) << "InfiniOps handle factory already registered for device type "
39+
<< static_cast<int>(type);
40+
factories.emplace(type, factory);
41+
}
42+
43+
infini::ops::Handle GetHandle(const Device &device) {
44+
auto &factories = HandleFactories();
45+
auto it = factories.find(device.type());
46+
CHECK(it != factories.end()) << "InfiniOps handle factory is not registered for device type "
47+
<< static_cast<int>(device.type());
48+
49+
auto *stream = core::GetDeviceGuardImpl(device.type())->GetStream(device);
50+
return it->second(device, stream);
51+
}
52+
53+
infini::ops::DataType ToOpsDataType(DataType dtype) {
54+
auto it = kOpsDataTypeMap.find(dtype);
55+
if (it == kOpsDataTypeMap.end()) {
56+
LOG(FATAL) << "Unsupported DataType for InfiniOps: " << static_cast<int>(dtype);
57+
__builtin_unreachable();
58+
}
59+
return it->second;
60+
}
61+
62+
infini::ops::Device ToOpsDevice(const Device &device) {
63+
auto it = kOpsDeviceTypeMap.find(device.type());
64+
if (it == kOpsDeviceTypeMap.end()) {
65+
LOG(FATAL) << "Unsupported DeviceType for InfiniOps: " << static_cast<int>(device.type());
66+
__builtin_unreachable();
67+
}
68+
return {it->second, device.index()};
69+
}
70+
71+
std::mutex &InfiniOpsCallMutex() {
72+
static std::mutex mutex;
73+
return mutex;
74+
}
75+
76+
namespace {
77+
infini::ops::Tensor::Strides ComputeContiguousStrides(const std::vector<int64_t> &dims) {
78+
infini::ops::Tensor::Strides strides(dims.size());
79+
if (dims.empty()) {
80+
return strides;
81+
}
82+
strides.back() = 1;
83+
for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) {
84+
strides[i] = strides[i + 1] * static_cast<infini::ops::Tensor::Stride>(dims[i + 1]);
85+
}
86+
return strides;
87+
}
88+
89+
infini::ops::Tensor::Shape ToShape(const std::vector<int64_t> &dims) {
90+
infini::ops::Tensor::Shape shape(dims.size());
91+
for (size_t i = 0; i < dims.size(); ++i) { shape[i] = static_cast<infini::ops::Tensor::Size>(dims[i]); }
92+
return shape;
93+
}
94+
95+
infini::ops::Tensor::Strides ToStrides(const std::vector<int64_t> &strides) {
96+
infini::ops::Tensor::Strides ops_strides(strides.size());
97+
for (size_t i = 0; i < strides.size(); ++i) {
98+
ops_strides[i] = static_cast<infini::ops::Tensor::Stride>(strides[i]);
99+
}
100+
return ops_strides;
101+
}
102+
} // namespace
103+
104+
infini::ops::Tensor ToOpsTensor(const std::shared_ptr<Tensor> &tensor) {
105+
const auto &dims = tensor->Dims();
106+
return {tensor->DataPtr(), ToShape(dims), ToOpsDataType(tensor->Dtype()), ToOpsDevice(tensor->GetDevice()),
107+
ComputeContiguousStrides(dims)};
108+
}
109+
110+
infini::ops::Tensor ToOpsTensor(void *data, const std::vector<int64_t> &dims, DataType dtype, const Device &device) {
111+
return {data, ToShape(dims), ToOpsDataType(dtype), ToOpsDevice(device), ComputeContiguousStrides(dims)};
112+
}
113+
114+
infini::ops::Tensor ToOpsTensor(void *data, const std::vector<int64_t> &dims, DataType dtype, const Device &device,
115+
const std::vector<int64_t> &strides) {
116+
CHECK_EQ(dims.size(), strides.size());
117+
return {data, ToShape(dims), ToOpsDataType(dtype), ToOpsDevice(device), ToStrides(strides)};
118+
}
119+
120+
} // namespace infini_train::kernel_provider::infiniops

0 commit comments

Comments
 (0)