 import os
 import pprint
+import time
 from typing import Any, Generic, Protocol, Self, TypeVar
 
 from absl import logging
 from clu import data as clu_data
@@ -558,28 +559,56 @@ def _write_marker_file(self):
       f.write("COMPLETED")
 
   def _train_n_steps(
-      self,
-      train_iter: Iterator[PyTree],
-      train_step: partitioning.StepFn,
-      state: State,
-      start_step: int,
-      num_steps: int,
-      summary_writer: metrics_tools.AsyncMultiWriter,
-  ) -> tuple[State, Mapping[str, Any]]:
-    """Performs a training loop and returns the updated state and metrics."""
-    metrics_accum = metrics_tools.MetricAccumulator(summary_writer)
-    for step in range(start_step, start_step + num_steps):
-      with jax.profiler.StepTraceAnnotation("train", step_num=step):
-        train_batch = next(train_iter)
-        inputs = self._partitioner.shard_inputs(train_batch)
-        state, metrics_update = train_step(inputs, state)
-        metrics_accum.accumulate(metrics_update, step)
-      self.report_progress(step)
-      if (step != start_step + num_steps - 1) and self._enable_checkpointing:
-        self._maybe_save_checkpoint(step, state)
-
-    metrics = metrics_accum.compute_and_log_scalars(start_step + num_steps - 1)
-    return state, metrics
+      self,
+      train_iter: Iterator[PyTree],
+      train_step: partitioning.StepFn,
+      state: State,
+      start_step: int,
+      num_steps: int,
+      summary_writer: metrics_tools.AsyncMultiWriter,
+  ) -> tuple[State, Mapping[str, Any]]:
+    """Performs a training loop and returns the updated state and metrics."""
+    metrics_accum = metrics_tools.MetricAccumulator(summary_writer)
+    for step in range(start_step, start_step + num_steps):
+      with jax.profiler.StepTraceAnnotation("train", step_num=step):
+        train_batch = next(train_iter)
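+        # Start the clock after the host-side batch fetch so input-pipeline
+        # latency is not attributed to the device step time.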
+        step_start = time.time()
+        inputs = self._partitioner.shard_inputs(train_batch)
+        state, metrics_update = train_step(inputs, state)
+
+        timing_metrics = {}
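+        # Skip the first steps so compilation and other one-time warmup
+        # costs do not skew the timing metrics.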
+        if step - start_step > 10:
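+          # train_step dispatches asynchronously; block until its outputs are
+          # ready so the measured duration reflects actual device work.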
+          jax.block_until_ready(metrics_update)
+          step_duration = time.time() - step_start
+
+          timing_metrics = {
+              "perf/step_time_ms": base_metrics.scalar(step_duration * 1000),
+              "perf/steps_per_sec": base_metrics.scalar(
+                  1.0 / step_duration if step_duration > 0 else 0
+              ),
+          }
+
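+          # Derive examples/sec when the step's metrics report a batch size.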
+          if "common/batch_size" in metrics_update:
+            bs = metrics_update["common/batch_size"].compute()
+            timing_metrics["perf/throughput_ex_per_sec"] = base_metrics.scalar(
+                bs / step_duration if step_duration > 0 else 0
+            )
+
+        metrics_accum.accumulate({**metrics_update, **timing_metrics}, step)
+
+      self.report_progress(step)
+      if (step != start_step + num_steps - 1) and self._enable_checkpointing:
+        self._maybe_save_checkpoint(step, state)
+
+    metrics = metrics_accum.compute_and_log_scalars(start_step + num_steps - 1)
+    return state, metrics
 
   def _evaluate_n_steps(
       self,
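For reviewers unfamiliar with the pattern: JAX dispatches device work asynchronously, so `train_step` returns as soon as the work is enqueued, and a bare `time.time()` difference would measure dispatch latency rather than compute. Below is a minimal, self-contained sketch of the same measurement pattern; `fake_step` and the loop bounds are illustrative placeholders, not part of this change, while `jax.block_until_ready` is the actual JAX API used above.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def fake_step(x):
  # Stand-in for the real train step; any jitted device computation works.
  return x @ x


x = jnp.ones((2048, 2048))
warmup_steps = 10  # the first steps include compilation, so skip them
for step in range(20):
  start = time.time()
  x = fake_step(x)
  if step > warmup_steps:
    # Without this, time.time() returns almost immediately because JAX has
    # only enqueued the work; blocking waits for the device result.
    jax.block_until_ready(x)
    duration = time.time() - start
    print(f"step {step}: {duration * 1000:.2f} ms "
          f"({1.0 / duration:.1f} steps/sec)")
```

One caveat worth flagging in review: `block_until_ready` adds a host-device sync on every timed step, which can reduce the overlap between host-side input processing and device compute that asynchronous dispatch normally provides.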