Skip to content

Commit f7b3fcb

Browse files
Merge branch 'InfiniTensor:master' into lr_scheduler
2 parents 3a7abb4 + b1e4b03 commit f7b3fcb

70 files changed

Lines changed: 2171 additions & 1177 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ endif()
4848
# Framework core sources (*.cc), excluding cpu kernels (they are built separately)
4949
file(GLOB_RECURSE SRC ${PROJECT_SOURCE_DIR}/infini_train/src/*.cc)
5050
list(FILTER SRC EXCLUDE REGEX ".*kernels/cpu/.*")
51+
if(NOT USE_NCCL)
52+
list(FILTER SRC EXCLUDE REGEX ".*infini_train/src/core/ccl/cuda/.*")
53+
endif()
5154

5255
# CPU kernels (*.cc)
5356
file(GLOB_RECURSE CPU_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/kernels/cpu/*.cc)

example/gpt2/main.cc

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#include "glog/logging.h"
1111

1212
#include "infini_train/include/autocast.h"
13-
#include "infini_train/include/core/device_guard.h"
13+
#include "infini_train/include/core/runtime/device_guard.h"
1414
#include "infini_train/include/dataloader.h"
1515
#include "infini_train/include/device.h"
1616
#include "infini_train/include/lr_scheduler.h"
@@ -152,24 +152,25 @@ void Train(const nn::parallel::Rank &rank) {
152152

153153
if (rank.IsParallel()) {
154154
device = Device(Device::DeviceType::kCUDA, rank.thread_rank());
155+
auto *pg_factory = ProcessGroupFactory::Instance(device.type());
155156

156157
if (ddp_world_size > 1) {
157-
ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
158-
GetDataParallelGroupRanks(rank.GlobalRank()));
158+
ddp_pg = pg_factory->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
159+
GetDataParallelGroupRanks(rank.GlobalRank()));
159160
ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank());
160161
}
161162

162163
if (tp_world_size > 1) {
163-
tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
164-
GetTensorParallelGroupRanks(rank.GlobalRank()));
164+
tp_pg = pg_factory->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
165+
GetTensorParallelGroupRanks(rank.GlobalRank()));
165166
tp_rank = tp_pg->GetGroupRank(rank.GlobalRank());
166167
// NOTE(zbl): Reserved for VocabParallelEmbedding
167168
nn::parallel::tp_rank = tp_rank;
168169
}
169170

170171
if (pp_world_size > 1) {
171-
pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
172-
GetPipelineParallelGroupRanks(rank.GlobalRank()));
172+
pp_pg = pg_factory->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
173+
GetPipelineParallelGroupRanks(rank.GlobalRank()));
173174
pp_rank = pp_pg->GetGroupRank(rank.GlobalRank());
174175

175176
nn::parallel::pp_rank = pp_rank;
@@ -231,8 +232,8 @@ void Train(const nn::parallel::Rank &rank) {
231232
= DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
232233
auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
233234
for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
234-
(*mutable_chunks)[chunk_id] = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id),
235-
rank.thread_rank(), ddp_config);
235+
(*mutable_chunks)[chunk_id]
236+
= std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id), rank, ddp_config);
236237
}
237238
}
238239
} else if (ddp_world_size > 1) {
@@ -241,7 +242,7 @@ void Train(const nn::parallel::Rank &rank) {
241242
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
242243
// are created during the conversion.
243244
auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
244-
model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank(), ddp_config);
245+
model = std::make_shared<DistributedDataParallel>(model, rank, ddp_config);
245246
}
246247

247248
DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),

example/gpt2/net.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,8 @@ GPT2FirstStage::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>>
198198
auto sequence_parallel_enabled = nn::parallel::global::GetSequenceParallelEnabled();
199199
int tp_rank = 0;
200200
if (tp_world_size > 1) {
201-
auto tp_group = nn::parallel::ProcessGroupFactory::Instance()->Get(
202-
nn::parallel::GetTensorParallelProcessGroupName(device.Rank().GlobalRank()));
201+
auto tp_group = nn::parallel::ProcessGroupFactory::Instance(device.type())
202+
->Get(nn::parallel::GetTensorParallelProcessGroupName(device.Rank().GlobalRank()));
203203
tp_rank = tp_group->GetGroupRank(device.Rank().GlobalRank());
204204
}
205205
int64_t t_local = sequence_parallel_enabled ? x1->Dims()[1] / tp_world_size : x1->Dims()[1];

example/llama3/main.cc

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include "glog/logging.h"
99

1010
#include "infini_train/include/autocast.h"
11-
#include "infini_train/include/core/device_guard.h"
11+
#include "infini_train/include/core/runtime/device_guard.h"
1212
#include "infini_train/include/dataloader.h"
1313
#include "infini_train/include/device.h"
1414
#include "infini_train/include/lr_scheduler.h"
@@ -133,24 +133,25 @@ void Train(const nn::parallel::Rank &rank) {
133133

134134
if (rank.IsParallel()) {
135135
device = Device(Device::DeviceType::kCUDA, rank.thread_rank());
136+
auto *pg_factory = ProcessGroupFactory::Instance(device.type());
136137

137138
if (ddp_world_size > 1) {
138-
ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
139-
GetDataParallelGroupRanks(rank.GlobalRank()));
139+
ddp_pg = pg_factory->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
140+
GetDataParallelGroupRanks(rank.GlobalRank()));
140141
ddp_rank = ddp_pg->GetGroupRank(rank.GlobalRank());
141142
}
142143

143144
if (tp_world_size > 1) {
144-
tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
145-
GetTensorParallelGroupRanks(rank.GlobalRank()));
145+
tp_pg = pg_factory->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
146+
GetTensorParallelGroupRanks(rank.GlobalRank()));
146147
tp_rank = tp_pg->GetGroupRank(rank.GlobalRank());
147148
// NOTE(zbl): Reserved for VocabParallelEmbedding
148149
nn::parallel::tp_rank = tp_rank;
149150
}
150151

151152
if (pp_world_size > 1) {
152-
pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
153-
GetPipelineParallelGroupRanks(rank.GlobalRank()));
153+
pp_pg = pg_factory->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
154+
GetPipelineParallelGroupRanks(rank.GlobalRank()));
154155
pp_rank = pp_pg->GetGroupRank(rank.GlobalRank());
155156

156157
nn::parallel::pp_rank = pp_rank;
@@ -210,8 +211,8 @@ void Train(const nn::parallel::Rank &rank) {
210211
= DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
211212
auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
212213
for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
213-
(*mutable_chunks)[chunk_id] = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id),
214-
rank.thread_rank(), ddp_config);
214+
(*mutable_chunks)[chunk_id]
215+
= std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id), rank, ddp_config);
215216
}
216217
}
217218
} else if (ddp_world_size > 1) {
@@ -221,7 +222,7 @@ void Train(const nn::parallel::Rank &rank) {
221222
// are created during the conversion.
222223

223224
auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
224-
model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank(), ddp_config);
225+
model = std::make_shared<DistributedDataParallel>(model, rank, ddp_config);
225226
}
226227

227228
DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),

infini_train/include/core/blas_handle.h

Lines changed: 0 additions & 11 deletions
This file was deleted.
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#pragma once
2+
3+
#include <cstddef>
4+
#include <cstdint>
5+
#include <memory>
6+
#include <string>
7+
#include <unordered_map>
8+
9+
#include "infini_train/include/core/ccl/ccl_common.h"
10+
#include "infini_train/include/datatype.h"
11+
#include "infini_train/include/device.h"
12+
#include "infini_train/include/nn/parallel/reduce_op_type.h"
13+
14+
namespace infini_train::core {
15+
16+
class Stream;
17+
18+
// Abstract, per-device-type interface for collective-communication operations.
// Only Type() is pure virtual; every other operation is a non-pure virtual whose
// base-class definition is not visible in this header.
// NOTE(review): the method set looks modelled on the NCCL C API (group start/end,
// communicator init/destroy, collectives, point-to-point) -- confirm against the
// backend implementations.
class CclImpl {
19+
public:
20+
CclImpl() {}
21+
virtual ~CclImpl() = default;
22+
23+
// Device type this implementation serves; must be provided by every subclass.
virtual Device::DeviceType Type() const = 0;
24+
25+
// Group bracketing calls (see CclGroupGuard, which is constructed from a device type).
virtual void GroupStart() const;
26+
27+
virtual void GroupEnd() const;
28+
29+
// Writes a status for `comm` into `*async_error`.
virtual void GetAsyncError(const CclComm *comm, CclStatus *async_error) const;
30+
31+
// Out-parameter: `*unique_id` receives a CclUniqueId pointer.
// NOTE(review): ownership of the returned object is not stated here -- confirm.
virtual void GetUniqueId(CclUniqueId **unique_id) const;
32+
33+
// Communicator lifecycle. `comms`/`comm` are out-parameters.
virtual void CommInitAll(CclComm **comms, int ndev, const int *devlist) const;
34+
35+
virtual void CommInitRank(CclComm **comm, int nranks, const CclUniqueId &unique_id, int rank) const;
36+
37+
virtual void CommDestroy(CclComm *comm) const;
38+
39+
// Collective operations. Each takes raw buffers, an element count, a DataType,
// the communicator and the Stream to use.
virtual void AllReduce(const void *sendbuff, void *recvbuff, size_t count, DataType dtype,
40+
nn::parallel::function::ReduceOpType reduce_op, const CclComm *comm, Stream *stream) const;
41+
42+
virtual void Broadcast(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, int root,
43+
const CclComm *comm, Stream *stream) const;
44+
45+
virtual void Reduce(const void *sendbuff, void *recvbuff, size_t count, DataType dtype,
46+
nn::parallel::function::ReduceOpType reduce_op, int root, const CclComm *comm,
47+
Stream *stream) const;
48+
49+
virtual void AllGather(const void *sendbuff, void *recvbuff, size_t count, DataType dtype, const CclComm *comm,
50+
Stream *stream) const;
51+
52+
// Note the count parameter here is named `recv_count`, unlike the other collectives.
virtual void ReduceScatter(const void *sendbuff, void *recvbuff, size_t recv_count, DataType dtype,
53+
nn::parallel::function::ReduceOpType reduce_op, const CclComm *comm,
54+
Stream *stream) const;
55+
56+
// Point-to-point operations addressed by `peer`.
virtual void Send(const void *buff, size_t count, DataType dtype, int peer, const CclComm *comm,
57+
Stream *stream) const;
58+
59+
virtual void Recv(void *buff, size_t count, DataType dtype, int peer, const CclComm *comm, Stream *stream) const;
60+
};
61+
62+
// Scope object constructed from a device type; it is neither copyable nor movable.
// The constructor and destructor are defined elsewhere.
// NOTE(review): presumably brackets a scope with CclImpl::GroupStart()/GroupEnd()
// on the implementation for `type` -- confirm in the .cc file.
class CclGroupGuard {
63+
public:
64+
explicit CclGroupGuard(Device::DeviceType type);
65+
~CclGroupGuard();
66+
67+
// Copy and move are disabled so exactly one object owns the guarded scope.
CclGroupGuard(const CclGroupGuard &) = delete;
68+
CclGroupGuard &operator=(const CclGroupGuard &) = delete;
69+
CclGroupGuard(CclGroupGuard &&) = delete;
70+
CclGroupGuard &operator=(CclGroupGuard &&) = delete;
71+
72+
private:
73+
// Raw pointer, null until set by the constructor (defined elsewhere).
CclImpl *impl_ = nullptr;
74+
};
75+
76+
// Registry mapping a Device::DeviceType to the CclImpl that serves it.
// Instances cannot be created or copied from outside the class (private default
// constructor, deleted copy operations); access goes through Instance().
class CclImplRegistry {
77+
public:
78+
static CclImplRegistry &Instance();
79+
80+
// Takes ownership of `impl` (passed as std::unique_ptr by value).
// NOTE(review): behaviour when `type` is already registered is not visible here -- confirm.
void Register(Device::DeviceType type, std::unique_ptr<CclImpl> impl);
81+
82+
// Returns a raw pointer to the implementation for `type`.
// NOTE(review): behaviour for an unregistered type is not visible here -- confirm.
CclImpl *Get(Device::DeviceType type) const;
83+
84+
private:
85+
CclImplRegistry() = default;
86+
CclImplRegistry(const CclImplRegistry &) = delete;
87+
CclImplRegistry &operator=(const CclImplRegistry &) = delete;
88+
89+
// Owning storage: one implementation per device type.
std::unordered_map<Device::DeviceType, std::unique_ptr<CclImpl>> impls_;
90+
};
91+
92+
// Free-function accessor returning the CclImpl for `type`; defined elsewhere.
// NOTE(review): presumably forwards to CclImplRegistry::Instance().Get(type) -- confirm.
CclImpl *GetCclImpl(Device::DeviceType type);
93+
94+
} // namespace infini_train::core
95+
96+
// Two-step token pasting: operands of ## are not macro-expanded, so pasting
// `__COUNTER__` directly would yield the literal identifier `...__COUNTER__`
// (identical for every expansion). Routing through a second macro lets the
// argument expand to its numeric value first.
#define INFINI_TRAIN_CCL_CONCAT_IMPL(a, b) a##b
#define INFINI_TRAIN_CCL_CONCAT(a, b) INFINI_TRAIN_CCL_CONCAT_IMPL(a, b)

// Registers `class_impl` (default-constructed via std::make_unique) as the CclImpl
// for `device_type` during static initialization of the translation unit in
// which the macro is expanded.
//
// Each expansion declares a uniquely named file-local flag, so the macro may be
// used more than once per translation unit. The flag name avoids a leading
// double underscore, which is reserved for the implementation. The trailing
// semicolon is kept so existing call sites (with or without their own `;`)
// continue to compile.
#define INFINI_TRAIN_REGISTER_CCL_IMPL(device_type, class_impl)                                                        \
    [[maybe_unused]] static const bool INFINI_TRAIN_CCL_CONCAT(infini_train_ccl_registered_, __COUNTER__) = []() {     \
        infini_train::core::CclImplRegistry::Instance().Register(device_type, std::make_unique<class_impl>());         \
        return true;                                                                                                   \
    }();
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#pragma once
2+
3+
#include <cstddef>
4+
#include <cstdint>
5+
6+
#include "glog/logging.h"
7+
8+
namespace infini_train::core {
9+
10+
#define INFINI_TRAIN_CCL_STATUS_LIST(X) \
11+
X(kSuccess, 0) \
12+
X(kInProgress, 1) \
13+
X(kTimeout, 2) \
14+
X(kError, -1) \
15+
X(kInvalidArgument, -2) \
16+
X(kUnavailable, -3) \
17+
X(kNotSupported, -4) \
18+
X(kInternal, -5) \
19+
X(kUnknown, -127)
20+
21+
enum class CclStatus : int32_t {
22+
#define INFINI_TRAIN_CCL_STATUS_ENUM_ITEM(name, value) name = value,
23+
INFINI_TRAIN_CCL_STATUS_LIST(INFINI_TRAIN_CCL_STATUS_ENUM_ITEM)
24+
#undef INFINI_TRAIN_CCL_STATUS_ENUM_ITEM
25+
};
26+
27+
inline const char *CclStatusToString(CclStatus status) {
28+
switch (status) {
29+
#define INFINI_TRAIN_CCL_STATUS_CASE(name, value) \
30+
case CclStatus::name: \
31+
return #name;
32+
INFINI_TRAIN_CCL_STATUS_LIST(INFINI_TRAIN_CCL_STATUS_CASE)
33+
#undef INFINI_TRAIN_CCL_STATUS_CASE
34+
default:
35+
LOG(FATAL) << "Unsupported RuntimeStatus type: " << static_cast<int>(status);
36+
return "";
37+
}
38+
}
39+
40+
#undef INFINI_TRAIN_CCL_STATUS_LIST
41+
42+
// Polymorphic base type for communicator handles passed through the CclImpl
// interface. It declares no operations of its own; the virtual destructor allows
// deletion through a CclComm pointer.
// NOTE(review): presumably each backend derives from this to wrap its native
// communicator -- confirm.
class CclComm {
43+
public:
44+
CclComm() = default;
45+
virtual ~CclComm() = default;
46+
};
47+
48+
// Abstract identifier exposed as an opaque byte blob: subclasses report its size,
// expose a read-only pointer to its bytes, and can be filled from a raw buffer.
class CclUniqueId {
49+
public:
50+
CclUniqueId() = default;
51+
virtual ~CclUniqueId() = default;
52+
53+
// Size of the identifier; NOTE(review): presumably in bytes -- confirm.
virtual size_t Size() const = 0;
54+
// Read-only pointer to the identifier's storage.
virtual const void *Data() const = 0;
55+
// Populates the identifier from `src`, given `size`.
// NOTE(review): handling of a `size` that differs from Size() is not visible here -- confirm.
virtual void Load(const void *src, size_t size) = 0;
56+
};
57+
58+
} // namespace infini_train::core
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#pragma once
2+
3+
#include <string>
4+
5+
#include "infini_train/include/core/ccl/ccl_common.h"
6+
7+
namespace infini_train::core {
8+
9+
// Helpers for a unique-id file associated with a process-group name (`pg_name`).
// Only the declarations are visible here.
// NOTE(review): the file location/naming scheme and error handling live in the
// definitions -- confirm there.

// Writes `unique_id` for the process group `pg_name`.
void WriteUniqueIdFile(const CclUniqueId &unique_id, const std::string &pg_name);
10+
11+
// Reads the id for `pg_name` into `*unique_id` (out-parameter).
void ReadUniqueIdFile(CclUniqueId *unique_id, const std::string &pg_name);
12+
13+
// Cleans up the unique-id file for `pg_name`.
void CleanupUniqueIdFile(const std::string &pg_name);
14+
15+
} // namespace infini_train::core

infini_train/include/core/device_guard.h renamed to infini_train/include/core/runtime/device_guard.h

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <memory>
55
#include <unordered_map>
66

7+
#include "infini_train/include/core/runtime/runtime_common.h"
78
#include "infini_train/include/device.h"
89

910
namespace infini_train::core {
@@ -55,8 +56,6 @@ inline const char *MemcpyKindToString(MemcpyKind k) {
5556
// DeviceGuard (the public RAII wrapper) forwards calls to the DeviceGuardImpl
5657
// instance registered for the device type.
5758
//
58-
// TODO(dcj): add event management
59-
//
6059
class DeviceGuardImpl {
6160
public:
6261
DeviceGuardImpl() {}
@@ -81,6 +80,34 @@ class DeviceGuardImpl {
8180

8281
virtual Stream *GetStream(Device) const;
8382

83+
virtual Stream *CreateStream(Device) const;
84+
85+
virtual Stream *CreateStreamWithPriority(Device, int priority) const;
86+
87+
virtual void DestroyStream(Stream *) const;
88+
89+
virtual void GetStreamPriorityRange(int *low, int *high) const;
90+
91+
// ----------------------------------------------------------------------
92+
// Event management
93+
// ----------------------------------------------------------------------
94+
95+
virtual void EventCreate(Event **event) const;
96+
97+
virtual void EventCreateWithFlags(Event **event, EventFlag flags) const;
98+
99+
virtual void EventDestroy(Event *event) const;
100+
101+
virtual void EventRecord(Event *event, Stream *stream) const;
102+
103+
virtual void StreamWaitEvent(Stream *stream, Event *event, uint32_t flags) const;
104+
105+
virtual RuntimeStatus EventSynchronize(Event *event) const;
106+
107+
virtual RuntimeStatus EventQuery(Event *event) const;
108+
109+
virtual float EventElapsedTime(Event *start_event, Event *stop_event) const;
110+
84111
// ----------------------------------------------------------------------
85112
// Synchronization
86113
// ----------------------------------------------------------------------

0 commit comments

Comments
 (0)