Skip to content

Commit f0012be

Browse files
committed
style: apply clang-format to all legacy files
1 parent 831f55e commit f0012be

File tree

15 files changed

+287
-384
lines changed

15 files changed

+287
-384
lines changed

CMakeLists.txt

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -204,25 +204,23 @@ link_infini_train_exe(test_precision_check)
204204
add_executable(test_lora test/lora/test_lora.cc)
205205
link_infini_train_exe(test_lora)
206206

207-
target_link_libraries(test_precision_check infini_train)
208-
209207
add_executable(test_lr_scheduler test/lr_scheduler/test_lr_scheduler.cc)
210-
target_link_libraries(test_lr_scheduler infini_train)
208+
link_infini_train_exe(test_lr_scheduler)
211209

212210
add_executable(test_constant_lr test/lr_scheduler/test_constant_lr.cc)
213-
target_link_libraries(test_constant_lr infini_train)
211+
link_infini_train_exe(test_constant_lr)
214212

215213
add_executable(test_step_lr test/lr_scheduler/test_step_lr.cc)
216-
target_link_libraries(test_step_lr infini_train)
214+
link_infini_train_exe(test_step_lr)
217215

218216
add_executable(test_linear_lr test/lr_scheduler/test_linear_lr.cc)
219-
target_link_libraries(test_linear_lr infini_train)
217+
link_infini_train_exe(test_linear_lr)
220218

221219
add_executable(test_lambda_lr test/lr_scheduler/test_lambda_lr.cc)
222-
target_link_libraries(test_lambda_lr infini_train)
220+
link_infini_train_exe(test_lambda_lr)
223221

224222
add_executable(test_sequential_lr test/lr_scheduler/test_sequential_lr.cc)
225-
target_link_libraries(test_sequential_lr infini_train)
223+
link_infini_train_exe(test_sequential_lr)
226224

227225
add_executable(test_chained_lr test/lr_scheduler/test_chained_lr.cc)
228-
target_link_libraries(test_chained_lr infini_train)
226+
link_infini_train_exe(test_chained_lr)

example/gpt2/main.cc

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,7 @@ DEFINE_uint32(text_length, 64, "the length of the generated text");
5858
DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations");
5959
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
6060
// lr scheduler
61-
DEFINE_string(lr_scheduler, "none",
62-
"Learning rate scheduler type: none|constant|step|linear");
61+
DEFINE_string(lr_scheduler, "none", "Learning rate scheduler type: none|constant|step|linear");
6362
DEFINE_int64(warmup_steps, 0, "Number of linear warmup steps (0 = no warmup)");
6463
DEFINE_double(warmup_start_factor, 0.333333, "Starting learning rate factor for linear warmup (multiplied by base LR)");
6564
DEFINE_double(warmup_end_factor, 1.0, "Ending learning rate factor for linear warmup (multiplied by base LR)");
@@ -331,10 +330,10 @@ void Train(const nn::parallel::Rank &rank) {
331330
sched_config.step_gamma = static_cast<float>(FLAGS_gamma);
332331
sched_config.linear_start_factor = static_cast<float>(FLAGS_start_factor);
333332
sched_config.linear_end_factor = static_cast<float>(FLAGS_end_factor);
334-
sched_config.constant_factor = static_cast<float>(FLAGS_start_factor); // 复用
333+
sched_config.constant_factor = static_cast<float>(FLAGS_start_factor); // 复用
335334
sched_config.constant_total_iters = FLAGS_lr_total_iters;
336335
sched_config.linear_total_iters = FLAGS_lr_total_iters;
337-
auto scheduler = CreateLRScheduler(optimizer,sched_config);
336+
auto scheduler = CreateLRScheduler(optimizer, sched_config);
338337

339338
auto train_iter = train_loader.begin();
340339
std::shared_ptr<nn::Module> loss_fn
@@ -455,8 +454,8 @@ void Train(const nn::parallel::Rank &rank) {
455454
std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);
456455
LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
457456
"peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
458-
step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f,
459-
tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
457+
step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps,
458+
used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
460459
pp_world_size);
461460

462461
if ((step + 1) % FLAGS_freq_generate_txt == 0) {

example/llama3/main.cc

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@ DEFINE_uint32(text_length, 64, "the length of the generated text");
5757
DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations");
5858
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
5959
// lr scheduler
60-
DEFINE_string(lr_scheduler, "none",
61-
"Learning rate scheduler type: none|constant|step|linear");
60+
DEFINE_string(lr_scheduler, "none", "Learning rate scheduler type: none|constant|step|linear");
6261
DEFINE_int64(warmup_steps, 0, "Number of linear warmup steps (0 = no warmup)");
6362
DEFINE_double(warmup_start_factor, 0.333333, "Starting learning rate factor for linear warmup (multiplied by base LR)");
6463
DEFINE_double(warmup_end_factor, 1.0, "Ending learning rate factor for linear warmup (multiplied by base LR)");
@@ -303,10 +302,10 @@ void Train(const nn::parallel::Rank &rank) {
303302
sched_config.step_gamma = static_cast<float>(FLAGS_gamma);
304303
sched_config.linear_start_factor = static_cast<float>(FLAGS_start_factor);
305304
sched_config.linear_end_factor = static_cast<float>(FLAGS_end_factor);
306-
sched_config.constant_factor = static_cast<float>(FLAGS_start_factor); // 复用
305+
sched_config.constant_factor = static_cast<float>(FLAGS_start_factor); // 复用
307306
sched_config.constant_total_iters = FLAGS_lr_total_iters;
308307
sched_config.linear_total_iters = FLAGS_lr_total_iters;
309-
auto scheduler = CreateLRScheduler(optimizer,sched_config);
308+
auto scheduler = CreateLRScheduler(optimizer, sched_config);
310309

311310
auto train_iter = train_loader.begin();
312311
std::shared_ptr<nn::Module> loss_fn
@@ -424,8 +423,8 @@ void Train(const nn::parallel::Rank &rank) {
424423
std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);
425424
LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
426425
"peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
427-
step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f,
428-
tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
426+
step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps,
427+
used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
429428
pp_world_size);
430429

431430
if ((step + 1) % FLAGS_freq_generate_txt == 0) {

infini_train/include/lr_scheduler.h

Lines changed: 17 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#pragma once
22

3-
#include <cstdint>
43
#include <cmath>
4+
#include <cstdint>
55
#include <functional>
66
#include <memory>
77
#include <string>
@@ -13,12 +13,11 @@ namespace infini_train {
1313

1414
class Optimizer;
1515

16-
using StateValue = std::variant<int64_t, float, double, std::string,
17-
std::vector<float>>;
16+
using StateValue = std::variant<int64_t, float, double, std::string, std::vector<float>>;
1817
using StateDict = std::unordered_map<std::string, StateValue>;
1918

2019
struct LRSchedulerConfig {
21-
std::string type = "none";
20+
std::string type = "none";
2221
// ConstantLR
2322
float constant_factor = 1.0f / 3.0f;
2423
int constant_total_iters = 5;
@@ -44,15 +43,13 @@ struct LRSchedulerConfig {
4443

4544
class LRScheduler {
4645
public:
47-
template<typename T, typename... Args>
48-
static std::shared_ptr<T> Create(Args&&... args) {
46+
template <typename T, typename... Args> static std::shared_ptr<T> Create(Args &&...args) {
4947
auto scheduler = std::make_shared<T>(std::forward<Args>(args)...);
5048
scheduler->InitialStep();
5149
return scheduler;
5250
}
5351

54-
explicit LRScheduler(std::shared_ptr<Optimizer> optimizer,
55-
int64_t last_step = -1);
52+
explicit LRScheduler(std::shared_ptr<Optimizer> optimizer, int64_t last_step = -1);
5653
virtual ~LRScheduler() = default;
5754

5855
LRScheduler(const LRScheduler &) = delete;
@@ -82,17 +79,13 @@ class LRScheduler {
8279
bool is_initial_ = false;
8380
};
8481

85-
std::shared_ptr<LRScheduler> CreateLRScheduler(
86-
std::shared_ptr<Optimizer> optimizer,
87-
const LRSchedulerConfig& config);
82+
std::shared_ptr<LRScheduler> CreateLRScheduler(std::shared_ptr<Optimizer> optimizer, const LRSchedulerConfig &config);
8883

8984
namespace lr_schedulers {
9085

9186
class ConstantLR : public LRScheduler {
9287
public:
93-
ConstantLR(std::shared_ptr<Optimizer> optimizer,
94-
float factor = 1.0f / 3.0f,
95-
int total_iters = 5,
88+
ConstantLR(std::shared_ptr<Optimizer> optimizer, float factor = 1.0f / 3.0f, int total_iters = 5,
9689
int64_t last_step = -1);
9790
~ConstantLR() override = default;
9891

@@ -107,10 +100,7 @@ class ConstantLR : public LRScheduler {
107100

108101
class StepLR : public LRScheduler {
109102
public:
110-
StepLR(std::shared_ptr<Optimizer> optimizer,
111-
int64_t step_size,
112-
float gamma = 0.1f,
113-
int64_t last_step = -1);
103+
StepLR(std::shared_ptr<Optimizer> optimizer, int64_t step_size, float gamma = 0.1f, int64_t last_step = -1);
114104
~StepLR() override = default;
115105

116106
protected:
@@ -124,11 +114,8 @@ class StepLR : public LRScheduler {
124114

125115
class LinearLR : public LRScheduler {
126116
public:
127-
LinearLR(std::shared_ptr<Optimizer> optimizer,
128-
float start_factor = 1.0f / 3.0f,
129-
float end_factor = 1.0f,
130-
int64_t total_iters = 5,
131-
int64_t last_step = -1);
117+
LinearLR(std::shared_ptr<Optimizer> optimizer, float start_factor = 1.0f / 3.0f, float end_factor = 1.0f,
118+
int64_t total_iters = 5, int64_t last_step = -1);
132119
~LinearLR() override = default;
133120

134121
protected:
@@ -145,9 +132,7 @@ class LambdaLR : public LRScheduler {
145132
public:
146133
using LambdaFunc = std::function<float(int64_t)>;
147134

148-
LambdaLR(std::shared_ptr<Optimizer> optimizer,
149-
LambdaFunc lr_lambda,
150-
int64_t last_step = -1);
135+
LambdaLR(std::shared_ptr<Optimizer> optimizer, LambdaFunc lr_lambda, int64_t last_step = -1);
151136
~LambdaLR() override = default;
152137

153138
protected:
@@ -157,13 +142,10 @@ class LambdaLR : public LRScheduler {
157142
const LambdaFunc lr_lambda_;
158143
};
159144

160-
161145
class SequentialLR : public LRScheduler {
162146
public:
163-
SequentialLR(std::shared_ptr<Optimizer> optimizer,
164-
std::vector<std::shared_ptr<LRScheduler>> schedulers,
165-
std::vector<int64_t> milestones,
166-
int64_t last_step = -1);
147+
SequentialLR(std::shared_ptr<Optimizer> optimizer, std::vector<std::shared_ptr<LRScheduler>> schedulers,
148+
std::vector<int64_t> milestones, int64_t last_step = -1);
167149
~SequentialLR() override = default;
168150

169151
void Step() override;
@@ -183,16 +165,15 @@ class SequentialLR : public LRScheduler {
183165

184166
class ChainedScheduler : public LRScheduler {
185167
public:
186-
ChainedScheduler(std::shared_ptr<Optimizer> optimizer,
187-
std::vector<std::shared_ptr<LRScheduler>> schedulers,
168+
ChainedScheduler(std::shared_ptr<Optimizer> optimizer, std::vector<std::shared_ptr<LRScheduler>> schedulers,
188169
int64_t last_step = -1);
189170
~ChainedScheduler() override = default;
190171

191172
void Step() override;
192173
void InitialStep() override;
193174

194175
StateDict State() const override;
195-
void LoadState(const StateDict& state) override;
176+
void LoadState(const StateDict &state) override;
196177

197178
protected:
198179
float GetClosedFormLR() const override { return current_lr_; }
@@ -201,6 +182,5 @@ class ChainedScheduler : public LRScheduler {
201182
std::vector<std::shared_ptr<LRScheduler>> schedulers_;
202183
};
203184

204-
205-
} // namespace lr_schedulers
206-
} // namespace infini_train
185+
} // namespace lr_schedulers
186+
} // namespace infini_train

infini_train/include/optimizer.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class Optimizer {
2626
virtual float GetLearningRate() const;
2727

2828
float GetInitialLearningRate() const;
29-
29+
3030
void SetInitialLearningRate(float lr);
3131

3232
protected:
@@ -48,7 +48,6 @@ class SGD : public Optimizer {
4848
return std::make_shared<SGD>(params, learning_rate);
4949
};
5050
}
51-
5251
};
5352

5453
class Adam : public Optimizer {

0 commit comments

Comments
 (0)