Skip to content

Commit 226a1da

Browse files
committed
Support sharded target logits for EAGLE3 online training
1 parent 0a7006f commit 226a1da

3 files changed

Lines changed: 653 additions & 100 deletions

File tree

scripts/train_eagle3.py

Lines changed: 70 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
destroy_distributed,
3636
get_dp_group,
3737
get_draft_dp_group,
38+
get_draft_sp_group,
39+
get_sp_ring_group,
3840
get_tp_group,
3941
init_distributed,
4042
)
@@ -96,6 +98,12 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]:
9698
choices=["sglang", "hf", "custom"],
9799
help="The backend of the target model",
98100
)
101+
model_group.add_argument(
102+
"--shard-target-logits",
103+
action="store_true",
104+
default=False,
105+
help="Shard target logits across TP ranks instead of materializing full target logits on every rank.",
106+
)
99107

100108
# dataset arguments
101109
dataset_group = parser.add_argument_group("dataset")
@@ -291,6 +299,7 @@ def build_target_model(
291299
torch_dtype=torch.bfloat16,
292300
device="cuda",
293301
cache_dir=args.model_download_dir,
302+
shard_target_logits=args.shard_target_logits,
294303
**target_model_kwargs,
295304
trust_remote_code=args.trust_remote_code,
296305
)
@@ -339,8 +348,16 @@ def sanity_check(args: Namespace) -> None:
339348
"""
340349
args.dp_size = dist.get_world_size() // args.tp_size
341350
args.target_batch_size = args.tp_size * args.batch_size
351+
if args.shard_target_logits:
352+
assert (
353+
args.target_model_backend == "sglang"
354+
), "--shard-target-logits is only supported for the SGLang backend"
355+
342356
if args.attention_backend == "usp":
343357
sp_sanity_check(args)
358+
if args.train_data_path is not None and args.train_hidden_states_path is None:
359+
sp_size = args.sp_ring_size * args.sp_ulysses_size
360+
args.target_batch_size = (args.tp_size // sp_size) * args.batch_size
344361

345362

346363
def sp_sanity_check(args: Namespace) -> None:
@@ -356,7 +373,16 @@ def sp_sanity_check(args: Namespace) -> None:
356373
f"Got sp_ring_size={args.sp_ring_size}, sp_ulysses_size={args.sp_ulysses_size}."
357374
)
358375

359-
assert args.train_hidden_states_path is not None, f"USP only support offline mode"
376+
is_online = args.train_data_path is not None and args.train_hidden_states_path is None
377+
if is_online:
378+
sp_size = args.sp_ring_size * args.sp_ulysses_size
379+
assert args.shard_target_logits, "Online USP requires --shard-target-logits"
380+
assert not args.is_vlm, "Online USP with sharded target logits does not support VLM yet"
381+
assert (
382+
args.tp_size % sp_size == 0
383+
), f"Online USP with sharded target logits requires tp_size ({args.tp_size}) to be divisible by SP size ({sp_size})"
384+
else:
385+
assert args.train_hidden_states_path is not None, f"USP only support offline mode"
360386

361387
if args.eval_data_path is not None and args.eval_hidden_states_path is not None:
362388
raise ValueError(
@@ -610,6 +636,18 @@ def run_forward(
610636
image_grid_thw = None
611637
if is_online:
612638
# we generate the EAGLE3 training data using the target model in an online fashion
639+
tp_size = dist.get_world_size(get_tp_group())
640+
tp_rank = dist.get_rank(get_tp_group())
641+
sequence_parallel = args.attention_backend == "usp"
642+
sp_group = get_draft_sp_group() if sequence_parallel else None
643+
sp_rank = dist.get_rank(sp_group) if sequence_parallel else 0
644+
sp_size = dist.get_world_size(sp_group) if sequence_parallel else 1
645+
target_dp_rank = tp_rank // sp_size if sequence_parallel else tp_rank
646+
target_dp_size = tp_size // sp_size if sequence_parallel else tp_size
647+
ring_group = get_sp_ring_group() if sequence_parallel else None
648+
sp_ring_rank = dist.get_rank(ring_group) if sequence_parallel else 0
649+
sp_ring_size = dist.get_world_size(ring_group) if sequence_parallel else 1
650+
613651
# Handle VLM data: pixel_values and image_grid_thw are lists
614652
# pixel_values = [pv.cuda() for pv in data["pixel_values"]] if args.is_vlm else None
615653
if args.is_vlm:
@@ -626,19 +664,43 @@ def run_forward(
626664
is_vlm=args.is_vlm,
627665
pixel_values=pixel_values,
628666
image_grid_thw=image_grid_thw,
667+
dp_rank=target_dp_rank,
668+
dp_size=target_dp_size,
669+
sequence_parallel=sequence_parallel,
670+
sp_rank=sp_rank,
671+
sp_size=sp_size,
672+
sp_ring_rank=sp_ring_rank,
673+
sp_ring_size=sp_ring_size,
674+
ttt_length=args.ttt_length,
629675
)
630676
else:
631677
eagle3_data = target_model.generate_eagle3_data(
632678
input_ids=data["input_ids"].cuda(),
633679
attention_mask=data["attention_mask"].cuda(),
634680
loss_mask=data["loss_mask"].cuda(),
681+
dp_rank=target_dp_rank,
682+
dp_size=target_dp_size,
683+
sequence_parallel=sequence_parallel,
684+
sp_rank=sp_rank,
685+
sp_size=sp_size,
686+
sp_ring_rank=sp_ring_rank,
687+
sp_ring_size=sp_ring_size,
688+
ttt_length=args.ttt_length,
635689
)
636690

637-
input_ids = get_dp_data_shard_from_tp(eagle3_data.input_ids)
638-
attention_mask = get_dp_data_shard_from_tp(eagle3_data.attention_mask)
639-
loss_mask = get_dp_data_shard_from_tp(eagle3_data.loss_mask)
640-
target = get_dp_data_shard_from_tp(eagle3_data.target)
641-
hidden_states = get_dp_data_shard_from_tp(eagle3_data.hidden_states)
691+
if sequence_parallel or args.shard_target_logits:
692+
input_ids = eagle3_data.input_ids
693+
attention_mask = eagle3_data.attention_mask
694+
loss_mask = eagle3_data.loss_mask
695+
target = eagle3_data.target
696+
hidden_states = eagle3_data.hidden_states
697+
else:
698+
input_ids = get_dp_data_shard_from_tp(eagle3_data.input_ids)
699+
attention_mask = get_dp_data_shard_from_tp(eagle3_data.attention_mask)
700+
loss_mask = get_dp_data_shard_from_tp(eagle3_data.loss_mask)
701+
target = get_dp_data_shard_from_tp(eagle3_data.target)
702+
hidden_states = get_dp_data_shard_from_tp(eagle3_data.hidden_states)
703+
position_ids = eagle3_data.position_ids
642704
else:
643705
# we generate the logits using the hidden states loaded from disk
644706
attention_mask = data["attention_mask"].cuda()
@@ -651,15 +713,14 @@ def run_forward(
651713
target.cuda()
652714
) # `data['target']` has shape [seqlen, vocab_size] and occupies a large amount of GPU memory, so it must be processed before it is moved to the GPU.
653715
loss_mask = loss_mask.cuda()
716+
position_ids = data["position_ids"].cuda() if "position_ids" in data else None
654717
plosses, _, acces = eagle3_model(
655718
input_ids=input_ids,
656719
attention_mask=attention_mask,
657720
loss_mask=loss_mask,
658721
target=target,
659722
hidden_states=hidden_states,
660-
position_ids=(
661-
data["position_ids"].cuda() if "position_ids" in data else None
662-
),
723+
position_ids=position_ids,
663724
image_grid_thw=image_grid_thw,
664725
is_vlm=args.is_vlm,
665726
)

0 commit comments

Comments
 (0)