@@ -95,6 +95,16 @@ def __init__(
9595 # Optional EMA (fully opt-in)
9696 self .ema = EMAHelper (self .args )
9797
98+ # send_to_device uses non_blocking=True so the host can keep queuing work
99+ # while the H2D copy is in flight. This requires pinned host memory; from
100+ # pageable memory the copy is effectively synchronous and the flag is a no-op.
101+ if not self .args .dataloader_pin_memory :
102+ logger .warning (
103+ "send_to_device uses non_blocking=True but dataloader_pin_memory "
104+ "is False; H2D copies will fall back to synchronous. Enable "
105+ "dataloader_pin_memory for best throughput."
106+ )
107+
98108 # Optional Eval Server Backend (only on rank 0)
99109 self .eval_backend = None
100110 if dist .get_rank () == 0 and self .args .eval_config is not None and self .args .eval_strategy != "no" :
@@ -380,7 +390,7 @@ def train(self, resume_from_checkpoint: bool = False):
380390 break
381391 # send batch to device
382392 with self .cuda_event_profiler .record ("host_to_device" , self .global_step ):
383- batch = send_to_device (batch , self .fsdp2_model .device )
393+ batch = send_to_device (batch , self .fsdp2_model .device , non_blocking = True )
384394 self .memory_snapshot_profiler .step (self .global_step )
385395 start_time = time .perf_counter ()
386396 try :
@@ -408,14 +418,13 @@ def train(self, resume_from_checkpoint: bool = False):
408418 )
409419 self .compute_tracker .accumulate_flops (raw_flops )
410420 device = self .fsdp2_model .device
411- flops_tensor = torch .tensor (flops , device = device )
412421 sp_size = pgm .process_group_manager .cp_world_size
413422 tp_size = pgm .process_group_manager .tp_world_size
414423 parallel_size = sp_size * tp_size
415424
416425 # Calculate training metrics (MFU, token stats, throughput)
417426 perf_metrics , self .total_tokens = self .calculate_training_metrics (
418- flops_tensor = flops_tensor ,
427+ flops = flops ,
419428 parallel_size = parallel_size ,
420429 promised_flops = promised_flops ,
421430 device = device ,
@@ -654,7 +663,7 @@ def print_batch_input(self, batch):
654663
655664 @staticmethod
656665 def calculate_training_metrics (
657- flops_tensor : torch . Tensor ,
666+ flops : float ,
658667 parallel_size : int ,
659668 promised_flops : float ,
660669 device : torch .device ,
@@ -666,54 +675,61 @@ def calculate_training_metrics(
666675 """
667676 Calculate training performance metrics including MFU, token statistics, and throughput.
668677
678+ Uses a single ``all_gather`` over a 2-element per-rank tensor
679+ ``[mfu_local, seq_len_sum_local]`` and reduces (mean/sum/min/max)
680+ locally, replacing five independent ``all_reduce`` calls.
681+
669682 Args:
670- flops_tensor: Tensor containing FLOPs count
671- parallel_size: Product of sequence and tensor parallel sizes
672- promised_flops: Promised FLOPs capacity
673- device: Device to perform computations on
674- seq_len: List of sequence lengths per batch
675- total_tokens: Current total token count
676- delta_time: Time taken for the training step
677- world_size: Distributed training world size
683+ flops: Per-rank FLOPs count (Python float from ``estimate_flops``).
684+ parallel_size: Product of sequence and tensor parallel sizes.
685+ promised_flops: Promised FLOPs capacity.
686+ device: Device to perform computations on.
687+ seq_len: List of sequence lengths per batch (one entry per local sample).
688+ total_tokens: Current total token count.
689+ delta_time: Time taken for the training step.
690+ world_size: Distributed training world size.
678691
679692 Returns:
680693 tuple: (metrics_dict, updated_total_tokens)
681694 """
682- metrics = {}
683-
684- # Calculate mfu per rank
685- # Divide by parallel size because seq_len/flops are estimated before SP/TP sharding.
686- mfu = flops_tensor .item () / parallel_size / promised_flops
687- mfu = torch .tensor (mfu , device = device )
688- torch .distributed .all_reduce (mfu , op = torch .distributed .ReduceOp .AVG )
689- mfu = mfu .item ()
690-
691- # Calculating token stats
692- seq_len = torch .tensor (seq_len , device = device , dtype = torch .float32 )
693- # Divide by parallel size to avoid counting replicated SP/TP batches multiple times.
694- total_seq_len = seq_len .sum () / parallel_size
695- torch .distributed .all_reduce (total_seq_len , op = torch .distributed .ReduceOp .SUM )
696- # Avg seq len won't be affected by sp since we perform all reduce
697- # across world size
698- global_seq_len_avg = seq_len .sum ()
699- torch .distributed .all_reduce (global_seq_len_avg , op = torch .distributed .ReduceOp .AVG )
700- metrics ["perf/global_seq_len_avg" ] = global_seq_len_avg .item ()
701- global_seq_len_min = seq_len .sum ()
702- torch .distributed .all_reduce (global_seq_len_min , op = torch .distributed .ReduceOp .MIN )
703- metrics ["perf/global_seq_len_min" ] = global_seq_len_min .item ()
704- global_seq_len_max = seq_len .sum ()
705- torch .distributed .all_reduce (global_seq_len_max , op = torch .distributed .ReduceOp .MAX )
706- metrics ["perf/global_seq_len_max" ] = global_seq_len_max .item ()
707-
708- metrics ["train/mfu" ] = round (mfu , 2 )
709- total_tokens += total_seq_len .item ()
710-
711- tokens_per_second = total_seq_len .item () / delta_time
695+ # Divide mfu by parallel size because seq_len/flops are estimated
696+ # before SP/TP sharding. seq_len comes in as a plain Python list so we
697+ # avoid an extra GPU sync and just sum it on the host.
698+ mfu_local = flops / parallel_size / promised_flops
699+ seq_len_sum_local = float (sum (seq_len ))
700+
701+ local = torch .tensor ([mfu_local , seq_len_sum_local ], device = device , dtype = torch .float32 )
702+ gathered = torch .empty (world_size * 2 , device = device , dtype = torch .float32 )
703+ torch .distributed .all_gather_into_tensor (gathered , local )
704+ gathered = gathered .view (world_size , 2 )
705+
706+ mfu_all = gathered [:, 0 ]
707+ sl_all = gathered [:, 1 ]
708+
709+ # Reduce on-device, then do a single batched .tolist() sync to pull
710+ # all five scalars at once.
711+ reduced = torch .stack (
712+ [
713+ mfu_all .mean (),
714+ sl_all .sum () / parallel_size , # total_seq_len (deduped across SP/TP)
715+ sl_all .mean (), # global_seq_len_avg
716+ sl_all .amin (), # global_seq_len_min
717+ sl_all .amax (), # global_seq_len_max
718+ ]
719+ ).tolist ()
720+ mfu , total_seq_len , global_seq_len_avg , global_seq_len_min , global_seq_len_max = reduced
721+
722+ total_tokens += total_seq_len
723+ tokens_per_second = total_seq_len / delta_time
712724 tokens_per_gpu = tokens_per_second / world_size
713725
714- # Log total tokens and total tokens per second
715- metrics ["train/total_tokens" ] = TrainUtilities .format_tokens (total_tokens )
716- metrics ["train/tokens_per_second" ] = round (tokens_per_second )
717- metrics ["train/tokens_per_gpu" ] = round (tokens_per_gpu )
718-
726+ metrics = {
727+ "train/mfu" : round (mfu , 2 ),
728+ "perf/global_seq_len_avg" : global_seq_len_avg ,
729+ "perf/global_seq_len_min" : global_seq_len_min ,
730+ "perf/global_seq_len_max" : global_seq_len_max ,
731+ "train/total_tokens" : TrainUtilities .format_tokens (total_tokens ),
732+ "train/tokens_per_second" : round (tokens_per_second ),
733+ "train/tokens_per_gpu" : round (tokens_per_gpu ),
734+ }
719735 return metrics , total_tokens
0 commit comments