
Commit 1f95e29 (parent f7b3fcb)

style: apply clang-format to all legacy files

14 files changed: 282 additions & 379 deletions

example/gpt2/main.cc

Lines changed: 6 additions & 8 deletions
@@ -57,8 +57,7 @@ DEFINE_uint32(text_length, 64, "the length of the generated text");
 DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations");
 DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
 // lr scheduler
-DEFINE_string(lr_scheduler, "none",
-              "Learning rate scheduler type: none|constant|step|linear");
+DEFINE_string(lr_scheduler, "none", "Learning rate scheduler type: none|constant|step|linear");
 DEFINE_int64(warmup_steps, 0, "Number of linear warmup steps (0 = no warmup)");
 DEFINE_double(warmup_start_factor, 0.333333, "Starting learning rate factor for linear warmup (multiplied by base LR)");
 DEFINE_double(warmup_end_factor, 1.0, "Ending learning rate factor for linear warmup (multiplied by base LR)");

@@ -289,10 +288,10 @@ void Train(const nn::parallel::Rank &rank) {
     sched_config.step_gamma = static_cast<float>(FLAGS_gamma);
     sched_config.linear_start_factor = static_cast<float>(FLAGS_start_factor);
     sched_config.linear_end_factor = static_cast<float>(FLAGS_end_factor);
-    sched_config.constant_factor = static_cast<float>(FLAGS_start_factor); // reuse
+    sched_config.constant_factor = static_cast<float>(FLAGS_start_factor); // reuse
     sched_config.constant_total_iters = FLAGS_lr_total_iters;
     sched_config.linear_total_iters = FLAGS_lr_total_iters;
-    auto scheduler = CreateLRScheduler(optimizer,sched_config);
+    auto scheduler = CreateLRScheduler(optimizer, sched_config);

     auto train_iter = train_loader.begin();
     std::shared_ptr<nn::Module> loss_fn

@@ -410,12 +409,11 @@ void Train(const nn::parallel::Rank &rank) {
         if (rank.IsLastRank()) {
             size_t used_mb = 0, reserved_mb = 0;
             std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);
-            const float current_lr = scheduler ? scheduler->GetLR()
-                                               : static_cast<float>(FLAGS_learning_rate);
+            const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
             LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
                                       "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
-                                      step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f,
-                                      tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
+                                      step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps,
+                                      used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
                                       pp_world_size);

             if ((step + 1) % FLAGS_freq_generate_txt == 0) {

example/llama3/main.cc

Lines changed: 6 additions & 8 deletions
@@ -56,8 +56,7 @@ DEFINE_uint32(text_length, 64, "the length of the generated text");
 DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations");
 DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
 // lr scheduler
-DEFINE_string(lr_scheduler, "none",
-              "Learning rate scheduler type: none|constant|step|linear");
+DEFINE_string(lr_scheduler, "none", "Learning rate scheduler type: none|constant|step|linear");
 DEFINE_int64(warmup_steps, 0, "Number of linear warmup steps (0 = no warmup)");
 DEFINE_double(warmup_start_factor, 0.333333, "Starting learning rate factor for linear warmup (multiplied by base LR)");
 DEFINE_double(warmup_end_factor, 1.0, "Ending learning rate factor for linear warmup (multiplied by base LR)");

@@ -268,10 +267,10 @@ void Train(const nn::parallel::Rank &rank) {
     sched_config.step_gamma = static_cast<float>(FLAGS_gamma);
     sched_config.linear_start_factor = static_cast<float>(FLAGS_start_factor);
     sched_config.linear_end_factor = static_cast<float>(FLAGS_end_factor);
-    sched_config.constant_factor = static_cast<float>(FLAGS_start_factor); // reuse
+    sched_config.constant_factor = static_cast<float>(FLAGS_start_factor); // reuse
     sched_config.constant_total_iters = FLAGS_lr_total_iters;
     sched_config.linear_total_iters = FLAGS_lr_total_iters;
-    auto scheduler = CreateLRScheduler(optimizer,sched_config);
+    auto scheduler = CreateLRScheduler(optimizer, sched_config);

     auto train_iter = train_loader.begin();
     std::shared_ptr<nn::Module> loss_fn

@@ -386,12 +385,11 @@ void Train(const nn::parallel::Rank &rank) {
         if (rank.IsLastRank()) {
             size_t used_mb = 0, reserved_mb = 0;
             std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);
-            const float current_lr = scheduler ? scheduler->GetLR()
-                                               : static_cast<float>(FLAGS_learning_rate);
+            const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
             LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
                                       "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
-                                      step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f,
-                                      tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
+                                      step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps,
+                                      used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
                                       pp_world_size);

             if ((step + 1) % FLAGS_freq_generate_txt == 0) {
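Both example mains wire the new gflags into an LRSchedulerConfig, hand it to CreateLRScheduler, and read the live rate back with GetLR() for the log line. The following is a minimal sketch of that flow: the flag names, config fields, CreateLRScheduler, and GetLR() appear in the hunks above, while the loop skeleton, the assignment of sched_config.type, and the Step() placement are assumptions about the surrounding Train() code that this diff does not show.

    // Sketch only: flag names, LRSchedulerConfig fields, CreateLRScheduler and GetLR()
    // come from the hunks above; the loop skeleton, the type assignment, and the Step()
    // placement are assumptions about the rest of Train().
    LRSchedulerConfig sched_config;
    sched_config.type = FLAGS_lr_scheduler; // "none" | "constant" | "step" | "linear"
    sched_config.linear_start_factor = static_cast<float>(FLAGS_start_factor);
    sched_config.linear_end_factor = static_cast<float>(FLAGS_end_factor);
    sched_config.linear_total_iters = FLAGS_lr_total_iters;

    auto scheduler = CreateLRScheduler(optimizer, sched_config); // assumed null when type == "none"

    for (int64_t step = 0; step < FLAGS_num_iteration; ++step) {
        // ... forward pass, backward pass, optimizer update ...
        if (scheduler) {
            scheduler->Step(); // advance the schedule once per training step (assumed placement)
        }
        // current_lr feeds the std::format log line shown in the hunks above
        const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
    }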

infini_train/include/lr_scheduler.h

Lines changed: 17 additions & 37 deletions
@@ -1,7 +1,7 @@
 #pragma once

-#include <cstdint>
 #include <cmath>
+#include <cstdint>
 #include <functional>
 #include <memory>
 #include <string>

@@ -13,12 +13,11 @@ namespace infini_train {

 class Optimizer;

-using StateValue = std::variant<int64_t, float, double, std::string,
-                                std::vector<float>>;
+using StateValue = std::variant<int64_t, float, double, std::string, std::vector<float>>;
 using StateDict = std::unordered_map<std::string, StateValue>;

 struct LRSchedulerConfig {
-    std::string type = "none";
+    std::string type = "none";
     // ConstantLR
     float constant_factor = 1.0f / 3.0f;
     int constant_total_iters = 5;

@@ -44,15 +43,13 @@ struct LRSchedulerConfig {

 class LRScheduler {
 public:
-    template<typename T, typename... Args>
-    static std::shared_ptr<T> Create(Args&&... args) {
+    template <typename T, typename... Args> static std::shared_ptr<T> Create(Args &&...args) {
         auto scheduler = std::make_shared<T>(std::forward<Args>(args)...);
         scheduler->InitialStep();
         return scheduler;
     }

-    explicit LRScheduler(std::shared_ptr<Optimizer> optimizer,
-                         int64_t last_step = -1);
+    explicit LRScheduler(std::shared_ptr<Optimizer> optimizer, int64_t last_step = -1);
     virtual ~LRScheduler() = default;

     LRScheduler(const LRScheduler &) = delete;

@@ -82,17 +79,13 @@ class LRScheduler {
     bool is_initial_ = false;
 };

-std::shared_ptr<LRScheduler> CreateLRScheduler(
-    std::shared_ptr<Optimizer> optimizer,
-    const LRSchedulerConfig& config);
+std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimizer, const LRSchedulerConfig &config);

 namespace lr_schedulers {

 class ConstantLR : public LRScheduler {
 public:
-    ConstantLR(std::shared_ptr<Optimizer> optimizer,
-               float factor = 1.0f / 3.0f,
-               int total_iters = 5,
+    ConstantLR(std::shared_ptr<Optimizer> optimizer, float factor = 1.0f / 3.0f, int total_iters = 5,
                int64_t last_step = -1);
     ~ConstantLR() override = default;

@@ -107,10 +100,7 @@ class ConstantLR : public LRScheduler {

 class StepLR : public LRScheduler {
 public:
-    StepLR(std::shared_ptr<Optimizer> optimizer,
-           int64_t step_size,
-           float gamma = 0.1f,
-           int64_t last_step = -1);
+    StepLR(std::shared_ptr<Optimizer> optimizer, int64_t step_size, float gamma = 0.1f, int64_t last_step = -1);
     ~StepLR() override = default;

 protected:

@@ -124,11 +114,8 @@ class StepLR : public LRScheduler {

 class LinearLR : public LRScheduler {
 public:
-    LinearLR(std::shared_ptr<Optimizer> optimizer,
-             float start_factor = 1.0f / 3.0f,
-             float end_factor = 1.0f,
-             int64_t total_iters = 5,
-             int64_t last_step = -1);
+    LinearLR(std::shared_ptr<Optimizer> optimizer, float start_factor = 1.0f / 3.0f, float end_factor = 1.0f,
+             int64_t total_iters = 5, int64_t last_step = -1);
     ~LinearLR() override = default;

 protected:

@@ -145,9 +132,7 @@ class LambdaLR : public LRScheduler {
 public:
     using LambdaFunc = std::function<float(int64_t)>;

-    LambdaLR(std::shared_ptr<Optimizer> optimizer,
-             LambdaFunc lr_lambda,
-             int64_t last_step = -1);
+    LambdaLR(std::shared_ptr<Optimizer> optimizer, LambdaFunc lr_lambda, int64_t last_step = -1);
     ~LambdaLR() override = default;

 protected:

@@ -157,13 +142,10 @@ class LambdaLR : public LRScheduler {
     const LambdaFunc lr_lambda_;
 };

-
 class SequentialLR : public LRScheduler {
 public:
-    SequentialLR(std::shared_ptr<Optimizer> optimizer,
-                 std::vector<std::shared_ptr<LRScheduler>> schedulers,
-                 std::vector<int64_t> milestones,
-                 int64_t last_step = -1);
+    SequentialLR(std::shared_ptr<Optimizer> optimizer, std::vector<std::shared_ptr<LRScheduler>> schedulers,
+                 std::vector<int64_t> milestones, int64_t last_step = -1);
     ~SequentialLR() override = default;

     void Step() override;

@@ -183,16 +165,15 @@ class SequentialLR : public LRScheduler {

 class ChainedScheduler : public LRScheduler {
 public:
-    ChainedScheduler(std::shared_ptr<Optimizer> optimizer,
-                     std::vector<std::shared_ptr<LRScheduler>> schedulers,
+    ChainedScheduler(std::shared_ptr<Optimizer> optimizer, std::vector<std::shared_ptr<LRScheduler>> schedulers,
                      int64_t last_step = -1);
     ~ChainedScheduler() override = default;

     void Step() override;
     void InitialStep() override;

     StateDict State() const override;
-    void LoadState(const StateDict& state) override;
+    void LoadState(const StateDict &state) override;

 protected:
     float GetClosedFormLR() const override { return current_lr_; }

@@ -201,6 +182,5 @@ class ChainedScheduler : public LRScheduler {
     std::vector<std::shared_ptr<LRScheduler>> schedulers_;
 };

-
-} // namespace lr_schedulers
-} // namespace infini_train
+} // namespace lr_schedulers
+} // namespace infini_train
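Besides the config-driven CreateLRScheduler factory, the header declares the scheduler classes directly along with the LRScheduler::Create<T> helper, which forwards to the constructor and then calls InitialStep(). Below is a hedged sketch of composing them into a warmup-then-decay schedule: the constructor signatures match the declarations above, while the include path, the MakeWarmupDecaySchedule helper name, and the concrete milestone/step values are illustrative assumptions.

    #include <memory>
    #include <vector>

    #include "infini_train/include/lr_scheduler.h" // include path assumed from this diff's layout

    using infini_train::LRScheduler;
    using infini_train::Optimizer;
    namespace sched = infini_train::lr_schedulers;

    // Hypothetical helper: linear warmup over the first 100 steps, then decay by 10x
    // every 1000 steps. Constructor signatures follow the declarations above.
    std::shared_ptr<LRScheduler> MakeWarmupDecaySchedule(std::shared_ptr<Optimizer> optimizer) {
        // LRScheduler::Create<T>(...) constructs T and calls InitialStep() before returning.
        auto warmup = LRScheduler::Create<sched::LinearLR>(optimizer, /*start_factor=*/1.0f / 3.0f,
                                                           /*end_factor=*/1.0f, /*total_iters=*/100);
        auto decay = LRScheduler::Create<sched::StepLR>(optimizer, /*step_size=*/1000, /*gamma=*/0.1f);
        // SequentialLR switches from the warmup scheduler to the decay scheduler at step 100.
        return LRScheduler::Create<sched::SequentialLR>(
            optimizer, std::vector<std::shared_ptr<LRScheduler>>{warmup, decay}, std::vector<int64_t>{100});
    }

ChainedScheduler, by contrast, appears intended to apply all of its child schedulers on every Step(); its implementation lives in the corresponding .cc file and is not part of this header diff.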

infini_train/include/optimizer.h

Lines changed: 1 addition & 2 deletions
@@ -26,7 +26,7 @@ class Optimizer {
     virtual float GetLearningRate() const;

     float GetInitialLearningRate() const;
-
+
     void SetInitialLearningRate(float lr);

 protected:

@@ -48,7 +48,6 @@ class SGD : public Optimizer {
             return std::make_shared<SGD>(params, learning_rate);
         };
     }
-
 };

 class Adam : public Optimizer {
