Commit 1e37842

Chamberlain0w0 authored and kilinchange committed
fix: remove main_grad, modify DistOpt constructor, create ddp folder and other minor fixes
1 parent 8ac4f96 commit 1e37842

16 files changed: 119 additions & 168 deletions

example/gpt2/main.cc

Lines changed: 6 additions & 22 deletions
@@ -14,8 +14,8 @@
 #include "infini_train/include/device.h"
 #include "infini_train/include/nn/modules/loss.h"
 #include "infini_train/include/nn/modules/module.h"
-#include "infini_train/include/nn/parallel/distributed_data_parallel.h"
-#include "infini_train/include/nn/parallel/distributed_optimizer.h"
+#include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h"
+#include "infini_train/include/nn/parallel/ddp/distributed_optimizer.h"
 #include "infini_train/include/nn/parallel/global.h"
 #include "infini_train/include/nn/parallel/parallel_functional.h"
 #include "infini_train/include/nn/parallel/pp/pipeline_parallel.h"
@@ -254,27 +254,11 @@ void Train(const nn::parallel::Rank &rank) {
     std::shared_ptr<Optimizer> optimizer = nullptr;
 
     if (FLAGS_use_distributed_optimizer) {
-        std::vector<std::shared_ptr<ParamAndGradBuffer>> param_grad_buffers;
-        std::vector<std::shared_ptr<ParamAndGradBucketGroup>> bucket_groups;
-
-        if (pp_world_size > 1 && ddp_world_size > 1) {
-            auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
-            for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
-                auto buffers
-                    = dynamic_cast<DistributedDataParallel *>(mutable_chunks->at(chunk_id).get())->param_grad_buffers();
-                auto groups
-                    = dynamic_cast<DistributedDataParallel *>(mutable_chunks->at(chunk_id).get())->bucket_groups();
-                param_grad_buffers.insert(param_grad_buffers.end(), buffers.begin(), buffers.end());
-                bucket_groups.insert(bucket_groups.end(), groups.begin(), groups.end());
-            }
-        } else if (ddp_world_size > 1) {
-            param_grad_buffers = dynamic_cast<DistributedDataParallel *>(model.get())->param_grad_buffers();
-            bucket_groups = dynamic_cast<DistributedDataParallel *>(model.get())->bucket_groups();
-        }
-
+        auto model_chunks = (pp_world_size > 1)
+                                ? *(dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks())
+                                : std::vector<std::shared_ptr<nn::Module>>{model};
         optimizer = std::make_shared<nn::parallel::DistributedOptimizer>(optimizer_creator, model->Parameters(),
-                                                                         param_grad_buffers, bucket_groups, ddp_pg,
-                                                                         ddp_world_size, ddp_rank);
+                                                                         model_chunks, ddp_world_size, ddp_rank);
    } else {
         optimizer = optimizer_creator(model->Parameters());
     }
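
In effect, the call site shrinks to the lines added above: the caller builds a vector of model chunks (the pipeline-parallel chunks, or the whole model wrapped in a one-element vector) and hands it to DistributedOptimizer, which now extracts the param/grad buffers and bucket groups itself (see infini_train/src/nn/parallel/ddp/distributed_optimizer.cc below); the explicit ddp_pg argument is gone as well. A minimal recap of the new call, assuming optimizer_creator, model, pp_world_size, ddp_world_size and ddp_rank are in scope as in this file:

// Hedged recap of the new construction path; all identifiers are assumed from main.cc above.
std::vector<std::shared_ptr<nn::Module>> model_chunks
    = (pp_world_size > 1) ? *(dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks())
                          : std::vector<std::shared_ptr<nn::Module>>{model};
auto optimizer = std::make_shared<nn::parallel::DistributedOptimizer>(optimizer_creator, model->Parameters(),
                                                                      model_chunks, ddp_world_size, ddp_rank);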

example/llama3/main.cc

Lines changed: 6 additions & 22 deletions
@@ -12,8 +12,8 @@
 #include "infini_train/include/device.h"
 #include "infini_train/include/nn/modules/loss.h"
 #include "infini_train/include/nn/modules/module.h"
-#include "infini_train/include/nn/parallel/distributed_data_parallel.h"
-#include "infini_train/include/nn/parallel/distributed_optimizer.h"
+#include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h"
+#include "infini_train/include/nn/parallel/ddp/distributed_optimizer.h"
 #include "infini_train/include/nn/parallel/parallel_functional.h"
 #include "infini_train/include/nn/parallel/pp/pipeline_parallel.h"
 #include "infini_train/include/nn/parallel/rank.h"
@@ -233,27 +233,11 @@ void Train(const nn::parallel::Rank &rank) {
     std::shared_ptr<Optimizer> optimizer = nullptr;
 
     if (FLAGS_use_distributed_optimizer) {
-        std::vector<std::shared_ptr<ParamAndGradBuffer>> param_grad_buffers;
-        std::vector<std::shared_ptr<ParamAndGradBucketGroup>> bucket_groups;
-
-        if (pp_world_size > 1 && ddp_world_size > 1) {
-            auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
-            for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
-                auto buffers
-                    = dynamic_cast<DistributedDataParallel *>(mutable_chunks->at(chunk_id).get())->param_grad_buffers();
-                auto groups
-                    = dynamic_cast<DistributedDataParallel *>(mutable_chunks->at(chunk_id).get())->bucket_groups();
-                param_grad_buffers.insert(param_grad_buffers.end(), buffers.begin(), buffers.end());
-                bucket_groups.insert(bucket_groups.end(), groups.begin(), groups.end());
-            }
-        } else if (ddp_world_size > 1) {
-            param_grad_buffers = dynamic_cast<DistributedDataParallel *>(model.get())->param_grad_buffers();
-            bucket_groups = dynamic_cast<DistributedDataParallel *>(model.get())->bucket_groups();
-        }
-
+        auto model_chunks = (pp_world_size > 1)
+                                ? *(dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks())
+                                : std::vector<std::shared_ptr<nn::Module>>{model};
         optimizer = std::make_shared<nn::parallel::DistributedOptimizer>(optimizer_creator, model->Parameters(),
-                                                                         param_grad_buffers, bucket_groups, ddp_pg,
-                                                                         ddp_world_size, ddp_rank);
+                                                                         model_chunks, ddp_world_size, ddp_rank);
    } else {
         optimizer = optimizer_creator(model->Parameters());
     }

infini_train/include/nn/parallel/distributed_data_parallel.h renamed to infini_train/include/nn/parallel/ddp/distributed_data_parallel.h

Lines changed: 6 additions & 4 deletions
@@ -3,21 +3,23 @@
 #include <memory>
 
 #include "infini_train/include/nn/modules/module.h"
-#include "infini_train/include/nn/parallel/distributed_data_parallel_config.h"
-#include "infini_train/include/nn/parallel/param_and_grad_buffer.h"
-#include "infini_train/include/nn/parallel/reducer.h"
+#include "infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h"
+#include "infini_train/include/nn/parallel/ddp/reducer.h"
 
 namespace infini_train {
 class Tensor;
 class Device;
+namespace nn::parallel {
+class DistributedDataParallelConfig;
+} // namespace nn::parallel
 } // namespace infini_train
 
 namespace infini_train::nn::parallel {
 
 class DistributedDataParallel : public nn::Module {
 public:
     DistributedDataParallel(std::shared_ptr<nn::Module> module, int thread_rank,
-                            DistributedDataParallelConfig ddp_config = DistributedDataParallelConfig());
+                            DistributedDataParallelConfig ddp_config);
 
     std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
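
Because the default DistributedDataParallelConfig argument is removed, wrapping a module now requires an explicit config at the call site. A minimal sketch, assuming a module and thread_rank supplied by the caller; the two fields set here are the ones this commit shows being read in distributed_data_parallel.cc (use_distributed_optimizer and overlap_grad_reduce):

// Hedged sketch: an explicit config is now mandatory when constructing DDP.
nn::parallel::DistributedDataParallelConfig ddp_config;
ddp_config.use_distributed_optimizer = true; // route gradients through the flat param/grad buffers
ddp_config.overlap_grad_reduce = true;       // launch grad collectives as buckets become ready
auto ddp_module = std::make_shared<nn::parallel::DistributedDataParallel>(module, thread_rank, ddp_config);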

infini_train/include/nn/parallel/distributed_data_parallel_config.h renamed to infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h

Lines changed: 5 additions & 2 deletions
@@ -47,16 +47,19 @@ class DistributedDataParallelConfig {
     bool average_in_collective = true;
 
     // Whether to check NaNs/Infs/unusually large in gradients before collectives.
+    // TODO(zbl): Unused by now, to be implemented in ParamAndGradBucketGroup::StartGradSync()
     bool check_for_nan_in_grad = false;
     bool check_for_large_grads = false;
 
     // Number of DistributedOptimizer instances.
     // Multiple DistOpt is used for building hierarchical collective groups for param/grad.
+    // TODO(zbl): Unused by now, to be implemented in ParamAndGradBucketGroup
     int num_distributed_optimizer_instances = 1;
 
     // Maximum number of parameters in each ParamAndGradBucket.
-    // This is distinct from DDP Reducer's MB-based bucket caps.
-    size_t bucket_size_in_elements = std::numeric_limits<size_t>::max();
+    // NOTE(zbl): This is distinct from DDP Reducer's MB-based bucket caps.
+    // TODO(zbl): To unify the definition of bucket_size argument for users
+    size_t bucket_size_in_elements = 40000000;
 
     // Whether to pad bucket sizes to improve NCCL bus bandwidth utilization.
     bool pad_buckets_for_high_nccl_busbw = false;
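
The bucket cap changes from effectively unbounded (std::numeric_limits<size_t>::max()) to a finite default of 40,000,000 elements, so parameters are now split across multiple buckets by default. A hedged sketch of overriding the related knobs, using only fields visible in this header:

// Assumption: smaller buckets let per-bucket collectives start earlier, at the cost of more of them.
nn::parallel::DistributedDataParallelConfig config;
config.bucket_size_in_elements = 10'000'000;
config.pad_buckets_for_high_nccl_busbw = true; // pad bucket sizes to improve NCCL bus bandwidth utilization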

infini_train/include/nn/parallel/distributed_optimizer.h renamed to infini_train/include/nn/parallel/ddp/distributed_optimizer.h

Lines changed: 11 additions & 7 deletions
@@ -5,18 +5,24 @@
 #include <unordered_map>
 #include <vector>
 
-#include "infini_train/include/nn/parallel/param_and_grad_buffer.h"
 #include "infini_train/include/optimizer.h"
 
+namespace infini_train::nn {
+class Module;
+namespace parallel {
+class ParamAndGradBuffer;
+class ParamAndGradBucketGroup;
+} // namespace parallel
+} // namespace infini_train::nn
+
 namespace infini_train::nn::parallel {
 
 class DistributedOptimizer final : public infini_train::Optimizer {
 public:
-    DistributedOptimizer(OptimizerCreator inner_optimizer_creator,
+    DistributedOptimizer(OptimizerCreator base_optimizer_creator,
                          const std::vector<std::shared_ptr<Tensor>> &full_params,
-                         const std::vector<std::shared_ptr<ParamAndGradBuffer>> &buffers,
-                         const std::vector<std::shared_ptr<ParamAndGradBucketGroup>> &bucket_groups,
-                         const ProcessGroup *dp_pg, size_t dp_world_size, size_t ddp_rank);
+                         const std::vector<std::shared_ptr<Module>> &model_chunks, size_t dp_world_size,
+                         size_t dp_rank);
 
     void Step() override;
 
@@ -37,15 +43,13 @@ class DistributedOptimizer final : public infini_train::Optimizer {
     std::vector<std::shared_ptr<ParamAndGradBucketGroup>> bucket_groups_;
 
     // DP info
-    const ProcessGroup *dp_pg_;
     size_t dp_world_size_;
     size_t dp_rank_;
 
     // shard params
     std::vector<std::shared_ptr<Tensor>> shard_params_;
 
     // Base optimizer (SGD, Adam and etc.)
-    OptimizerCreator creator_;
     std::shared_ptr<Optimizer> base_optimizer_;
 };

infini_train/include/nn/parallel/param_and_grad_buffer.h renamed to infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h

Lines changed: 7 additions & 3 deletions
@@ -8,8 +8,7 @@
 #include <vector>
 
 #include "infini_train/include/datatype.h"
-#include "infini_train/include/device.h"
-#include "infini_train/include/nn/parallel/distributed_data_parallel_config.h"
+#include "infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h"
 
 namespace infini_train {
 class Tensor;
@@ -135,14 +134,17 @@ class ParamAndGradBuffer {
 
     void ScaleGradients(float scaling_factor);
 
-    void Reset();
+    void Reset(bool need_rebind = true);
+
+    void RebindGradViews();
 
 private:
     void BuildBuckets(DataType param_dtype, DataType grad_dtype);
 
 private:
     DistributedDataParallelConfig ddp_config_;
     std::vector<std::shared_ptr<Tensor>> params_;
+    std::vector<std::shared_ptr<Tensor>> grads_;
     std::shared_ptr<Tensor> param_buffer_;
     std::shared_ptr<Tensor> grad_buffer_;
 
@@ -153,6 +155,8 @@
     size_t ddp_world_size_ = 1;
     std::vector<std::shared_ptr<ParamAndGradBucket>> buckets_;
 
+    bool need_rebind_grad_views_ = true;
+
     std::vector<std::pair<size_t, size_t>> bucket_indices_;
     // Param to (start, end, bucket_id)
     std::unordered_map<Tensor *, std::tuple<size_t, size_t, size_t>> param_index_map_;
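
The Reset(bool need_rebind) / RebindGradViews() pair replaces the old unconditional Reset(). The sequence this commit appears to intend, inferred from the .cc changes below rather than guaranteed by the header: a reset with need_rebind = true marks the buffer's gradient views as stale, and RebindGradViews() re-attaches each parameter's grad() to its slice of the flat grad buffer before the next backward pass accumulates into it.

// Hedged sketch of the inferred call sequence; buffer is a std::shared_ptr<ParamAndGradBuffer>.
buffer->Reset(/*need_rebind=*/true); // clear bucket state; grad views must be rebound before reuse
buffer->RebindGradViews();           // point each param's grad() back into grad_buffer_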

infini_train/include/nn/parallel/reducer.h renamed to infini_train/include/nn/parallel/ddp/reducer.h

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 #include <vector>
 
 #include "infini_train/include/datatype.h"
-#include "infini_train/include/nn/parallel/distributed_data_parallel_config.h"
+#include "infini_train/include/nn/parallel/ddp/distributed_data_parallel_config.h"
 #include "infini_train/include/nn/parallel/parallel_functional.h"
 
 namespace infini_train {
@@ -55,7 +55,7 @@ class Reducer : public std::enable_shared_from_this<Reducer> {
      * @param ddp_config DDP related options, see definition of DistributedDataParallelConfig
      */
     explicit Reducer(std::vector<std::shared_ptr<Tensor>> parameters, std::vector<std::vector<size_t>> bucket_indices,
-                     const DistributedDataParallelConfig ddp_config = DistributedDataParallelConfig());
+                     const DistributedDataParallelConfig ddp_config);
 
     // Attach PostAllReduceHooks to params
     void AttachHooksToParameters();

infini_train/include/tensor.h

Lines changed: 1 addition & 6 deletions
@@ -63,7 +63,7 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
 
     Tensor(const Tensor &tensor, size_t offset, const std::vector<int64_t> &dims);
 
-    void SetData(const Tensor &tensor, size_t offset, bool overwrite = false);
+    void SetData(const Tensor &tensor, size_t offset, bool preserve_data = false);
 
     Tensor(const float *data, const std::vector<int64_t> &dims, DataType dtype, const Device *device);
     Tensor(const float *data, const std::vector<int64_t> &dims, DataType dtype)
@@ -205,9 +205,6 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
     std::shared_ptr<Tensor> grad() const;
     void set_grad(const std::shared_ptr<Tensor> &grad);
 
-    std::shared_ptr<Tensor> main_grad() const;
-    void set_main_grad(const std::shared_ptr<Tensor> &grad);
-
     bool requires_grad() const;
     void set_requires_grad(bool requires_grad);
 
@@ -236,8 +233,6 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
 
 private:
     std::shared_ptr<Tensor> grad_ = nullptr;
-    // Points to a view in flat buffer constantly
-    std::shared_ptr<Tensor> main_grad_ = nullptr;
     bool requires_grad_ = false;
     bool is_leaf_ = true;
     std::shared_ptr<autograd::Function> grad_fn_ = nullptr;

infini_train/src/nn/parallel/distributed_data_parallel.cc renamed to infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc

Lines changed: 4 additions & 9 deletions
@@ -1,4 +1,4 @@
-#include "infini_train/include/nn/parallel/distributed_data_parallel.h"
+#include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h"
 
 #include <map>
 #include <memory>
@@ -147,14 +147,6 @@ void DistributedDataParallel::OnGradReady(const std::shared_ptr<Tensor> &param)
         CHECK(param->grad()) << "param.grad being None is not safe when overlap_grad_reduce is True";
     }
 
-    if (param->grad()) {
-        // Add to main_grad(buffer)
-        auto kernel = Dispatcher::Instance().GetKernel({param->GetDevice()->Type(), "AccumulateGrad"});
-        kernel.Call<void>(param->grad(), 1.f, param->main_grad());
-    }
-    // Can safely set grad to null because grad has already been added to main_grad(buffer)
-    param->set_grad(nullptr);
-
     if (ddp_config_.overlap_grad_reduce) {
         it->second->RegisterGradReady(param);
     }
@@ -167,6 +159,9 @@ DistributedDataParallel::Forward(const std::vector<std::shared_ptr<Tensor>> &inp
     if (reducer_) {
         reducer_->PrepareForBackward();
     }
+    if (ddp_config_.use_distributed_optimizer) {
+        for (auto buffer : param_grad_buffers_) { buffer->RebindGradViews(); }
+    }
     return outputs;
 }
 } // namespace infini_train::nn::parallel
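
With main_grad removed there is a single gradient path: each parameter's grad() is expected to be a view into the flat grad buffer, so backward accumulates into the buffer directly and the old copy-then-clear step in OnGradReady becomes unnecessary. Conceptually the rebinding can be expressed with the Tensor view constructor kept in tensor.h above; the helper below is hypothetical and only illustrates the idea (real offsets and dims come from the buffer's param_index_map_):

// Hypothetical illustration only; not code from this commit.
std::shared_ptr<Tensor> MakeGradView(const std::shared_ptr<Tensor> &grad_buffer, size_t offset_bytes,
                                     const std::vector<int64_t> &dims) {
    return std::make_shared<Tensor>(*grad_buffer, offset_bytes, dims); // view ctor: Tensor(const Tensor &, offset, dims)
}
// ...
param->set_grad(MakeGradView(grad_buffer, offset_bytes, param_dims)); // param_dims assumed known per parameter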

infini_train/src/nn/parallel/distributed_optimizer.cc renamed to infini_train/src/nn/parallel/ddp/distributed_optimizer.cc

Lines changed: 27 additions & 30 deletions
@@ -1,43 +1,33 @@
-#include "infini_train/include/nn/parallel/distributed_optimizer.h"
+#include "infini_train/include/nn/parallel/ddp/distributed_optimizer.h"
 
 #include "glog/logging.h"
 
-#include "infini_train/include/device.h"
+#include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h"
 #include "infini_train/include/tensor.h"
 
 namespace infini_train::nn::parallel {
-
-namespace {
-std::shared_ptr<Tensor> GetShardView(const std::shared_ptr<Tensor> &buffer, size_t world_size, size_t rank) {
-
-    CHECK(buffer);
-    CHECK_GT(world_size, 0);
-    CHECK_LT(rank, world_size);
-    CHECK_EQ(buffer->NumElements() % world_size, 0);
-
-    const size_t shard_numel = buffer->NumElements() / world_size;
-    const size_t offset_bytes = shard_numel * rank * kDataTypeToSize.at(buffer->Dtype());
-
-    return std::make_shared<Tensor>(*buffer, offset_bytes, std::vector<int64_t>{static_cast<int64_t>(shard_numel)});
-}
-
-} // namespace
-
 DistributedOptimizer::DistributedOptimizer(OptimizerCreator creator,
                                            const std::vector<std::shared_ptr<Tensor>> &full_params,
-                                           const std::vector<std::shared_ptr<ParamAndGradBuffer>> &buffers,
-                                           const std::vector<std::shared_ptr<ParamAndGradBucketGroup>> &bucket_groups,
-                                           const ProcessGroup *dp_pg, size_t dp_world_size, size_t dp_rank)
-    : Optimizer(full_params), param_grad_buffers_(buffers), bucket_groups_(bucket_groups), dp_pg_(dp_pg),
-      dp_world_size_(dp_world_size), dp_rank_(dp_rank), creator_(std::move(creator)) {
+                                           const std::vector<std::shared_ptr<Module>> &model_chunks,
+                                           size_t dp_world_size, size_t dp_rank)
+    : Optimizer(full_params), dp_world_size_(dp_world_size), dp_rank_(dp_rank) {
 
-    CHECK(dp_pg_);
     CHECK(dp_world_size_ > 1) << "DistributedOptimizer: dp_world_size must be greater than 1.";
 
+    for (size_t i = 0; i < model_chunks.size(); ++i) {
+        auto ddp_chunk = std::dynamic_pointer_cast<DistributedDataParallel>(model_chunks[i]);
+        CHECK(ddp_chunk) << "DistributedOptimizer: model_chunks[" << i << "] is not a DDP model.";
+
+        param_grad_buffers_.insert(param_grad_buffers_.end(), ddp_chunk->param_grad_buffers().begin(),
+                                   ddp_chunk->param_grad_buffers().end());
+        bucket_groups_.insert(bucket_groups_.end(), ddp_chunk->bucket_groups().begin(),
+                              ddp_chunk->bucket_groups().end());
+    }
+
     BuildShardParamsAndBindGrads();
 
     // Build base optimizer
-    base_optimizer_ = creator_(shard_params_);
+    base_optimizer_ = creator(shard_params_);
     CHECK(base_optimizer_) << "DistributedOptimizer: failed to create base optimizer.";
 }
 
@@ -110,11 +100,18 @@ void DistributedOptimizer::FinishParamSync(bool skip_next_bucket_dispatch) {
 }
 
 void DistributedOptimizer::ZeroGrad(bool set_to_none) {
-    // Zero main_grad buffer and clear BucketGroup state
-    for (auto &buffer : param_grad_buffers_) { buffer->Reset(); }
+    // Clear BucketGroup state and reset buffer:
+    // If set_to_none is true:
+    //   1) buffers will not be zeroed,
+    //   2) each of full_params's tensor->grad() will be set to nullptr
+    // If set_to_none is false:
+    //   1) buffers will be zeroed,
+    //   2) do not perform Fill(0) for each param
+    for (auto &buffer : param_grad_buffers_) { buffer->Reset(set_to_none); }
     for (auto &group : bucket_groups_) { group->Reset(); }
-    // Call base class's method: Zero each param's grad to guarantee consistency
-    infini_train::Optimizer::ZeroGrad(set_to_none);
+    if (set_to_none) {
+        for (auto param : params_) { param->ZeroGrad(set_to_none); }
+    }
 }
 
 void DistributedOptimizer::Step() {
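
The rewritten ZeroGrad encodes the contract spelled out in its comments: with set_to_none = true the flat buffers are left untouched and each full parameter's grad() is dropped (the views are rebound on the next Forward), while with set_to_none = false the views are kept and the buffers are zeroed instead. A hedged sketch of how one iteration interacts with this; the forward/backward calls are elided and the real loop in example/*/main.cc may differ:

// Hedged sketch of one step with the distributed optimizer.
optimizer->ZeroGrad(/*set_to_none=*/true);  // buffers untouched; full params' grad() set to nullptr
// ... forward (rebinds grad views) + backward (accumulates into the flat grad buffers) ...
optimizer->Step();                          // presumably updates shard_params_ via base_optimizer_, then syncs params

optimizer->ZeroGrad(/*set_to_none=*/false); // alternative: keep the grad views and zero the buffers instead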
