|
| 1 | +#include "infini_train/src/core/ccl/maca/mccl_impl.h" |
| 2 | + |
| 3 | +#include <mccl.h> |
| 4 | +#include <vector> |
| 5 | + |
| 6 | +#include "glog/logging.h" |
| 7 | + |
| 8 | +#include "infini_train/include/common/maca/common_maca.h" |
| 9 | +#include "infini_train/include/core/runtime/runtime_common.h" |
| 10 | +#include "infini_train/include/device.h" |
| 11 | + |
| 12 | +#include "infini_train/src/core/ccl/maca/mccl_common.h" |
| 13 | +#include "infini_train/src/core/runtime/maca/maca_runtime_common.h" |
| 14 | + |
| 15 | +namespace infini_train::core::maca { |
| 16 | +namespace { |
| 17 | + |
| 18 | +inline const std::unordered_map<DataType, mcclDataType_t> kMcclDtypeMap = { |
| 19 | + {DataType::kUINT8, mcclUint8}, {DataType::kINT8, mcclInt8}, {DataType::kUINT32, mcclUint32}, |
| 20 | + {DataType::kINT32, mcclInt32}, {DataType::kUINT64, mcclUint64}, {DataType::kINT64, mcclInt64}, |
| 21 | + {DataType::kBFLOAT16, mcclBfloat16}, {DataType::kFLOAT16, mcclHalf}, {DataType::kFLOAT32, mcclFloat32}, |
| 22 | + {DataType::kFLOAT64, mcclFloat64}, |
| 23 | +}; |
| 24 | + |
| 25 | +inline const std::unordered_map<nn::parallel::function::ReduceOpType, mcclRedOp_t> kMcclReduceOpMap = { |
| 26 | + {nn::parallel::function::ReduceOpType::kSum, mcclSum}, {nn::parallel::function::ReduceOpType::kProd, mcclProd}, |
| 27 | + {nn::parallel::function::ReduceOpType::kMin, mcclMin}, {nn::parallel::function::ReduceOpType::kMax, mcclMax}, |
| 28 | + {nn::parallel::function::ReduceOpType::kAvg, mcclAvg}, |
| 29 | +}; |
| 30 | + |
| 31 | +inline mcclComm_t GetMcclComm(const CclComm *comm) { |
| 32 | + auto *mccl_comm = dynamic_cast<const McclComm *>(comm); |
| 33 | + CHECK_NOTNULL(mccl_comm); |
| 34 | + return mccl_comm->mccl_comm(); |
| 35 | +} |
| 36 | + |
| 37 | +inline void SetMcclComm(CclComm *comm, mcclComm_t mccl_comm) { |
| 38 | + auto *typed_comm = dynamic_cast<McclComm *>(comm); |
| 39 | + CHECK_NOTNULL(typed_comm); |
| 40 | + typed_comm->set_mccl_comm(mccl_comm); |
| 41 | +} |
| 42 | + |
| 43 | +inline const mcclUniqueId &GetMcclUniqueId(const CclUniqueId &unique_id) { |
| 44 | + auto *mccl_unique_id = dynamic_cast<const McclUniqueId *>(&unique_id); |
| 45 | + CHECK_NOTNULL(mccl_unique_id); |
| 46 | + return *mccl_unique_id->mccl_unique_id(); |
| 47 | +} |
| 48 | + |
| 49 | +inline mcStream_t GetMacaStream(Stream *stream) { |
| 50 | + auto *maca_stream = dynamic_cast<MacaStream *>(stream); |
| 51 | + CHECK_NOTNULL(maca_stream); |
| 52 | + return maca_stream->maca_stream(); |
| 53 | +} |
| 54 | + |
| 55 | +} // namespace |
| 56 | + |
| 57 | +Device::DeviceType McclImpl::Type() const { return Device::DeviceType::kMACA; } |
| 58 | + |
| 59 | +void McclImpl::GroupStart() const { MCCL_CHECK(mcclGroupStart()); } |
| 60 | + |
| 61 | +void McclImpl::GroupEnd() const { MCCL_CHECK(mcclGroupEnd()); } |
| 62 | + |
| 63 | +void McclImpl::GetAsyncError(const CclComm *comm, CclStatus *async_error) const { |
| 64 | + mcclResult_t mccl_async_error = mcclSuccess; |
| 65 | + MCCL_CHECK(mcclCommGetAsyncError(GetMcclComm(comm), &mccl_async_error)); |
| 66 | + if (async_error != nullptr) { |
| 67 | + *async_error = (mccl_async_error == mcclSuccess) ? CclStatus::kSuccess : CclStatus::kError; |
| 68 | + } |
| 69 | +} |
| 70 | + |
| 71 | +void McclImpl::GetUniqueId(CclUniqueId **unique_id) const { |
| 72 | + CHECK_NOTNULL(unique_id); |
| 73 | + if (*unique_id == nullptr) { |
| 74 | + *unique_id = new McclUniqueId(); |
| 75 | + } |
| 76 | + auto *mccl_unique_id = dynamic_cast<McclUniqueId *>(*unique_id); |
| 77 | + CHECK_NOTNULL(mccl_unique_id); |
| 78 | + MCCL_CHECK(mcclGetUniqueId(mccl_unique_id->mccl_unique_id())); |
| 79 | +} |
| 80 | + |
| 81 | +void McclImpl::CommInitAll(CclComm **comms, int ndev, const int *devlist) const { |
| 82 | + CHECK_NOTNULL(comms); |
| 83 | + CHECK_GT(ndev, 0); |
| 84 | + CHECK_NOTNULL(devlist); |
| 85 | + |
| 86 | + std::vector<mcclComm_t> mccl_comms(static_cast<size_t>(ndev), nullptr); |
| 87 | + MCCL_CHECK(mcclCommInitAll(mccl_comms.data(), ndev, devlist)); |
| 88 | + for (int i = 0; i < ndev; ++i) { |
| 89 | + if (comms[i] == nullptr) { |
| 90 | + comms[i] = new McclComm(); |
| 91 | + } |
| 92 | + SetMcclComm(comms[i], mccl_comms[static_cast<size_t>(i)]); |
| 93 | + } |
| 94 | +} |
| 95 | + |
| 96 | +void McclImpl::CommInitRank(CclComm **comm, int nranks, const CclUniqueId &unique_id, int rank) const { |
| 97 | + CHECK_NOTNULL(comm); |
| 98 | + CHECK_GT(nranks, 0); |
| 99 | + |
| 100 | + if (*comm == nullptr) { |
| 101 | + *comm = new McclComm(); |
| 102 | + } |
| 103 | + |
| 104 | + mcclComm_t mccl_comm = nullptr; |
| 105 | + MCCL_CHECK(mcclCommInitRank(&mccl_comm, nranks, GetMcclUniqueId(unique_id), rank)); |
| 106 | + SetMcclComm(*comm, mccl_comm); |
| 107 | +} |
| 108 | + |
| 109 | +void McclImpl::CommDestroy(CclComm *comm) const { |
| 110 | + if (comm == nullptr) { |
| 111 | + return; |
| 112 | + } |
| 113 | + MCCL_CHECK(mcclCommDestroy(GetMcclComm(comm))); |
| 114 | + SetMcclComm(comm, nullptr); |
| 115 | +} |
| 116 | + |
| 117 | +void McclImpl::AllReduce(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, |
| 118 | + nn::parallel::function::ReduceOpType reduce_op, const CclComm *comm, Stream *stream) const { |
| 119 | + MCCL_CHECK(mcclAllReduce(sendbuff, recvbuff, count, kMcclDtypeMap.at(dtype), kMcclReduceOpMap.at(reduce_op), |
| 120 | + GetMcclComm(comm), GetMacaStream(stream))); |
| 121 | +} |
| 122 | + |
| 123 | +void McclImpl::Broadcast(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, int root, |
| 124 | + const CclComm *comm, Stream *stream) const { |
| 125 | + MCCL_CHECK(mcclBroadcast(sendbuff, recvbuff, count, kMcclDtypeMap.at(dtype), root, GetMcclComm(comm), |
| 126 | + GetMacaStream(stream))); |
| 127 | +} |
| 128 | + |
| 129 | +void McclImpl::Reduce(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, |
| 130 | + nn::parallel::function::ReduceOpType reduce_op, int root, const CclComm *comm, |
| 131 | + Stream *stream) const { |
| 132 | + MCCL_CHECK(mcclReduce(sendbuff, recvbuff, count, kMcclDtypeMap.at(dtype), kMcclReduceOpMap.at(reduce_op), root, |
| 133 | + GetMcclComm(comm), GetMacaStream(stream))); |
| 134 | +} |
| 135 | + |
| 136 | +void McclImpl::AllGather(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, const CclComm *comm, |
| 137 | + Stream *stream) const { |
| 138 | + MCCL_CHECK( |
| 139 | + mcclAllGather(sendbuff, recvbuff, count, kMcclDtypeMap.at(dtype), GetMcclComm(comm), GetMacaStream(stream))); |
| 140 | +} |
| 141 | + |
| 142 | +void McclImpl::ReduceScatter(const void *sendbuff, void *recvbuff, size_t recv_count, DataType dtype, |
| 143 | + nn::parallel::function::ReduceOpType reduce_op, const CclComm *comm, |
| 144 | + Stream *stream) const { |
| 145 | + MCCL_CHECK(mcclReduceScatter(sendbuff, recvbuff, recv_count, kMcclDtypeMap.at(dtype), |
| 146 | + kMcclReduceOpMap.at(reduce_op), GetMcclComm(comm), GetMacaStream(stream))); |
| 147 | +} |
| 148 | + |
| 149 | +void McclImpl::Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm, |
| 150 | + Stream *stream) const { |
| 151 | + MCCL_CHECK(mcclSend(buff, count, kMcclDtypeMap.at(dtype), peer, GetMcclComm(comm), GetMacaStream(stream))); |
| 152 | +} |
| 153 | + |
| 154 | +void McclImpl::Recv(void *buff, size_t count, DataType dtype, int peer, const CclComm *comm, Stream *stream) const { |
| 155 | + MCCL_CHECK(mcclRecv(buff, count, kMcclDtypeMap.at(dtype), peer, GetMcclComm(comm), GetMacaStream(stream))); |
| 156 | +} |
| 157 | + |
| 158 | +INFINI_TRAIN_REGISTER_CCL_IMPL(Device::DeviceType::kMACA, McclImpl) |
| 159 | + |
| 160 | +} // namespace infini_train::core::maca |
0 commit comments