|
| 1 | +#include "infini_train/include/core/kernel_provider/infiniops/adapter.h" |
| 2 | + |
| 3 | +#include <map> |
| 4 | +#include <unordered_map> |
| 5 | + |
| 6 | +#include "glog/logging.h" |
| 7 | + |
| 8 | +#include "infini_train/include/core/runtime/device_guard.h" |
| 9 | + |
| 10 | +namespace infini_train::kernel_provider::infiniops { |
| 11 | + |
| 12 | +namespace { |
| 13 | + |
| 14 | +inline const std::unordered_map<DataType, infini::ops::DataType> kOpsDataTypeMap = { |
| 15 | + {DataType::kFLOAT16, infini::ops::DataType::kFloat16}, {DataType::kBFLOAT16, infini::ops::DataType::kBFloat16}, |
| 16 | + {DataType::kFLOAT32, infini::ops::DataType::kFloat32}, {DataType::kFLOAT64, infini::ops::DataType::kFloat64}, |
| 17 | + {DataType::kINT8, infini::ops::DataType::kInt8}, {DataType::kINT16, infini::ops::DataType::kInt16}, |
| 18 | + {DataType::kINT32, infini::ops::DataType::kInt32}, {DataType::kINT64, infini::ops::DataType::kInt64}, |
| 19 | + {DataType::kUINT8, infini::ops::DataType::kUInt8}, {DataType::kUINT16, infini::ops::DataType::kUInt16}, |
| 20 | + {DataType::kUINT32, infini::ops::DataType::kUInt32}, {DataType::kUINT64, infini::ops::DataType::kUInt64}, |
| 21 | +}; |
| 22 | + |
| 23 | +inline const std::unordered_map<Device::DeviceType, infini::ops::Device::Type> kOpsDeviceTypeMap = { |
| 24 | + {Device::DeviceType::kCUDA, infini::ops::Device::Type::kNvidia}, |
| 25 | + {Device::DeviceType::kCPU, infini::ops::Device::Type::kCpu}, |
| 26 | +}; |
| 27 | + |
| 28 | +std::map<Device::DeviceType, HandleFactory> &HandleFactories() { |
| 29 | + static std::map<Device::DeviceType, HandleFactory> factories; |
| 30 | + return factories; |
| 31 | +} |
| 32 | + |
| 33 | +} // namespace |
| 34 | + |
| 35 | +void RegisterHandleFactory(Device::DeviceType type, HandleFactory factory) { |
| 36 | + CHECK(factory != nullptr); |
| 37 | + auto &factories = HandleFactories(); |
| 38 | + CHECK(!factories.contains(type)) << "InfiniOps handle factory already registered for device type " |
| 39 | + << static_cast<int>(type); |
| 40 | + factories.emplace(type, factory); |
| 41 | +} |
| 42 | + |
| 43 | +infini::ops::Handle GetHandle(const Device &device) { |
| 44 | + auto &factories = HandleFactories(); |
| 45 | + auto it = factories.find(device.type()); |
| 46 | + CHECK(it != factories.end()) << "InfiniOps handle factory is not registered for device type " |
| 47 | + << static_cast<int>(device.type()); |
| 48 | + |
| 49 | + auto *stream = core::GetDeviceGuardImpl(device.type())->GetStream(device); |
| 50 | + return it->second(device, stream); |
| 51 | +} |
| 52 | + |
| 53 | +infini::ops::DataType ToOpsDataType(DataType dtype) { |
| 54 | + auto it = kOpsDataTypeMap.find(dtype); |
| 55 | + if (it == kOpsDataTypeMap.end()) { |
| 56 | + LOG(FATAL) << "Unsupported DataType for InfiniOps: " << static_cast<int>(dtype); |
| 57 | + __builtin_unreachable(); |
| 58 | + } |
| 59 | + return it->second; |
| 60 | +} |
| 61 | + |
| 62 | +infini::ops::Device ToOpsDevice(const Device &device) { |
| 63 | + auto it = kOpsDeviceTypeMap.find(device.type()); |
| 64 | + if (it == kOpsDeviceTypeMap.end()) { |
| 65 | + LOG(FATAL) << "Unsupported DeviceType for InfiniOps: " << static_cast<int>(device.type()); |
| 66 | + __builtin_unreachable(); |
| 67 | + } |
| 68 | + return {it->second, device.index()}; |
| 69 | +} |
| 70 | + |
| 71 | +std::mutex &InfiniOpsCallMutex() { |
| 72 | + static std::mutex mutex; |
| 73 | + return mutex; |
| 74 | +} |
| 75 | + |
| 76 | +namespace { |
| 77 | +infini::ops::Tensor::Strides ComputeContiguousStrides(const std::vector<int64_t> &dims) { |
| 78 | + infini::ops::Tensor::Strides strides(dims.size()); |
| 79 | + if (dims.empty()) { |
| 80 | + return strides; |
| 81 | + } |
| 82 | + strides.back() = 1; |
| 83 | + for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) { |
| 84 | + strides[i] = strides[i + 1] * static_cast<infini::ops::Tensor::Stride>(dims[i + 1]); |
| 85 | + } |
| 86 | + return strides; |
| 87 | +} |
| 88 | + |
| 89 | +infini::ops::Tensor::Shape ToShape(const std::vector<int64_t> &dims) { |
| 90 | + infini::ops::Tensor::Shape shape(dims.size()); |
| 91 | + for (size_t i = 0; i < dims.size(); ++i) { shape[i] = static_cast<infini::ops::Tensor::Size>(dims[i]); } |
| 92 | + return shape; |
| 93 | +} |
| 94 | + |
| 95 | +infini::ops::Tensor::Strides ToStrides(const std::vector<int64_t> &strides) { |
| 96 | + infini::ops::Tensor::Strides ops_strides(strides.size()); |
| 97 | + for (size_t i = 0; i < strides.size(); ++i) { |
| 98 | + ops_strides[i] = static_cast<infini::ops::Tensor::Stride>(strides[i]); |
| 99 | + } |
| 100 | + return ops_strides; |
| 101 | +} |
| 102 | +} // namespace |
| 103 | + |
| 104 | +infini::ops::Tensor ToOpsTensor(const std::shared_ptr<Tensor> &tensor) { |
| 105 | + const auto &dims = tensor->Dims(); |
| 106 | + return {tensor->DataPtr(), ToShape(dims), ToOpsDataType(tensor->Dtype()), ToOpsDevice(tensor->GetDevice()), |
| 107 | + ComputeContiguousStrides(dims)}; |
| 108 | +} |
| 109 | + |
| 110 | +infini::ops::Tensor ToOpsTensor(void *data, const std::vector<int64_t> &dims, DataType dtype, const Device &device) { |
| 111 | + return {data, ToShape(dims), ToOpsDataType(dtype), ToOpsDevice(device), ComputeContiguousStrides(dims)}; |
| 112 | +} |
| 113 | + |
| 114 | +infini::ops::Tensor ToOpsTensor(void *data, const std::vector<int64_t> &dims, DataType dtype, const Device &device, |
| 115 | + const std::vector<int64_t> &strides) { |
| 116 | + CHECK_EQ(dims.size(), strides.size()); |
| 117 | + return {data, ToShape(dims), ToOpsDataType(dtype), ToOpsDevice(device), ToStrides(strides)}; |
| 118 | +} |
| 119 | + |
| 120 | +} // namespace infini_train::kernel_provider::infiniops |
0 commit comments