Skip to content

Commit 2fe89eb

Browse files
authored
perf(fsdp2): reduce per-step host/comm overhead (#181)
* perf(fsdp2): overlap host-to-device copy with non_blocking=True

  send_to_device defaulted to non_blocking=False, which made the H2D transfer a synchronous step even when the dataloader produced pinned tensors (the repo's default). With non_blocking=True the copy is submitted to the CUDA stream and overlaps the first kernels of the following training_step, eliminating the dedicated H2D wait. Adds a one-time warning at trainer init when dataloader_pin_memory is False, since the flag becomes a no-op on pageable memory.

* perf(fsdp2): collapse training metrics collectives to a single all_gather

  calculate_training_metrics previously issued five independent all_reduce calls (mfu/sum/avg/min/max) on tiny scalar tensors, each paying full collective latency and an extra GPU->CPU sync via .item(). Replace them with one all_gather_into_tensor over a 2-element per-rank tensor [mfu_local, seq_len_sum_local], reduce locally (mean/sum/min/max), and do a single batched .tolist() to pull all scalars at once. Also drops the redundant torch.tensor(flops, device=...) wrapper since the callee now accepts a Python float directly, removing one host->device roundtrip per step.
1 parent 025fe08 commit 2fe89eb

1 file changed

Lines changed: 63 additions & 47 deletions

File tree

src/lmms_engine/train/fsdp2/fsdp2_trainer.py

Lines changed: 63 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -95,6 +95,16 @@ def __init__(
9595
# Optional EMA (fully opt-in)
9696
self.ema = EMAHelper(self.args)
9797

98+
# send_to_device uses non_blocking=True to overlap H2D with the next
99+
# training step. This requires pinned memory; pageable memory falls
100+
# back to a synchronous copy and the flag becomes a no-op.
101+
if not self.args.dataloader_pin_memory:
102+
logger.warning(
103+
"send_to_device uses non_blocking=True but dataloader_pin_memory "
104+
"is False; H2D copies will fall back to synchronous. Enable "
105+
"dataloader_pin_memory for best throughput."
106+
)
107+
98108
# Optional Eval Server Backend (only on rank 0)
99109
self.eval_backend = None
100110
if dist.get_rank() == 0 and self.args.eval_config is not None and self.args.eval_strategy != "no":
@@ -380,7 +390,7 @@ def train(self, resume_from_checkpoint: bool = False):
380390
break
381391
# send batch to device
382392
with self.cuda_event_profiler.record("host_to_device", self.global_step):
383-
batch = send_to_device(batch, self.fsdp2_model.device)
393+
batch = send_to_device(batch, self.fsdp2_model.device, non_blocking=True)
384394
self.memory_snapshot_profiler.step(self.global_step)
385395
start_time = time.perf_counter()
386396
try:
@@ -408,14 +418,13 @@ def train(self, resume_from_checkpoint: bool = False):
408418
)
409419
self.compute_tracker.accumulate_flops(raw_flops)
410420
device = self.fsdp2_model.device
411-
flops_tensor = torch.tensor(flops, device=device)
412421
sp_size = pgm.process_group_manager.cp_world_size
413422
tp_size = pgm.process_group_manager.tp_world_size
414423
parallel_size = sp_size * tp_size
415424

416425
# Calculate training metrics (MFU, token stats, throughput)
417426
perf_metrics, self.total_tokens = self.calculate_training_metrics(
418-
flops_tensor=flops_tensor,
427+
flops=flops,
419428
parallel_size=parallel_size,
420429
promised_flops=promised_flops,
421430
device=device,
@@ -654,7 +663,7 @@ def print_batch_input(self, batch):
654663

655664
@staticmethod
656665
def calculate_training_metrics(
657-
flops_tensor: torch.Tensor,
666+
flops: float,
658667
parallel_size: int,
659668
promised_flops: float,
660669
device: torch.device,
@@ -666,54 +675,61 @@ def calculate_training_metrics(
666675
"""
667676
Calculate training performance metrics including MFU, token statistics, and throughput.
668677
678+
Uses a single ``all_gather`` over a 2-element per-rank tensor
679+
``[mfu_local, seq_len_sum_local]`` and reduces (mean/sum/min/max)
680+
locally, replacing five independent ``all_reduce`` calls.
681+
669682
Args:
670-
flops_tensor: Tensor containing FLOPs count
671-
parallel_size: Product of sequence and tensor parallel sizes
672-
promised_flops: Promised FLOPs capacity
673-
device: Device to perform computations on
674-
seq_len: List of sequence lengths per batch
675-
total_tokens: Current total token count
676-
delta_time: Time taken for the training step
677-
world_size: Distributed training world size
683+
flops: Per-rank FLOPs count (Python float from ``estimate_flops``).
684+
parallel_size: Product of sequence and tensor parallel sizes.
685+
promised_flops: Promised FLOPs capacity.
686+
device: Device to perform computations on.
687+
seq_len: List of sequence lengths per batch (one entry per local sample).
688+
total_tokens: Current total token count.
689+
delta_time: Time taken for the training step.
690+
world_size: Distributed training world size.
678691
679692
Returns:
680693
tuple: (metrics_dict, updated_total_tokens)
681694
"""
682-
metrics = {}
683-
684-
# Calculate mfu per rank
685-
# Divide by parallel size because seq_len/flops are estimated before SP/TP sharding.
686-
mfu = flops_tensor.item() / parallel_size / promised_flops
687-
mfu = torch.tensor(mfu, device=device)
688-
torch.distributed.all_reduce(mfu, op=torch.distributed.ReduceOp.AVG)
689-
mfu = mfu.item()
690-
691-
# Calculating token stats
692-
seq_len = torch.tensor(seq_len, device=device, dtype=torch.float32)
693-
# Divide by parallel size to avoid counting replicated SP/TP batches multiple times.
694-
total_seq_len = seq_len.sum() / parallel_size
695-
torch.distributed.all_reduce(total_seq_len, op=torch.distributed.ReduceOp.SUM)
696-
# Avg seq len won't be effected by sp since we perform all reduce
697-
# across world size
698-
global_seq_len_avg = seq_len.sum()
699-
torch.distributed.all_reduce(global_seq_len_avg, op=torch.distributed.ReduceOp.AVG)
700-
metrics["perf/global_seq_len_avg"] = global_seq_len_avg.item()
701-
global_seq_len_min = seq_len.sum()
702-
torch.distributed.all_reduce(global_seq_len_min, op=torch.distributed.ReduceOp.MIN)
703-
metrics["perf/global_seq_len_min"] = global_seq_len_min.item()
704-
global_seq_len_max = seq_len.sum()
705-
torch.distributed.all_reduce(global_seq_len_max, op=torch.distributed.ReduceOp.MAX)
706-
metrics["perf/global_seq_len_max"] = global_seq_len_max.item()
707-
708-
metrics["train/mfu"] = round(mfu, 2)
709-
total_tokens += total_seq_len.item()
710-
711-
tokens_per_second = total_seq_len.item() / delta_time
695+
# Divide mfu by parallel size because seq_len/flops are estimated
696+
# before SP/TP sharding. seq_len comes in as a plain Python list so we
697+
# avoid an extra GPU sync and just sum it on the host.
698+
mfu_local = flops / parallel_size / promised_flops
699+
seq_len_sum_local = float(sum(seq_len))
700+
701+
local = torch.tensor([mfu_local, seq_len_sum_local], device=device, dtype=torch.float32)
702+
gathered = torch.empty(world_size * 2, device=device, dtype=torch.float32)
703+
torch.distributed.all_gather_into_tensor(gathered, local)
704+
gathered = gathered.view(world_size, 2)
705+
706+
mfu_all = gathered[:, 0]
707+
sl_all = gathered[:, 1]
708+
709+
# Reduce on-device, then do a single batched .tolist() sync to pull
710+
# all five scalars at once.
711+
reduced = torch.stack(
712+
[
713+
mfu_all.mean(),
714+
sl_all.sum() / parallel_size, # total_seq_len (deduped across SP/TP)
715+
sl_all.mean(), # global_seq_len_avg
716+
sl_all.amin(), # global_seq_len_min
717+
sl_all.amax(), # global_seq_len_max
718+
]
719+
).tolist()
720+
mfu, total_seq_len, global_seq_len_avg, global_seq_len_min, global_seq_len_max = reduced
721+
722+
total_tokens += total_seq_len
723+
tokens_per_second = total_seq_len / delta_time
712724
tokens_per_gpu = tokens_per_second / world_size
713725

714-
# Log total tokens and total tokens per second
715-
metrics["train/total_tokens"] = TrainUtilities.format_tokens(total_tokens)
716-
metrics["train/tokens_per_second"] = round(tokens_per_second)
717-
metrics["train/tokens_per_gpu"] = round(tokens_per_gpu)
718-
726+
metrics = {
727+
"train/mfu": round(mfu, 2),
728+
"perf/global_seq_len_avg": global_seq_len_avg,
729+
"perf/global_seq_len_min": global_seq_len_min,
730+
"perf/global_seq_len_max": global_seq_len_max,
731+
"train/total_tokens": TrainUtilities.format_tokens(total_tokens),
732+
"train/tokens_per_second": round(tokens_per_second),
733+
"train/tokens_per_gpu": round(tokens_per_gpu),
734+
}
719735
return metrics, total_tokens

0 commit comments

Comments
 (0)