Skip to content

Commit b97ad98

Browse files
authored
feat(train): show finish time in ETA logs (#5328)
Make long-running training progress easier to read by keeping the relative ETA and appending a concise absolute finish time across the pt, pd, tf, and pt_expt backends. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Training logs now include both remaining ETA and an estimated local finish time (YYYY-MM-DD HH:MM). * Timezone-aware local timestamps are shown across training frameworks for clearer cross-region monitoring and more consistent periodic timing output. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent b2923e8 commit b97ad98

5 files changed

Lines changed: 110 additions & 8 deletions

File tree

deepmd/loggers/training.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,65 @@
66
log = logging.getLogger(__name__)
77

88

9+
def _format_estimated_finish_time(
10+
eta_seconds: int,
11+
current_time: datetime.datetime | None = None,
12+
) -> str:
13+
"""Format the estimated local finish time.
14+
15+
Parameters
16+
----------
17+
eta_seconds : int
18+
Remaining time in seconds.
19+
current_time : datetime.datetime | None, optional
20+
Current local time used to estimate the finish timestamp. If ``None``,
21+
the current local time is used.
22+
23+
Returns
24+
-------
25+
str
26+
Estimated local finish time in ``YYYY-MM-DD HH:MM`` format.
27+
"""
28+
if current_time is None:
29+
current_time = datetime.datetime.now(datetime.timezone.utc).astimezone()
30+
elif current_time.tzinfo is not None:
31+
current_time = current_time.astimezone()
32+
finish_time = current_time + datetime.timedelta(seconds=eta_seconds)
33+
return finish_time.strftime("%Y-%m-%d %H:%M")
34+
35+
936
def format_training_message(
1037
batch: int,
1138
wall_time: float,
1239
eta: int | None = None,
40+
current_time: datetime.datetime | None = None,
1341
) -> str:
14-
"""Format a training message."""
42+
"""Format the summary message for one training interval.
43+
44+
Parameters
45+
----------
46+
batch : int
47+
The batch index.
48+
wall_time : float
49+
Wall-clock time shown in the progress message in seconds.
50+
eta : int | None, optional
51+
Remaining time in seconds.
52+
current_time : datetime.datetime | None, optional
53+
Current local time used to estimate the finish timestamp. This is only
54+
used when ``eta`` is provided.
55+
56+
Returns
57+
-------
58+
str
59+
The formatted training message.
60+
"""
1561
msg = f"Batch {batch:7d}: total wall time = {wall_time:.2f} s"
1662
if isinstance(eta, int):
17-
msg += f", eta = {datetime.timedelta(seconds=int(eta))!s}"
63+
eta_seconds = int(eta)
64+
msg += (
65+
f", eta = {datetime.timedelta(seconds=eta_seconds)!s} at "
66+
f"{_format_estimated_finish_time(eta_seconds, current_time=current_time)}"
67+
)
1868
return msg
1969

2070

@@ -39,6 +89,11 @@ def format_training_message_per_task(
3989
The learning rate
4090
check_total_rmse_nan : bool
4191
Whether to throw an error if the total RMSE is NaN
92+
93+
Returns
94+
-------
95+
str
96+
The formatted training message for the task.
4297
"""
4398
if task_name:
4499
task_name += ": "

deepmd/pd/train/training.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import contextlib
3+
import datetime
34
import functools
45
import logging
56
import time
@@ -982,6 +983,10 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
982983
batch=display_step_id,
983984
wall_time=train_time,
984985
eta=eta,
986+
current_time=datetime.datetime.fromtimestamp(
987+
current_time,
988+
tz=datetime.timezone.utc,
989+
).astimezone(),
985990
)
986991
)
987992
# the first training time is not accurate

deepmd/pt/train/training.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import datetime
23
import functools
34
import json
45
import logging
@@ -1334,6 +1335,10 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
13341335
batch=display_step_id,
13351336
wall_time=train_time,
13361337
eta=eta,
1338+
current_time=datetime.datetime.fromtimestamp(
1339+
current_time,
1340+
tz=datetime.timezone.utc,
1341+
).astimezone(),
13371342
)
13381343
)
13391344
if (

deepmd/pt_expt/train/training.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
converted to torch tensors at the boundary.
77
"""
88

9+
import datetime
910
import functools
1011
import logging
1112
import time
@@ -30,6 +31,7 @@
3031
LearningRateExp,
3132
)
3233
from deepmd.loggers.training import (
34+
format_training_message,
3335
format_training_message_per_task,
3436
)
3537
from deepmd.pt_expt.loss import (
@@ -732,6 +734,7 @@ def run(self) -> None:
732734

733735
self.wrapper.train()
734736
wall_start = time.time()
737+
last_log_time = wall_start
735738

736739
for step_id in range(self.start_step, self.num_steps):
737740
cur_lr = float(self.lr_schedule.value(step_id))
@@ -792,17 +795,40 @@ def run(self) -> None:
792795
}
793796

794797
# wall-clock time
795-
wall_elapsed = time.time() - wall_start
798+
current_time = time.time()
799+
wall_elapsed = current_time - wall_start
800+
interval_wall_time = current_time - last_log_time
801+
last_log_time = current_time
796802
if self.timing_in_training:
797803
step_time = t_end - t_start
804+
steps_completed_since_restart = max(
805+
1,
806+
display_step_id - self.start_step,
807+
)
808+
eta = int(
809+
(self.num_steps - display_step_id)
810+
/ steps_completed_since_restart
811+
* wall_elapsed
812+
)
798813
log.info(
799-
"step=%d wall=%.2fs step_time=%.4fs",
800-
display_step_id,
801-
wall_elapsed,
802-
step_time,
814+
format_training_message(
815+
batch=display_step_id,
816+
wall_time=interval_wall_time,
817+
eta=eta,
818+
current_time=datetime.datetime.fromtimestamp(
819+
current_time,
820+
tz=datetime.timezone.utc,
821+
).astimezone(),
822+
)
803823
)
824+
log.info("step=%d step_time=%.4fs", display_step_id, step_time)
804825
else:
805-
log.info("step=%d wall=%.2fs", display_step_id, wall_elapsed)
826+
log.info(
827+
format_training_message(
828+
batch=display_step_id,
829+
wall_time=interval_wall_time,
830+
)
831+
)
806832

807833
# log
808834
log.info(

deepmd/tf/train/trainer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/usr/bin/env python3
22
# SPDX-License-Identifier: LGPL-3.0-or-later
3+
import datetime
34
import logging
45
import os
56
import shutil
@@ -603,10 +604,20 @@ def train(
603604
toc = time.time()
604605
test_time = toc - tic
605606
wall_time = toc - wall_time_tic
607+
displayed_batches = max(
608+
1,
609+
min(self.disp_freq, int(cur_batch - start_batch)),
610+
)
611+
eta = int((stop_batch - cur_batch) / displayed_batches * wall_time)
606612
log.info(
607613
format_training_message(
608614
batch=cur_batch,
609615
wall_time=wall_time,
616+
eta=eta,
617+
current_time=datetime.datetime.fromtimestamp(
618+
toc,
619+
tz=datetime.timezone.utc,
620+
).astimezone(),
610621
)
611622
)
612623
# the first training time is not accurate

0 commit comments

Comments
 (0)