Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@
[submodule "third_party/googletest"]
path = third_party/googletest
url = git@github.com:google/googletest.git
[submodule "third_party/InfiniOps"]
path = third_party/InfiniOps
url = git@github.com:InfiniTensor/InfiniOps.git
34 changes: 34 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ option(USE_CUDA "Support NVIDIA CUDA" OFF)
option(PROFILE_MODE "ENABLE PROFILE MODE" OFF)
option(USE_OMP "Use OpenMP as backend for Eigen" ON)
option(USE_NCCL "Build project for distributed running" ON)
# Opt-in switch for the InfiniOps kernel provider. When ON, the build requires
# the third_party/InfiniOps submodule, defines USE_INFINIOPS=1 and links
# infini_train against InfiniOps::infiniops.
option(USE_INFINIOPS "Use InfiniOps as an optional kernel provider" OFF)
option(BUILD_TEST "Build InfiniTrain tests" OFF)

project(infini_train VERSION 0.5.0 LANGUAGES CXX)
Expand Down Expand Up @@ -51,6 +52,32 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party/eigen)

include_directories(${PROJECT_SOURCE_DIR})

# Optional InfiniOps kernel provider.
# Pulls third_party/InfiniOps into the build as a sub-project and exposes it to
# the rest of this file as the namespaced target InfiniOps::infiniops.
if(USE_INFINIOPS)
    add_compile_definitions(USE_INFINIOPS=1)

    set(INFINIOPS_SOURCE_DIR "${PROJECT_SOURCE_DIR}/third_party/InfiniOps")
    if(NOT EXISTS "${INFINIOPS_SOURCE_DIR}/CMakeLists.txt")
        message(FATAL_ERROR
            "USE_INFINIOPS=ON requires InfiniOps under third_party/InfiniOps. "
            "Run: git submodule update --init third_party/InfiniOps")
    endif()

    # The InfiniOps CPU backend is enabled only when CUDA is not requested.
    if(USE_CUDA)
        set(INFINIOPS_WITH_CPU OFF)
    else()
        set(INFINIOPS_WITH_CPU ON)
    endif()

    # WITH_CPU / WITH_NVIDIA are written with FORCE so that they track USE_CUDA
    # on every re-configure.
    # NOTE(review): these cache names are not project-prefixed and will
    # overwrite any value the user passed for them on the command line.
    set(WITH_CPU ${INFINIOPS_WITH_CPU} CACHE BOOL "Enable InfiniOps CPU backend" FORCE)
    set(WITH_NVIDIA ${USE_CUDA} CACHE BOOL "Enable InfiniOps NVIDIA backend" FORCE)

    # PROJECT_BINARY_DIR (not CMAKE_BINARY_DIR) keeps the sub-build inside this
    # project's build tree even when infini_train is itself added as a
    # sub-project; for a top-level build both variables name the same directory.
    add_subdirectory(
        "${INFINIOPS_SOURCE_DIR}"
        "${PROJECT_BINARY_DIR}/third_party/InfiniOps"
        EXCLUDE_FROM_ALL)

    if(NOT TARGET infiniops)
        message(FATAL_ERROR "InfiniOps third-party project did not define target `infiniops`")
    endif()
    # Provide the namespaced alias if the sub-project did not create it itself.
    if(NOT TARGET InfiniOps::infiniops)
        add_library(InfiniOps::infiniops ALIAS infiniops)
    endif()
endif()

if(PROFILE_MODE)
add_compile_definitions(PROFILE_MODE=1)
endif()
Expand All @@ -62,9 +89,13 @@ endif()
# Framework core sources (*.cc), excluding cpu kernels (they are built separately)
file(GLOB_RECURSE SRC ${PROJECT_SOURCE_DIR}/infini_train/src/*.cc)
list(FILTER SRC EXCLUDE REGEX ".*kernels/cpu/.*")
# Without the InfiniOps provider, drop its adapter sources from the core library.
if(NOT USE_INFINIOPS)
    # "\\." reaches the regex engine as "\." (a literal dot). A single backslash
    # would be consumed by CMake's own string escaping and leave a bare ".",
    # which matches any character.
    list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/kernel_provider/infiniops/.*\\.cc$")
endif()
# CPU-only build: remove every source that lives under a CUDA-specific directory,
# including the CUDA part of the InfiniOps adapter.
if(NOT USE_CUDA)
list(FILTER SRC EXCLUDE REGEX ".*runtime/cuda/.*")
list(FILTER SRC EXCLUDE REGEX ".*ccl/cuda/.*")
# Applies whether or not USE_INFINIOPS is ON; when it is OFF these files have
# already been removed by the InfiniOps filter above.
list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/kernel_provider/infiniops/cuda/.*")
endif()
if(NOT USE_NCCL)
list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/ccl/cuda/.*")
Expand Down Expand Up @@ -126,6 +157,9 @@ endif()
# ------------------------------------------------------------------------------

add_library(infini_train STATIC ${SRC})
# Linked PUBLIC: the adapter header under infini_train/include includes
# InfiniOps headers, so consumers of infini_train need its usage requirements.
if(USE_INFINIOPS)
target_link_libraries(infini_train PUBLIC InfiniOps::infiniops)
endif()
target_link_libraries(infini_train
PUBLIC
glog
Expand Down
45 changes: 45 additions & 0 deletions infini_train/include/core/kernel_provider/infiniops/adapter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#pragma once

#include <cstdint>
#include <memory>
#include <mutex>
#include <vector>

#include <handle.h>

#include "data_type.h"
#include "tensor.h"

#include "infini_train/include/datatype.h"
#include "infini_train/include/device.h"

namespace infini_train {
class Tensor;
} // namespace infini_train

namespace infini_train::core {
class Stream;
} // namespace infini_train::core

// Conversion helpers between InfiniTrain types and their InfiniOps counterparts.
namespace infini_train::kernel_provider::infiniops {

// Maps an InfiniTrain DataType to the matching infini::ops::DataType.
// Terminates the process (fatal log) when the dtype has no InfiniOps equivalent.
infini::ops::DataType ToOpsDataType(DataType dtype);

// Maps an InfiniTrain Device to an infini::ops::Device, carrying over the device index.
// Terminates the process (fatal log) when the device type has no InfiniOps equivalent.
infini::ops::Device ToOpsDevice(const Device &device);

// Returns a reference to one shared mutex; every call yields the same object.
// NOTE(review): presumably used to serialize calls into InfiniOps — callers are not visible here.
std::mutex &InfiniOpsCallMutex();

// Signature of a per-device-type function that builds an InfiniOps handle for a device and stream.
using HandleFactory = infini::ops::Handle (*)(const Device &device, core::Stream *stream);

// Registers the handle factory for a device type.
// CHECK-fails if `factory` is null or a factory is already registered for `type`.
void RegisterHandleFactory(Device::DeviceType type, HandleFactory factory);

// Builds a handle for `device` by calling the factory registered for its device type,
// passing the stream obtained from the device guard implementation.
// CHECK-fails if no factory is registered for that type.
infini::ops::Handle GetHandle(const Device &device);

// Wraps an InfiniTrain tensor as an InfiniOps tensor that shares its data pointer.
// Strides are derived from the dims as for a contiguous layout.
infini::ops::Tensor ToOpsTensor(const std::shared_ptr<Tensor> &tensor);

// Wraps a raw buffer as an InfiniOps tensor with contiguous strides derived from `dims`.
infini::ops::Tensor ToOpsTensor(void *data, const std::vector<int64_t> &dims, DataType dtype, const Device &device);

// Wraps a raw buffer as an InfiniOps tensor with caller-supplied strides.
// CHECK-fails unless `strides` has exactly one entry per dimension.
infini::ops::Tensor ToOpsTensor(void *data, const std::vector<int64_t> &dims, DataType dtype, const Device &device,
const std::vector<int64_t> &strides);

} // namespace infini_train::kernel_provider::infiniops
50 changes: 50 additions & 0 deletions infini_train/include/core/kernel_provider/infiniops_registry.h
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是出于什么原因要单独写一套 registry,而不能直接复用 InfiniTrain 原有的注册表呢?

Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#pragma once

#include <map>
#include <string>
#include <utility>

#include "glog/logging.h"

#include "infini_train/include/device.h"
#include "infini_train/include/dispatcher.h"

namespace infini_train::kernel_provider {

using KeyT = std::pair<Device::DeviceType, std::string>;

// Singleton table of InfiniOps-backed kernels, keyed by kernel name only
// (no device type in the key).
class InfiniOpsRegistry {
public:
    // Returns the single registry object, constructed on first use.
    static InfiniOpsRegistry &Instance() {
        static InfiniOpsRegistry registry;
        return registry;
    }

    // Returns a pointer to the kernel registered under `kernel_name`,
    // or nullptr when no such kernel exists.
    const KernelFunction *Lookup(const std::string &kernel_name) const {
        const auto entry = name_to_kernel_map_.find(kernel_name);
        if (entry == name_to_kernel_map_.end()) {
            return nullptr;
        }
        return &entry->second;
    }

    // Stores `kernel` under `kernel_name`; CHECK-fails if the name is already taken.
    template <typename FuncT> void Register(const std::string &kernel_name, FuncT &&kernel) {
        CHECK(!name_to_kernel_map_.contains(kernel_name)) << "InfiniOps kernel already registered: " << kernel_name;
        name_to_kernel_map_.emplace(kernel_name, kernel);
    }

private:
    std::map<std::string, KernelFunction> name_to_kernel_map_;
};

// Bridge functions used by Dispatcher::GetKernel. Implemented in
// infiniops_registry.cc; declared here for users that already include
// the full registry header (e.g. unit tests).
//
// NOTE(review): the definitions are not visible from this header. As used by
// Dispatcher::GetKernel, InfiniOpsEnabled(key) gates the InfiniOps path for a
// (device type, kernel name) pair and LookupInfiniOpsKernel(key) yields a
// kernel pointer that is null when nothing is registered — confirm against
// infiniops_registry.cc.
bool InfiniOpsEnabled();
bool InfiniOpsEnabled(const KeyT &key);
const KernelFunction *LookupInfiniOpsKernel(const KeyT &key);

} // namespace infini_train::kernel_provider

// Two-level concatenation: the outer macro lets its arguments (notably
// __COUNTER__) be macro-expanded before the inner macro pastes them. Pasting
// `##__COUNTER__` directly would glue the literal token `__COUNTER__` instead.
#define INFINIOPS_KERNEL_CONCAT_IMPL(lhs, rhs) lhs##rhs
#define INFINIOPS_KERNEL_CONCAT(lhs, rhs) INFINIOPS_KERNEL_CONCAT_IMPL(lhs, rhs)

// Registers `kernel_func` in the InfiniOpsRegistry under the stringified
// `kernel_name` during static initialization of the enclosing translation unit.
//
// The generated flag variable is named infiniops_registered_<kernel_name>_<N>,
// where <N> is a fresh __COUNTER__ value; the name deliberately avoids a
// leading underscore, which is reserved when followed by an uppercase letter.
// The trailing semicolon is part of the macro, so a call site may omit it.
#define REGISTER_INFINIOPS_KERNEL(kernel_name, kernel_func)                                                            \
    static const bool INFINIOPS_KERNEL_CONCAT(infiniops_registered_##kernel_name##_, __COUNTER__) = []() {            \
        infini_train::kernel_provider::InfiniOpsRegistry::Instance().Register(#kernel_name, kernel_func);             \
        return true;                                                                                                   \
    }();
17 changes: 17 additions & 0 deletions infini_train/include/dispatcher.h
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里不应该给 InfiniOps 开额外分支，之前接沐曦 kernel 这块是不需要动的。

Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <map>
#include <string>
#include <type_traits>
#include <utility>

Expand Down Expand Up @@ -47,6 +48,11 @@ class KernelFunction {
void *func_ptr_ = nullptr;
};

// Forward declarations of the InfiniOps bridge consulted by Dispatcher::GetKernel.
// The key is spelled out as std::pair<Device::DeviceType, std::string>, the same
// type as Dispatcher::KeyT, which is not yet declared at this point.
namespace kernel_provider {
// Whether the InfiniOps path should be tried for this (device type, kernel name) key.
bool InfiniOpsEnabled(const std::pair<Device::DeviceType, std::string> &key);
// Kernel registered for the key; GetKernel treats a null result as "not registered".
const KernelFunction *LookupInfiniOpsKernel(const std::pair<Device::DeviceType, std::string> &key);
} // namespace kernel_provider

class Dispatcher {
public:
using KeyT = std::pair<Device::DeviceType, std::string>;
Expand All @@ -57,6 +63,17 @@ class Dispatcher {
}

const KernelFunction &GetKernel(KeyT key) const {
if (kernel_provider::InfiniOpsEnabled(key)) {
if (const auto *kernel = kernel_provider::LookupInfiniOpsKernel(key)) {
#ifdef PROFILE_MODE
SetProfileContext(key.second, key.first);
#endif
return *kernel;
}
LOG(WARNING) << "InfiniOps kernel enabled but not registered: " << key.second
<< " on device: " << static_cast<int>(key.first) << "; falling back to default kernel";
}

CHECK(key_to_kernel_map_.contains(key))
<< "Kernel not found: " << key.second << " on device: " << static_cast<int>(key.first);
#ifdef PROFILE_MODE
Expand Down
48 changes: 48 additions & 0 deletions infini_train/include/kernels/common/gemm.h
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个头文件内容没什么问题,但不适合放到 include 里作为公共头文件暴露,先放 infini_train/src/kernels/common 里吧

Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#pragma once

#include "infini_train/include/datatype.h"
#include "infini_train/include/device.h"

namespace infini_train::kernels {

// Selects op(X) for a GEMM operand: the matrix itself or its transpose.
enum class GemmTranspose : int {
kNoTranspose = 0,
kTranspose = 1,
};

/**
 * Parameter bundle for a single GEMM call:
 *   C = alpha * op(A) * op(B) + beta * C
 *
 * batch_count == 1 describes a non-batched GEMM. batch_count > 1 describes a
 * strided-batched GEMM. When batch_count == 1, stride_a/b/c are unused and must
 * be left at 0.
 *
 * Every member has a default, so a default-constructed GemmParams never
 * carries indeterminate values. Callers are still expected to set
 * input_dtype / output_dtype explicitly for anything other than float32.
 */
struct GemmParams {
    GemmTranspose trans_a = GemmTranspose::kNoTranspose;
    GemmTranspose trans_b = GemmTranspose::kNoTranspose;

    int m = 0; // rows of op(A) and C
    int n = 0; // cols of op(B) and C
    int k = 0; // cols of op(A) == rows of op(B)

    // Operand pointers with their leading dimensions (lda/ldb/ldc).
    const void *A = nullptr;
    int lda = 0;
    const void *B = nullptr;
    int ldb = 0;
    void *C = nullptr;
    int ldc = 0;

    float alpha = 1.0f;
    float beta = 0.0f;

    // Strided-batched description; see the struct comment for batch_count == 1.
    int batch_count = 1;
    long long stride_a = 0;
    long long stride_b = 0;
    long long stride_c = 0;

    // Element types of the inputs (A, B) and of the output (C).
    DataType input_dtype = DataType::kFLOAT32;
    DataType output_dtype = DataType::kFLOAT32;
};

} // namespace infini_train::kernels
2 changes: 1 addition & 1 deletion infini_train/include/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
std::shared_ptr<Tensor> View(const std::vector<int64_t> &dims);
std::shared_ptr<Tensor> Contiguous();
// FIXME: Currently returns true unconditionally. Requires stride tracking in the Tensor
// class before this can be implemented correctly. The guard in elementwise.cu ensures
// class before this can be implemented correctly. The elementwise broadcast guard ensures
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不用改吧

// non-contiguous tensors fall back to the broadcast path until this is resolved.
bool IsContiguous() const;
std::shared_ptr<Tensor> Flatten(int64_t start = 0, int64_t end = -1);
Expand Down
120 changes: 120 additions & 0 deletions infini_train/src/core/kernel_provider/infiniops/adapter.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#include "infini_train/include/core/kernel_provider/infiniops/adapter.h"

#include <map>
#include <unordered_map>

#include "glog/logging.h"

#include "infini_train/include/core/runtime/device_guard.h"

namespace infini_train::kernel_provider::infiniops {

namespace {

inline const std::unordered_map<DataType, infini::ops::DataType> kOpsDataTypeMap = {
{DataType::kFLOAT16, infini::ops::DataType::kFloat16}, {DataType::kBFLOAT16, infini::ops::DataType::kBFloat16},
{DataType::kFLOAT32, infini::ops::DataType::kFloat32}, {DataType::kFLOAT64, infini::ops::DataType::kFloat64},
{DataType::kINT8, infini::ops::DataType::kInt8}, {DataType::kINT16, infini::ops::DataType::kInt16},
{DataType::kINT32, infini::ops::DataType::kInt32}, {DataType::kINT64, infini::ops::DataType::kInt64},
{DataType::kUINT8, infini::ops::DataType::kUInt8}, {DataType::kUINT16, infini::ops::DataType::kUInt16},
{DataType::kUINT32, infini::ops::DataType::kUInt32}, {DataType::kUINT64, infini::ops::DataType::kUInt64},
};

inline const std::unordered_map<Device::DeviceType, infini::ops::Device::Type> kOpsDeviceTypeMap = {
{Device::DeviceType::kCUDA, infini::ops::Device::Type::kNvidia},
{Device::DeviceType::kCPU, infini::ops::Device::Type::kCpu},
};

std::map<Device::DeviceType, HandleFactory> &HandleFactories() {
static std::map<Device::DeviceType, HandleFactory> factories;
return factories;
}

} // namespace

// Records `factory` as the handle builder for device type `type`.
// Aborts via CHECK when `factory` is null or when `type` already has a factory.
void RegisterHandleFactory(Device::DeviceType type, HandleFactory factory) {
CHECK(factory != nullptr);
auto &factories = HandleFactories();
// Registration is one-shot per device type; a second attempt is treated as a programming error.
CHECK(!factories.contains(type)) << "InfiniOps handle factory already registered for device type "
<< static_cast<int>(type);
factories.emplace(type, factory);
}

// Builds an InfiniOps handle for `device` using the factory registered for its
// device type. Aborts via CHECK when no factory has been registered.
infini::ops::Handle GetHandle(const Device &device) {
auto &factories = HandleFactories();
auto it = factories.find(device.type());
CHECK(it != factories.end()) << "InfiniOps handle factory is not registered for device type "
<< static_cast<int>(device.type());

// The stream comes from the device guard implementation for this device type.
// NOTE(review): the guard pointer is dereferenced without a null check — confirm
// GetDeviceGuardImpl cannot return null for a type that has a factory.
auto *stream = core::GetDeviceGuardImpl(device.type())->GetStream(device);
return it->second(device, stream);
}

// Translates an InfiniTrain DataType into the matching InfiniOps DataType.
// Aborts via CHECK when the dtype has no entry in kOpsDataTypeMap.
infini::ops::DataType ToOpsDataType(DataType dtype) {
    const auto it = kOpsDataTypeMap.find(dtype);
    // CHECK replaces LOG(FATAL) + __builtin_unreachable(): it aborts the same way,
    // needs no compiler-specific builtin, and matches the lookup style of GetHandle.
    CHECK(it != kOpsDataTypeMap.end()) << "Unsupported DataType for InfiniOps: " << static_cast<int>(dtype);
    return it->second;
}

// Translates an InfiniTrain Device into an InfiniOps Device, keeping the device index.
// Aborts via CHECK when the device type has no entry in kOpsDeviceTypeMap.
infini::ops::Device ToOpsDevice(const Device &device) {
    const auto it = kOpsDeviceTypeMap.find(device.type());
    // CHECK replaces LOG(FATAL) + __builtin_unreachable(): it aborts the same way,
    // needs no compiler-specific builtin, and matches the lookup style of GetHandle.
    CHECK(it != kOpsDeviceTypeMap.end())
        << "Unsupported DeviceType for InfiniOps: " << static_cast<int>(device.type());
    return {it->second, device.index()};
}

// Hands out the one mutex shared by all callers; it is created on first use
// and every call returns a reference to that same object.
std::mutex &InfiniOpsCallMutex() {
    static std::mutex call_mutex;
    return call_mutex;
}

namespace {
// Strides for a densely packed layout whose last dimension varies fastest:
// the last stride is 1 and each earlier stride is the product of all later dims.
// An empty `dims` yields an empty stride list.
infini::ops::Tensor::Strides ComputeContiguousStrides(const std::vector<int64_t> &dims) {
    const size_t rank = dims.size();
    infini::ops::Tensor::Strides strides(rank);
    if (rank == 0) {
        return strides;
    }
    strides[rank - 1] = 1;
    // Walk from the innermost axis outwards, accumulating the running product.
    for (size_t axis = rank - 1; axis > 0; --axis) {
        strides[axis - 1] = strides[axis] * static_cast<infini::ops::Tensor::Stride>(dims[axis]);
    }
    return strides;
}

// Element-wise cast of InfiniTrain dims into an InfiniOps shape.
infini::ops::Tensor::Shape ToShape(const std::vector<int64_t> &dims) {
    infini::ops::Tensor::Shape shape(dims.size());
    size_t axis = 0;
    for (const int64_t extent : dims) {
        shape[axis] = static_cast<infini::ops::Tensor::Size>(extent);
        ++axis;
    }
    return shape;
}

// Element-wise cast of caller-provided strides into InfiniOps strides.
infini::ops::Tensor::Strides ToStrides(const std::vector<int64_t> &strides) {
    infini::ops::Tensor::Strides ops_strides(strides.size());
    size_t axis = 0;
    for (const int64_t step : strides) {
        ops_strides[axis] = static_cast<infini::ops::Tensor::Stride>(step);
        ++axis;
    }
    return ops_strides;
}
} // namespace

// Wraps an InfiniTrain tensor as an InfiniOps tensor that refers to the same data pointer.
// Shape, dtype and device are taken from `tensor`; strides are computed from the dims as
// for a contiguous layout. Aborts via CHECK when `tensor` is null.
infini::ops::Tensor ToOpsTensor(const std::shared_ptr<Tensor> &tensor) {
    // Fail with a clear message instead of dereferencing a null shared_ptr.
    CHECK(tensor != nullptr) << "ToOpsTensor requires a non-null tensor";
    const auto &dims = tensor->Dims();
    return {tensor->DataPtr(), ToShape(dims), ToOpsDataType(tensor->Dtype()), ToOpsDevice(tensor->GetDevice()),
            ComputeContiguousStrides(dims)};
}

// Wraps a raw buffer as an InfiniOps tensor. Strides are computed from `dims`
// as for a contiguous layout; `data` is passed through unchanged.
infini::ops::Tensor ToOpsTensor(void *data, const std::vector<int64_t> &dims, DataType dtype, const Device &device) {
return {data, ToShape(dims), ToOpsDataType(dtype), ToOpsDevice(device), ComputeContiguousStrides(dims)};
}

// Wraps a raw buffer as an InfiniOps tensor using caller-supplied strides.
// Aborts via CHECK_EQ unless `strides` has exactly one entry per dimension.
infini::ops::Tensor ToOpsTensor(void *data, const std::vector<int64_t> &dims, DataType dtype, const Device &device,
const std::vector<int64_t> &strides) {
CHECK_EQ(dims.size(), strides.size());
return {data, ToShape(dims), ToOpsDataType(dtype), ToOpsDevice(device), ToStrides(strides)};
}

} // namespace infini_train::kernel_provider::infiniops
Loading
Loading