Skip to content

Commit 5d1c73e

Browse files
authored
Merge pull request #151 from AI-Hypercomputer/yujiedeng/throughput
add throughput related metrics
2 parents 5023bd5 + ce11f31 commit 5d1c73e

2 files changed

Lines changed: 37 additions & 10 deletions

File tree

recml/core/training/jax_trainer.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import os
2121
import pprint
2222
from typing import Any, Generic, Protocol, Self, TypeVar
23+
import time
2324

2425
from absl import logging
2526
from clu import data as clu_data
@@ -557,28 +558,52 @@ def _write_marker_file(self):
557558
) as f:
558559
f.write("COMPLETED")
559560

560-
def _train_n_steps(
561-
self,
562-
train_iter: Iterator[PyTree],
563-
train_step: partitioning.StepFn,
564-
state: State,
565-
start_step: int,
566-
num_steps: int,
567-
summary_writer: metrics_tools.AsyncMultiWriter,
568-
) -> tuple[State, Mapping[str, Any]]:
569-
"""Performs a training loop and returns the updated state and metrics."""
561+
def _train_n_steps(self, train_iter, train_step, state, start_step, num_steps, summary_writer):
570562
metrics_accum = metrics_tools.MetricAccumulator(summary_writer)
563+
564+
warmup_steps = 3
565+
total_examples_in_loop = 0
566+
valid_steps_in_loop = 0
567+
loop_start_time = time.time()
568+
571569
for step in range(start_step, start_step + num_steps):
572570
with jax.profiler.StepTraceAnnotation("train", step_num=step):
571+
if step == warmup_steps:
572+
loop_start_time = time.time()
573573
train_batch = next(train_iter)
574574
inputs = self._partitioner.shard_inputs(train_batch)
575+
575576
state, metrics_update = train_step(inputs, state)
577+
if step >= warmup_steps:
578+
if 'common/batch_size' in metrics_update:
579+
total_examples_in_loop += metrics_update['common/batch_size'].compute()
580+
valid_steps_in_loop += 1
581+
576582
metrics_accum.accumulate(metrics_update, step)
577583
self.report_progress(step)
584+
578585
if (step != start_step + num_steps - 1) and self._enable_checkpointing:
579586
self._maybe_save_checkpoint(step, state)
580587

588+
duration = time.time() - loop_start_time
589+
581590
metrics = metrics_accum.compute_and_log_scalars(start_step + num_steps - 1)
591+
592+
# Calculate and inject overall loop performance
593+
if valid_steps_in_loop > 0 and duration > 0:
594+
throughput = total_examples_in_loop / duration
595+
ms_per_step = (duration / valid_steps_in_loop) * 1000
596+
597+
metrics.update({
598+
'perf/loop_throughput_ex_per_sec': throughput,
599+
'perf/loop_ms_per_step': ms_per_step,
600+
})
601+
602+
summary_writer.write_scalars(start_step + num_steps - 1, {
603+
'perf/loop_throughput_ex_per_sec': throughput,
604+
'perf/loop_ms_per_step': ms_per_step,
605+
})
606+
582607
return state, metrics
583608

584609
def _evaluate_n_steps(

recml/examples/dlrm_experiment.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ def _loss_fn(params: jt.PyTree) -> tuple[jt.Scalar, jt.Array]:
276276
loss = jnp.mean(optax.sigmoid_binary_cross_entropy(logits, label), axis=0)
277277
return loss, logits
278278

279+
global_batch_size = self.train_data.global_batch_size
279280
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True, allow_int=True)
280281
(loss, logits), grads = grad_fn(state.params)
281282
state = state.update(grads=grads)
@@ -287,6 +288,7 @@ def _loss_fn(params: jt.PyTree) -> tuple[jt.Scalar, jt.Array]:
287288
'aucroc': recml.metrics.aucroc(label, logits, from_logits=True),
288289
'label/mean': recml.metrics.mean(label),
289290
'prediction/mean': recml.metrics.mean(jax.nn.sigmoid(logits)),
291+
"common/batch_size": recml.metrics.scalar(global_batch_size),
290292
}
291293
return state, metrics
292294

0 commit comments

Comments
 (0)