Skip to content

Commit 4e2c8a1

Browse files
committed
refactor for a more explicit logic
1 parent 5665c0f commit 4e2c8a1

1 file changed

Lines changed: 13 additions & 9 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -942,6 +942,13 @@ def log_loss_valid(_task_key="Default"):
942942
>= self.disp_freq # skip first disp_freq steps
943943
):
944944
self.total_train_time += train_time
945+
if display_step_id == 1:
946+
self.timed_steps += 1
947+
else:
948+
self.timed_steps += min(
949+
self.disp_freq, _step_id - self.start_step
950+
)
951+
print(f"{self.timed_steps=}")
945952

946953
if fout:
947954
if self.lcurve_should_print_header:
@@ -986,6 +993,7 @@ def log_loss_valid(_task_key="Default"):
986993
self.wrapper.train()
987994
self.t0 = time.time()
988995
self.total_train_time = 0.0
996+
self.timed_steps = 0
989997
for step_id in range(self.start_step, self.num_steps):
990998
step(step_id)
991999
if JIT:
@@ -1025,16 +1033,12 @@ def log_loss_valid(_task_key="Default"):
10251033
with open("checkpoint", "w") as f:
10261034
f.write(str(self.latest_model))
10271035

1028-
elapsed_steps = self.num_steps - self.start_step
10291036
if self.timing_in_training:
1030-
if elapsed_steps <= 2 * self.disp_freq:
1031-
log.info(
1032-
f"average training time: {self.total_train_time / elapsed_steps:.4f} s/batch"
1033-
)
1034-
else:
1035-
log.info(
1036-
f"average training time: {self.total_train_time / (elapsed_steps - self.disp_freq - elapsed_steps % self.disp_freq):.4f} s/batch (first {self.disp_freq} batches excluded)",
1037-
)
1037+
msg = f"average training time: {self.total_train_time / self.timed_steps:.4f} s/batch"
1038+
excluded_steps = self.num_steps - self.start_step - self.timed_steps
1039+
if excluded_steps > 0:
1040+
msg += f" ({excluded_steps} batches excluded)"
1041+
log.info(msg)
10381042

10391043
if JIT:
10401044
pth_model_path = (

0 commit comments

Comments
 (0)