Skip to content

Commit cdda167

Browse files
committed
add throughput-related metrics
1 parent 5023bd5 commit cdda167

2 files changed

Lines changed: 46 additions & 22 deletions

File tree

recml/core/training/jax_trainer.py

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import os
2121
import pprint
2222
from typing import Any, Generic, Protocol, Self, TypeVar
23+
import time
2324

2425
from absl import logging
2526
from clu import data as clu_data
@@ -558,28 +559,49 @@ def _write_marker_file(self):
558559
f.write("COMPLETED")
559560

560561
def _train_n_steps(
561-
self,
562-
train_iter: Iterator[PyTree],
563-
train_step: partitioning.StepFn,
564-
state: State,
565-
start_step: int,
566-
num_steps: int,
567-
summary_writer: metrics_tools.AsyncMultiWriter,
568-
) -> tuple[State, Mapping[str, Any]]:
569-
"""Performs a training loop and returns the updated state and metrics."""
570-
metrics_accum = metrics_tools.MetricAccumulator(summary_writer)
571-
for step in range(start_step, start_step + num_steps):
572-
with jax.profiler.StepTraceAnnotation("train", step_num=step):
573-
train_batch = next(train_iter)
574-
inputs = self._partitioner.shard_inputs(train_batch)
575-
state, metrics_update = train_step(inputs, state)
576-
metrics_accum.accumulate(metrics_update, step)
577-
self.report_progress(step)
578-
if (step != start_step + num_steps - 1) and self._enable_checkpointing:
579-
self._maybe_save_checkpoint(step, state)
580-
581-
metrics = metrics_accum.compute_and_log_scalars(start_step + num_steps - 1)
582-
return state, metrics
562+
self,
563+
train_iter: Iterator[PyTree],
564+
train_step: partitioning.StepFn,
565+
state: State,
566+
start_step: int,
567+
num_steps: int,
568+
summary_writer: metrics_tools.AsyncMultiWriter,
569+
) -> tuple[State, Mapping[str, Any]]:
570+
"""Performs a training loop and returns the updated state and metrics."""
571+
metrics_accum = metrics_tools.MetricAccumulator(summary_writer)
572+
for step in range(start_step, start_step + num_steps):
573+
with jax.profiler.StepTraceAnnotation("train", step_num=step):
574+
train_batch = next(train_iter)
575+
step_start = time.time()
576+
inputs = self._partitioner.shard_inputs(train_batch)
577+
state, metrics_update = train_step(inputs, state)
578+
579+
timing_metrics = {}
580+
if step - start_step > 10:
581+
jax.block_until_ready(metrics_update)
582+
step_duration = time.time() - step_start
583+
584+
timing_metrics = {
585+
"perf/step_time_ms": base_metrics.scalar(step_duration * 1000),
586+
"perf/steps_per_sec": base_metrics.scalar(
587+
1.0 / step_duration if step_duration > 0 else 0
588+
),
589+
}
590+
591+
if "common/batch_size" in metrics_update:
592+
bs = metrics_update["common/batch_size"].compute()
593+
timing_metrics["perf/throughput_ex_per_sec"] = (
594+
base_metrics.scalar(bs / step_duration)
595+
)
596+
597+
metrics_accum.accumulate({**metrics_update, **timing_metrics}, step)
598+
599+
self.report_progress(step)
600+
if (step != start_step + num_steps - 1) and self._enable_checkpointing:
601+
self._maybe_save_checkpoint(step, state)
602+
603+
metrics = metrics_accum.compute_and_log_scalars(start_step + num_steps - 1)
604+
return state, metrics
583605

584606
def _evaluate_n_steps(
585607
self,

recml/examples/dlrm_experiment.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ def _loss_fn(params: jt.PyTree) -> tuple[jt.Scalar, jt.Array]:
276276
loss = jnp.mean(optax.sigmoid_binary_cross_entropy(logits, label), axis=0)
277277
return loss, logits
278278

279+
global_batch_size = self.train_data.global_batch_size
279280
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True, allow_int=True)
280281
(loss, logits), grads = grad_fn(state.params)
281282
state = state.update(grads=grads)
@@ -287,6 +288,7 @@ def _loss_fn(params: jt.PyTree) -> tuple[jt.Scalar, jt.Array]:
287288
'aucroc': recml.metrics.aucroc(label, logits, from_logits=True),
288289
'label/mean': recml.metrics.mean(label),
289290
'prediction/mean': recml.metrics.mean(jax.nn.sigmoid(logits)),
291+
"common/batch_size": recml.metrics.scalar(global_batch_size),
290292
}
291293
return state, metrics
292294

0 commit comments

Comments (0)