@@ -942,6 +942,13 @@ def log_loss_valid(_task_key="Default"):
942942 >= self .disp_freq # skip first disp_freq steps
943943 ):
944944 self .total_train_time += train_time
945+ if display_step_id == 1 :
946+ self .timed_steps += 1
947+ else :
948+ self .timed_steps += min (
949+ self .disp_freq , _step_id - self .start_step
950+ )
951+ print (f"{ self .timed_steps = } " )
945952
946953 if fout :
947954 if self .lcurve_should_print_header :
@@ -986,6 +993,7 @@ def log_loss_valid(_task_key="Default"):
986993 self .wrapper .train ()
987994 self .t0 = time .time ()
988995 self .total_train_time = 0.0
996+ self .timed_steps = 0
989997 for step_id in range (self .start_step , self .num_steps ):
990998 step (step_id )
991999 if JIT :
@@ -1025,16 +1033,12 @@ def log_loss_valid(_task_key="Default"):
10251033 with open ("checkpoint" , "w" ) as f :
10261034 f .write (str (self .latest_model ))
10271035
1028- elapsed_steps = self .num_steps - self .start_step
10291036 if self .timing_in_training :
1030- if elapsed_steps <= 2 * self .disp_freq :
1031- log .info (
1032- f"average training time: { self .total_train_time / elapsed_steps :.4f} s/batch"
1033- )
1034- else :
1035- log .info (
1036- f"average training time: { self .total_train_time / (elapsed_steps - self .disp_freq - elapsed_steps % self .disp_freq ):.4f} s/batch (first { self .disp_freq } batches excluded)" ,
1037- )
1037+ msg = f"average training time: { self .total_train_time / self .timed_steps :.4f} s/batch"
1038+ excluded_steps = self .num_steps - self .start_step - self .timed_steps
1039+ if excluded_steps > 0 :
1040+ msg += f" ({ excluded_steps } batches excluded)"
1041+ log .info (msg )
10381042
10391043 if JIT :
10401044 pth_model_path = (
0 commit comments