3535 destroy_distributed ,
3636 get_dp_group ,
3737 get_draft_dp_group ,
38+ get_draft_sp_group ,
39+ get_sp_ring_group ,
3840 get_tp_group ,
3941 init_distributed ,
4042)
@@ -96,6 +98,12 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]:
9698 choices = ["sglang" , "hf" , "custom" ],
9799 help = "The backend of the target model" ,
98100 )
101+ model_group .add_argument (
102+ "--shard-target-logits" ,
103+ action = "store_true" ,
104+ default = False ,
105+ help = "Shard target logits across TP ranks instead of materializing full target logits on every rank." ,
106+ )
99107
100108 # dataset arguments
101109 dataset_group = parser .add_argument_group ("dataset" )
@@ -291,6 +299,7 @@ def build_target_model(
291299 torch_dtype = torch .bfloat16 ,
292300 device = "cuda" ,
293301 cache_dir = args .model_download_dir ,
302+ shard_target_logits = args .shard_target_logits ,
294303 ** target_model_kwargs ,
295304 trust_remote_code = args .trust_remote_code ,
296305 )
@@ -339,8 +348,16 @@ def sanity_check(args: Namespace) -> None:
339348 """
340349 args .dp_size = dist .get_world_size () // args .tp_size
341350 args .target_batch_size = args .tp_size * args .batch_size
351+ if args .shard_target_logits :
352+ assert (
353+ args .target_model_backend == "sglang"
354+ ), "--shard-target-logits is only supported for the SGLang backend"
355+
342356 if args .attention_backend == "usp" :
343357 sp_sanity_check (args )
358+ if args .train_data_path is not None and args .train_hidden_states_path is None :
359+ sp_size = args .sp_ring_size * args .sp_ulysses_size
360+ args .target_batch_size = (args .tp_size // sp_size ) * args .batch_size
344361
345362
346363def sp_sanity_check (args : Namespace ) -> None :
@@ -356,7 +373,16 @@ def sp_sanity_check(args: Namespace) -> None:
356373 f"Got sp_ring_size={ args .sp_ring_size } , sp_ulysses_size={ args .sp_ulysses_size } ."
357374 )
358375
359- assert args .train_hidden_states_path is not None , f"USP only support offline mode"
376+ is_online = args .train_data_path is not None and args .train_hidden_states_path is None
377+ if is_online :
378+ sp_size = args .sp_ring_size * args .sp_ulysses_size
379+ assert args .shard_target_logits , "Online USP requires --shard-target-logits"
380+ assert not args .is_vlm , "Online USP with sharded target logits does not support VLM yet"
381+ assert (
382+ args .tp_size % sp_size == 0
383+ ), f"Online USP with sharded target logits requires tp_size ({ args .tp_size } ) to be divisible by SP size ({ sp_size } )"
384+ else :
385+ assert args .train_hidden_states_path is not None , f"USP only support offline mode"
360386
361387 if args .eval_data_path is not None and args .eval_hidden_states_path is not None :
362388 raise ValueError (
@@ -610,6 +636,18 @@ def run_forward(
610636 image_grid_thw = None
611637 if is_online :
612638 # we generate the eagle3 using the target model in an online fashion
639+ tp_size = dist .get_world_size (get_tp_group ())
640+ tp_rank = dist .get_rank (get_tp_group ())
641+ sequence_parallel = args .attention_backend == "usp"
642+ sp_group = get_draft_sp_group () if sequence_parallel else None
643+ sp_rank = dist .get_rank (sp_group ) if sequence_parallel else 0
644+ sp_size = dist .get_world_size (sp_group ) if sequence_parallel else 1
645+ target_dp_rank = tp_rank // sp_size if sequence_parallel else tp_rank
646+ target_dp_size = tp_size // sp_size if sequence_parallel else tp_size
647+ ring_group = get_sp_ring_group () if sequence_parallel else None
648+ sp_ring_rank = dist .get_rank (ring_group ) if sequence_parallel else 0
649+ sp_ring_size = dist .get_world_size (ring_group ) if sequence_parallel else 1
650+
613651 # Handle VLM data: pixel_values and image_grid_thw are lists
614652 # pixel_values = [pv.cuda() for pv in data["pixel_values"]] if args.is_vlm else None
615653 if args .is_vlm :
@@ -626,19 +664,43 @@ def run_forward(
626664 is_vlm = args .is_vlm ,
627665 pixel_values = pixel_values ,
628666 image_grid_thw = image_grid_thw ,
667+ dp_rank = target_dp_rank ,
668+ dp_size = target_dp_size ,
669+ sequence_parallel = sequence_parallel ,
670+ sp_rank = sp_rank ,
671+ sp_size = sp_size ,
672+ sp_ring_rank = sp_ring_rank ,
673+ sp_ring_size = sp_ring_size ,
674+ ttt_length = args .ttt_length ,
629675 )
630676 else :
631677 eagle3_data = target_model .generate_eagle3_data (
632678 input_ids = data ["input_ids" ].cuda (),
633679 attention_mask = data ["attention_mask" ].cuda (),
634680 loss_mask = data ["loss_mask" ].cuda (),
681+ dp_rank = target_dp_rank ,
682+ dp_size = target_dp_size ,
683+ sequence_parallel = sequence_parallel ,
684+ sp_rank = sp_rank ,
685+ sp_size = sp_size ,
686+ sp_ring_rank = sp_ring_rank ,
687+ sp_ring_size = sp_ring_size ,
688+ ttt_length = args .ttt_length ,
635689 )
636690
637- input_ids = get_dp_data_shard_from_tp (eagle3_data .input_ids )
638- attention_mask = get_dp_data_shard_from_tp (eagle3_data .attention_mask )
639- loss_mask = get_dp_data_shard_from_tp (eagle3_data .loss_mask )
640- target = get_dp_data_shard_from_tp (eagle3_data .target )
641- hidden_states = get_dp_data_shard_from_tp (eagle3_data .hidden_states )
691+ if sequence_parallel or args .shard_target_logits :
692+ input_ids = eagle3_data .input_ids
693+ attention_mask = eagle3_data .attention_mask
694+ loss_mask = eagle3_data .loss_mask
695+ target = eagle3_data .target
696+ hidden_states = eagle3_data .hidden_states
697+ else :
698+ input_ids = get_dp_data_shard_from_tp (eagle3_data .input_ids )
699+ attention_mask = get_dp_data_shard_from_tp (eagle3_data .attention_mask )
700+ loss_mask = get_dp_data_shard_from_tp (eagle3_data .loss_mask )
701+ target = get_dp_data_shard_from_tp (eagle3_data .target )
702+ hidden_states = get_dp_data_shard_from_tp (eagle3_data .hidden_states )
703+ position_ids = eagle3_data .position_ids
642704 else :
643705 # we generate the logits using the hidden states loaded from disk
644706 attention_mask = data ["attention_mask" ].cuda ()
@@ -651,15 +713,14 @@ def run_forward(
651713 target .cuda ()
652714 ) # The `data['target']` value occupies a large amount of GPU memory, with a shape of [seqlen, vocab_size]. It needs to be processed before being loaded into the GPU.
653715 loss_mask = loss_mask .cuda ()
716+ position_ids = data ["position_ids" ].cuda () if "position_ids" in data else None
654717 plosses , _ , acces = eagle3_model (
655718 input_ids = input_ids ,
656719 attention_mask = attention_mask ,
657720 loss_mask = loss_mask ,
658721 target = target ,
659722 hidden_states = hidden_states ,
660- position_ids = (
661- data ["position_ids" ].cuda () if "position_ids" in data else None
662- ),
723+ position_ids = position_ids ,
663724 image_grid_thw = image_grid_thw ,
664725 is_vlm = args .is_vlm ,
665726 )
0 commit comments