
Commit eacbaba

feat(maca): add MCCL collective backend and remaining kernels
Complete the MACA backend by adding the MCCL-based collective implementation and the rest of the kernel library, enabling multi-card training (DDP) and larger models such as gpt2.

- core/ccl/maca: McclComm / McclUniqueId wrappers around mcclComm_t / mcclUniqueId, with Size/Data/Load tied to sizeof(mcclUniqueId) so that the existing backend-agnostic WriteUniqueIdFile / ReadUniqueIdFile unique-id exchange path works unchanged. McclImpl mirrors NcclImpl with kMcclDtypeMap / kMcclReduceOpMap and routes every collective through mcStream_t via dynamic_cast<MacaStream *>. Registered via INFINI_TRAIN_REGISTER_CCL_IMPL(kMACA, McclImpl), so a ProcessGroup backed by Device::DeviceType::kMACA transparently picks up MCCL without any ProcessGroupMCCL subclass.
- kernels/maca: mechanically port the remaining 15 kernels (cast, comm, concat, cross_entropy, embedding, gather, layernorm, outer, reduction, slice, softmax, split, stack, transform, vocab_parallel_cross_entropy) from their .cu counterparts, including the cub_compat path for cross_entropy/softmax/reduction, the mcblas GEMM / GemmEx calls in outer, and __maca_bfloat16 / __half typing throughout.
1 parent: 1757509

21 files changed: +3,364 −3 lines
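
The "transparent pickup" described in the commit message is the usual self-registration pattern: a static initializer adds the impl to a registry keyed by device type, which ProcessGroup then consults. A minimal sketch of what INFINI_TRAIN_REGISTER_CCL_IMPL could expand to, where CclImplRegistry and its methods are illustrative stand-ins rather than the project's real API:

#include <memory>
#include <unordered_map>

// Illustrative stand-in; the real CclImpl lives in infini_train/include/core/ccl/ccl.h.
class CclImpl {
public:
    virtual ~CclImpl() = default;
};

class CclImplRegistry {
public:
    static CclImplRegistry &Instance() {
        static CclImplRegistry registry; // constructed once, on first use
        return registry;
    }
    void Register(int device_type, std::unique_ptr<CclImpl> impl) { impls_[device_type] = std::move(impl); }
    CclImpl *Get(int device_type) const { return impls_.at(device_type).get(); }

private:
    std::unordered_map<int, std::unique_ptr<CclImpl>> impls_;
};

// A registration macro then reduces to a static initializer with a side effect:
#define REGISTER_CCL_IMPL_SKETCH(device_type, Impl)                                                                   \
    static const bool kRegistered##Impl =                                                                             \
        (CclImplRegistry::Instance().Register(device_type, std::make_unique<Impl>()), true);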

infini_train/include/common/maca/common_maca.h

Lines changed: 1 addition & 1 deletion
@@ -1,8 +1,8 @@
 #pragma once
 
+#include <mcblas/mcblas.h>
 #include <mcr/mc_runtime.h>
 #include <mcr/mc_runtime_api.h>
-#include <mcblas/mcblas.h>
 
 #ifdef USE_MCCL
 #include <mccl.h>

infini_train/include/common/maca/kernel_helper.cuh

Lines changed: 4 additions & 2 deletions
@@ -65,9 +65,11 @@ template <typename DST, typename SRC> __host__ __device__ DST Cast(SRC &&x) {
     // Fallback for all other conversions
     if constexpr (std::is_same_v<DST_base, __maca_bfloat16> || std::is_same_v<DST_base, __half>
                   || std::is_same_v<SRC_base, __maca_bfloat16> || std::is_same_v<SRC_base, __half>) {
-        return (DST)(static_cast<float>(std::forward<SRC>(x)));;
+        return (DST)(static_cast<float>(std::forward<SRC>(x)));
+        ;
     } else {
-        return static_cast<DST>(std::forward<SRC>(x));;
+        return static_cast<DST>(std::forward<SRC>(x));
+        ;
     }
 }
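
For context on the fallback: __maca_bfloat16 and __half typically define conversions only to and from float, so every other conversion has to go through float. A host-only sketch of the same if constexpr dispatch, with FakeHalf standing in for the device half types:

#include <type_traits>
#include <utility>

// Stand-in for a device half type that only converts via float.
struct FakeHalf {
    float v = 0.0f;
    FakeHalf() = default;
    explicit FakeHalf(float f) : v(f) {}
    explicit operator float() const { return v; }
};

template <typename DST, typename SRC> DST Cast(SRC &&x) {
    using SRC_base = std::decay_t<SRC>;
    if constexpr (std::is_same_v<DST, FakeHalf> || std::is_same_v<SRC_base, FakeHalf>) {
        // Widen (or narrow) through float, mirroring the kernel_helper fallback.
        return DST(static_cast<float>(std::forward<SRC>(x)));
    } else {
        return static_cast<DST>(std::forward<SRC>(x));
    }
}

// e.g. int i = Cast<int>(FakeHalf(3.5f));  // -> 3, via float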

Lines changed: 35 additions & 0 deletions

#include "infini_train/src/core/ccl/maca/mccl_common.h"

#include <cstring>

#include "glog/logging.h"

namespace infini_train::core {

McclComm::McclComm() = default;

McclComm::McclComm(mcclComm_t comm) : mccl_comm_(comm) {}

mcclComm_t McclComm::mccl_comm() const { return mccl_comm_; }

void McclComm::set_mccl_comm(mcclComm_t comm) { mccl_comm_ = comm; }

McclUniqueId::McclUniqueId() = default;

McclUniqueId::McclUniqueId(const mcclUniqueId &id) : id_(id) {}

size_t McclUniqueId::Size() const { return sizeof(id_); }

const void *McclUniqueId::Data() const { return &id_; }

void McclUniqueId::Load(const void *src, size_t size) {
    CHECK_NOTNULL(src);
    CHECK_EQ(size, sizeof(id_));
    std::memcpy(&id_, src, sizeof(id_));
}

mcclUniqueId *McclUniqueId::mccl_unique_id() { return &id_; }

const mcclUniqueId *McclUniqueId::mccl_unique_id() const { return &id_; }

} // namespace infini_train::core
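
Why Size(), Data(), and Load() look this way: the backend-agnostic WriteUniqueIdFile / ReadUniqueIdFile path named in the commit message can treat the id as an opaque blob of Size() bytes, and the CHECK_EQ in Load() rejects truncated or mismatched files. A minimal sketch of that contract, where the *Sketch helpers are illustrative and not the real tree's functions:

#include <fstream>
#include <string>
#include <vector>

#include "infini_train/src/core/ccl/maca/mccl_common.h"

// Rank 0 serializes the opaque id; other ranks read it back and Load() it.
void WriteUniqueIdFileSketch(const infini_train::core::McclUniqueId &id, const std::string &path) {
    std::ofstream ofs(path, std::ios::binary);
    ofs.write(static_cast<const char *>(id.Data()), static_cast<std::streamsize>(id.Size()));
}

void ReadUniqueIdFileSketch(infini_train::core::McclUniqueId &id, const std::string &path) {
    std::ifstream ifs(path, std::ios::binary);
    std::vector<char> buf(id.Size());
    ifs.read(buf.data(), static_cast<std::streamsize>(buf.size()));
    id.Load(buf.data(), buf.size()); // the CHECK_EQ inside guards against truncated files
}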
infini_train/src/core/ccl/maca/mccl_common.h

Lines changed: 37 additions & 0 deletions

#pragma once

#include <mccl.h>

#include "infini_train/include/core/ccl/ccl_common.h"

namespace infini_train::core {

class McclComm final : public CclComm {
public:
    McclComm();
    explicit McclComm(mcclComm_t comm);

    mcclComm_t mccl_comm() const;
    void set_mccl_comm(mcclComm_t comm);

private:
    mcclComm_t mccl_comm_ = nullptr;
};

class McclUniqueId final : public CclUniqueId {
public:
    McclUniqueId();
    explicit McclUniqueId(const mcclUniqueId &id);

    size_t Size() const override;
    const void *Data() const override;
    void Load(const void *src, size_t size) override;

    mcclUniqueId *mccl_unique_id();
    const mcclUniqueId *mccl_unique_id() const;

private:
    mcclUniqueId id_;
};

} // namespace infini_train::core
Lines changed: 160 additions & 0 deletions

#include "infini_train/src/core/ccl/maca/mccl_impl.h"

#include <mccl.h>
#include <unordered_map>
#include <vector>

#include "glog/logging.h"

#include "infini_train/include/common/maca/common_maca.h"
#include "infini_train/include/core/runtime/runtime_common.h"
#include "infini_train/include/device.h"

#include "infini_train/src/core/ccl/maca/mccl_common.h"
#include "infini_train/src/core/runtime/maca/maca_runtime_common.h"

namespace infini_train::core::maca {
namespace {

inline const std::unordered_map<DataType, mcclDataType_t> kMcclDtypeMap = {
    {DataType::kUINT8, mcclUint8},       {DataType::kINT8, mcclInt8},     {DataType::kUINT32, mcclUint32},
    {DataType::kINT32, mcclInt32},       {DataType::kUINT64, mcclUint64}, {DataType::kINT64, mcclInt64},
    {DataType::kBFLOAT16, mcclBfloat16}, {DataType::kFLOAT16, mcclHalf},  {DataType::kFLOAT32, mcclFloat32},
    {DataType::kFLOAT64, mcclFloat64},
};

inline const std::unordered_map<nn::parallel::function::ReduceOpType, mcclRedOp_t> kMcclReduceOpMap = {
    {nn::parallel::function::ReduceOpType::kSum, mcclSum}, {nn::parallel::function::ReduceOpType::kProd, mcclProd},
    {nn::parallel::function::ReduceOpType::kMin, mcclMin}, {nn::parallel::function::ReduceOpType::kMax, mcclMax},
    {nn::parallel::function::ReduceOpType::kAvg, mcclAvg},
};

inline mcclComm_t GetMcclComm(const CclComm *comm) {
    auto *mccl_comm = dynamic_cast<const McclComm *>(comm);
    CHECK_NOTNULL(mccl_comm);
    return mccl_comm->mccl_comm();
}

inline void SetMcclComm(CclComm *comm, mcclComm_t mccl_comm) {
    auto *typed_comm = dynamic_cast<McclComm *>(comm);
    CHECK_NOTNULL(typed_comm);
    typed_comm->set_mccl_comm(mccl_comm);
}

inline const mcclUniqueId &GetMcclUniqueId(const CclUniqueId &unique_id) {
    auto *mccl_unique_id = dynamic_cast<const McclUniqueId *>(&unique_id);
    CHECK_NOTNULL(mccl_unique_id);
    return *mccl_unique_id->mccl_unique_id();
}

inline mcStream_t GetMacaStream(Stream *stream) {
    auto *maca_stream = dynamic_cast<MacaStream *>(stream);
    CHECK_NOTNULL(maca_stream);
    return maca_stream->maca_stream();
}

} // namespace

Device::DeviceType McclImpl::Type() const { return Device::DeviceType::kMACA; }

void McclImpl::GroupStart() const { MCCL_CHECK(mcclGroupStart()); }

void McclImpl::GroupEnd() const { MCCL_CHECK(mcclGroupEnd()); }

void McclImpl::GetAsyncError(const CclComm *comm, CclStatus *async_error) const {
    mcclResult_t mccl_async_error = mcclSuccess;
    MCCL_CHECK(mcclCommGetAsyncError(GetMcclComm(comm), &mccl_async_error));
    if (async_error != nullptr) {
        *async_error = (mccl_async_error == mcclSuccess) ? CclStatus::kSuccess : CclStatus::kError;
    }
}

void McclImpl::GetUniqueId(CclUniqueId **unique_id) const {
    CHECK_NOTNULL(unique_id);
    if (*unique_id == nullptr) {
        *unique_id = new McclUniqueId();
    }
    auto *mccl_unique_id = dynamic_cast<McclUniqueId *>(*unique_id);
    CHECK_NOTNULL(mccl_unique_id);
    MCCL_CHECK(mcclGetUniqueId(mccl_unique_id->mccl_unique_id()));
}

void McclImpl::CommInitAll(CclComm **comms, int ndev, const int *devlist) const {
    CHECK_NOTNULL(comms);
    CHECK_GT(ndev, 0);
    CHECK_NOTNULL(devlist);

    std::vector<mcclComm_t> mccl_comms(static_cast<size_t>(ndev), nullptr);
    MCCL_CHECK(mcclCommInitAll(mccl_comms.data(), ndev, devlist));
    for (int i = 0; i < ndev; ++i) {
        if (comms[i] == nullptr) {
            comms[i] = new McclComm();
        }
        SetMcclComm(comms[i], mccl_comms[static_cast<size_t>(i)]);
    }
}

void McclImpl::CommInitRank(CclComm **comm, int nranks, const CclUniqueId &unique_id, int rank) const {
    CHECK_NOTNULL(comm);
    CHECK_GT(nranks, 0);

    if (*comm == nullptr) {
        *comm = new McclComm();
    }

    mcclComm_t mccl_comm = nullptr;
    MCCL_CHECK(mcclCommInitRank(&mccl_comm, nranks, GetMcclUniqueId(unique_id), rank));
    SetMcclComm(*comm, mccl_comm);
}

void McclImpl::CommDestroy(CclComm *comm) const {
    if (comm == nullptr) {
        return;
    }
    MCCL_CHECK(mcclCommDestroy(GetMcclComm(comm)));
    SetMcclComm(comm, nullptr);
}

void McclImpl::AllReduce(const void *sendbuff, void *recvbuff, size_t count, DataType dtype,
                         nn::parallel::function::ReduceOpType reduce_op, const CclComm *comm, Stream *stream) const {
    MCCL_CHECK(mcclAllReduce(sendbuff, recvbuff, count, kMcclDtypeMap.at(dtype), kMcclReduceOpMap.at(reduce_op),
                             GetMcclComm(comm), GetMacaStream(stream)));
}

void McclImpl::Broadcast(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, int root,
                         const CclComm *comm, Stream *stream) const {
    MCCL_CHECK(mcclBroadcast(sendbuff, recvbuff, count, kMcclDtypeMap.at(dtype), root, GetMcclComm(comm),
                             GetMacaStream(stream)));
}

void McclImpl::Reduce(const void *sendbuff, void *recvbuff, size_t count, DataType dtype,
                      nn::parallel::function::ReduceOpType reduce_op, int root, const CclComm *comm,
                      Stream *stream) const {
    MCCL_CHECK(mcclReduce(sendbuff, recvbuff, count, kMcclDtypeMap.at(dtype), kMcclReduceOpMap.at(reduce_op), root,
                          GetMcclComm(comm), GetMacaStream(stream)));
}

void McclImpl::AllGather(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, const CclComm *comm,
                         Stream *stream) const {
    MCCL_CHECK(
        mcclAllGather(sendbuff, recvbuff, count, kMcclDtypeMap.at(dtype), GetMcclComm(comm), GetMacaStream(stream)));
}

void McclImpl::ReduceScatter(const void *sendbuff, void *recvbuff, size_t recv_count, DataType dtype,
                             nn::parallel::function::ReduceOpType reduce_op, const CclComm *comm,
                             Stream *stream) const {
    MCCL_CHECK(mcclReduceScatter(sendbuff, recvbuff, recv_count, kMcclDtypeMap.at(dtype),
                                 kMcclReduceOpMap.at(reduce_op), GetMcclComm(comm), GetMacaStream(stream)));
}

void McclImpl::Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm,
                    Stream *stream) const {
    MCCL_CHECK(mcclSend(buff, count, kMcclDtypeMap.at(dtype), peer, GetMcclComm(comm), GetMacaStream(stream)));
}

void McclImpl::Recv(void *buff, size_t count, DataType dtype, int peer, const CclComm *comm, Stream *stream) const {
    MCCL_CHECK(mcclRecv(buff, count, kMcclDtypeMap.at(dtype), peer, GetMcclComm(comm), GetMacaStream(stream)));
}

INFINI_TRAIN_REGISTER_CCL_IMPL(Device::DeviceType::kMACA, McclImpl)

} // namespace infini_train::core::maca
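
With the impl registered, a DDP-style gradient all-reduce is a single call against the CclImpl interface. In the real tree this call site sits inside ProcessGroup; the sketch below, which assumes a comm from CommInitRank and a MACA stream already in hand, is only illustrative:

#include <cstddef>

#include "infini_train/src/core/ccl/maca/mccl_impl.h"

using namespace infini_train;

// Average gradients across ranks in place; sendbuff == recvbuff is the usual
// in-place convention for NCCL-style collectives, which MCCL mirrors.
void AllReduceGradients(float *grads, size_t count, core::CclComm *comm, core::Stream *stream) {
    core::maca::McclImpl impl;
    impl.AllReduce(/*sendbuff=*/grads, /*recvbuff=*/grads, count, DataType::kFLOAT32,
                   nn::parallel::function::ReduceOpType::kAvg, comm, stream);
}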
infini_train/src/core/ccl/maca/mccl_impl.h

Lines changed: 51 additions & 0 deletions

#pragma once

#include <string>
#include <unordered_map>

#include "infini_train/include/core/ccl/ccl.h"

namespace infini_train::core::maca {

class McclImpl final : public CclImpl {
public:
    Device::DeviceType Type() const override;

    void GroupStart() const override;

    void GroupEnd() const override;

    void GetAsyncError(const CclComm *comm, CclStatus *async_error) const override;

    void GetUniqueId(CclUniqueId **unique_id) const override;

    void CommInitAll(CclComm **comms, int ndev, const int *devlist) const override;

    void CommInitRank(CclComm **comm, int nranks, const CclUniqueId &unique_id, int rank) const override;

    void CommDestroy(CclComm *comm) const override;

    void AllReduce(const void *sendbuff, void *recvbuff, size_t count, DataType dtype,
                   nn::parallel::function::ReduceOpType reduce_op, const CclComm *comm, Stream *stream) const override;

    void Broadcast(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, int root, const CclComm *comm,
                   Stream *stream) const override;

    void Reduce(const void *sendbuff, void *recvbuff, size_t count, DataType dtype,
                nn::parallel::function::ReduceOpType reduce_op, int root, const CclComm *comm,
                Stream *stream) const override;

    void AllGather(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, const CclComm *comm,
                   Stream *stream) const override;

    void ReduceScatter(const void *sendbuff, void *recvbuff, size_t recv_count, DataType dtype,
                       nn::parallel::function::ReduceOpType reduce_op, const CclComm *comm,
                       Stream *stream) const override;

    void Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm,
              Stream *stream) const override;

    void Recv(void *buff, size_t count, DataType dtype, int peer, const CclComm *comm, Stream *stream) const override;
};

} // namespace infini_train::core::maca
Lines changed: 56 additions & 0 deletions

#include <memory>

#include "infini_train/include/common/common.h"
#include "infini_train/include/common/maca/kernel_helper.cuh"
#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/datatype.h"
#include "infini_train/include/device.h"
#include "infini_train/include/dispatcher.h"
#include "infini_train/include/tensor.h"

#include "infini_train/src/core/runtime/maca/maca_runtime_common.h"

namespace infini_train::kernels::maca {

template <typename Tdst, typename Tsrc>
__global__ void CastKernel(Tdst *dst, const Tsrc *src, size_t num_elements, size_t offset) {
    size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;

    if (idx < num_elements) {
        dst[idx] = common::maca::Cast<Tdst>(src[idx]);
    }
}

std::shared_ptr<Tensor> Cast(std::shared_ptr<Tensor> input, DataType dtype) {
    auto dst_tensor = std::make_shared<Tensor>(input->Dims(), dtype, input->GetDevice());
    auto device = input->GetDevice();
    const auto &maca_stream = dynamic_cast<infini_train::core::maca::MacaStream *>(
                                  infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
                                  ->maca_stream();

    const size_t num_elements = input->NumElements();
    dim3 block_dims(256);
    dim3 grid_dims(CEIL_DIV(num_elements, block_dims.x));
    const size_t step = grid_dims.x * block_dims.x;

    DispatchFunc<DataTypeList<INFINI_ALL_TYPES>, DataTypeList<INFINI_ALL_TYPES>>(
        {dtype, input->Dtype()},
        [=]<typename Tdst, typename Tsrc>() {
            auto dst = static_cast<Tdst *>(dst_tensor->DataPtr());
            auto src = static_cast<const Tsrc *>(input->DataPtr());
            for (size_t offset = 0; offset < num_elements; offset += step) {
                CastKernel<<<grid_dims, block_dims, 0, maca_stream>>>(dst, src, num_elements, offset);
            }
        },
        "MACA Cast");

    return {dst_tensor};
}
} // namespace infini_train::kernels::maca

#define REGISTER_MACA_CAST_KERNEL(kernel_name)                                                                        \
    REGISTER_KERNEL(infini_train::Device::DeviceType::kMACA, kernel_name, infini_train::kernels::maca::kernel_name)

REGISTER_MACA_CAST_KERNEL(Cast)

#undef REGISTER_MACA_CAST_KERNEL
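
One subtlety in the launch above: grid_dims is sized to cover all of num_elements, so step is at least num_elements and the offset loop runs exactly once; the loop only matters if the grid were ever clamped below full coverage. A toy host-side check of that arithmetic, assuming CEIL_DIV is the usual (a + b - 1) / b:

#include <cstddef>
#include <cstdio>

constexpr size_t CeilDiv(size_t a, size_t b) { return (a + b - 1) / b; }

int main() {
    const size_t num_elements = 1000003;
    const size_t block = 256;
    const size_t grid = CeilDiv(num_elements, block); // 3907 blocks
    const size_t step = grid * block;                 // 1000192 >= num_elements
    std::printf("grid=%zu step=%zu loop_iterations=%zu\n", grid, step, CeilDiv(num_elements, step));
    return 0;
}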
