Skip to content

Commit ce11f31

Browse files
committed
minor fix
1 parent 289139c commit ce11f31

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

recml/core/training/jax_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ def _train_n_steps(self, train_iter, train_step, state, start_step, num_steps, s
564564
warmup_steps = 3
565565
total_examples_in_loop = 0
566566
valid_steps_in_loop = 0
567-
567+
loop_start_time = time.time()
568568

569569
for step in range(start_step, start_step + num_steps):
570570
with jax.profiler.StepTraceAnnotation("train", step_num=step):

0 commit comments

Comments
 (0)