Commit fa65e35

[TRTLLM-12669][perf] Add torch.compile(max-autotune) to compute_probs_from_logits

Profiling on H200 shows +15% rejection sampling throughput (1135 → 1304 tok/s) at bs=16 with Qwen3-8B Eagle3 dynamic tree.

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
1 parent b7d6987 commit fa65e35

1 file changed

Lines changed: 1 addition & 0 deletions

tensorrt_llm/_torch/speculative/one_model_sampler.py

@@ -117,6 +117,7 @@ def sampling_batch_spec_dec_one_model(
     return random_sampled
 
 
+@torch.compile(options={"max-autotune": True})
 def compute_probs_from_logits(
     logits: torch.Tensor,
     temperatures: torch.Tensor,
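
For readers unfamiliar with the decorator, the following is a minimal sketch of the same pattern applied to a stand-in function. The diff does not show the body of `compute_probs_from_logits`, so `compute_probs_sketch`, its temperature-scaled softmax body, the tensor shapes, and `make_compiled` are all assumptions made for illustration, not the code in `one_model_sampler.py`. As far as I know, `torch.compile` normalizes dashes in `options` keys to underscores, so `"max-autotune"` and `"max_autotune"` should select the same Inductor setting.

```python
import torch


def compute_probs_sketch(logits: torch.Tensor, temperatures: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in: per-row temperature scaling followed by softmax.

    Assumed shapes: logits [num_tokens, vocab_size], temperatures [num_tokens].
    """
    # Guard against division by zero for rows with temperature 0 (assumption).
    safe_temps = temperatures.clamp_min(1e-6).unsqueeze(-1)
    return torch.softmax(logits / safe_temps, dim=-1)


def make_compiled():
    """Wrap the sketch with the same Inductor option the commit uses.

    Compilation is lazy: autotuning runs on the first call of the returned
    callable, not when torch.compile itself is invoked.
    """
    return torch.compile(compute_probs_sketch, options={"max_autotune": True})
```

Because autotuning happens on the first call, a warm-up call with representative shapes before taking measurements keeps compile time out of the throughput numbers. Comparing the compiled output against the eager function with `torch.testing.assert_close` is a cheap way to check that the compiled path is numerically equivalent.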
