Skip to content

Commit d0e09ad

Browse files
[TRTLLM-12669][cleanup] Remove nvtx_range context markers from dynamic tree rejection paths
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
1 parent fa65e35 commit d0e09ad

1 file changed

Lines changed: 91 additions & 100 deletions

File tree

tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py

Lines changed: 91 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -769,30 +769,29 @@ def _sample_and_accept_dynamic_tree(
769769
self._accepted_draft_indices_tensor[:batch_size].fill_(-1)
770770

771771
num_flat_tokens = logits.shape[0]
772-
with nvtx_range("dyn_tree.greedy.sample_target", color="blue"):
773-
if not spec_metadata.is_all_greedy_sample:
774-
# Non-greedy: sample target tokens with per-request temperature/top_k/top_p.
775-
# Lazily initialize RNG tensors for CUDA graph compatibility.
776-
if self.seed is None:
777-
self.seed = torch.tensor([0], dtype=torch.int64, device=logits.device)
778-
self.offset = torch.tensor([0], dtype=torch.int64, device=logits.device)
779-
self.seed.add_(1).remainder_(2**31)
780-
top_ks = spec_metadata.top_ks[:num_flat_tokens]
781-
if self.use_flashinfer:
782-
top_ks = top_ks.clamp(min=1, max=logits.shape[-1] - 1)
783-
sampled = sampling_batch_spec_dec_one_model(
784-
logits,
785-
spec_metadata.temperatures[:num_flat_tokens],
786-
top_ks,
787-
spec_metadata.top_ps[:num_flat_tokens],
788-
use_flashinfer=self.use_flashinfer,
789-
seed=self.seed,
790-
offset=self.offset,
791-
)
792-
self._target_tokens_buf[:num_flat_tokens].copy_(sampled)
793-
else:
794-
# Greedy fast path (CUDA graph key: is_all_greedy_sample=True).
795-
torch.argmax(logits, dim=-1, out=self._target_tokens_buf[:num_flat_tokens])
772+
if not spec_metadata.is_all_greedy_sample:
773+
# Non-greedy: sample target tokens with per-request temperature/top_k/top_p.
774+
# Lazily initialize RNG tensors for CUDA graph compatibility.
775+
if self.seed is None:
776+
self.seed = torch.tensor([0], dtype=torch.int64, device=logits.device)
777+
self.offset = torch.tensor([0], dtype=torch.int64, device=logits.device)
778+
self.seed.add_(1).remainder_(2**31)
779+
top_ks = spec_metadata.top_ks[:num_flat_tokens]
780+
if self.use_flashinfer:
781+
top_ks = top_ks.clamp(min=1, max=logits.shape[-1] - 1)
782+
sampled = sampling_batch_spec_dec_one_model(
783+
logits,
784+
spec_metadata.temperatures[:num_flat_tokens],
785+
top_ks,
786+
spec_metadata.top_ps[:num_flat_tokens],
787+
use_flashinfer=self.use_flashinfer,
788+
seed=self.seed,
789+
offset=self.offset,
790+
)
791+
self._target_tokens_buf[:num_flat_tokens].copy_(sampled)
792+
else:
793+
# Greedy fast path (CUDA graph key: is_all_greedy_sample=True).
794+
torch.argmax(logits, dim=-1, out=self._target_tokens_buf[:num_flat_tokens])
796795
target_tokens = self._target_tokens_buf[:num_flat_tokens]
797796

798797
# Context requests: accept sampled token
@@ -812,41 +811,37 @@ def _sample_and_accept_dynamic_tree(
812811
self._accepted_draft_indices_tensor[num_contexts:batch_size] = -1
813812
return accepted_tokens, num_accepted_tokens
814813

815-
with nvtx_range("dyn_tree.greedy.build_candidates", color="blue"):
816-
candidates = self._candidates_buf[:num_gens]
817-
candidates[:, 1:] = spec_metadata.draft_tokens.reshape(num_gens, N - 1)
818-
candidates[:, 0] = target_predict[:, 0]
814+
candidates = self._candidates_buf[:num_gens]
815+
candidates[:, 1:] = spec_metadata.draft_tokens.reshape(num_gens, N - 1)
816+
candidates[:, 0] = target_predict[:, 0]
819817

820818
slot_storage = spec_tree_manager.slot_storage
821819
gen_slot_ids = slot_storage.all_ids_buf[num_contexts : num_contexts + num_gens]
822820
tree_valid = slot_storage.has_tree[gen_slot_ids]
823-
with nvtx_range("dyn_tree.greedy.pack_retrieve", color="blue"):
824-
retrieve_packed = slot_storage.pack_retrieve_from_slots(gen_slot_ids, num_gens)
825-
826-
with nvtx_range("dyn_tree.greedy.verify_greedy", color="cyan"):
827-
accept_index, accept_token_num, accept_token = (
828-
self.tree_ops_converter.verify_dynamic_tree_greedy_out_packed(
829-
candidates,
830-
retrieve_packed,
831-
target_predict,
832-
num_gens,
833-
self._max_path_len,
834-
tree_valid=tree_valid,
835-
)
821+
retrieve_packed = slot_storage.pack_retrieve_from_slots(gen_slot_ids, num_gens)
822+
823+
accept_index, accept_token_num, accept_token = (
824+
self.tree_ops_converter.verify_dynamic_tree_greedy_out_packed(
825+
candidates,
826+
retrieve_packed,
827+
target_predict,
828+
num_gens,
829+
self._max_path_len,
830+
tree_valid=tree_valid,
836831
)
832+
)
837833

838-
with nvtx_range("dyn_tree.greedy.finalize", color="blue"):
839-
self._finalize_dynamic_tree_verify_outputs(
840-
accept_index=accept_index,
841-
accept_token_num=accept_token_num,
842-
accept_token=accept_token,
843-
accepted_tokens=accepted_tokens,
844-
num_accepted_tokens=num_accepted_tokens,
845-
num_contexts=num_contexts,
846-
batch_size=batch_size,
847-
num_gens=num_gens,
848-
max_path_len=max_path_len,
849-
)
834+
self._finalize_dynamic_tree_verify_outputs(
835+
accept_index=accept_index,
836+
accept_token_num=accept_token_num,
837+
accept_token=accept_token,
838+
accepted_tokens=accepted_tokens,
839+
num_accepted_tokens=num_accepted_tokens,
840+
num_contexts=num_contexts,
841+
batch_size=batch_size,
842+
num_gens=num_gens,
843+
max_path_len=max_path_len,
844+
)
850845

851846
num_accepted_tokens = self._apply_force_accepted_tokens(
852847
num_accepted_tokens, num_contexts, self.max_draft_len
@@ -888,21 +883,20 @@ def _sample_and_accept_dynamic_tree_rejection(
888883
self.seed.add_(1).remainder_(2**31)
889884

890885
# Context tokens bypass the rejection kernel — sample them directly.
891-
with nvtx_range("dyn_tree.rej.sample_ctx", color="orange"):
892-
if num_contexts > 0:
893-
top_ks_ctx = spec_metadata.top_ks[:num_contexts]
894-
if self.use_flashinfer:
895-
top_ks_ctx = top_ks_ctx.clamp(min=1, max=vocab_size - 1)
896-
sampled_ctx = sampling_batch_spec_dec_one_model(
897-
logits[:num_contexts],
898-
spec_metadata.temperatures[:num_contexts],
899-
top_ks_ctx,
900-
spec_metadata.top_ps[:num_contexts],
901-
use_flashinfer=self.use_flashinfer,
902-
seed=self.seed,
903-
offset=self.offset,
904-
)
905-
accepted_tokens[:num_contexts, 0].copy_(sampled_ctx)
886+
if num_contexts > 0:
887+
top_ks_ctx = spec_metadata.top_ks[:num_contexts]
888+
if self.use_flashinfer:
889+
top_ks_ctx = top_ks_ctx.clamp(min=1, max=vocab_size - 1)
890+
sampled_ctx = sampling_batch_spec_dec_one_model(
891+
logits[:num_contexts],
892+
spec_metadata.temperatures[:num_contexts],
893+
top_ks_ctx,
894+
spec_metadata.top_ps[:num_contexts],
895+
use_flashinfer=self.use_flashinfer,
896+
seed=self.seed,
897+
offset=self.offset,
898+
)
899+
accepted_tokens[:num_contexts, 0].copy_(sampled_ctx)
906900

907901
if num_gens > 0:
908902
spec_tree_manager = self.spec_tree_manager
@@ -928,28 +922,26 @@ def _sample_and_accept_dynamic_tree_rejection(
928922
gen_slot_ids = slot_storage.all_ids_buf[num_contexts : num_contexts + num_gens]
929923
tree_valid = slot_storage.has_tree[gen_slot_ids]
930924

931-
with nvtx_range("dyn_tree.rej.next_links", color="orange"):
932-
retrieve_next_token, retrieve_next_sibling = slot_storage.next_links_from_slots(
933-
gen_slot_ids, num_gens
934-
)
925+
retrieve_next_token, retrieve_next_sibling = slot_storage.next_links_from_slots(
926+
gen_slot_ids, num_gens
927+
)
935928

936-
with nvtx_range("dyn_tree.rej.verify_rejection", color="red"):
937-
accept_index, accept_token_num, accept_token = (
938-
self.tree_ops_converter.verify_dynamic_tree_rejection_out(
939-
spec_metadata.draft_tokens.reshape(num_gens, N - 1).long(),
940-
target_logits_tree,
941-
retrieve_next_token,
942-
retrieve_next_sibling,
943-
tree_valid,
944-
temps,
945-
top_ks_rej,
946-
top_ps_rej,
947-
num_gens,
948-
self._max_path_len,
949-
seed=self.seed,
950-
offset=self.offset,
951-
)
929+
accept_index, accept_token_num, accept_token = (
930+
self.tree_ops_converter.verify_dynamic_tree_rejection_out(
931+
spec_metadata.draft_tokens.reshape(num_gens, N - 1).long(),
932+
target_logits_tree,
933+
retrieve_next_token,
934+
retrieve_next_sibling,
935+
tree_valid,
936+
temps,
937+
top_ks_rej,
938+
top_ps_rej,
939+
num_gens,
940+
self._max_path_len,
941+
seed=self.seed,
942+
offset=self.offset,
952943
)
944+
)
953945

954946
if self.force_num_accepted_tokens != 0.0:
955947
# Fill accept_token positions 1..max_path_len-1 with draft tokens so
@@ -967,18 +959,17 @@ def _sample_and_accept_dynamic_tree_rejection(
967959
)
968960
)
969961

970-
with nvtx_range("dyn_tree.rej.finalize", color="orange"):
971-
self._finalize_dynamic_tree_verify_outputs(
972-
accept_index=accept_index,
973-
accept_token_num=accept_token_num,
974-
accept_token=accept_token,
975-
accepted_tokens=accepted_tokens,
976-
num_accepted_tokens=num_accepted_tokens,
977-
num_contexts=num_contexts,
978-
batch_size=batch_size,
979-
num_gens=num_gens,
980-
max_path_len=max_path_len,
981-
)
962+
self._finalize_dynamic_tree_verify_outputs(
963+
accept_index=accept_index,
964+
accept_token_num=accept_token_num,
965+
accept_token=accept_token,
966+
accepted_tokens=accepted_tokens,
967+
num_accepted_tokens=num_accepted_tokens,
968+
num_contexts=num_contexts,
969+
batch_size=batch_size,
970+
num_gens=num_gens,
971+
max_path_len=max_path_len,
972+
)
982973

983974
num_accepted_tokens = self._apply_force_accepted_tokens(
984975
num_accepted_tokens, num_contexts, self.max_draft_len

0 commit comments

Comments
 (0)