Skip to content

Commit 5ccf9a2

Browse files
committed
refactor: consolidate guided spec decoding logic into GuidedSpecHelper
Extract all guided-decoding logic specific to speculative decoding into a dedicated GuidedSpecHelper class. This replaces the scattered free functions (_accept_spec_rejection_tokens, _fill_spec_bitmask, _accept_spec_forked_tokens) and inline guided logic in spec_agent.py and BaseSpecProposer with a single, well-defined API. Key changes: - New GuidedSpecHelper class (guided_spec_helper.py) encapsulates: - Session lifecycle (cleanup_sessions, get_processors) - Draft-side bitmask (prepare_bitmask, apply_bitmask, accept_draft_tokens) - Target-side serial bitmask (apply_serial_bitmask with forked matchers) - Rejection-sampling-aware token acceptance (accept_rejection_sampled_tokens) - All public methods are null-safe: GuidedSpecHelper(manager=None) is a valid no-op instance, so callers never need to guard with if guided_helper: or if processors:. - Replaced guided_decoding_manager on SpecModelAgent/BaseSpecProposer with guided_helper (a GuidedSpecHelper instance, always set). - Removed _prepare_guided_bitmask and _accept_guided_tokens from BaseSpecProposer (subsumed by helper methods). - Simplified spec_agent.py: removed 3 free functions, removed all if guided_helper: / if guided_processors: guards, delegate to helper.
1 parent ecc143b commit 5ccf9a2

8 files changed

Lines changed: 279 additions & 178 deletions

File tree

lmdeploy/pytorch/engine/model_agent/agent.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,8 +308,10 @@ def __init__(
308308
misc_config=misc_config,
309309
device=device)
310310
if self.spec_agent.is_enabled():
311-
self.spec_agent.guided_decoding_manager = self.guided_decoding_manager
312-
self.spec_agent.proposer.guided_decoding_manager = self.guided_decoding_manager
311+
from lmdeploy.pytorch.spec_decode.guided_spec_helper import GuidedSpecHelper
312+
helper = GuidedSpecHelper(self.guided_decoding_manager)
313+
self.spec_agent.guided_helper = helper
314+
self.spec_agent.proposer.guided_helper = helper
313315
# sleep wakeup state
314316
self.state: SleepWakeupState = SleepWakeupState()
315317

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from __future__ import annotations
3+
4+
import asyncio
5+
from typing import TYPE_CHECKING
6+
7+
import torch
8+
9+
if TYPE_CHECKING:
10+
import xgrammar as xgr
11+
12+
from ..engine.guided_process import GuidedDecodingManager
13+
14+
15+
class GuidedSpecHelper:
16+
"""Guided-decoding support for speculative decoding.
17+
18+
Wraps a :class:`GuidedDecodingManager` and provides spec-decoding-specific
19+
operations that cannot be handled by :class:`FusedLogitsProcessor` because
20+
speculative decoding needs:
21+
22+
* Position-serial bitmasking across N+1 positions (not 1).
23+
* Forked matchers to preserve originals for target-side verification.
24+
* Rejection-sampling-driven token acceptance (not direct argmax).
25+
* Draft-vocab bitmask translation (Eagle3).
26+
27+
Instead of passing ``guided_decoding_manager`` into ``FusedLogitsProcessor``,
28+
the spec-decoding path constructs a ``GuidedSpecHelper`` and calls its
29+
methods at the appropriate points.
30+
31+
All public methods are no-ops when constructed with ``guided_manager=None``
32+
or when no guided processors are active, so callers never need to guard
33+
with ``if guided_helper:`` or ``if processors:``.
34+
"""
35+
36+
def __init__(self, guided_manager: GuidedDecodingManager | None = None):
37+
self._mgr = guided_manager
38+
39+
@property
40+
def manager(self) -> GuidedDecodingManager | None:
41+
"""Access the underlying :class:`GuidedDecodingManager`."""
42+
return self._mgr
43+
44+
# ------------------------------------------------------------------
45+
# Session lifecycle
46+
# ------------------------------------------------------------------
47+
48+
def cleanup_sessions(self, session_ids: list[int] | None):
49+
"""Remove grammar processors for ended sessions."""
50+
if self._mgr is None or not session_ids:
51+
return
52+
for session_id in session_ids:
53+
self._mgr.remove_processor(session_id)
54+
55+
def get_processors(self, session_ctx, response_formats) -> dict[int, xgr.GrammarMatcher]:
56+
"""Get grammar processors for active guided sessions.
57+
58+
Returns an empty dict when no manager is set or no sessions are
59+
guided, so callers can use ``if processors:`` uniformly.
60+
"""
61+
if self._mgr is None or session_ctx is None:
62+
return {}
63+
return self._mgr.get_processors(session_ctx, response_formats)
64+
65+
# ------------------------------------------------------------------
66+
# Draft side (called from proposer.get_outputs)
67+
# ------------------------------------------------------------------
68+
69+
async def prepare_bitmask(self, logits: torch.Tensor,
70+
processors: dict[int, xgr.GrammarMatcher] | None) -> torch.Tensor | None:
71+
"""Allocate and fill a guided-decoding bitmask for draft logits.
72+
73+
Returns the filled bitmask tensor (or ``None`` if no guided processors
74+
are active). The caller is responsible for applying the bitmask —
75+
some proposers (e.g. Eagle3) may need to translate the bitmask to
76+
their draft vocabulary first.
77+
"""
78+
if not processors or self._mgr is None:
79+
return None
80+
bitmask = self._mgr.allocate_batched_bitmap(logits.size(0))
81+
82+
def _fill():
83+
for idx, proc in processors.items():
84+
self._mgr.fill_bitmap(proc, bitmask, idx)
85+
86+
await asyncio.to_thread(_fill)
87+
return bitmask
88+
89+
def apply_bitmask(self, logits: torch.Tensor, bitmask: torch.Tensor | None):
90+
"""Apply a guided bitmask to logits.
91+
92+
No-op when *bitmask* is ``None``.
93+
"""
94+
if bitmask is None or self._mgr is None:
95+
return
96+
self._mgr.apply_batched_bitmap(logits, bitmask)
97+
98+
async def accept_draft_tokens(self, draft_token_ids: torch.Tensor,
99+
processors: dict[int, xgr.GrammarMatcher] | None):
100+
"""Accept draft tokens on the provided (forked) grammar matchers.
101+
102+
In speculative decoding the matchers are typically forked from the
103+
originals (created in :meth:`SpecModelAgent._async_model_forward`),
104+
so this method accepts on whichever matchers are passed in.
105+
"""
106+
if not processors or self._mgr is None:
107+
return
108+
cpu_ids = draft_token_ids[:, 0].cpu()
109+
110+
def _accept():
111+
for idx, proc in processors.items():
112+
self._mgr.accept_token(proc, cpu_ids[idx].item())
113+
114+
await asyncio.to_thread(_accept)
115+
116+
# ------------------------------------------------------------------
117+
# Target side: position-serial bitmask with forked matchers
118+
# ------------------------------------------------------------------
119+
120+
async def apply_serial_bitmask(
121+
self,
122+
scores_3d: torch.Tensor,
123+
processors: dict[int, xgr.GrammarMatcher],
124+
draft_token_ids: torch.LongTensor,
125+
num_spec_tokens: int,
126+
):
127+
"""Apply position-serial grammar mask to target logits.
128+
129+
Forks the provided processors, applies bitmask at each speculative
130+
position, and advances the forks using the draft tokens. The original
131+
processors are **not** modified.
132+
133+
No-op when *processors* is empty.
134+
135+
Args:
136+
scores_3d: ``[batch_size, num_expand, vocab_size]`` logits tensor
137+
(modified in-place).
138+
processors: Original grammar matchers indexed by batch position.
139+
draft_token_ids: ``[batch_size, num_spec_tokens]`` draft tokens
140+
from the proposer. Forks are advanced using these (not
141+
argmax) because target logits are conditioned on the draft
142+
tokens.
143+
num_spec_tokens: Number of speculative tokens per step.
144+
"""
145+
if not processors or self._mgr is None:
146+
return
147+
forked = {idx: proc.fork() for idx, proc in processors.items()}
148+
cpu_draft = draft_token_ids.cpu()
149+
batch_size = scores_3d.size(0)
150+
num_expand = scores_3d.size(1)
151+
bitmask = self._mgr.allocate_batched_bitmap(batch_size)
152+
153+
for pos in range(num_expand):
154+
await asyncio.to_thread(self._fill_bitmask, forked, bitmask)
155+
pos_logits = scores_3d[:, pos, :]
156+
self._mgr.apply_batched_bitmap(pos_logits, bitmask)
157+
scores_3d[:, pos, :] = pos_logits
158+
159+
# Advance fork using draft tokens for draft positions.
160+
if pos < num_spec_tokens:
161+
await asyncio.to_thread(self._accept_forked_at_pos, forked, cpu_draft, pos)
162+
163+
# ------------------------------------------------------------------
164+
# Token acceptance (rejection-sampling-aware)
165+
# ------------------------------------------------------------------
166+
167+
async def accept_rejection_sampled_tokens(
168+
self,
169+
processors: dict[int, xgr.GrammarMatcher],
170+
num_rejected: torch.Tensor,
171+
output_token_ids: torch.Tensor,
172+
next_token_ids: torch.Tensor,
173+
num_spec_tokens: int,
174+
):
175+
"""Accept rejection-sampled tokens on original grammar matchers.
176+
177+
After rejection sampling, the original matchers must be advanced to
178+
reflect the accepted tokens. For each sequence, ``num_spec_tokens -
179+
num_rejected`` draft tokens are accepted, followed by the bonus token.
180+
181+
No-op when *processors* is empty.
182+
183+
Args:
184+
processors: Original (non-forked) grammar matchers.
185+
num_rejected: Per-sequence rejection counts (GPU or CPU tensor).
186+
output_token_ids: Accepted output tokens ``[batch, num_spec]``
187+
(GPU or CPU tensor).
188+
next_token_ids: Bonus tokens ``[batch]`` (GPU or CPU tensor).
189+
num_spec_tokens: Number of speculative tokens per step.
190+
"""
191+
if not processors or self._mgr is None:
192+
return
193+
cpu_num_rejected = num_rejected.cpu() if num_rejected.is_cuda else num_rejected
194+
cpu_output_token_ids = output_token_ids.cpu() if output_token_ids.is_cuda else output_token_ids
195+
cpu_next_token_ids = next_token_ids.cpu() if next_token_ids.is_cuda else next_token_ids
196+
197+
def _accept():
198+
for idx, processor in processors.items():
199+
n_rejected = cpu_num_rejected[idx].item()
200+
n_valid_draft = num_spec_tokens - n_rejected
201+
for pos in range(n_valid_draft):
202+
tid = cpu_output_token_ids[idx, pos].item()
203+
if tid >= 0:
204+
self._mgr.accept_token(processor, tid)
205+
self._mgr.accept_token(processor, cpu_next_token_ids[idx].item())
206+
207+
await asyncio.to_thread(_accept)
208+
209+
# ------------------------------------------------------------------
210+
# Private helpers
211+
# ------------------------------------------------------------------
212+
213+
def _fill_bitmask(self, processors: dict, bitmask: torch.Tensor):
214+
for idx, proc in processors.items():
215+
self._mgr.fill_bitmap(proc, bitmask, idx)
216+
217+
def _accept_forked_at_pos(self, forked: dict, cpu_draft: torch.Tensor, pos: int):
218+
for idx, fork_proc in forked.items():
219+
self._mgr.accept_token(fork_proc, cpu_draft[idx, pos].item())

lmdeploy/pytorch/spec_decode/proposers/base.py

Lines changed: 4 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
import asyncio
2+
from __future__ import annotations
3+
34
from typing import Any
45

56
import torch
@@ -14,6 +15,7 @@
1415
from ...models.patch import build_patched_model, update_custom_module_map
1516
from ...strategies.base.model_agent import ExtraInputs
1617
from ...weight_loader.model_weight_loader import load_model_weights
18+
from ..guided_spec_helper import GuidedSpecHelper
1719

1820
SPEC_PROPOSERS = Registry('spec_proposers')
1921

@@ -66,53 +68,7 @@ def __init__(self, specdecode_config: SpecDecodeConfig, device: torch.device = N
6668
self.num_speculative_tokens = specdecode_config.num_speculative_tokens
6769
self.target_model = None
6870
# Set by SpecModelAgent after construction
69-
self.guided_decoding_manager = None
70-
71-
async def _prepare_guided_bitmask(self, logits: torch.Tensor,
72-
guided_processors: dict | None) -> torch.Tensor | None:
73-
"""Allocate and fill a guided-decoding bitmask for draft logits.
74-
75-
Returns the filled bitmask tensor (or None if no guided processors are
76-
active). The caller is responsible for actually applying the bitmask to
77-
logits — some proposers (e.g. Eagle3) may need to translate the bitmask
78-
to their draft vocabulary first.
79-
80-
CPU-bound xgrammar ``fill_bitmap`` calls are offloaded to a thread
81-
so they don't block the asyncio event loop.
82-
"""
83-
if not guided_processors or self.guided_decoding_manager is None:
84-
return None
85-
guided_manager = self.guided_decoding_manager
86-
guided_bitmask = guided_manager.allocate_batched_bitmap(logits.size(0))
87-
88-
def _fill():
89-
for idx, processor in guided_processors.items():
90-
guided_manager.fill_bitmap(processor, guided_bitmask, idx)
91-
92-
await asyncio.to_thread(_fill)
93-
return guided_bitmask
94-
95-
async def _accept_guided_tokens(self, draft_token_ids: torch.Tensor,
96-
guided_processors: dict | None):
97-
"""Accept draft tokens on the provided grammar matchers.
98-
99-
In speculative decoding the matchers are typically forked from the
100-
originals (created in ``SpecModelAgent._async_model_forward``), so this
101-
method accepts on whichever matchers are passed in.
102-
103-
CPU-bound xgrammar ``accept_token`` calls are offloaded to a thread
104-
so they don't block the asyncio event loop.
105-
"""
106-
if not guided_processors or self.guided_decoding_manager is None:
107-
return
108-
guided_manager = self.guided_decoding_manager
109-
cpu_draft_token_ids = draft_token_ids[:, 0].cpu()
110-
111-
def _accept():
112-
for idx, processor in guided_processors.items():
113-
guided_manager.accept_token(processor, cpu_draft_token_ids[idx].item())
114-
115-
await asyncio.to_thread(_accept)
71+
self.guided_helper = GuidedSpecHelper()
11672

11773
def build_model(self, empty_init: bool, target_model: torch.nn.Module = None, build_model_ctx=None):
11874
if self.specdecode_config is None:

lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ async def get_outputs(self,
3232

3333
logits = self.get_logits(hidden_states)[0]
3434

35-
guided_bitmask = await self._prepare_guided_bitmask(logits, guided_processors)
35+
guided_bitmask = await self.guided_helper.prepare_bitmask(logits, guided_processors)
3636
if guided_bitmask is not None:
37-
self.guided_decoding_manager.apply_batched_bitmap(logits, guided_bitmask)
37+
self.guided_helper.apply_bitmask(logits, guided_bitmask)
3838

3939
draft_token_ids = logits.argmax(dim=-1, keepdim=True)
40-
await self._accept_guided_tokens(draft_token_ids, guided_processors)
40+
await self.guided_helper.accept_draft_tokens(draft_token_ids, guided_processors)
4141

4242
return draft_token_ids, model_metas, target_hidden_states

lmdeploy/pytorch/spec_decode/proposers/eagle3.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,14 @@ async def get_outputs(self,
113113

114114
logits = self.get_logits(hidden_states)[0]
115115

116-
guided_bitmask = await self._prepare_guided_bitmask(logits, guided_processors)
116+
guided_bitmask = await self.guided_helper.prepare_bitmask(logits, guided_processors)
117117
if guided_bitmask is not None:
118118
draft_bitmask = self._translate_bitmask(guided_bitmask)
119-
self.guided_decoding_manager.apply_batched_bitmap(logits, draft_bitmask)
119+
self.guided_helper.apply_bitmask(logits, draft_bitmask)
120120

121121
draft_token_ids = logits.argmax(dim=-1, keepdim=True)
122122
draft_token_ids = self.draft_id_to_target_id[draft_token_ids]
123123

124-
await self._accept_guided_tokens(draft_token_ids, guided_processors)
124+
await self.guided_helper.accept_draft_tokens(draft_token_ids, guided_processors)
125125

126126
return draft_token_ids, model_metas, hidden_states_prenorm

0 commit comments

Comments
 (0)