@@ -558,50 +558,62 @@ def _write_marker_file(self):
         ) as f:
             f.write("COMPLETED")
 
-    def _train_n_steps(
-        self,
-        train_iter: Iterator[PyTree],
-        train_step: partitioning.StepFn,
-        state: State,
-        start_step: int,
-        num_steps: int,
-        summary_writer: metrics_tools.AsyncMultiWriter,
-    ) -> tuple[State, Mapping[str, Any]]:
-        """Performs a training loop and returns the updated state and metrics."""
-        metrics_accum = metrics_tools.MetricAccumulator(summary_writer)
-        for step in range(start_step, start_step + num_steps):
-            with jax.profiler.StepTraceAnnotation("train", step_num=step):
-                train_batch = next(train_iter)
-                step_start = time.time()
-                inputs = self._partitioner.shard_inputs(train_batch)
-                state, metrics_update = train_step(inputs, state)
-
-                timing_metrics = {}
-                if step - start_step > 10:
-                    jax.block_until_ready(metrics_update)
-                    step_duration = time.time() - step_start
-
-                    timing_metrics = {
-                        "perf/step_time_ms": base_metrics.scalar(step_duration * 1000),
-                        "perf/steps_per_sec": base_metrics.scalar(
-                            1.0 / step_duration if step_duration > 0 else 0
-                        ),
-                    }
-
-                    if "common/batch_size" in metrics_update:
-                        bs = metrics_update["common/batch_size"].compute()
-                        timing_metrics["perf/throughput_ex_per_sec"] = (
-                            base_metrics.scalar(bs / step_duration)
-                        )
-
-                metrics_accum.accumulate({**metrics_update, **timing_metrics}, step)
-
-                self.report_progress(step)
-                if (step != start_step + num_steps - 1) and self._enable_checkpointing:
-                    self._maybe_save_checkpoint(step, state)
-
-        metrics = metrics_accum.compute_and_log_scalars(start_step + num_steps - 1)
-        return state, metrics
+    def _train_n_steps(
+        self,
+        train_iter: Iterator[PyTree],
+        train_step: partitioning.StepFn,
+        state: State,
+        start_step: int,
+        num_steps: int,
+        summary_writer: metrics_tools.AsyncMultiWriter,
+    ) -> tuple[State, Mapping[str, Any]]:
+        """Performs a training loop and returns the updated state and metrics."""
+        metrics_accum = metrics_tools.MetricAccumulator(summary_writer)
+
+        # Exclude the first few steps from the loop timing so compilation and
+        # input-pipeline warmup do not skew the perf metrics.
+        warmup_steps = 3
+        total_examples_in_loop = 0
+        valid_steps_in_loop = 0
+        loop_start_time = time.time()
+
+        for step in range(start_step, start_step + num_steps):
+            with jax.profiler.StepTraceAnnotation("train", step_num=step):
+                if step - start_step == warmup_steps:
+                    # Restart the clock once warmup is behind us.
+                    loop_start_time = time.time()
+                train_batch = next(train_iter)
+                inputs = self._partitioner.shard_inputs(train_batch)
+
+                state, metrics_update = train_step(inputs, state)
+                if step - start_step >= warmup_steps:
+                    if "common/batch_size" in metrics_update:
+                        total_examples_in_loop += metrics_update["common/batch_size"].compute()
+                    valid_steps_in_loop += 1
+
+                metrics_accum.accumulate(metrics_update, step)
+                self.report_progress(step)
+
+                if (step != start_step + num_steps - 1) and self._enable_checkpointing:
+                    self._maybe_save_checkpoint(step, state)
+
+        # train_step dispatches asynchronously, so block before reading the
+        # clock to account for work still in flight on the device.
+        jax.block_until_ready(state)
+        duration = time.time() - loop_start_time
+
+        metrics = metrics_accum.compute_and_log_scalars(start_step + num_steps - 1)
+
+        # Calculate and inject overall loop performance.
+        if valid_steps_in_loop > 0 and duration > 0:
+            perf_metrics = {
+                "perf/loop_throughput_ex_per_sec": total_examples_in_loop / duration,
+                "perf/loop_ms_per_step": duration / valid_steps_in_loop * 1000,
+            }
+            metrics.update(perf_metrics)
+            summary_writer.write_scalars(start_step + num_steps - 1, perf_metrics)
+
+        return state, metrics
 
     def _evaluate_n_steps(
         self,
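For reference, the pattern the new loop relies on can be shown in isolation: time only the steps after a short warmup, then derive throughput and ms/step from the accumulated totals. The sketch below is a minimal stand-alone illustration, not the trainer's code; `timed_loop` and `fake_train_step` are hypothetical stand-ins for the real step function and sharded batches, and the real loop must additionally sync with `jax.block_until_ready` before reading the clock, since JAX dispatches asynchronously.

    # Minimal sketch of the warmup-excluded timing pattern used by the new loop.
    import time

    def timed_loop(train_step, batches, warmup_steps=3):
        total_examples = 0
        valid_steps = 0
        loop_start = time.time()
        for i, batch in enumerate(batches):
            if i == warmup_steps:
                # Restart the clock once compilation/warmup is behind us.
                loop_start = time.time()
            train_step(batch)
            if i >= warmup_steps:
                total_examples += len(batch)
                valid_steps += 1
        duration = time.time() - loop_start
        if valid_steps == 0 or duration <= 0:
            return {}
        return {
            "perf/loop_throughput_ex_per_sec": total_examples / duration,
            "perf/loop_ms_per_step": duration / valid_steps * 1000,
        }

    def fake_train_step(batch):
        time.sleep(0.01)  # hypothetical stand-in for a device step

    print(timed_loop(fake_train_step, [[0] * 32 for _ in range(20)]))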