@@ -769,30 +769,29 @@ def _sample_and_accept_dynamic_tree(
769769 self ._accepted_draft_indices_tensor [:batch_size ].fill_ (- 1 )
770770
771771 num_flat_tokens = logits .shape [0 ]
772- with nvtx_range ("dyn_tree.greedy.sample_target" , color = "blue" ):
773- if not spec_metadata .is_all_greedy_sample :
774- # Non-greedy: sample target tokens with per-request temperature/top_k/top_p.
775- # Lazily initialize RNG tensors for CUDA graph compatibility.
776- if self .seed is None :
777- self .seed = torch .tensor ([0 ], dtype = torch .int64 , device = logits .device )
778- self .offset = torch .tensor ([0 ], dtype = torch .int64 , device = logits .device )
779- self .seed .add_ (1 ).remainder_ (2 ** 31 )
780- top_ks = spec_metadata .top_ks [:num_flat_tokens ]
781- if self .use_flashinfer :
782- top_ks = top_ks .clamp (min = 1 , max = logits .shape [- 1 ] - 1 )
783- sampled = sampling_batch_spec_dec_one_model (
784- logits ,
785- spec_metadata .temperatures [:num_flat_tokens ],
786- top_ks ,
787- spec_metadata .top_ps [:num_flat_tokens ],
788- use_flashinfer = self .use_flashinfer ,
789- seed = self .seed ,
790- offset = self .offset ,
791- )
792- self ._target_tokens_buf [:num_flat_tokens ].copy_ (sampled )
793- else :
794- # Greedy fast path (CUDA graph key: is_all_greedy_sample=True).
795- torch .argmax (logits , dim = - 1 , out = self ._target_tokens_buf [:num_flat_tokens ])
772+ if not spec_metadata .is_all_greedy_sample :
773+ # Non-greedy: sample target tokens with per-request temperature/top_k/top_p.
774+ # Lazily initialize RNG tensors for CUDA graph compatibility.
775+ if self .seed is None :
776+ self .seed = torch .tensor ([0 ], dtype = torch .int64 , device = logits .device )
777+ self .offset = torch .tensor ([0 ], dtype = torch .int64 , device = logits .device )
778+ self .seed .add_ (1 ).remainder_ (2 ** 31 )
779+ top_ks = spec_metadata .top_ks [:num_flat_tokens ]
780+ if self .use_flashinfer :
781+ top_ks = top_ks .clamp (min = 1 , max = logits .shape [- 1 ] - 1 )
782+ sampled = sampling_batch_spec_dec_one_model (
783+ logits ,
784+ spec_metadata .temperatures [:num_flat_tokens ],
785+ top_ks ,
786+ spec_metadata .top_ps [:num_flat_tokens ],
787+ use_flashinfer = self .use_flashinfer ,
788+ seed = self .seed ,
789+ offset = self .offset ,
790+ )
791+ self ._target_tokens_buf [:num_flat_tokens ].copy_ (sampled )
792+ else :
793+ # Greedy fast path (CUDA graph key: is_all_greedy_sample=True).
794+ torch .argmax (logits , dim = - 1 , out = self ._target_tokens_buf [:num_flat_tokens ])
796795 target_tokens = self ._target_tokens_buf [:num_flat_tokens ]
797796
798797 # Context requests: accept sampled token
@@ -812,41 +811,37 @@ def _sample_and_accept_dynamic_tree(
812811 self ._accepted_draft_indices_tensor [num_contexts :batch_size ] = - 1
813812 return accepted_tokens , num_accepted_tokens
814813
815- with nvtx_range ("dyn_tree.greedy.build_candidates" , color = "blue" ):
816- candidates = self ._candidates_buf [:num_gens ]
817- candidates [:, 1 :] = spec_metadata .draft_tokens .reshape (num_gens , N - 1 )
818- candidates [:, 0 ] = target_predict [:, 0 ]
814+ candidates = self ._candidates_buf [:num_gens ]
815+ candidates [:, 1 :] = spec_metadata .draft_tokens .reshape (num_gens , N - 1 )
816+ candidates [:, 0 ] = target_predict [:, 0 ]
819817
820818 slot_storage = spec_tree_manager .slot_storage
821819 gen_slot_ids = slot_storage .all_ids_buf [num_contexts : num_contexts + num_gens ]
822820 tree_valid = slot_storage .has_tree [gen_slot_ids ]
823- with nvtx_range ("dyn_tree.greedy.pack_retrieve" , color = "blue" ):
824- retrieve_packed = slot_storage .pack_retrieve_from_slots (gen_slot_ids , num_gens )
825-
826- with nvtx_range ("dyn_tree.greedy.verify_greedy" , color = "cyan" ):
827- accept_index , accept_token_num , accept_token = (
828- self .tree_ops_converter .verify_dynamic_tree_greedy_out_packed (
829- candidates ,
830- retrieve_packed ,
831- target_predict ,
832- num_gens ,
833- self ._max_path_len ,
834- tree_valid = tree_valid ,
835- )
821+ retrieve_packed = slot_storage .pack_retrieve_from_slots (gen_slot_ids , num_gens )
822+
823+ accept_index , accept_token_num , accept_token = (
824+ self .tree_ops_converter .verify_dynamic_tree_greedy_out_packed (
825+ candidates ,
826+ retrieve_packed ,
827+ target_predict ,
828+ num_gens ,
829+ self ._max_path_len ,
830+ tree_valid = tree_valid ,
836831 )
832+ )
837833
838- with nvtx_range ("dyn_tree.greedy.finalize" , color = "blue" ):
839- self ._finalize_dynamic_tree_verify_outputs (
840- accept_index = accept_index ,
841- accept_token_num = accept_token_num ,
842- accept_token = accept_token ,
843- accepted_tokens = accepted_tokens ,
844- num_accepted_tokens = num_accepted_tokens ,
845- num_contexts = num_contexts ,
846- batch_size = batch_size ,
847- num_gens = num_gens ,
848- max_path_len = max_path_len ,
849- )
834+ self ._finalize_dynamic_tree_verify_outputs (
835+ accept_index = accept_index ,
836+ accept_token_num = accept_token_num ,
837+ accept_token = accept_token ,
838+ accepted_tokens = accepted_tokens ,
839+ num_accepted_tokens = num_accepted_tokens ,
840+ num_contexts = num_contexts ,
841+ batch_size = batch_size ,
842+ num_gens = num_gens ,
843+ max_path_len = max_path_len ,
844+ )
850845
851846 num_accepted_tokens = self ._apply_force_accepted_tokens (
852847 num_accepted_tokens , num_contexts , self .max_draft_len
@@ -888,21 +883,20 @@ def _sample_and_accept_dynamic_tree_rejection(
888883 self .seed .add_ (1 ).remainder_ (2 ** 31 )
889884
890885 # Context tokens bypass the rejection kernel — sample them directly.
891- with nvtx_range ("dyn_tree.rej.sample_ctx" , color = "orange" ):
892- if num_contexts > 0 :
893- top_ks_ctx = spec_metadata .top_ks [:num_contexts ]
894- if self .use_flashinfer :
895- top_ks_ctx = top_ks_ctx .clamp (min = 1 , max = vocab_size - 1 )
896- sampled_ctx = sampling_batch_spec_dec_one_model (
897- logits [:num_contexts ],
898- spec_metadata .temperatures [:num_contexts ],
899- top_ks_ctx ,
900- spec_metadata .top_ps [:num_contexts ],
901- use_flashinfer = self .use_flashinfer ,
902- seed = self .seed ,
903- offset = self .offset ,
904- )
905- accepted_tokens [:num_contexts , 0 ].copy_ (sampled_ctx )
886+ if num_contexts > 0 :
887+ top_ks_ctx = spec_metadata .top_ks [:num_contexts ]
888+ if self .use_flashinfer :
889+ top_ks_ctx = top_ks_ctx .clamp (min = 1 , max = vocab_size - 1 )
890+ sampled_ctx = sampling_batch_spec_dec_one_model (
891+ logits [:num_contexts ],
892+ spec_metadata .temperatures [:num_contexts ],
893+ top_ks_ctx ,
894+ spec_metadata .top_ps [:num_contexts ],
895+ use_flashinfer = self .use_flashinfer ,
896+ seed = self .seed ,
897+ offset = self .offset ,
898+ )
899+ accepted_tokens [:num_contexts , 0 ].copy_ (sampled_ctx )
906900
907901 if num_gens > 0 :
908902 spec_tree_manager = self .spec_tree_manager
@@ -928,28 +922,26 @@ def _sample_and_accept_dynamic_tree_rejection(
928922 gen_slot_ids = slot_storage .all_ids_buf [num_contexts : num_contexts + num_gens ]
929923 tree_valid = slot_storage .has_tree [gen_slot_ids ]
930924
931- with nvtx_range ("dyn_tree.rej.next_links" , color = "orange" ):
932- retrieve_next_token , retrieve_next_sibling = slot_storage .next_links_from_slots (
933- gen_slot_ids , num_gens
934- )
925+ retrieve_next_token , retrieve_next_sibling = slot_storage .next_links_from_slots (
926+ gen_slot_ids , num_gens
927+ )
935928
936- with nvtx_range ("dyn_tree.rej.verify_rejection" , color = "red" ):
937- accept_index , accept_token_num , accept_token = (
938- self .tree_ops_converter .verify_dynamic_tree_rejection_out (
939- spec_metadata .draft_tokens .reshape (num_gens , N - 1 ).long (),
940- target_logits_tree ,
941- retrieve_next_token ,
942- retrieve_next_sibling ,
943- tree_valid ,
944- temps ,
945- top_ks_rej ,
946- top_ps_rej ,
947- num_gens ,
948- self ._max_path_len ,
949- seed = self .seed ,
950- offset = self .offset ,
951- )
929+ accept_index , accept_token_num , accept_token = (
930+ self .tree_ops_converter .verify_dynamic_tree_rejection_out (
931+ spec_metadata .draft_tokens .reshape (num_gens , N - 1 ).long (),
932+ target_logits_tree ,
933+ retrieve_next_token ,
934+ retrieve_next_sibling ,
935+ tree_valid ,
936+ temps ,
937+ top_ks_rej ,
938+ top_ps_rej ,
939+ num_gens ,
940+ self ._max_path_len ,
941+ seed = self .seed ,
942+ offset = self .offset ,
952943 )
944+ )
953945
954946 if self .force_num_accepted_tokens != 0.0 :
955947 # Fill accept_token positions 1..max_path_len-1 with draft tokens so
@@ -967,18 +959,17 @@ def _sample_and_accept_dynamic_tree_rejection(
967959 )
968960 )
969961
970- with nvtx_range ("dyn_tree.rej.finalize" , color = "orange" ):
971- self ._finalize_dynamic_tree_verify_outputs (
972- accept_index = accept_index ,
973- accept_token_num = accept_token_num ,
974- accept_token = accept_token ,
975- accepted_tokens = accepted_tokens ,
976- num_accepted_tokens = num_accepted_tokens ,
977- num_contexts = num_contexts ,
978- batch_size = batch_size ,
979- num_gens = num_gens ,
980- max_path_len = max_path_len ,
981- )
962+ self ._finalize_dynamic_tree_verify_outputs (
963+ accept_index = accept_index ,
964+ accept_token_num = accept_token_num ,
965+ accept_token = accept_token ,
966+ accepted_tokens = accepted_tokens ,
967+ num_accepted_tokens = num_accepted_tokens ,
968+ num_contexts = num_contexts ,
969+ batch_size = batch_size ,
970+ num_gens = num_gens ,
971+ max_path_len = max_path_len ,
972+ )
982973
983974 num_accepted_tokens = self ._apply_force_accepted_tokens (
984975 num_accepted_tokens , num_contexts , self .max_draft_len
0 commit comments