@@ -418,14 +418,13 @@ def train(self, resume_from_checkpoint: bool = False):
418418 )
419419 self .compute_tracker .accumulate_flops (raw_flops )
420420 device = self .fsdp2_model .device
421- flops_tensor = torch .tensor (flops , device = device )
422421 sp_size = pgm .process_group_manager .cp_world_size
423422 tp_size = pgm .process_group_manager .tp_world_size
424423 parallel_size = sp_size * tp_size
425424
426425 # Calculate training metrics (MFU, token stats, throughput)
427426 perf_metrics , self .total_tokens = self .calculate_training_metrics (
428- flops_tensor = flops_tensor ,
427+ flops = flops ,
429428 parallel_size = parallel_size ,
430429 promised_flops = promised_flops ,
431430 device = device ,
@@ -664,7 +663,7 @@ def print_batch_input(self, batch):
664663
665664 @staticmethod
666665 def calculate_training_metrics (
667- flops_tensor : torch . Tensor ,
666+ flops : float ,
668667 parallel_size : int ,
669668 promised_flops : float ,
670669 device : torch .device ,
@@ -676,54 +675,61 @@ def calculate_training_metrics(
676675 """
677676 Calculate training performance metrics including MFU, token statistics, and throughput.
678677
678+ Uses a single ``all_gather`` over a 2-element per-rank tensor
679+ ``[mfu_local, seq_len_sum_local]`` and reduces (mean/sum/min/max)
680+ locally, replacing five independent ``all_reduce`` calls.
681+
679682 Args:
680- flops_tensor: Tensor containing FLOPs count
681- parallel_size: Product of sequence and tensor parallel sizes
682- promised_flops: Promised FLOPs capacity
683- device: Device to perform computations on
684- seq_len: List of sequence lengths per batch
685- total_tokens: Current total token count
686- delta_time: Time taken for the training step
687- world_size: Distributed training world size
683+ flops: Per-rank FLOPs count (Python float from ``estimate_flops``).
684+ parallel_size: Product of sequence and tensor parallel sizes.
685+ promised_flops: Promised FLOPs capacity.
686+ device: Device to perform computations on.
687+ seq_len: List of sequence lengths per batch (one entry per local sample).
688+ total_tokens: Current total token count.
689+ delta_time: Time taken for the training step.
690+ world_size: Distributed training world size.
688691
689692 Returns:
690693 tuple: (metrics_dict, updated_total_tokens)
691694 """
692- metrics = {}
693-
694- # Calculate mfu per rank
695- # Divide by parallel size because seq_len/flops are estimated before SP/TP sharding.
696- mfu = flops_tensor .item () / parallel_size / promised_flops
697- mfu = torch .tensor (mfu , device = device )
698- torch .distributed .all_reduce (mfu , op = torch .distributed .ReduceOp .AVG )
699- mfu = mfu .item ()
700-
701- # Calculating token stats
702- seq_len = torch .tensor (seq_len , device = device , dtype = torch .float32 )
703- # Divide by parallel size to avoid counting replicated SP/TP batches multiple times.
704- total_seq_len = seq_len .sum () / parallel_size
705- torch .distributed .all_reduce (total_seq_len , op = torch .distributed .ReduceOp .SUM )
706- # Avg seq len won't be effected by sp since we perform all reduce
707- # across world size
708- global_seq_len_avg = seq_len .sum ()
709- torch .distributed .all_reduce (global_seq_len_avg , op = torch .distributed .ReduceOp .AVG )
710- metrics ["perf/global_seq_len_avg" ] = global_seq_len_avg .item ()
711- global_seq_len_min = seq_len .sum ()
712- torch .distributed .all_reduce (global_seq_len_min , op = torch .distributed .ReduceOp .MIN )
713- metrics ["perf/global_seq_len_min" ] = global_seq_len_min .item ()
714- global_seq_len_max = seq_len .sum ()
715- torch .distributed .all_reduce (global_seq_len_max , op = torch .distributed .ReduceOp .MAX )
716- metrics ["perf/global_seq_len_max" ] = global_seq_len_max .item ()
717-
718- metrics ["train/mfu" ] = round (mfu , 2 )
719- total_tokens += total_seq_len .item ()
720-
721- tokens_per_second = total_seq_len .item () / delta_time
695+ # Divide mfu by parallel size because seq_len/flops are estimated
696+ # before SP/TP sharding. seq_len comes in as a plain Python list so we
697+ # avoid an extra GPU sync and just sum it on the host.
698+ mfu_local = flops / parallel_size / promised_flops
699+ seq_len_sum_local = float (sum (seq_len ))
700+
701+ local = torch .tensor ([mfu_local , seq_len_sum_local ], device = device , dtype = torch .float32 )
702+ gathered = torch .empty (world_size * 2 , device = device , dtype = torch .float32 )
703+ torch .distributed .all_gather_into_tensor (gathered , local )
704+ gathered = gathered .view (world_size , 2 )
705+
706+ mfu_all = gathered [:, 0 ]
707+ sl_all = gathered [:, 1 ]
708+
709+ # Reduce on-device, then do a single batched .tolist() sync to pull
710+ # all five scalars at once.
711+ reduced = torch .stack (
712+ [
713+ mfu_all .mean (),
714+ sl_all .sum () / parallel_size , # total_seq_len (deduped across SP/TP)
715+ sl_all .mean (), # global_seq_len_avg
716+ sl_all .amin (), # global_seq_len_min
717+ sl_all .amax (), # global_seq_len_max
718+ ]
719+ ).tolist ()
720+ mfu , total_seq_len , global_seq_len_avg , global_seq_len_min , global_seq_len_max = reduced
721+
722+ total_tokens += total_seq_len
723+ tokens_per_second = total_seq_len / delta_time
722724 tokens_per_gpu = tokens_per_second / world_size
723725
724- # Log total tokens and total tokens per second
725- metrics ["train/total_tokens" ] = TrainUtilities .format_tokens (total_tokens )
726- metrics ["train/tokens_per_second" ] = round (tokens_per_second )
727- metrics ["train/tokens_per_gpu" ] = round (tokens_per_gpu )
728-
726+ metrics = {
727+ "train/mfu" : round (mfu , 2 ),
728+ "perf/global_seq_len_avg" : global_seq_len_avg ,
729+ "perf/global_seq_len_min" : global_seq_len_min ,
730+ "perf/global_seq_len_max" : global_seq_len_max ,
731+ "train/total_tokens" : TrainUtilities .format_tokens (total_tokens ),
732+ "train/tokens_per_second" : round (tokens_per_second ),
733+ "train/tokens_per_gpu" : round (tokens_per_gpu ),
734+ }
729735 return metrics , total_tokens
0 commit comments