Skip to content

Commit 6f4b612

Browse files
committed
fix: address Copilot review feedback — CPU tensor, redundant assignment, precompute max word idx
1 parent 9faeefa commit 6f4b612

2 files changed

Lines changed: 5 additions & 3 deletions

File tree

lmdeploy/pytorch/spec_decode/proposers/eagle3.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,8 @@ def _init_bitmask_translate_constants(self):
3434
self._draft_words = draft_indices // 32
3535
self._draft_bits = draft_indices % 32
3636
self._n_draft_words = (draft_vocab_size + 31) // 32
37+
# Precompute max word index (avoids GPU→CPU sync in _translate_bitmask)
38+
self._max_d2t_word = self._d2t_words.max().item()
3739
# Cache device-specific constants; keyed by device.
3840
self._bitmask_cache: dict[torch.device, dict] = {}
3941

@@ -66,7 +68,7 @@ def _translate_bitmask(self, target_bitmask: torch.Tensor) -> torch.Tensor:
6668
draft_words = c['draft_words']
6769
draft_bits = c['draft_bits']
6870

69-
max_word_idx = d2t_words.max().item()
71+
max_word_idx = self._max_d2t_word
7072
if max_word_idx >= target_bitmask.size(1):
7173
raise ValueError(
7274
f'd2t mapping references word index {max_word_idx} but target_bitmask '

lmdeploy/pytorch/spec_decode/spec_agent.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -497,7 +497,6 @@ def __compute_logprobs(raw_logprobs: torch.Tensor, token_ids: torch.LongTensor,
497497
bonus_sampling_inputs,
498498
logprobs_mode=self.misc_config.logprobs_mode,
499499
)
500-
logits_processor.sampling_inputs = bonus_sampling_inputs
501500

502501
next_token_ids = logits_processor.sampling(bonus_logits) # [batch_size]
503502

@@ -559,7 +558,8 @@ def __compute_logprobs(raw_logprobs: torch.Tensor, token_ids: torch.LongTensor,
559558
_accept_spec_rejection_tokens,
560559
guided_manager,
561560
guided_processors,
562-
torch.zeros_like(next_token_ids), # 0 rejected
561+
torch.zeros(next_token_ids.shape,
562+
dtype=next_token_ids.dtype), # 0 rejected, CPU
563563
output_token_ids.cpu(),
564564
cpu_next_token_ids,
565565
0, # num_spec_tokens=0, only bonus accepted

0 commit comments

Comments
 (0)