|
20 | 20 | import os |
21 | 21 | import pprint |
22 | 22 | from typing import Any, Generic, Protocol, Self, TypeVar |
| 23 | +import time |
23 | 24 |
|
24 | 25 | from absl import logging |
25 | 26 | from clu import data as clu_data |
@@ -557,28 +558,52 @@ def _write_marker_file(self): |
557 | 558 | ) as f: |
558 | 559 | f.write("COMPLETED") |
559 | 560 |
|
def _train_n_steps(
    self,
    train_iter: Iterator[PyTree],
    train_step: partitioning.StepFn,
    state: State,
    start_step: int,
    num_steps: int,
    summary_writer: metrics_tools.AsyncMultiWriter,
) -> tuple[State, Mapping[str, Any]]:
  """Runs `num_steps` training steps and returns the updated state and metrics.

  In addition to accumulating per-step metrics, measures steady-state loop
  performance (examples/sec and ms/step) over the non-warmup steps and
  publishes it both in the returned metrics mapping and to `summary_writer`
  under `perf/...` keys.
  """
  metrics_accum = metrics_tools.MetricAccumulator(summary_writer)

  # The first few steps are dominated by JIT compilation and input-pipeline
  # warmup; excluding them keeps the throughput numbers representative.
  warmup_steps = 3
  total_examples_in_loop = 0
  valid_steps_in_loop = 0
  # time.monotonic() is immune to system clock adjustments (NTP, DST),
  # unlike time.time(), so durations can never come out negative or skewed.
  loop_start_time = time.monotonic()

  for step in range(start_step, start_step + num_steps):
    with jax.profiler.StepTraceAnnotation("train", step_num=step):
      if step == warmup_steps:
        # Restart the clock once warmup ends so the measured window covers
        # only steady-state steps.
        loop_start_time = time.monotonic()
      train_batch = next(train_iter)
      inputs = self._partitioner.shard_inputs(train_batch)

      state, metrics_update = train_step(inputs, state)
      if step >= warmup_steps:
        # NOTE(review): JAX dispatch is asynchronous, so this clock measures
        # dispatch time, not device completion — the final step's compute may
        # fall outside the window. Confirm whether a block_until_ready on the
        # last step is wanted for exact throughput.
        if "common/batch_size" in metrics_update:
          total_examples_in_loop += metrics_update[
              "common/batch_size"
          ].compute()
        valid_steps_in_loop += 1

      metrics_accum.accumulate(metrics_update, step)
      self.report_progress(step)

      # Skip the checkpoint on the final step; the caller handles end-of-loop
      # checkpointing separately.
      if (step != start_step + num_steps - 1) and self._enable_checkpointing:
        self._maybe_save_checkpoint(step, state)

  duration = time.monotonic() - loop_start_time

  metrics = metrics_accum.compute_and_log_scalars(start_step + num_steps - 1)

  # Inject overall loop performance. Guarded so runs that are entirely warmup
  # (valid_steps_in_loop == 0) or instantaneous never divide by zero.
  if valid_steps_in_loop > 0 and duration > 0:
    perf_scalars = {
        "perf/loop_throughput_ex_per_sec": total_examples_in_loop / duration,
        "perf/loop_ms_per_step": (duration / valid_steps_in_loop) * 1000,
    }
    metrics.update(perf_scalars)
    summary_writer.write_scalars(start_step + num_steps - 1, perf_scalars)

  return state, metrics
583 | 608 |
|
584 | 609 | def _evaluate_n_steps( |
|
0 commit comments