Skip to content

Commit d330445

Browse files
committed
perf(fsdp2): collapse training metrics collectives to a single all_gather
calculate_training_metrics previously issued five independent all_reduce calls (mfu/sum/avg/min/max) on tiny scalar tensors, each paying full collective latency and an extra GPU->CPU sync via .item(). Replace them with one all_gather_into_tensor over a 2-element per-rank tensor [mfu_local, seq_len_sum_local], reduce locally (mean/sum/min/max), and do a single batched .tolist() to pull all scalars at once. Also drops the redundant torch.tensor(flops, device=...) wrapper since the callee now accepts a Python float directly, removing one host->device roundtrip per step.
1 parent 1001b50 commit d330445

1 file changed

Lines changed: 52 additions & 46 deletions

File tree

src/lmms_engine/train/fsdp2/fsdp2_trainer.py

Lines changed: 52 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -418,14 +418,13 @@ def train(self, resume_from_checkpoint: bool = False):
418418
)
419419
self.compute_tracker.accumulate_flops(raw_flops)
420420
device = self.fsdp2_model.device
421-
flops_tensor = torch.tensor(flops, device=device)
422421
sp_size = pgm.process_group_manager.cp_world_size
423422
tp_size = pgm.process_group_manager.tp_world_size
424423
parallel_size = sp_size * tp_size
425424

426425
# Calculate training metrics (MFU, token stats, throughput)
427426
perf_metrics, self.total_tokens = self.calculate_training_metrics(
428-
flops_tensor=flops_tensor,
427+
flops=flops,
429428
parallel_size=parallel_size,
430429
promised_flops=promised_flops,
431430
device=device,
@@ -664,7 +663,7 @@ def print_batch_input(self, batch):
664663

665664
@staticmethod
666665
def calculate_training_metrics(
667-
flops_tensor: torch.Tensor,
666+
flops: float,
668667
parallel_size: int,
669668
promised_flops: float,
670669
device: torch.device,
@@ -676,54 +675,61 @@ def calculate_training_metrics(
676675
"""
677676
Calculate training performance metrics including MFU, token statistics, and throughput.
678677
678+
Uses a single ``all_gather`` over a 2-element per-rank tensor
679+
``[mfu_local, seq_len_sum_local]`` and reduces (mean/sum/min/max)
680+
locally, replacing five independent ``all_reduce`` calls.
681+
679682
Args:
680-
flops_tensor: Tensor containing FLOPs count
681-
parallel_size: Product of sequence and tensor parallel sizes
682-
promised_flops: Promised FLOPs capacity
683-
device: Device to perform computations on
684-
seq_len: List of sequence lengths per batch
685-
total_tokens: Current total token count
686-
delta_time: Time taken for the training step
687-
world_size: Distributed training world size
683+
flops: Per-rank FLOPs count (Python float from ``estimate_flops``).
684+
parallel_size: Product of sequence and tensor parallel sizes.
685+
promised_flops: Promised FLOPs capacity.
686+
device: Device to perform computations on.
687+
seq_len: List of sequence lengths per batch (one entry per local sample).
688+
total_tokens: Current total token count.
689+
delta_time: Time taken for the training step.
690+
world_size: Distributed training world size.
688691
689692
Returns:
690693
tuple: (metrics_dict, updated_total_tokens)
691694
"""
692-
metrics = {}
693-
694-
# Calculate mfu per rank
695-
# Divide by parallel size because seq_len/flops are estimated before SP/TP sharding.
696-
mfu = flops_tensor.item() / parallel_size / promised_flops
697-
mfu = torch.tensor(mfu, device=device)
698-
torch.distributed.all_reduce(mfu, op=torch.distributed.ReduceOp.AVG)
699-
mfu = mfu.item()
700-
701-
# Calculating token stats
702-
seq_len = torch.tensor(seq_len, device=device, dtype=torch.float32)
703-
# Divide by parallel size to avoid counting replicated SP/TP batches multiple times.
704-
total_seq_len = seq_len.sum() / parallel_size
705-
torch.distributed.all_reduce(total_seq_len, op=torch.distributed.ReduceOp.SUM)
706-
# Avg seq len won't be effected by sp since we perform all reduce
707-
# across world size
708-
global_seq_len_avg = seq_len.sum()
709-
torch.distributed.all_reduce(global_seq_len_avg, op=torch.distributed.ReduceOp.AVG)
710-
metrics["perf/global_seq_len_avg"] = global_seq_len_avg.item()
711-
global_seq_len_min = seq_len.sum()
712-
torch.distributed.all_reduce(global_seq_len_min, op=torch.distributed.ReduceOp.MIN)
713-
metrics["perf/global_seq_len_min"] = global_seq_len_min.item()
714-
global_seq_len_max = seq_len.sum()
715-
torch.distributed.all_reduce(global_seq_len_max, op=torch.distributed.ReduceOp.MAX)
716-
metrics["perf/global_seq_len_max"] = global_seq_len_max.item()
717-
718-
metrics["train/mfu"] = round(mfu, 2)
719-
total_tokens += total_seq_len.item()
720-
721-
tokens_per_second = total_seq_len.item() / delta_time
695+
# Divide mfu by parallel size because seq_len/flops are estimated
696+
# before SP/TP sharding. seq_len comes in as a plain Python list so we
697+
# avoid an extra GPU sync and just sum it on the host.
698+
mfu_local = flops / parallel_size / promised_flops
699+
seq_len_sum_local = float(sum(seq_len))
700+
701+
local = torch.tensor([mfu_local, seq_len_sum_local], device=device, dtype=torch.float32)
702+
gathered = torch.empty(world_size * 2, device=device, dtype=torch.float32)
703+
torch.distributed.all_gather_into_tensor(gathered, local)
704+
gathered = gathered.view(world_size, 2)
705+
706+
mfu_all = gathered[:, 0]
707+
sl_all = gathered[:, 1]
708+
709+
# Reduce on-device, then do a single batched .tolist() sync to pull
710+
# all five scalars at once.
711+
reduced = torch.stack(
712+
[
713+
mfu_all.mean(),
714+
sl_all.sum() / parallel_size, # total_seq_len (deduped across SP/TP)
715+
sl_all.mean(), # global_seq_len_avg
716+
sl_all.amin(), # global_seq_len_min
717+
sl_all.amax(), # global_seq_len_max
718+
]
719+
).tolist()
720+
mfu, total_seq_len, global_seq_len_avg, global_seq_len_min, global_seq_len_max = reduced
721+
722+
total_tokens += total_seq_len
723+
tokens_per_second = total_seq_len / delta_time
722724
tokens_per_gpu = tokens_per_second / world_size
723725

724-
# Log total tokens and total tokens per second
725-
metrics["train/total_tokens"] = TrainUtilities.format_tokens(total_tokens)
726-
metrics["train/tokens_per_second"] = round(tokens_per_second)
727-
metrics["train/tokens_per_gpu"] = round(tokens_per_gpu)
728-
726+
metrics = {
727+
"train/mfu": round(mfu, 2),
728+
"perf/global_seq_len_avg": global_seq_len_avg,
729+
"perf/global_seq_len_min": global_seq_len_min,
730+
"perf/global_seq_len_max": global_seq_len_max,
731+
"train/total_tokens": TrainUtilities.format_tokens(total_tokens),
732+
"train/tokens_per_second": round(tokens_per_second),
733+
"train/tokens_per_gpu": round(tokens_per_gpu),
734+
}
729735
return metrics, total_tokens

0 commit comments

Comments
 (0)