Skip to content

Commit afd98ff

Browse files
Author: kinorw (committed)
fix: get the learning rate (LR) of the current epoch before the scheduler steps.
1 parent 1f95e29 commit afd98ff

2 files changed

Lines changed: 2 additions & 2 deletions

File tree

example/gpt2/main.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ void Train(const nn::parallel::Rank &rank) {
336336
Profiler::Instance().SetTag("Step_" + std::to_string(step));
337337
#endif
338338

339+
const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
339340
float lossf = 0.0f;
340341
// model->Train();
341342
if (pp_world_size == 1) {
@@ -409,7 +410,6 @@ void Train(const nn::parallel::Rank &rank) {
409410
if (rank.IsLastRank()) {
410411
size_t used_mb = 0, reserved_mb = 0;
411412
std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);
412-
const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
413413
LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
414414
"peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
415415
step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps,

example/llama3/main.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ void Train(const nn::parallel::Rank &rank) {
312312
Profiler::Instance().SetTag("Step_" + std::to_string(step));
313313
#endif
314314

315+
const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
315316
float lossf = 0.0f;
316317
if (pp_world_size == 1) {
317318
// model->Train();
@@ -385,7 +386,6 @@ void Train(const nn::parallel::Rank &rank) {
385386
if (rank.IsLastRank()) {
386387
size_t used_mb = 0, reserved_mb = 0;
387388
std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);
388-
const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
389389
LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
390390
"peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
391391
step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps,

0 commit comments

Comments (0)