Skip to content

Commit 289139c

Browse files
committed
Change to loop-level logging.
1 parent cdda167 commit 289139c

1 file changed

Lines changed: 47 additions & 44 deletions

File tree

recml/core/training/jax_trainer.py

Lines changed: 47 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -558,50 +558,53 @@ def _write_marker_file(self):
558558
) as f:
559559
f.write("COMPLETED")
560560

561-
def _train_n_steps(
562-
self,
563-
train_iter: Iterator[PyTree],
564-
train_step: partitioning.StepFn,
565-
state: State,
566-
start_step: int,
567-
num_steps: int,
568-
summary_writer: metrics_tools.AsyncMultiWriter,
569-
) -> tuple[State, Mapping[str, Any]]:
570-
"""Performs a training loop and returns the updated state and metrics."""
571-
metrics_accum = metrics_tools.MetricAccumulator(summary_writer)
572-
for step in range(start_step, start_step + num_steps):
573-
with jax.profiler.StepTraceAnnotation("train", step_num=step):
574-
train_batch = next(train_iter)
575-
step_start = time.time()
576-
inputs = self._partitioner.shard_inputs(train_batch)
577-
state, metrics_update = train_step(inputs, state)
578-
579-
timing_metrics = {}
580-
if step - start_step > 10:
581-
jax.block_until_ready(metrics_update)
582-
step_duration = time.time() - step_start
583-
584-
timing_metrics = {
585-
"perf/step_time_ms": base_metrics.scalar(step_duration * 1000),
586-
"perf/steps_per_sec": base_metrics.scalar(
587-
1.0 / step_duration if step_duration > 0 else 0
588-
),
589-
}
590-
591-
if "common/batch_size" in metrics_update:
592-
bs = metrics_update["common/batch_size"].compute()
593-
timing_metrics["perf/throughput_ex_per_sec"] = (
594-
base_metrics.scalar(bs / step_duration)
595-
)
596-
597-
metrics_accum.accumulate({**metrics_update, **timing_metrics}, step)
598-
599-
self.report_progress(step)
600-
if (step != start_step + num_steps - 1) and self._enable_checkpointing:
601-
self._maybe_save_checkpoint(step, state)
602-
603-
metrics = metrics_accum.compute_and_log_scalars(start_step + num_steps - 1)
604-
return state, metrics
561+
def _train_n_steps(
    self,
    train_iter: Iterator[PyTree],
    train_step: partitioning.StepFn,
    state: State,
    start_step: int,
    num_steps: int,
    summary_writer: metrics_tools.AsyncMultiWriter,
) -> tuple[State, Mapping[str, Any]]:
  """Performs a training loop and returns the updated state and metrics.

  Measures loop-level performance (examples/sec and ms/step) over the
  steps after a short warmup, so compilation / tracing of the first few
  steps does not skew the timing numbers.

  Args:
    train_iter: Iterator yielding training batches.
    train_step: Partitioned step function mapping (inputs, state) to
      (new state, metrics update).
    state: The training state entering this loop.
    start_step: Global step number of the first step in this loop.
    num_steps: Number of steps to run.
    summary_writer: Writer that receives accumulated and perf scalars.

  Returns:
    A tuple of the updated state and the computed metrics mapping.
  """
  metrics_accum = metrics_tools.MetricAccumulator(summary_writer)

  # Number of *loop-relative* steps to exclude from timing (compilation
  # typically dominates the first iterations).
  warmup_steps = 3
  total_examples_in_loop = 0
  valid_steps_in_loop = 0
  # Sentinel: stays None if the loop never reaches the warmup boundary
  # (e.g. num_steps <= warmup_steps), in which case perf metrics are
  # skipped instead of raising NameError.
  loop_start_time = None

  for step in range(start_step, start_step + num_steps):
    with jax.profiler.StepTraceAnnotation("train", step_num=step):
      # BUG FIX: compare relative to start_step. The previous absolute
      # comparison (`step == warmup_steps`) never fired on resumed runs
      # where start_step > warmup_steps, leaving loop_start_time unset.
      if step - start_step == warmup_steps:
        loop_start_time = time.time()

      train_batch = next(train_iter)
      inputs = self._partitioner.shard_inputs(train_batch)

      state, metrics_update = train_step(inputs, state)

      if step - start_step >= warmup_steps:
        if 'common/batch_size' in metrics_update:
          total_examples_in_loop += (
              metrics_update['common/batch_size'].compute()
          )
        valid_steps_in_loop += 1

      metrics_accum.accumulate(metrics_update, step)
      self.report_progress(step)

      if (step != start_step + num_steps - 1) and self._enable_checkpointing:
        self._maybe_save_checkpoint(step, state)

  metrics = metrics_accum.compute_and_log_scalars(start_step + num_steps - 1)

  # Calculate and inject overall loop performance. Guarded so that short
  # loops (no post-warmup steps) and zero durations are skipped cleanly.
  # NOTE(review): duration is wall time from async dispatch; if train_step
  # returns before device work completes this may undercount — confirm
  # whether a block_until_ready is needed before timing.
  if loop_start_time is not None and valid_steps_in_loop > 0:
    duration = time.time() - loop_start_time
    if duration > 0:
      perf_metrics = {
          'perf/loop_throughput_ex_per_sec': (
              total_examples_in_loop / duration
          ),
          'perf/loop_ms_per_step': (duration / valid_steps_in_loop) * 1000,
      }
      # Build the perf dict once and feed both sinks, so the writer can
      # never reference values that were not computed.
      metrics.update(perf_metrics)
      summary_writer.write_scalars(start_step + num_steps - 1, perf_metrics)

  return state, metrics
605608

606609
def _evaluate_n_steps(
607610
self,

0 commit comments

Comments
 (0)