Skip to content

Commit df3dc51

Browse files
[TRTLLM-12669][chore] Apply CI yapf reformat to interface.py
CI yapf hook reformatted a few line wraps in interface.py — apply locally to keep CI green. No functional change. Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
1 parent beb4b32 commit df3dc51

1 file changed

Lines changed: 7 additions & 8 deletions

File tree

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,8 +1057,8 @@ def _sample_and_accept_draft_tokens_base(
10571057
device=logits.device)
10581058

10591059
# Sample tokens using per-request sampling parameters
1060-
target_tokens = self._sample_tokens_for_batch(
1061-
logits, spec_metadata, num_contexts, batch_size)
1060+
target_tokens = self._sample_tokens_for_batch(logits, spec_metadata,
1061+
num_contexts, batch_size)
10621062

10631063
# Context requests: only accept the sampled token (no draft tokens yet)
10641064
accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
@@ -1072,8 +1072,7 @@ def _sample_and_accept_draft_tokens_base(
10721072
# Compare draft tokens with target tokens using cumulative product
10731073
# Counts consecutive matches from the start
10741074
num_accepted_tokens[num_contexts:] += torch.cumprod(
1075-
(draft_tokens
1076-
== gen_target_tokens[:, :runtime_draft_len]).int(),
1075+
(draft_tokens == gen_target_tokens[:, :runtime_draft_len]).int(),
10771076
dim=-1).sum(1)
10781077

10791078
# Apply force override if set
@@ -1182,8 +1181,9 @@ def _sample_and_accept_draft_tokens_rejection(
11821181

11831182
target_probs_flat = compute_probs_from_logits(
11841183
gen_logits, temperatures, top_ks, top_ps)
1185-
target_probs = target_probs_flat.reshape(
1186-
num_gens, runtime_draft_len + 1, vocab_size)
1184+
target_probs = target_probs_flat.reshape(num_gens,
1185+
runtime_draft_len + 1,
1186+
vocab_size)
11871187

11881188
draft_vocab_size = draft_probs.shape[-1]
11891189
assert draft_probs.shape[0] == num_gens, (
@@ -1200,8 +1200,7 @@ def _sample_and_accept_draft_tokens_rejection(
12001200
# configured, e.g. when use_rejection_sampling was off at
12011201
# prepare() time.
12021202
if spec_metadata.full_draft_probs is not None:
1203-
full_draft_probs = spec_metadata.full_draft_probs[:
1204-
num_gens]
1203+
full_draft_probs = spec_metadata.full_draft_probs[:num_gens]
12051204
else:
12061205
full_draft_probs = torch.zeros(
12071206
(num_gens, runtime_draft_len, vocab_size),

0 commit comments

Comments
 (0)