Skip to content

Commit dc407f2

Browse files
committed
Log step_time
Signed-off-by: Jared Wilber <jwilber@nvidia.com>
1 parent a29272f commit dc407f2

1 file changed

Lines changed: 6 additions & 1 deletion

File tree

  • recipes/geneformer_native_te_nvfsdp_fp8

recipes/geneformer_native_te_nvfsdp_fp8/train.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@
4646

4747
import logging
4848
import os
49+
import time
4950
from dataclasses import dataclass, field
5051

5152
import hydra
@@ -217,7 +218,7 @@ def main(cfg: DictConfig) -> None:
217218
logger=logger,
218219
start_step=start_step,
219220
)
220-
221+
previous_step_time = time.perf_counter()
221222
for step in range(start_step, cfg.training.num_train_steps):
222223
# Get batch
223224
batch = next(dataloader)
@@ -252,6 +253,9 @@ def main(cfg: DictConfig) -> None:
252253

253254
# Log metrics to wandb on main process
254255
if dist_config.is_main_process():
256+
current_time = time.perf_counter()
257+
step_time = current_time - previous_step_time
258+
previous_step_time = current_time
255259
logger.info(
256260
f"Step {step} loss: {loss.item()}, grad_norm: {total_norm}, lr: {optimizer.param_groups[0]['lr']}"
257261
)
@@ -262,6 +266,7 @@ def main(cfg: DictConfig) -> None:
262266
"train/learning_rate": optimizer.param_groups[0]["lr"],
263267
"train/grad_norm": total_norm,
264268
"train/epoch": step / dataloader_length,
269+
"train/step_time": step_time,
265270
}
266271
)
267272

0 commit comments

Comments
 (0)