Skip to content

Commit 1431df3

Browse files
committed
Add new PTL callback to measure the wall-clock time per optmizer step to match native recipe
1 parent 0db38e1 commit 1431df3

3 files changed

Lines changed: 63 additions & 0 deletions

File tree

bionemo-recipes/recipes/codonfm_ptl_te/src/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from src.tokenizer import Tokenizer
3636
from src.utils.fsdp_config import get_fsdp_strategy
3737
from src.utils.grad_norm_callback import GradientNormLogger
38+
from src.utils.interval_step_timing import IntervalStepTimingCallback
3839
from src.utils.pred_writer import PredWriter
3940
from src.utils.scheduler import linear_scheduler_with_warmup_lr_lambda
4041
from src.utils.throughput_logger import ThroughputLogger
@@ -136,6 +137,7 @@ def get_callbacks_config(args: Any) -> Dict[str, fdl.Config]:
136137
"lr_monitor": fdl.Config(LearningRateMonitor, logging_interval="step", log_weight_decay=True),
137138
"grad_norm_callback": fdl.Config(GradientNormLogger, log_every_n_steps=args.log_every_n_steps),
138139
"timer_callback": fdl.Config(StepTimingCallback, log_every_n_steps=args.log_every_n_steps, mode="train"),
140+
"interval_timer_callback": fdl.Config(IntervalStepTimingCallback, log_every_n_steps=args.log_every_n_steps),
139141
"throughput_callback": fdl.Config(ThroughputLogger, log_every_n_steps=args.log_every_n_steps, warmup_steps=40),
140142
}
141143
if args.mode == "eval":

bionemo-recipes/recipes/codonfm_ptl_te/src/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515

1616

1717
from src.utils.grad_norm_callback import GradientNormLogger
18+
from src.utils.interval_step_timing import IntervalStepTimingCallback
1819
from src.utils.pred_writer import PredWriter
1920
from src.utils.pylogger import RankedLogger
2021
from src.utils.throughput_logger import ThroughputLogger
2122

2223

2324
__all__ = [
2425
"GradientNormLogger",
26+
"IntervalStepTimingCallback",
2527
"PredWriter",
2628
"RankedLogger",
2729
"ThroughputLogger",
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import time
17+
18+
import torch
19+
from lightning.pytorch.callbacks import Callback
20+
21+
22+
class IntervalStepTimingCallback(Callback):
23+
"""Logs mean wall-clock time per optimizer step over a fixed logging interval.
24+
25+
Mirrors the semantics of `train/step_time` in the native_te recipe's `PerfLogger`:
26+
samples `time.perf_counter()` only at log boundaries and divides by
27+
`log_every_n_steps`, yielding the average optimizer-step wall time over the
28+
last interval rather than a per-step measurement.
29+
"""
30+
31+
def __init__(self, log_every_n_steps: int = 10): # noqa: D107
32+
self.log_every_n_steps = log_every_n_steps
33+
self.previous_log_time: float | None = None
34+
35+
def on_train_start(self, trainer, pl_module): # noqa: D102
36+
self.previous_log_time = time.perf_counter()
37+
38+
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): # noqa: D102
39+
if (batch_idx + 1) % trainer.accumulate_grad_batches != 0:
40+
return
41+
42+
step = trainer.global_step
43+
if step == 0 or step % self.log_every_n_steps != 0:
44+
return
45+
46+
if torch.cuda.is_available():
47+
torch.cuda.synchronize()
48+
now = time.perf_counter()
49+
step_time = (now - self.previous_log_time) / self.log_every_n_steps
50+
self.previous_log_time = now
51+
52+
pl_module.log(
53+
"timing_train/step_time",
54+
step_time,
55+
prog_bar=True,
56+
on_step=True,
57+
on_epoch=False,
58+
sync_dist=True,
59+
)

0 commit comments

Comments
 (0)