diff --git a/docs/en/advance/spec_decoding.md b/docs/en/advance/spec_decoding.md index e5366f0194..c57b4c2b52 100644 --- a/docs/en/advance/spec_decoding.md +++ b/docs/en/advance/spec_decoding.md @@ -2,8 +2,9 @@ Speculative decoding is an optimization technique that introcude a lightweight draft model to propose multiple next tokens and then, the main model verify and choose the longest matched tokens in a forward pass. Compared with standard auto-regressive decoding, this methold lets the system generate multiple tokens at once. -> \[!NOTE\] -> This is an experimental feature in lmdeploy. +:::{note} +This is an experimental feature in lmdeploy. +::: ## Examples @@ -104,3 +105,93 @@ deepseek-ai/DeepSeek-V3 \ --max-batch-size 128 \ --enable-metrics ``` + +## Guided Decoding with Speculative Decoding + +Speculative decoding (MTP) can be combined with [structured output](./structed_output.md) so that the draft tokens proposed by the spec model also respect the grammar constraints (e.g. JSON schema, regex). This significantly improves the acceptance rate compared to running spec decoding without grammar masks. + +:::{note} +This feature is supported for spec methods that inherit from `DeepseekMTP`, including `deepseek_mtp`, `qwen3_5_mtp`, and `eagle3`. Only the PyTorch backend is supported. +::: + +### How it works + +The grammar mask is applied at two stages: + +1. **Draft model** — forked grammar matchers are used to mask each draft position serially. Each position's mask depends on the token accepted at the previous position, ensuring the draft model proposes grammatically valid tokens. +2. **Target model verification** — position-serial grammar masking is applied to the target model's logits. After rejection sampling, only the accepted tokens are fed back to the original (un-forked) grammar matchers, keeping them in sync for the next step. + +When the draft model uses a different vocabulary from the target model (e.g. 
Eagle 3 with a compressed draft vocabulary), the target-vocab bitmask produced by xgrammar is translated to a draft-vocab bitmask via an efficient scatter-add kernel before being applied to the draft logits. + +### pipeline + +```python +from lmdeploy import PytorchEngineConfig, pipeline +from lmdeploy.messages import GenerationConfig, SpeculativeConfig + +if __name__ == '__main__': + + model_path = 'deepseek-ai/DeepSeek-V3' + spec_cfg = SpeculativeConfig(method='deepseek_mtp', num_speculative_tokens=3) + pipe = pipeline( + model_path, + backend_config=PytorchEngineConfig(tp=16, max_batch_size=128), + speculative_config=spec_cfg, + ) + + schema = { + 'type': 'object', + 'properties': { + 'name': {'type': 'string'}, + 'age': {'type': 'integer'}, + }, + 'required': ['name', 'age'], + } + gen_config = GenerationConfig( + response_format=dict(type='json_schema', json_schema=dict(name='person', schema=schema)), + max_new_tokens=256, + ) + + response = pipe(['Introduce yourself as JSON.'], gen_config=gen_config) + print(response) +``` + +### api_server + +```shell +lmdeploy serve api_server \ +deepseek-ai/DeepSeek-V3 \ +--backend pytorch \ +--server-port 24545 \ +--tp 16 \ +--speculative-algorithm deepseek_mtp \ +--speculative-num-draft-tokens 3 \ +--max-batch-size 128 +``` + +The client can then use `response_format` as described in the [structured output](./structed_output.md) documentation: + +```python +from openai import OpenAI + +if __name__ == '__main__': + + schema = { + 'type': 'object', + 'properties': { + 'name': {'type': 'string'}, + 'age': {'type': 'integer'}, + }, + 'required': ['name', 'age'], + } + response_format = dict(type='json_schema', json_schema=dict(name='person', schema=schema)) + + client = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:24545/v1') + model_name = client.models.list().data[0].id + response = client.chat.completions.create( + model=model_name, + messages=[{'role': 'user', 'content': 'Introduce yourself as JSON.'}], + 
response_format=response_format, + ) + print(response) +``` diff --git a/docs/zh_cn/advance/spec_decoding.md b/docs/zh_cn/advance/spec_decoding.md index f011f61cb9..6497fd7858 100644 --- a/docs/zh_cn/advance/spec_decoding.md +++ b/docs/zh_cn/advance/spec_decoding.md @@ -2,8 +2,9 @@ 投机解码是一种优化技术,它通过引入轻量级草稿模型来预测多个后续token,再由主模型在前向推理过程中验证并选择匹配度最高的长token序列。与标准的自回归解码相比,这种方法可使系统一次性生成多个token。 -> \[!NOTE\] -> 请注意,这是lmdeploy中的实验性功能。 +:::{note} +请注意,这是lmdeploy中的实验性功能。 +::: ## 示例 @@ -103,3 +104,93 @@ deepseek-ai/DeepSeek-V3 \ --max-batch-size 128 \ --enable-metrics ``` + +## 投机解码与结构化输出 + +投机解码(MTP)可以与[结构化输出](./structed_output.md)结合使用,使草稿模型提出的 token 也遵循语法约束(如 JSON Schema、正则表达式),从而显著提高接受率。 + +:::{note} +该功能支持继承自 `DeepseekMTP` 的投机方法,包括 `deepseek_mtp`、`qwen3_5_mtp` 和 `eagle3`。仅支持 PyTorch 后端。 +::: + +### 工作原理 + +语法掩码在两个阶段分别施加: + +1. **草稿模型** — 使用 fork 出的语法匹配器,逐位置串行施加掩码。每个位置的掩码依赖于前一位置接受的 token,确保草稿模型提出符合语法的 token。 +2. **主模型验证** — 对主模型的 logits 进行逐位置串行的语法掩码处理。拒绝采样后,仅将接受的 token 反馈给原始(未 fork 的)语法匹配器,使其为下一步保持正确的状态。 + +当草稿模型使用与主模型不同的词表时(例如 Eagle 3 使用压缩的草稿词表),xgrammar 生成的目标词表位掩码会通过高效的 scatter-add 内核转换为草稿词表位掩码,然后再应用于草稿 logits。 + +### pipeline + +```python +from lmdeploy import PytorchEngineConfig, pipeline +from lmdeploy.messages import GenerationConfig, SpeculativeConfig + +if __name__ == '__main__': + + model_path = 'deepseek-ai/DeepSeek-V3' + spec_cfg = SpeculativeConfig(method='deepseek_mtp', num_speculative_tokens=3) + pipe = pipeline( + model_path, + backend_config=PytorchEngineConfig(tp=16, max_batch_size=128), + speculative_config=spec_cfg, + ) + + schema = { + 'type': 'object', + 'properties': { + 'name': {'type': 'string'}, + 'age': {'type': 'integer'}, + }, + 'required': ['name', 'age'], + } + gen_config = GenerationConfig( + response_format=dict(type='json_schema', json_schema=dict(name='person', schema=schema)), + max_new_tokens=256, + ) + + response = pipe(['请用 JSON 格式做自我介绍。'], gen_config=gen_config) + print(response) +``` + +### api_server + +```shell +lmdeploy serve 
api_server \ +deepseek-ai/DeepSeek-V3 \ +--backend pytorch \ +--server-port 24545 \ +--tp 16 \ +--speculative-algorithm deepseek_mtp \ +--speculative-num-draft-tokens 3 \ +--max-batch-size 128 +``` + +客户端可以按照[结构化输出](./structed_output.md)文档中的方式使用 `response_format`: + +```python +from openai import OpenAI + +if __name__ == '__main__': + + schema = { + 'type': 'object', + 'properties': { + 'name': {'type': 'string'}, + 'age': {'type': 'integer'}, + }, + 'required': ['name', 'age'], + } + response_format = dict(type='json_schema', json_schema=dict(name='person', schema=schema)) + + client = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:24545/v1') + model_name = client.models.list().data[0].id + response = client.chat.completions.create( + model=model_name, + messages=[{'role': 'user', 'content': '请用 JSON 格式做自我介绍。'}], + response_format=response_format, + ) + print(response) +``` diff --git a/lmdeploy/pytorch/engine/guided_process.py b/lmdeploy/pytorch/engine/guided_process.py index 506ebc74a9..2f303820d9 100644 --- a/lmdeploy/pytorch/engine/guided_process.py +++ b/lmdeploy/pytorch/engine/guided_process.py @@ -11,7 +11,6 @@ class GuidedDecodingManager: - processors = {} def __init__(self, tokenizer: PreTrainedTokenizerBase, vocab_size: int | None): if vocab_size is None: @@ -20,6 +19,7 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, vocab_size: int | None): tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size) self.compiler = xgr.GrammarCompiler(tokenizer_info) self.vocab_size = vocab_size + self.processors: dict[int, dict[int, xgr.GrammarMatcher]] = {} def get_processors(self, session_ctx: list[dict[str, Any]], response_formats: tuple[dict]) -> dict[int, xgr.GrammarMatcher]: @@ -32,7 +32,8 @@ def get_processors(self, session_ctx: list[dict[str, Any]], if isinstance(schema, dict): for key in ['json_schema', 'schema']: if key in schema: - schema = json.dumps(schema[key], ensure_ascii=False) + val = schema[key] + schema = 
val if isinstance(val, str) else json.dumps(val, ensure_ascii=False) if not isinstance(schema, str): raise ValueError(f'Cannot parse schema {schema}. The schema must be ' diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py index ff38daafc0..b6e6336d45 100644 --- a/lmdeploy/pytorch/engine/model_agent/agent.py +++ b/lmdeploy/pytorch/engine/model_agent/agent.py @@ -307,6 +307,11 @@ def __init__( self.agent_strategy, misc_config=misc_config, device=device) + if self.spec_agent.is_enabled(): + from lmdeploy.pytorch.spec_decode.guided_spec_helper import GuidedSpecHelper + helper = GuidedSpecHelper(self.guided_decoding_manager) + self.spec_agent.guided_helper = helper + self.spec_agent.proposer.guided_helper = helper # sleep wakeup state self.state: SleepWakeupState = SleepWakeupState() diff --git a/lmdeploy/pytorch/spec_decode/guided_spec_helper.py b/lmdeploy/pytorch/spec_decode/guided_spec_helper.py new file mode 100644 index 0000000000..9a2ecaa255 --- /dev/null +++ b/lmdeploy/pytorch/spec_decode/guided_spec_helper.py @@ -0,0 +1,219 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING + +import torch + +if TYPE_CHECKING: + import xgrammar as xgr + + from ..engine.guided_process import GuidedDecodingManager + + +class GuidedSpecHelper: + """Guided-decoding support for speculative decoding. + + Wraps a :class:`GuidedDecodingManager` and provides spec-decoding-specific + operations that cannot be handled by :class:`FusedLogitsProcessor` because + speculative decoding needs: + + * Position-serial bitmasking across N+1 positions (not 1). + * Forked matchers to preserve originals for target-side verification. + * Rejection-sampling-driven token acceptance (not direct argmax). + * Draft-vocab bitmask translation (Eagle3). 
+ + Instead of passing ``guided_decoding_manager`` into ``FusedLogitsProcessor``, + the spec-decoding path constructs a ``GuidedSpecHelper`` and calls its + methods at the appropriate points. + + All public methods are no-ops when constructed with ``guided_manager=None`` + or when no guided processors are active, so callers never need to guard + with ``if guided_helper:`` or ``if processors:``. + """ + + def __init__(self, guided_manager: GuidedDecodingManager | None = None): + self._mgr = guided_manager + + @property + def manager(self) -> GuidedDecodingManager | None: + """Access the underlying :class:`GuidedDecodingManager`.""" + return self._mgr + + # ------------------------------------------------------------------ + # Session lifecycle + # ------------------------------------------------------------------ + + def cleanup_sessions(self, session_ids: list[int] | None): + """Remove grammar processors for ended sessions.""" + if self._mgr is None or not session_ids: + return + for session_id in session_ids: + self._mgr.remove_processor(session_id) + + def get_processors(self, session_ctx, response_formats) -> dict[int, xgr.GrammarMatcher]: + """Get grammar processors for active guided sessions. + + Returns an empty dict when no manager is set or no sessions are + guided, so callers can use ``if processors:`` uniformly. + """ + if self._mgr is None or session_ctx is None: + return {} + return self._mgr.get_processors(session_ctx, response_formats) + + # ------------------------------------------------------------------ + # Draft side (called from proposer.get_outputs) + # ------------------------------------------------------------------ + + async def prepare_bitmask(self, logits: torch.Tensor, + processors: dict[int, xgr.GrammarMatcher] | None) -> torch.Tensor | None: + """Allocate and fill a guided-decoding bitmask for draft logits. + + Returns the filled bitmask tensor (or ``None`` if no guided processors + are active). 
The caller is responsible for applying the bitmask — + some proposers (e.g. Eagle3) may need to translate the bitmask to + their draft vocabulary first. + """ + if not processors or self._mgr is None: + return None + bitmask = self._mgr.allocate_batched_bitmap(logits.size(0)) + + def _fill(): + for idx, proc in processors.items(): + self._mgr.fill_bitmap(proc, bitmask, idx) + + await asyncio.to_thread(_fill) + return bitmask + + def apply_bitmask(self, logits: torch.Tensor, bitmask: torch.Tensor | None): + """Apply a guided bitmask to logits. + + No-op when *bitmask* is ``None``. + """ + if bitmask is None or self._mgr is None: + return + self._mgr.apply_batched_bitmap(logits, bitmask) + + async def accept_draft_tokens(self, draft_token_ids: torch.Tensor, + processors: dict[int, xgr.GrammarMatcher] | None): + """Accept draft tokens on the provided (forked) grammar matchers. + + In speculative decoding the matchers are typically forked from the + originals (created in :meth:`SpecModelAgent._async_model_forward`), + so this method accepts on whichever matchers are passed in. + """ + if not processors or self._mgr is None: + return + + def _accept(): + cpu_ids = draft_token_ids[:, 0].cpu() + for idx, proc in processors.items(): + self._mgr.accept_token(proc, cpu_ids[idx].item()) + + await asyncio.to_thread(_accept) + + # ------------------------------------------------------------------ + # Target side: position-serial bitmask with forked matchers + # ------------------------------------------------------------------ + + async def apply_serial_bitmask( + self, + scores_3d: torch.Tensor, + processors: dict[int, xgr.GrammarMatcher], + draft_token_ids: torch.LongTensor, + num_spec_tokens: int, + ): + """Apply position-serial grammar mask to target logits. + + Forks the provided processors, applies bitmask at each speculative + position, and advances the forks using the draft tokens. The original + processors are **not** modified. + + No-op when *processors* is empty. 
+ + Args: + scores_3d: ``[batch_size, num_expand, vocab_size]`` logits tensor + (modified in-place). + processors: Original grammar matchers indexed by batch position. + draft_token_ids: ``[batch_size, num_spec_tokens]`` draft tokens + from the proposer. Forks are advanced using these (not + argmax) because target logits are conditioned on the draft + tokens. + num_spec_tokens: Number of speculative tokens per step. + """ + if not processors or self._mgr is None: + return + forked = {idx: proc.fork() for idx, proc in processors.items()} + cpu_draft = await asyncio.to_thread(draft_token_ids.cpu) + batch_size = scores_3d.size(0) + num_expand = scores_3d.size(1) + bitmask = self._mgr.allocate_batched_bitmap(batch_size) + + for pos in range(num_expand): + await asyncio.to_thread(self._fill_bitmask, forked, bitmask) + pos_logits = scores_3d[:, pos, :] + self._mgr.apply_batched_bitmap(pos_logits, bitmask) + scores_3d[:, pos, :] = pos_logits + + # Advance fork using draft tokens for draft positions. + if pos < num_spec_tokens: + await asyncio.to_thread(self._accept_forked_at_pos, forked, cpu_draft, pos) + + # ------------------------------------------------------------------ + # Token acceptance (rejection-sampling-aware) + # ------------------------------------------------------------------ + + async def accept_rejection_sampled_tokens( + self, + processors: dict[int, xgr.GrammarMatcher], + num_rejected: torch.Tensor, + output_token_ids: torch.Tensor, + next_token_ids: torch.Tensor, + num_spec_tokens: int, + ): + """Accept rejection-sampled tokens on original grammar matchers. + + After rejection sampling, the original matchers must be advanced to + reflect the accepted tokens. For each sequence, ``num_spec_tokens - + num_rejected`` draft tokens are accepted, followed by the bonus token. + + No-op when *processors* is empty. + + Args: + processors: Original (non-forked) grammar matchers. + num_rejected: Per-sequence rejection counts (GPU or CPU tensor). 
+ output_token_ids: Accepted output tokens ``[batch, num_spec]`` + (GPU or CPU tensor). + next_token_ids: Bonus tokens ``[batch]`` (GPU or CPU tensor). + num_spec_tokens: Number of speculative tokens per step. + """ + if not processors or self._mgr is None: + return + + def _accept(): + cpu_num_rejected = num_rejected.cpu() if num_rejected.is_cuda else num_rejected + cpu_output_token_ids = output_token_ids.cpu() if output_token_ids.is_cuda else output_token_ids + cpu_next_token_ids = next_token_ids.cpu() if next_token_ids.is_cuda else next_token_ids + for idx, processor in processors.items(): + n_rejected = cpu_num_rejected[idx].item() + n_valid_draft = num_spec_tokens - n_rejected + for pos in range(n_valid_draft): + tid = cpu_output_token_ids[idx, pos].item() + if tid >= 0: + self._mgr.accept_token(processor, tid) + self._mgr.accept_token(processor, cpu_next_token_ids[idx].item()) + + await asyncio.to_thread(_accept) + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _fill_bitmask(self, processors: dict, bitmask: torch.Tensor): + for idx, proc in processors.items(): + self._mgr.fill_bitmap(proc, bitmask, idx) + + def _accept_forked_at_pos(self, forked: dict, cpu_draft: torch.Tensor, pos: int): + for idx, fork_proc in forked.items(): + self._mgr.accept_token(fork_proc, cpu_draft[idx, pos].item()) diff --git a/lmdeploy/pytorch/spec_decode/proposers/base.py b/lmdeploy/pytorch/spec_decode/proposers/base.py index c8ece28c54..2ba1e36b78 100644 --- a/lmdeploy/pytorch/spec_decode/proposers/base.py +++ b/lmdeploy/pytorch/spec_decode/proposers/base.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from __future__ import annotations + from typing import Any import torch @@ -13,6 +15,7 @@ from ...models.patch import build_patched_model, update_custom_module_map from ...strategies.base.model_agent import ExtraInputs from ...weight_loader.model_weight_loader import load_model_weights +from ..guided_spec_helper import GuidedSpecHelper SPEC_PROPOSERS = Registry('spec_proposers') @@ -64,6 +67,8 @@ def __init__(self, specdecode_config: SpecDecodeConfig, device: torch.device = N self.lm_head = None self.num_speculative_tokens = specdecode_config.num_speculative_tokens self.target_model = None + # Set by SpecModelAgent after construction + self.guided_helper = GuidedSpecHelper() def build_model(self, empty_init: bool, target_model: torch.nn.Module = None, build_model_ctx=None): if self.specdecode_config is None: @@ -85,10 +90,11 @@ def build_model(self, empty_init: bool, target_model: torch.nn.Module = None, bu self.model = patched_model self.target_model = target_model - def get_outputs(self, + async def get_outputs(self, model_outputs: dict[str, torch.Tensor], model_inputs: ModelInputs, - extra_inputs: ExtraInputs = None): + extra_inputs: ExtraInputs = None, + guided_processors: dict | None = None): """Get outputs.""" raise NotImplementedError() diff --git a/lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py b/lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py index 9a434ff2b2..2912758928 100644 --- a/lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py +++ b/lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py @@ -14,10 +14,11 @@ @SPEC_PROPOSERS.register_module(name='deepseek_mtp') class DeepseekMTP(BaseSpecProposer): - def get_outputs(self, + async def get_outputs(self, model_outputs: dict[str, torch.Tensor], model_inputs: ModelInputs, - extra_inputs: ARSpecExtraInputs = None): + extra_inputs: ARSpecExtraInputs = None, + guided_processors: dict | None = None): """Get outputs.""" hidden_states = model_outputs['hidden_states'] model_metas = 
model_outputs['model_metas'] @@ -30,5 +31,12 @@ def get_outputs(self, target_hidden_states = hidden_states logits = self.get_logits(hidden_states)[0] + + guided_bitmask = await self.guided_helper.prepare_bitmask(logits, guided_processors) + if guided_bitmask is not None: + self.guided_helper.apply_bitmask(logits, guided_bitmask) + draft_token_ids = logits.argmax(dim=-1, keepdim=True) + await self.guided_helper.accept_draft_tokens(draft_token_ids, guided_processors) + return draft_token_ids, model_metas, target_hidden_states diff --git a/lmdeploy/pytorch/spec_decode/proposers/eagle3.py b/lmdeploy/pytorch/spec_decode/proposers/eagle3.py index db1011727e..7486ec964b 100644 --- a/lmdeploy/pytorch/spec_decode/proposers/eagle3.py +++ b/lmdeploy/pytorch/spec_decode/proposers/eagle3.py @@ -19,21 +19,84 @@ class Eagle3(DeepseekMTP): def build_model(self, empty_init: bool, target_model: torch.nn.Module = None, build_model_ctx=None): super().build_model(empty_init, target_model=target_model, build_model_ctx=build_model_ctx) self.draft_id_to_target_id = self.model.draft_id_to_target_id + self._init_bitmask_translate_constants() if not self.model.include_embed_tokens: logger.info('Using embed_tokens from target model.') del self.model.model.embed_tokens self.model.model.embed_tokens = target_model.get_input_embeddings() + def _init_bitmask_translate_constants(self): + d2t = self.draft_id_to_target_id + self._d2t_words = d2t // 32 + self._d2t_bits = d2t % 32 + draft_vocab_size = d2t.size(0) + draft_indices = torch.arange(draft_vocab_size, dtype=torch.int32) + self._draft_words = draft_indices // 32 + self._draft_bits = draft_indices % 32 + self._n_draft_words = (draft_vocab_size + 31) // 32 + # Precompute max word index (avoids GPU→CPU sync in _translate_bitmask) + self._max_d2t_word = self._d2t_words.max().item() + # Cache device-specific constants; keyed by device. 
+ self._bitmask_cache: dict[torch.device, dict] = {} + + def _get_bitmask_constants(self, device: torch.device): + """Return bitmask-translate constants on *device*, caching the + transfer.""" + if device not in self._bitmask_cache: + self._bitmask_cache[device] = dict( + d2t_words=self._d2t_words.to(device), + d2t_bits=self._d2t_bits.to(device), + draft_words=self._draft_words.to(device), + draft_bits=self._draft_bits.to(device), + ) + return self._bitmask_cache[device] + + def _translate_bitmask(self, target_bitmask: torch.Tensor) -> torch.Tensor: + """Translate target-vocab bitmask to draft-vocab bitmask. + + Args: + target_bitmask: [batch, ceil(target_vocab/32)] int32 bitmask + produced by xgr.GrammarMatcher.fill_next_token_bitmask. + + Returns: + draft_bitmask: [batch, ceil(draft_vocab/32)] int32 bitmask + compatible with apply_batched_bitmap. + """ + c = self._get_bitmask_constants(target_bitmask.device) + d2t_words = c['d2t_words'] + d2t_bits = c['d2t_bits'] + draft_words = c['draft_words'] + draft_bits = c['draft_bits'] + + max_word_idx = self._max_d2t_word + if max_word_idx >= target_bitmask.size(1): + raise ValueError( + f'd2t mapping references word index {max_word_idx} but target_bitmask ' + f'only has {target_bitmask.size(1)} words. ' + f'The draft-to-target mapping may be out of bounds for the current vocab.') + + word_vals = target_bitmask[:, d2t_words] + draft_valid = ((word_vals >> d2t_bits.unsqueeze(0)) & 1).to(torch.int32) + + # scatter_add_ is correct because bit positions within the same word + # never overlap, so addition ≡ bitwise OR. 
+ bits_to_set = draft_valid << draft_bits + out = target_bitmask.new_zeros(target_bitmask.size(0), self._n_draft_words) + out.scatter_add_(1, draft_words.to(torch.int64).unsqueeze(0).expand(target_bitmask.size(0), -1), + bits_to_set) + return out + def get_target_hidden_size(self, model_config: ModelConfig): """Get target hidden size.""" hf_config = self.specdecode_config.model_config.hf_config hidden_size = getattr(hf_config, 'target_hidden_size', hf_config.hidden_size) return hidden_size * 3 - def get_outputs(self, + async def get_outputs(self, model_outputs: dict[str, torch.Tensor], model_inputs: ModelInputs, - extra_inputs: ExtraInputs = None): + extra_inputs: ExtraInputs = None, + guided_processors: dict | None = None): """Get outputs.""" hidden_states = model_outputs['hidden_states'] hidden_states_prenorm = model_outputs['hidden_states_prenorm'] @@ -49,7 +112,15 @@ def get_outputs(self, hidden_states_prenorm = hidden_states_prenorm[:, last_token_loc] logits = self.get_logits(hidden_states)[0] + + guided_bitmask = await self.guided_helper.prepare_bitmask(logits, guided_processors) + if guided_bitmask is not None: + draft_bitmask = self._translate_bitmask(guided_bitmask) + self.guided_helper.apply_bitmask(logits, draft_bitmask) + draft_token_ids = logits.argmax(dim=-1, keepdim=True) - # token mapping draft_token_ids = self.draft_id_to_target_id[draft_token_ids] + + await self.guided_helper.accept_draft_tokens(draft_token_ids, guided_processors) + return draft_token_ids, model_metas, hidden_states_prenorm diff --git a/lmdeploy/pytorch/spec_decode/spec_agent.py b/lmdeploy/pytorch/spec_decode/spec_agent.py index 1894328169..77156ca299 100644 --- a/lmdeploy/pytorch/spec_decode/spec_agent.py +++ b/lmdeploy/pytorch/spec_decode/spec_agent.py @@ -1,7 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from __future__ import annotations + from contextlib import contextmanager +from typing import TYPE_CHECKING +import numpy as np import torch from torch.profiler import record_function @@ -17,10 +21,25 @@ from ..strategies.ar_spec.model_agent import ARSpecExtraInputs from ..strategies.base.model_agent import ExtraInputs from .base import BaseSpecModelAgent +from .guided_spec_helper import GuidedSpecHelper from .proposers.base import build_specdecode_proposer +if TYPE_CHECKING: + pass + logger = get_logger('lmdeploy') +# Fields that hold a single scalar value shared across the expanded batch. +_SCALAR_FIELDS = frozenset({ + 'max_top_k', 'min_top_p', 'max_num_logprobs', + 'max_repetition_ngram_size', +}) +# Fields that are global (not per-batch-element) and should not be +# repeated when expanding sampling inputs. +_GLOBAL_FIELDS = frozenset({ + 'session_to_cleanup', +}) + def _expand_sampling_inputs(sampling_inputs: SamplingInputs, num_tokens: int) -> SamplingInputs: """Expand per-batch SamplingInputs to per-token by repeating each batch @@ -48,6 +67,12 @@ def _expand_sampling_inputs(sampling_inputs: SamplingInputs, num_tokens: int) -> # reproducible but distinct random sampling arange = torch.arange(num_tokens, device=v.device) v = v + arange.repeat(sampling_inputs.batch_size) + elif k in _SCALAR_FIELDS or k in _GLOBAL_FIELDS: + pass + elif isinstance(v, np.ndarray) and v.ndim >= 1 and v.shape[0] == sampling_inputs.batch_size: + v = np.repeat(v, num_tokens, axis=0) + elif isinstance(v, (list, tuple)) and len(v) == sampling_inputs.batch_size: + v = type(v)(_item for elem in v for _item in [elem] * num_tokens) out_dict[k] = v out_dict['batch_size'] = sampling_inputs.batch_size * num_tokens @@ -90,6 +115,27 @@ def _slice_sampling_inputs(sampling_inputs: SamplingInputs, num_tokens: int, is_ shape = v.shape v = v.view(batch_size, num_tokens, *shape[1:]) v = v[:, :-1].reshape(batch_size * (num_tokens - 1), *shape[1:]) + elif k in _SCALAR_FIELDS or k in _GLOBAL_FIELDS: + 
pass + elif isinstance(v, np.ndarray) and v.ndim >= 1 and v.shape[0] == sampling_inputs.batch_size: + if is_last: + v = v[num_tokens - 1::num_tokens] + else: + v = v.reshape(batch_size, num_tokens, *v.shape[1:])[:, :-1].reshape( + batch_size * (num_tokens - 1), *v.shape[1:]) + elif isinstance(v, (list, tuple)): + # Skip if length doesn't match the expanded batch size (e.g. + # empty defaults or fields that were not per-batch). + if len(v) == sampling_inputs.batch_size: + if is_last: + indices = list(range(num_tokens - 1, len(v), num_tokens)) + v = type(v)(v[i] for i in indices) + else: + indices = [] + for b in range(batch_size): + start = b * num_tokens + indices.extend(range(start, start + num_tokens - 1)) + v = type(v)(v[i] for i in indices) out_dict[k] = v if is_last: @@ -122,6 +168,10 @@ def __init__( ) self.proposer = build_specdecode_proposer(specdecode_config, device=device) + + # Guided decoding — set by ModelAgent after construction + self.guided_helper = GuidedSpecHelper() + # make dummy meta self.make_dummy_meta = self.inputs_strategy.create_make_dummy_meta(self.model_config) # for long context carry-over in chunked decoding @@ -344,7 +394,7 @@ def _prepare_long_context_chunk_prepend_saved(self, key, tensor, save_last=True) self._prev_chunk_last.pop(key, None) return torch.cat([saved, tensor], dim=1) - async def _rejection_sampling(self, model_inputs: 'ModelInputs', extra_inputs: ARSpecExtraInputs, + async def _rejection_sampling(self, model_inputs: ModelInputs, extra_inputs: ARSpecExtraInputs, sampling_inputs: SamplingInputs): """Do rejection sampling.""" @@ -367,7 +417,6 @@ def __compute_logprobs(raw_logprobs: torch.Tensor, token_ids: torch.LongTensor, ) return output_logprobs - # Process target_logits via FusedLogitsProcessor for BOTH prefill and decoding target_logits = extra_inputs.target_logits batch_size = model_inputs.seq_length.size(0) num_rejected_tokens = torch.zeros_like(model_inputs.seq_length) @@ -390,23 +439,38 @@ def 
__compute_logprobs(raw_logprobs: torch.Tensor, token_ids: torch.LongTensor, num_expand_sampling = 1 if not model_inputs.is_decoding else self.num_spec_tokens + 1 expanded_sampling_inputs = _expand_sampling_inputs(sampling_inputs, num_expand_sampling) - logits_processor = FusedLogitsProcessor( - expanded_sampling_inputs, - logprobs_mode=self.misc_config.logprobs_mode, - ) + guided_helper = self.guided_helper + guided_helper.cleanup_sessions(sampling_inputs.session_to_cleanup) + guided_processors = guided_helper.get_processors( + sampling_inputs.session_ctx, sampling_inputs.response_formats) if model_inputs.is_decoding: - # TODO: guided decoding not supported yet for spec decoding - processed_logits, raw_logprobs = await logits_processor(target_logits) - # Slice bonus (last) position logits for each batch element + if guided_processors: + # Position-serial grammar mask via forked matchers; + # original matchers are NOT modified. + processed_logits, raw_logprobs = await self._guided_spec_logits_process( + target_logits, expanded_sampling_inputs, guided_helper, + guided_processors, batch_size, num_expand_sampling, + draft_token_ids=extra_inputs.output_draft_token_ids) + else: + logits_processor = FusedLogitsProcessor( + expanded_sampling_inputs, + logprobs_mode=self.misc_config.logprobs_mode, + ) + processed_logits, raw_logprobs = await logits_processor(target_logits) + + # Bonus logits already have grammar mask applied in guided path bonus_logits = processed_logits[num_expand_sampling - 1::num_expand_sampling] # [batch_size, vocab] - # Create a per-batch processor for bonus token sampling - # by slicing the expanded sampling_inputs back to batch_size + bonus_sampling_inputs = _slice_sampling_inputs(expanded_sampling_inputs, num_expand_sampling) - logits_processor.sampling_inputs = bonus_sampling_inputs - # Sample next token from bonus position + + logits_processor = FusedLogitsProcessor( + bonus_sampling_inputs, + logprobs_mode=self.misc_config.logprobs_mode, + ) + 
next_token_ids = logits_processor.sampling(bonus_logits) # [batch_size] - # Reshape back to 3D + processed_logits = processed_logits.view(batch_size, num_expand_sampling, -1) # Rejection sampling on processed logits (exclude bonus position) target_draft_logits = processed_logits[:, :-1].contiguous() # [batch, num_spec, vocab] @@ -417,9 +481,28 @@ def __compute_logprobs(raw_logprobs: torch.Tensor, token_ids: torch.LongTensor, next_token_ids, sampling_inputs=draft_sampling_inputs, ) - # update last token indices last_token_indices = last_token_indices - num_rejected_tokens + + # Guided: accept final tokens on original matchers. + # Forked matchers were used during processing, so originals are still + # at pre-step state. Accept rejection-sampled output + bonus token + # to bring originals to the correct state for the next step. + await guided_helper.accept_rejection_sampled_tokens( + guided_processors, + num_rejected_tokens, + output_token_ids, + next_token_ids, + self.num_spec_tokens, + ) else: + # Prefill path — handle guided decoding manually (same pattern as + # the decode path) to keep accept_token in asyncio.to_thread and + # avoid the double-accept bug that occurs when + # FusedLogitsProcessor.sampling() also calls accept_token. 
+ logits_processor = FusedLogitsProcessor( + expanded_sampling_inputs, + logprobs_mode=self.misc_config.logprobs_mode, + ) if model_inputs.is_chunk and not model_inputs.is_last_chunk: # dummy output, no need to sampling or compute logprobs for non-last chunk next_token_ids = num_rejected_tokens @@ -427,9 +510,20 @@ def __compute_logprobs(raw_logprobs: torch.Tensor, token_ids: torch.LongTensor, raw_logprobs = None else: bonus_logits, raw_logprobs = await logits_processor(target_logits) - # Sample next token from bonus position + # Apply guided bitmask (no fork needed — single position) + guided_bitmask = await guided_helper.prepare_bitmask(bonus_logits, guided_processors) + guided_helper.apply_bitmask(bonus_logits, guided_bitmask) next_token_ids = logits_processor.sampling(bonus_logits) # [batch_size] output_token_ids = next_token_ids.unsqueeze(-1) + # Accept the sampled token on original grammar matchers + await guided_helper.accept_rejection_sampled_tokens( + guided_processors, + torch.zeros(next_token_ids.shape, + dtype=next_token_ids.dtype), + output_token_ids, + next_token_ids, + 0, # num_spec_tokens=0, only bonus accepted + ) logprobs = __compute_logprobs(raw_logprobs, output_token_ids, sampling_inputs.max_num_logprobs) @@ -443,13 +537,47 @@ def __compute_logprobs(raw_logprobs: torch.Tensor, token_ids: torch.LongTensor, ) return new_extra_inputs + async def _guided_spec_logits_process( + self, + target_logits: torch.Tensor, + expanded_sampling_inputs: SamplingInputs, + guided_helper: GuidedSpecHelper, + guided_processors: dict, + batch_size: int, + num_expand: int, + draft_token_ids: torch.LongTensor, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Apply position-serial grammar mask to target logits for spec decode. + + Uses forked GrammarMatchers so that the original matchers are NOT + modified. The caller is responsible for accepting the final tokens + on the original matchers after rejection sampling. 
+ + All ``num_expand`` positions (including the bonus position) are masked. + """ + logits_processor = FusedLogitsProcessor( + expanded_sampling_inputs, + logprobs_mode=self.misc_config.logprobs_mode, + ) + scores, raw_logprobs = await logits_processor(target_logits) + + if not guided_processors: + return scores, raw_logprobs + + scores_3d = scores.view(batch_size, num_expand, -1) + await guided_helper.apply_serial_bitmask( + scores_3d, guided_processors, draft_token_ids, self.num_spec_tokens) + + scores = scores_3d.view(batch_size * num_expand, -1) + return scores, raw_logprobs + def _forward_impl(self, inputs: ModelInputs): """Forward impl.""" with self.draft_context(): output = self.proposer._forward(inputs, cache_engine=self.cache_engine) return output - async def async_sampling_logits(self, model_inputs: 'ModelInputs', extra_inputs: ARSpecExtraInputs, + async def async_sampling_logits(self, model_inputs: ModelInputs, extra_inputs: ARSpecExtraInputs, sampling_inputs: SamplingInputs): """Sample target logits and run rejection sampling.""" with record_function('spec_rejection_sampling'): @@ -492,9 +620,19 @@ def _update_dp_model_inputs(inputs: ModelInputs, dp_meta: DPMeta, padding_batch_ # remaining speculative forwards. output_draft_ids = inputs.input_ids.new_zeros(inputs.seq_length.size(0), self.num_spec_tokens) else: + # Fork guided processors for draft model. 
+ draft_guided_processors = None + orig_processors = self.guided_helper.get_processors( + sampling_inputs.session_ctx if sampling_inputs else None, + sampling_inputs.response_formats if sampling_inputs else None) + if orig_processors: + draft_guided_processors = {idx: proc.fork() + for idx, proc in orig_processors.items()} + loop_count = self.num_spec_tokens - 1 - draft_token_ids, model_metas, target_hidden_states = self.proposer.get_outputs( - outputs, inputs, extra_inputs) + draft_token_ids, model_metas, target_hidden_states = await self.proposer.get_outputs( + outputs, inputs, extra_inputs, + guided_processors=draft_guided_processors) draft_tokens_li = [draft_token_ids] if loop_count > 0: inputs = self.proposer.update_inputs_decoding(inputs, extra_inputs, draft_token_ids.transpose(0, 1), @@ -507,7 +645,9 @@ def _update_dp_model_inputs(inputs: ModelInputs, dp_meta: DPMeta, padding_batch_ for loop_idx in range(loop_count): inputs = _update_dp_model_inputs(inputs, dp_meta, padding_batch_size) outputs = self._forward_impl(inputs) - draft_token_ids, model_metas, target_hidden_states = self.proposer.get_outputs(outputs, inputs) + draft_token_ids, model_metas, target_hidden_states = await self.proposer.get_outputs( + outputs, inputs, + guided_processors=draft_guided_processors) draft_tokens_li.append(draft_token_ids) if loop_idx < loop_count - 1: step_seqlens = inputs.seq_length.new_ones(inputs.seq_length.size(0)) diff --git a/requirements/common.txt b/requirements/common.txt index b8de9a8c7c..2f97cc963c 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -20,4 +20,4 @@ shortuuid tiktoken transformers >= 4.56.0, != 5.0.*, != 5.1.*, != 5.2.*, != 5.3.*, != 5.4.*, != 5.5.0, !=5.7.*, !=5.8.*, !=5.9.* uvicorn -xgrammar +xgrammar >= 0.1.33 diff --git a/tests/pytorch/spec_decode/test_guided_spec_decode.py b/tests/pytorch/spec_decode/test_guided_spec_decode.py new file mode 100644 index 0000000000..e5f638ead7 --- /dev/null +++ 
b/tests/pytorch/spec_decode/test_guided_spec_decode.py @@ -0,0 +1,713 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Unit tests for MTP (speculative decoding) + Guided Decoding integration. + +1. _expand/_slice_sampling_inputs non-tensor field handling + - response_formats, session_ctx must be repeated/sliced alongside tensor + fields when spec decode expands or slices the batch dimension. + - Boundary cases: num_tokens=1, empty response_formats, None session_ctx. + +2. Grammar state management (fork / rollback / accept_string) + - fork() produces an independent GrammarMatcher snapshot. + - accept_string() advances state; rollback(n) reverts it. + - Fork-based strategy: fork before draft generation, accept final tokens + on original matcher only. + +3. Positional-serial grammar mask + - Prove that different spec positions need different grammar masks. + +4. Draft model grammar masking + - Masked argmax picks valid tokens; unmasked may pick invalid. + +5. Grammar state after rejection sampling + - After rejection, grammar state must reflect exactly the accepted tokens. + +NOTE: With BPE tokenizers (like Qwen), accept_token() may not visibly +advance grammar state because individual BPE tokens can be multi-character +partial bytes. Tests that need to observe state transitions use +accept_string() which reliably advances the grammar character-by-character. 
+""" +import pytest +import torch +import xgrammar as xgr + +from lmdeploy.pytorch.engine.logits_process import SamplingInputs +from lmdeploy.pytorch.spec_decode.spec_agent import ( + _expand_sampling_inputs, + _slice_sampling_inputs, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_QWEN_MODEL = 'Qwen/Qwen3.5-0.8B' + + +@pytest.fixture(scope='module') +def tokenizer_info(): + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(_QWEN_MODEL, trust_remote_code=True) + return xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=tokenizer.vocab_size) + + +@pytest.fixture(scope='module') +def compiler(tokenizer_info): + return xgr.GrammarCompiler(tokenizer_info) + + +def _json_matcher(compiler, schema): + compiled = compiler.compile_json_schema(schema) + return xgr.GrammarMatcher(compiled, terminate_without_stop_token=True) + + +def _allowed_ids(bitmask, row=0): + """Extract allowed token IDs from an xgrammar bitmask. + + Use ``xgr.allocate_token_bitmask(batch_size, vocab_size)`` or + ``xgr.get_bitmask_shape(batch_size, vocab_size)`` to obtain / query the + bitmask shape — do NOT hard-code ``ceil(vocab_size / 32)`` yourself. + + The internal bit-packing format may change across xgrammar versions. + Current format: int32 words, each bit maps to one token. + We decode bit-by-bit so the helper stays format-agnostic. + """ + bm_np = bitmask.numpy() + ids = set() + for word_idx in range(bm_np.shape[1]): + word = int(bm_np[row, word_idx]) & 0xFFFFFFFF # treat as unsigned + if word != 0: + for bit in range(32): + if word & (1 << bit): + ids.add(word_idx * 32 + bit) + return ids + + +# =========================================================================== +# 1. 
_expand_sampling_inputs — non-tensor field expansion +# =========================================================================== + + +class TestExpandSamplingInputsNonTensor: + """_expand_sampling_inputs must repeat non-tensor fields (response_formats, + session_ctx, logits_processors, session_to_cleanup) so that every expanded + position carries the same guided-decoding context as its source batch + element.""" + + # ---- response_formats (tuple) ---- + + def test_response_formats_repeated(self): + fmt = {'type': 'json_schema', 'json_schema': {'name': 't', 'schema': {'type': 'object'}}} + si = SamplingInputs( + max_top_k=1, + batch_size=2, + response_formats=(fmt, None), + ) + expanded = _expand_sampling_inputs(si, num_tokens=3) + # batch_size = 2 × 3 = 6 + assert expanded.batch_size == 6 + assert len(expanded.response_formats) == 6 + # [fmt, fmt, fmt, None, None, None] + assert expanded.response_formats[:3] == (fmt, fmt, fmt) + assert expanded.response_formats[3:] == (None, None, None) + + def test_response_formats_mixed_batch(self): + guided = {'type': 'json_schema', 'json_schema': {'name': 't', 'schema': {'type': 'object'}}} + si = SamplingInputs( + max_top_k=1, + batch_size=3, + response_formats=(guided, None, guided), + ) + expanded = _expand_sampling_inputs(si, num_tokens=2) + assert len(expanded.response_formats) == 6 + # [guided, guided, None, None, guided, guided] + assert expanded.response_formats[0] == guided + assert expanded.response_formats[1] == guided + assert expanded.response_formats[2] is None + assert expanded.response_formats[3] is None + assert expanded.response_formats[4] == guided + assert expanded.response_formats[5] == guided + + def test_response_formats_empty(self): + si = SamplingInputs(max_top_k=1, batch_size=2, response_formats=()) + expanded = _expand_sampling_inputs(si, num_tokens=3) + assert expanded.response_formats == () + + def test_num_tokens_1_identity(self): + fmt = {'type': 'json_schema', 'json_schema': {'name': 't', 
'schema': {'type': 'object'}}} + si = SamplingInputs(max_top_k=1, batch_size=2, response_formats=(fmt, None)) + result = _expand_sampling_inputs(si, num_tokens=1) + assert result is si + + def test_session_ctx_none(self): + si = SamplingInputs(max_top_k=1, batch_size=2, session_ctx=None) + expanded = _expand_sampling_inputs(si, num_tokens=3) + assert expanded.session_ctx is None + + # ---- tensor fields still correct ---- + + def test_tensor_fields_still_expanded(self): + si = SamplingInputs( + max_top_k=1, + batch_size=2, + temperature=torch.tensor([0.5, 1.0]), + top_k=torch.tensor([1, 10]), + random_offsets=torch.tensor([100, 200]), + response_formats=({'type': 'json_schema'}, None), + session_ctx=[{'session_id': 1, 'seq_id': 10}, {'session_id': 2, 'seq_id': 20}], + ) + expanded = _expand_sampling_inputs(si, num_tokens=3) + # Tensor fields + assert expanded.temperature.shape[0] == 6 + torch.testing.assert_close(expanded.temperature, torch.tensor([0.5, 0.5, 0.5, 1.0, 1.0, 1.0])) + torch.testing.assert_close(expanded.random_offsets, torch.tensor([100, 101, 102, 200, 201, 202])) + # Non-tensor fields + assert len(expanded.response_formats) == 6 + assert len(expanded.session_ctx) == 6 + + +# =========================================================================== +# 2. 
_slice_sampling_inputs — non-tensor field slicing +# =========================================================================== + + +class TestSliceSamplingInputsNonTensor: + """After expansion, _slice_sampling_inputs must also slice non-tensor + fields back to the expected size.""" + + def _make_expanded(self, num_tokens=3, batch_size=2): + """Create an already-expanded SamplingInputs (as if expansion handled + non-tensor fields correctly).""" + total = batch_size * num_tokens + fmt = {'type': 'json_schema', 'json_schema': {'name': 't', 'schema': {'type': 'object'}}} + return SamplingInputs( + max_top_k=1, + batch_size=total, + temperature=torch.ones(total), + response_formats=tuple([fmt] * num_tokens + [None] * num_tokens), + session_ctx=[{'session_id': 1, 'seq_id': 10}] * num_tokens + + [{'session_id': 2, 'seq_id': 20}] * num_tokens, + ) + + def test_slice_is_last_true(self): + """is_last=True → one element per original batch (the last token).""" + si = self._make_expanded(num_tokens=3, batch_size=2) + sliced = _slice_sampling_inputs(si, num_tokens=3, is_last=True) + assert sliced.batch_size == 2 + assert len(sliced.response_formats) == 2 + assert len(sliced.session_ctx) == 2 + # Last token per batch: indices 2 and 5 + assert sliced.response_formats[0] == si.response_formats[2] + assert sliced.response_formats[1] is None + assert sliced.session_ctx[0] == {'session_id': 1, 'seq_id': 10} + assert sliced.session_ctx[1] == {'session_id': 2, 'seq_id': 20} + + def test_slice_is_last_false(self): + """is_last=False → num_tokens-1 elements per original batch.""" + si = self._make_expanded(num_tokens=3, batch_size=2) + sliced = _slice_sampling_inputs(si, num_tokens=3, is_last=False) + assert sliced.batch_size == 4 # 2 * (3-1) + assert len(sliced.response_formats) == 4 + assert len(sliced.session_ctx) == 4 + # First 2 tokens per batch: [fmt, fmt, None, None] + assert sliced.response_formats[0] == si.response_formats[0] + assert sliced.response_formats[1] == 
si.response_formats[1] + assert sliced.response_formats[2] is None + assert sliced.response_formats[3] is None + + def test_slice_num_tokens_1_identity(self): + si = SamplingInputs( + max_top_k=1, + batch_size=2, + temperature=torch.ones(2), + response_formats=({'type': 'json_schema'}, None), + ) + result = _slice_sampling_inputs(si, num_tokens=1) + assert result is si + + +# =========================================================================== +# 3. Grammar state management — fork +# =========================================================================== + + +class TestGrammarFork: + """fork() creates an independent GrammarMatcher snapshot.""" + + def test_fork_is_independent_object(self, compiler): + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + matcher = _json_matcher(compiler, schema) + forked = matcher.fork() + assert forked is not matcher + + def test_accept_string_on_fork_does_not_affect_original(self, compiler, tokenizer_info): + """accept_string on a forked matcher does not change the original.""" + schema = {'type': 'object', 'properties': {'name': {'type': 'string'}}, 'required': ['name']} + original = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + + # Record original's allowed tokens + bm_orig = xgr.allocate_token_bitmask(1, vocab_size) + original.fill_next_token_bitmask(bm_orig, 0) + orig_allowed = _allowed_ids(bm_orig) + + # Fork and advance fork via accept_string (reliably changes state) + forked = original.fork() + forked.accept_string('{"') + + # Original should still be at initial state + bm_check = xgr.allocate_token_bitmask(1, vocab_size) + original.fill_next_token_bitmask(bm_check, 0) + check_allowed = _allowed_ids(bm_check) + assert check_allowed == orig_allowed + + def test_fork_chain_for_spec_positions(self, compiler, tokenizer_info): + """Simulate the spec-decode fork pattern: for each speculative + position, fork from current state, then advance current 
state + via accept_string. + + Each fork captures the grammar state at position i. + The original matcher stays at the initial state. + """ + schema = {'type': 'object', 'properties': {'name': {'type': 'string'}}, 'required': ['name']} + original = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + + num_spec = 4 + forks = [] + current = original.fork() + + # Advance through JSON construction using accept_string + advance_strings = ['{"', 'name', '"', ':', '"'] + for i in range(min(num_spec, len(advance_strings))): + pos_fork = current.fork() + forks.append(pos_fork) + current.accept_string(advance_strings[i]) + + # Original unchanged + bm_orig = xgr.allocate_token_bitmask(1, vocab_size) + original.fill_next_token_bitmask(bm_orig, 0) + orig_allowed = _allowed_ids(bm_orig) + + bm_fresh = xgr.allocate_token_bitmask(1, vocab_size) + _json_matcher(compiler, schema).fill_next_token_bitmask(bm_fresh, 0) + fresh_allowed = _allowed_ids(bm_fresh) + + assert orig_allowed == fresh_allowed + + # Forks should have progressively different states + fork_allowed_sets = [] + for fk in forks: + bm = xgr.allocate_token_bitmask(1, vocab_size) + fk.fill_next_token_bitmask(bm, 0) + fork_allowed_sets.append(_allowed_ids(bm)) + + # First and last must differ (grammar state advanced via accept_string) + assert fork_allowed_sets[0] != fork_allowed_sets[-1], ( + 'Fork chain must capture progressively different grammar states' + ) + + +# =========================================================================== +# 4. Grammar state management — rollback +# =========================================================================== + + +class TestGrammarRollback: + """rollback(n) reverts the last n accept_string / accept_token calls. + + Note: With BPE tokenizers, rollback counts the number of *accept* calls + (not characters). accept_string('abc') counts as 1 accept step. 
+ """ + + def test_rollback_one_accept_string_step(self, compiler, tokenizer_info): + """Rollback 1 accept_string step reverts to previous state.""" + schema = {'type': 'object', 'properties': {'name': {'type': 'string'}}, 'required': ['name']} + matcher = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + + # Record initial state + bm0 = xgr.allocate_token_bitmask(1, vocab_size) + matcher.fill_next_token_bitmask(bm0, 0) + initial = _allowed_ids(bm0) + + # Advance with accept_string + matcher.accept_string('{"') + + # Rollback 1 step + matcher.rollback(1) + + # Should be back to initial + bm1 = xgr.allocate_token_bitmask(1, vocab_size) + matcher.fill_next_token_bitmask(bm1, 0) + assert _allowed_ids(bm1) == initial + + def test_rollback_partial(self, compiler, tokenizer_info): + """Accept 3 steps, rollback 1 → state equals state after 2 steps.""" + schema = {'type': 'object', 'properties': {'name': {'type': 'string'}}, 'required': ['name']} + matcher = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + + # Accept 2 steps, record state + matcher.accept_string('{"') + matcher.accept_string('name') + + bm_after_2 = xgr.allocate_token_bitmask(1, vocab_size) + matcher.fill_next_token_bitmask(bm_after_2, 0) + expected = _allowed_ids(bm_after_2) + + # Accept 1 more + matcher.accept_string('"') + + # Rollback 1 + matcher.rollback(1) + + bm_check = xgr.allocate_token_bitmask(1, vocab_size) + matcher.fill_next_token_bitmask(bm_check, 0) + assert _allowed_ids(bm_check) == expected + + def test_rollback_after_partial_rejection(self, compiler, tokenizer_info): + """Simulate rejection sampling: advance N steps, rollback to K < N.""" + schema = {'type': 'object', 'properties': {'name': {'type': 'string'}}, 'required': ['name']} + matcher = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + + # Advance 4 steps (simulating draft generation) + advance_steps = ['{"', 'name', '"', ':'] + for s in advance_steps: + 
matcher.accept_string(s) + + # Rejection: only 1 step accepted → rollback 3 + num_accepted = 1 + rollback_count = len(advance_steps) - num_accepted + matcher.rollback(rollback_count) + + # State should match a matcher that accepted only 1 step + reference = _json_matcher(compiler, schema) + reference.accept_string(advance_steps[0]) + + bm_ref = xgr.allocate_token_bitmask(1, vocab_size) + reference.fill_next_token_bitmask(bm_ref, 0) + bm_actual = xgr.allocate_token_bitmask(1, vocab_size) + matcher.fill_next_token_bitmask(bm_actual, 0) + + assert _allowed_ids(bm_actual) == _allowed_ids(bm_ref) + + +# =========================================================================== +# 5. Grammar state management — fork-based strategy for spec decode +# =========================================================================== + + +class TestGrammarForkStrategy: + """The recommended approach: use fork() during draft generation and + target verification, then accept only the final output tokens on the + original matcher.""" + + def test_fork_strategy_target_verification(self, compiler, tokenizer_info): + """Target model verification uses forked matchers to apply position- + dependent grammar masks without mutating the original. + + After rejection sampling, accept the final output on the original. 
+ """ + schema = {'type': 'object', 'properties': {'name': {'type': 'string'}}, 'required': ['name']} + original = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + num_spec = 3 + + # Record original's initial state + bm_orig = xgr.allocate_token_bitmask(1, vocab_size) + original.fill_next_token_bitmask(bm_orig, 0) + orig_allowed = _allowed_ids(bm_orig) + + # Phase 1: Fork-based verification masks (don't mutate original) + advance_strings = ['{"', 'name', '"'] + forks = [] + current = original.fork() + for i in range(num_spec): + pos_fork = current.fork() + forks.append(pos_fork) + if i < len(advance_strings): + current.accept_string(advance_strings[i]) + + # Original must be unchanged + bm_check = xgr.allocate_token_bitmask(1, vocab_size) + original.fill_next_token_bitmask(bm_check, 0) + assert _allowed_ids(bm_check) == orig_allowed + + # Phase 2: After rejection sampling, accept output on original + # Simulate: all spec tokens accepted + bonus + output_strings = ['{"', 'name', '"', ':'] + for s in output_strings: + original.accept_string(s) + + # Original should have advanced + bm_final = xgr.allocate_token_bitmask(1, vocab_size) + original.fill_next_token_bitmask(bm_final, 0) + final_allowed = _allowed_ids(bm_final) + assert final_allowed != orig_allowed + + +# =========================================================================== +# 6. Positional-serial grammar mask +# =========================================================================== + + +class TestPositionalSerialGrammarMask: + """In spec decode, the grammar mask for position i depends on what tokens + were accepted at positions 0..i-1. + + Applying the same mask to all positions (parallel mask) is INCORRECT for position 1+. 
+ """ + + def test_mask_changes_after_accept_string(self, compiler, tokenizer_info): + """After accept_string, the allowed token set changes.""" + schema = {'type': 'object', 'properties': {'name': {'type': 'string'}}, 'required': ['name']} + matcher = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + + bm0 = xgr.allocate_token_bitmask(1, vocab_size) + matcher.fill_next_token_bitmask(bm0, 0) + before = _allowed_ids(bm0) + + matcher.accept_string('{"') + + bm1 = xgr.allocate_token_bitmask(1, vocab_size) + matcher.fill_next_token_bitmask(bm1, 0) + after = _allowed_ids(bm1) + + assert before != after, 'Grammar mask must change after accept_string' + + def test_parallel_mask_incorrect_for_later_positions(self, compiler, tokenizer_info): + """Parallel (same mask for all positions) differs from serial + (position-dependent mask). + + This proves spec decode MUST use serial mask application. + """ + schema = {'type': 'object', 'properties': {'name': {'type': 'string'}}, 'required': ['name']} + matcher = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + num_spec = 3 + + # Parallel: same initial mask for all positions + parallel_bm = xgr.allocate_token_bitmask(num_spec, vocab_size) + matcher.fill_next_token_bitmask(parallel_bm, 0) + for pos in range(1, num_spec): + parallel_bm.numpy()[pos] = parallel_bm.numpy()[0] + + # Serial: advance grammar state per position using accept_string + serial_bm = xgr.allocate_token_bitmask(num_spec, vocab_size) + temp = _json_matcher(compiler, schema) + advance_strings = ['{"', 'name', '"'] + for pos in range(num_spec): + temp.fill_next_token_bitmask(serial_bm, pos) + if pos < len(advance_strings): + # Accept a fixed string (standing in for the sampled token) + # to advance grammar state for next position + temp.accept_string(advance_strings[pos]) + + # At position 1+, the masks must differ from the parallel (initial) mask + pos0_parallel = _allowed_ids(parallel_bm, row=0) + pos0_serial = 
_allowed_ids(serial_bm, row=0) + assert pos0_parallel == pos0_serial, 'Position 0 masks should match' + + pos1_parallel = _allowed_ids(parallel_bm, row=1) + pos1_serial = _allowed_ids(serial_bm, row=1) + assert pos1_parallel != pos1_serial, ( + 'Parallel mask (same for all positions) is incorrect for pos 1+' + ) + + def test_serial_mask_with_fork_per_position(self, compiler, tokenizer_info): + """Each speculative position uses a fork to get the correct mask + without mutating the original matcher.""" + schema = {'type': 'object', 'properties': {'name': {'type': 'string'}}, 'required': ['name']} + original = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + num_spec = 4 + + advance_strings = ['{"', 'name', '"', ':'] + position_forks = [] + current = original.fork() # work on a fork, not original directly + for pos in range(num_spec): + fork_at_pos = current.fork() + position_forks.append(fork_at_pos) + if pos < len(advance_strings): + current.accept_string(advance_strings[pos]) + + # Verify: each fork captures grammar state at position i + allowed_per_pos = [] + for fk in position_forks: + bm = xgr.allocate_token_bitmask(1, vocab_size) + fk.fill_next_token_bitmask(bm, 0) + allowed_per_pos.append(_allowed_ids(bm)) + + # First and last must differ (grammar state advanced via accept_string) + assert allowed_per_pos[0] != allowed_per_pos[-1], ( + 'Position forks must capture different grammar states' + ) + + # Original matcher should be unchanged + bm_orig = xgr.allocate_token_bitmask(1, vocab_size) + original.fill_next_token_bitmask(bm_orig, 0) + bm_fresh = xgr.allocate_token_bitmask(1, vocab_size) + _json_matcher(compiler, schema).fill_next_token_bitmask(bm_fresh, 0) + assert _allowed_ids(bm_orig) == _allowed_ids(bm_fresh) + + +# =========================================================================== +# 7. 
Draft model grammar masking (logic validation) +# =========================================================================== + + +class TestDraftModelGrammarMasking: + """Draft model must apply grammar mask before argmax to produce + grammatically valid draft tokens.""" + + def test_unmasked_argmax_may_pick_invalid_token(self, compiler, tokenizer_info): + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + matcher = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + + bm = xgr.allocate_token_bitmask(1, vocab_size) + matcher.fill_next_token_bitmask(bm, 0) + allowed = _allowed_ids(bm) + disallowed = set(range(vocab_size)) - allowed + + if not disallowed: + pytest.skip('all tokens allowed (unlikely with real tokenizer)') + + # Create logits that strongly prefer a disallowed token + logits = torch.full((1, vocab_size), -100.0) + bad_token = int(list(disallowed)[0]) + logits[0, bad_token] = 100.0 + + unmasked_choice = logits.argmax(dim=-1).item() + assert unmasked_choice not in allowed + + def test_masked_argmax_picks_valid_token(self, compiler, tokenizer_info): + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + matcher = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + + bm = xgr.allocate_token_bitmask(1, vocab_size) + matcher.fill_next_token_bitmask(bm, 0) + allowed = _allowed_ids(bm) + disallowed = set(range(vocab_size)) - allowed + + logits = torch.full((1, vocab_size), -100.0) + if disallowed: + logits[0, int(list(disallowed)[0])] = 100.0 + + xgr.apply_token_bitmask_inplace(logits, bm) + masked_choice = logits.argmax(dim=-1).item() + assert masked_choice in allowed + + def test_draft_chain_all_valid(self, compiler, tokenizer_info): + """Generate a chain of draft tokens with grammar mask at each step. + + Every token should be valid at its position. 
+ """ + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + matcher = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + num_spec = 4 + + for step in range(num_spec): + bm = xgr.allocate_token_bitmask(1, vocab_size) + matcher.fill_next_token_bitmask(bm, 0) + allowed = _allowed_ids(bm) + if len(allowed) == 0: + break # grammar terminated + + logits = torch.randn(1, vocab_size) + xgr.apply_token_bitmask_inplace(logits, bm) + token = logits.argmax(dim=-1).item() + assert token in allowed, f'Step {step}: token {token} not in allowed set' + matcher.accept_token(token) + + +# =========================================================================== +# 8. Grammar state after rejection sampling +# =========================================================================== + + +class TestGrammarStateAfterRejection: + """After rejection sampling, the grammar matcher's state must reflect + exactly the accepted tokens.""" + + def test_fork_strategy_rejection_output(self, compiler, tokenizer_info): + """Production code accepts rejection-sampled output + bonus token on + the original (un-forked) matcher — no rollback needed because forks are + used for spec positions and originals stay at the pre-step state. + + This test models the same logic: accept exactly the rejection-sampled + draft tokens (n_valid_draft) plus the bonus token on the original + matcher. 
+ """ + schema = {'type': 'object', 'properties': {'name': {'type': 'string'}}, 'required': ['name']} + original = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + + # Simulate: 3 spec steps, only 1 accepted + 1 bonus + # Production code: accept the 1 accepted draft + 1 bonus on original + accepted_strings = ['{"', ':'] # 1 accepted draft + 1 bonus + for s in accepted_strings: + original.accept_string(s) + + # Verify + reference = _json_matcher(compiler, schema) + for s in accepted_strings: + reference.accept_string(s) + + bm_orig = xgr.allocate_token_bitmask(1, vocab_size) + original.fill_next_token_bitmask(bm_orig, 0) + bm_ref = xgr.allocate_token_bitmask(1, vocab_size) + reference.fill_next_token_bitmask(bm_ref, 0) + + assert _allowed_ids(bm_orig) == _allowed_ids(bm_ref) + + def test_fork_strategy_all_accepted(self, compiler, tokenizer_info): + """All draft tokens accepted → accept all + bonus on original.""" + schema = {'type': 'object', 'properties': {'name': {'type': 'string'}}, 'required': ['name']} + original = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + + # All spec tokens accepted + bonus + accepted_strings = ['{"', 'name', '"', ':'] # 3 spec + 1 bonus + for s in accepted_strings: + original.accept_string(s) + + # Verify state matches reference + reference = _json_matcher(compiler, schema) + for s in accepted_strings: + reference.accept_string(s) + + bm_orig = xgr.allocate_token_bitmask(1, vocab_size) + original.fill_next_token_bitmask(bm_orig, 0) + bm_ref = xgr.allocate_token_bitmask(1, vocab_size) + reference.fill_next_token_bitmask(bm_ref, 0) + + assert _allowed_ids(bm_orig) == _allowed_ids(bm_ref) + + def test_fork_strategy_partial_rejection(self, compiler, tokenizer_info): + """Partial rejection: only some draft tokens accepted. 
+ Original matcher accepts exactly the final output tokens.""" + schema = {'type': 'object', 'properties': {'name': {'type': 'string'}}, 'required': ['name']} + original = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + + # Simulate: 3 spec steps, only 1 accepted + 1 bonus + final_strings = ['{"', ':'] # 1 accepted draft + 1 bonus + for s in final_strings: + original.accept_string(s) + + # Verify + reference = _json_matcher(compiler, schema) + for s in final_strings: + reference.accept_string(s) + + bm_orig = xgr.allocate_token_bitmask(1, vocab_size) + original.fill_next_token_bitmask(bm_orig, 0) + bm_ref = xgr.allocate_token_bitmask(1, vocab_size) + reference.fill_next_token_bitmask(bm_ref, 0) + + assert _allowed_ids(bm_orig) == _allowed_ids(bm_ref) diff --git a/tests/pytorch/spec_decode/test_guided_spec_integration.py b/tests/pytorch/spec_decode/test_guided_spec_integration.py new file mode 100644 index 0000000000..323309b58b --- /dev/null +++ b/tests/pytorch/spec_decode/test_guided_spec_integration.py @@ -0,0 +1,1177 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Integration tests for MTP (speculative decoding) + Guided Decoding. + +These tests exercise the interaction between guided decoding and speculative +decoding at a higher level than the unit tests in test_guided_spec_decode.py. +They test the *pipeline* logic — the position-serial grammar masking, +fork-based matcher management, and grammar state consistency after rejection +sampling — without requiring a GPU or actual model weights. + +Key scenarios: +1. Position-serial grammar mask via GuidedDecodingManager (matches direct xgr). +2. Fork-based target verification: multiple forks are independent. +3. Simulated _guided_spec_logits_process: all positions masked, mixed batch. +4. Grammar state after rejection sampling: original matchers accept exactly + the rejection-sampled output tokens. +5. End-to-end: draft → target verification → rejection → grammar state. +6. 
Batch-level grammar mask: mixed guided/unguided sequences. +""" +import asyncio + +import pytest +import torch +import xgrammar as xgr + +from lmdeploy.pytorch.engine.guided_process import GuidedDecodingManager +from lmdeploy.pytorch.engine.logits_process import SamplingInputs +from lmdeploy.pytorch.spec_decode.guided_spec_helper import GuidedSpecHelper + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_QWEN_MODEL = 'Qwen/Qwen2.5-7B-Instruct' + + +@pytest.fixture(scope='module') +def tokenizer(): + from transformers import AutoTokenizer + return AutoTokenizer.from_pretrained(_QWEN_MODEL, trust_remote_code=True) + + +@pytest.fixture(scope='module') +def tokenizer_info(tokenizer): + return xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=tokenizer.vocab_size) + + +@pytest.fixture(scope='module') +def compiler(tokenizer_info): + return xgr.GrammarCompiler(tokenizer_info) + + +@pytest.fixture(scope='module') +def guided_manager(tokenizer): + return GuidedDecodingManager(tokenizer, vocab_size=tokenizer.vocab_size) + + +def _json_matcher(compiler, schema): + compiled = compiler.compile_json_schema(schema) + return xgr.GrammarMatcher(compiled, terminate_without_stop_token=True) + + +def _regex_matcher(compiler, pattern): + compiled = compiler.compile_regex(pattern) + return xgr.GrammarMatcher(compiled, terminate_without_stop_token=True) + + +def _allowed_ids(bitmask, row=0): + """Extract allowed token IDs from an xgrammar bitmask.""" + bm_np = bitmask.numpy() + ids = set() + for word_idx in range(bm_np.shape[1]): + word = int(bm_np[row, word_idx]) & 0xFFFFFFFF + if word != 0: + for bit in range(32): + if word & (1 << bit): + ids.add(word_idx * 32 + bit) + return ids + + +def _make_sampling_inputs(batch_size, response_formats=None, session_ctx=None): + """Create a minimal SamplingInputs for testing.""" + return SamplingInputs( + max_top_k=1, + 
batch_size=batch_size, + response_formats=response_formats or (), + session_ctx=session_ctx, + ) + + +# =========================================================================== +# 1. Position-serial grammar mask via GuidedDecodingManager +# =========================================================================== + + +class TestPositionSerialGrammarMaskViaManager: + """Verify that applying grammar masks position-by-position using + GuidedDecodingManager methods produces correct per-position masks.""" + + def test_manager_methods_match_direct_xgr(self, compiler, tokenizer_info, guided_manager): + """GuidedDecodingManager methods produce the same results as direct + xgrammar calls.""" + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + matcher = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + + # Via manager + bm_manager = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(matcher, bm_manager, 0) + allowed_manager = _allowed_ids(bm_manager) + + # Direct + bm_direct = xgr.allocate_token_bitmask(1, vocab_size) + matcher.fill_next_token_bitmask(bm_direct, 0) + allowed_direct = _allowed_ids(bm_direct) + + assert allowed_manager == allowed_direct + + +# =========================================================================== +# 2. 
Fork-based target verification +# =========================================================================== + + +class TestForkBasedTargetVerification: + """Verify that _guided_spec_logits_process uses forked matchers, leaving + originals untouched.""" + + def test_multiple_forks_independent(self, compiler, tokenizer_info, guided_manager): + """Multiple forks from the same original are independent.""" + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + original = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + + fork1 = original.fork() + fork2 = original.fork() + + # Advance fork1 by 2 steps + for _ in range(2): + bm = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(fork1, bm, 0) + logits = torch.randn(1, vocab_size) + guided_manager.apply_batched_bitmap(logits, bm) + token = logits.argmax(dim=-1).item() + guided_manager.accept_token(fork1, token) + + # fork2 should still be at original state + bm_fork2 = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(fork2, bm_fork2, 0) + bm_orig = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(original, bm_orig, 0) + + assert _allowed_ids(bm_fork2) == _allowed_ids(bm_orig) + + +# =========================================================================== +# 3. 
Simulated _guided_spec_logits_process +# =========================================================================== + + +class TestSimulatedGuidedSpecLogitsProcess: + """Simulate the _guided_spec_logits_process method's logic to verify + correctness without requiring a full SpecModelAgent instance.""" + + def _guided_spec_logits_process_sim( + self, + target_logits: torch.Tensor, + guided_manager: GuidedDecodingManager, + guided_processors: dict, + batch_size: int, + num_expand: int, + vocab_size: int, + ): + """Simplified version of _guided_spec_logits_process that applies + position-serial grammar mask.""" + # Reshape to [batch_size, num_expand, vocab_size] + scores_3d = target_logits.clone().view(batch_size, num_expand, -1) + + # Fork matchers + forked = {idx: proc.fork() for idx, proc in guided_processors.items()} + + for pos in range(num_expand): + guided_bitmask = guided_manager.allocate_batched_bitmap(batch_size) + for idx, fork_proc in forked.items(): + guided_manager.fill_bitmap(fork_proc, guided_bitmask, idx) + pos_logits = scores_3d[:, pos, :] + guided_manager.apply_batched_bitmap(pos_logits, guided_bitmask) + scores_3d[:, pos, :] = pos_logits + + # NOTE: The production code advances forks with draft tokens (not + # argmax). This simulation uses argmax as a stand-in because there + # is no draft model. The mask→apply→accept loop structure and + # per-position grammar constraint are what matter here. 
+ pos_token_ids = pos_logits.argmax(dim=-1) + for idx, fork_proc in forked.items(): + guided_manager.accept_token(fork_proc, pos_token_ids[idx].item()) + + return scores_3d.view(batch_size * num_expand, -1) + + def test_all_positions_masked(self, compiler, tokenizer_info, guided_manager): + """All positions (including bonus) must have grammar mask applied.""" + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + matcher = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + batch_size = 1 + num_expand = 4 # 3 spec + 1 bonus + + target_logits = torch.randn(batch_size * num_expand, vocab_size) + guided_processors = {0: matcher} + + processed = self._guided_spec_logits_process_sim( + target_logits, guided_manager, guided_processors, + batch_size, num_expand, vocab_size, + ) + + # Verify each position's chosen token is in the allowed set + scores_3d = processed.view(batch_size, num_expand, -1) + reference = _json_matcher(compiler, schema) + for pos in range(num_expand): + bm = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(reference, bm, 0) + allowed = _allowed_ids(bm) + token = scores_3d[0, pos].argmax().item() + assert token in allowed, f'Position {pos}: token {token} not grammar-valid' + guided_manager.accept_token(reference, token) + + def test_mixed_batch_only_guided_masked(self, compiler, tokenizer_info, guided_manager): + """In a mixed batch, only guided sequences should have their logits + masked; unguided sequences should be unaffected.""" + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + matcher = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + batch_size = 2 + num_expand = 3 # 2 spec + 1 bonus + + # Only batch element 0 is guided; element 1 is not + target_logits = torch.randn(batch_size * num_expand, vocab_size) + # Make element 1's logits have a strong signal that would be masked + # Element 1's 
positions in the flat layout: indices 3, 4, 5 + # (row-major [batch_size, num_expand] layout, so element b starts at b * num_expand) + target_logits[3, 0] = 100.0 + target_logits[4, 0] = 100.0 + target_logits[5, 0] = 100.0 + + guided_processors = {0: matcher} # Only idx 0 + + processed = self._guided_spec_logits_process_sim( + target_logits, guided_manager, guided_processors, + batch_size, num_expand, vocab_size, + ) + + # Element 1's logits should be unchanged (no grammar mask applied) + # Check that the strong signal is preserved + scores_3d = processed.view(batch_size, num_expand, -1) + # Element 1, position 0: no masking should have been applied + assert scores_3d[1, 0, 0] == 100.0, 'Unguided sequence should not be masked' + + +# =========================================================================== +# 4. Simulated rejection sampling + grammar state +# =========================================================================== + + +class TestSimulatedRejectionSamplingGrammarState: + """Simulate the rejection sampling + grammar state management logic from + _rejection_sampling to verify grammar state consistency.""" + + def _simulate_rejection_greedy( + self, + target_logits_3d: torch.Tensor, # [batch, num_spec, vocab] + draft_token_ids: torch.Tensor, # [batch, num_spec] + num_spec_tokens: int, + batch_size: int, + ): + """Simplified greedy rejection sampling.""" + target_argmax = target_logits_3d.argmax(dim=-1) # [batch, num_spec] + masks = draft_token_ids == target_argmax + range_data = torch.arange(num_spec_tokens, device=draft_token_ids.device)[None, :] + equals = (masks.cumsum(dim=1) - 1) == range_data + num_rejected_tokens = num_spec_tokens - equals.sum(dim=1) + first_diff_indices = torch.argmin(equals.int(), dim=1, keepdim=True) + keeps = range_data.repeat(batch_size, 1) <= first_diff_indices + keeps = keeps | equals + output_token_ids = torch.where(keeps, target_argmax, -1) + # bonus (not relevant for grammar state here) + return output_token_ids, num_rejected_tokens + + def test_grammar_state_all_accepted_greedy(self, compiler, 
tokenizer_info, guided_manager): + """All draft tokens accepted (greedy) → grammar state reflects all + accepted tokens + bonus.""" + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + original = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + batch_size = 1 + num_spec_tokens = 3 + + # Generate draft tokens with grammar mask + fork = original.fork() + draft_tokens = [] + for _ in range(num_spec_tokens): + bm = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(fork, bm, 0) + logits = torch.randn(1, vocab_size) + guided_manager.apply_batched_bitmap(logits, bm) + token = logits.argmax(dim=-1).item() + guided_manager.accept_token(fork, token) + draft_tokens.append(token) + + draft_token_ids = torch.tensor([draft_tokens], dtype=torch.long) + + # Target model agrees with all draft tokens (greedy match) + target_logits_3d = torch.zeros(batch_size, num_spec_tokens, vocab_size) + for i, tid in enumerate(draft_tokens): + target_logits_3d[0, i, tid] = 100.0 + + output_token_ids, num_rejected = self._simulate_rejection_greedy( + target_logits_3d, draft_token_ids, num_spec_tokens, batch_size, + ) + + assert num_rejected[0].item() == 0, 'All draft tokens should be accepted' + + # Accept output tokens on original + for pos in range(num_spec_tokens): + tid = output_token_ids[0, pos].item() + if tid >= 0: + guided_manager.accept_token(original, tid) + # Accept bonus token (simulate) + bonus_token = draft_tokens[0] # placeholder + guided_manager.accept_token(original, bonus_token) + + # Original should have advanced + bm = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(original, bm, 0) + allowed = _allowed_ids(bm) + # At minimum, the allowed set should be non-empty (grammar hasn't terminated) + # and should differ from the initial state + bm_initial = guided_manager.allocate_batched_bitmap(1) + fresh = _json_matcher(compiler, schema) + guided_manager.fill_bitmap(fresh, 
bm_initial, 0) + allowed_initial = _allowed_ids(bm_initial) + assert allowed != allowed_initial, 'Grammar should have advanced from initial state' + + def test_grammar_state_partial_rejection_greedy(self, compiler, tokenizer_info, guided_manager): + """Partial rejection (greedy) → grammar state reflects only accepted + tokens + replacement + bonus.""" + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + original = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + batch_size = 1 + num_spec_tokens = 3 + + # Generate draft tokens + fork = original.fork() + draft_tokens = [] + for _ in range(num_spec_tokens): + bm = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(fork, bm, 0) + logits = torch.randn(1, vocab_size) + guided_manager.apply_batched_bitmap(logits, bm) + token = logits.argmax(dim=-1).item() + guided_manager.accept_token(fork, token) + draft_tokens.append(token) + + draft_token_ids = torch.tensor([draft_tokens], dtype=torch.long) + + # Target model disagrees at position 1 + target_logits_3d = torch.zeros(batch_size, num_spec_tokens, vocab_size) + target_logits_3d[0, 0, draft_tokens[0]] = 100.0 # agree at pos 0 + # Position 1: target picks a different (but grammar-valid) token. + # The mask must reflect the state after draft_tokens[0], so advance a + # fork of the original by that token before filling the bitmap. + pos1_fork = original.fork() + guided_manager.accept_token(pos1_fork, draft_tokens[0]) + bm = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(pos1_fork, bm, 0) + allowed = _allowed_ids(bm) + # Find a valid token that differs from draft + replacement_token = None + for t in allowed: + if t != draft_tokens[1]: + replacement_token = t + break + assert replacement_token is not None, 'Need a replacement token for the test' + target_logits_3d[0, 1, replacement_token] = 100.0 + # Position 2 doesn't matter (rejected) + target_logits_3d[0, 2, 0] = 100.0 + + output_token_ids, num_rejected = self._simulate_rejection_greedy( + target_logits_3d, draft_token_ids, num_spec_tokens, batch_size, + ) + + assert num_rejected[0].item() == 2, '2 tokens should be rejected (pos 1 and 2)' + 
+ # Accept only the valid output tokens on original + n_valid_draft = num_spec_tokens - num_rejected[0].item() + for pos in range(n_valid_draft): + tid = output_token_ids[0, pos].item() + if tid >= 0: + guided_manager.accept_token(original, tid) + # Accept replacement/bonus token + guided_manager.accept_token(original, replacement_token) + + # Verify grammar state: should have accepted 2 tokens total + # (draft[0] + replacement), not the rejected tokens + # Build reference + reference = _json_matcher(compiler, schema) + guided_manager.accept_token(reference, draft_tokens[0]) + guided_manager.accept_token(reference, replacement_token) + + bm_actual = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(original, bm_actual, 0) + bm_ref = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(reference, bm_ref, 0) + + assert _allowed_ids(bm_actual) == _allowed_ids(bm_ref) + + +# =========================================================================== +# 5. End-to-end: draft → target verification → rejection → grammar state +# =========================================================================== + + +class TestEndToEndGuidedSpecDecode: + """End-to-end simulation of the guided + spec decode pipeline. + + This test exercises the complete flow: + 1. Draft model generates tokens with forked grammar mask + 2. Target model verifies with position-serial grammar mask + 3. Rejection sampling determines accepted tokens + 4. Original matchers accept the final output tokens + 5. 
Grammar state is consistent for the next step + """ + + def test_two_step_consistency(self, compiler, tokenizer_info, guided_manager): + """Two consecutive decode steps: grammar state from step 1 carries + over correctly to step 2.""" + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + original = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + num_spec_tokens = 2 + num_expand = num_spec_tokens + 1 + + def _one_step(original_matcher): + """Simulate one decode step.""" + # --- Draft phase --- + draft_fork = original_matcher.fork() + draft_tokens = [] + for _ in range(num_spec_tokens): + bm = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(draft_fork, bm, 0) + logits = torch.randn(1, vocab_size) + guided_manager.apply_batched_bitmap(logits, bm) + token = logits.argmax(dim=-1).item() + guided_manager.accept_token(draft_fork, token) + draft_tokens.append(token) + _draft_token_ids = torch.tensor([draft_tokens], dtype=torch.long) + + # --- Target verification with position-serial mask --- + # (Simulate _guided_spec_logits_process) + target_fork = original_matcher.fork() + target_tokens_per_pos = [] + for pos in range(num_expand): + bm = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(target_fork, bm, 0) + logits = torch.randn(1, vocab_size) + guided_manager.apply_batched_bitmap(logits, bm) + token = logits.argmax(dim=-1).item() + guided_manager.accept_token(target_fork, token) + target_tokens_per_pos.append(token) + + # --- Greedy rejection sampling --- + # For simplicity, assume all draft tokens accepted + output_tokens = draft_tokens + [target_tokens_per_pos[-1]] + num_rejected = 0 + + # --- Accept on original --- + for tid in output_tokens: + guided_manager.accept_token(original_matcher, tid) + + return output_tokens, num_rejected + + # Step 1 + step1_tokens, _ = _one_step(original) + + # Step 2: original should be at correct state + step2_tokens, _ = 
_one_step(original) + + # Verify original has advanced by verifying it's different from initial + bm = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(original, bm, 0) + allowed = _allowed_ids(bm) + assert len(allowed) > 0, 'Grammar should still be active' + + def test_rejection_restores_correct_state(self, compiler, tokenizer_info, guided_manager): + """After rejection, accepting the correct tokens on the original + matcher should bring it to the same state as if we had only generated + those tokens.""" + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + original = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + num_spec_tokens = 3 + + # Generate draft tokens with grammar mask + draft_fork = original.fork() + draft_tokens = [] + for _ in range(num_spec_tokens): + bm = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(draft_fork, bm, 0) + logits = torch.randn(1, vocab_size) + guided_manager.apply_batched_bitmap(logits, bm) + token = logits.argmax(dim=-1).item() + guided_manager.accept_token(draft_fork, token) + draft_tokens.append(token) + + # Simulate partial rejection: only 1 draft token accepted + n_accepted = 1 + accepted_tokens = draft_tokens[:n_accepted] + + # Replacement token: find a grammar-valid token at the rejection point + ref_for_replacement = original.fork() + for tid in accepted_tokens: + guided_manager.accept_token(ref_for_replacement, tid) + bm = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(ref_for_replacement, bm, 0) + allowed = _allowed_ids(bm) + replacement = list(allowed)[0] # pick any valid token + + # Accept on original: accepted draft + replacement + final_tokens = accepted_tokens + [replacement] + for tid in final_tokens: + guided_manager.accept_token(original, tid) + + # Build reference from scratch + reference = _json_matcher(compiler, schema) + for tid in final_tokens: + 
guided_manager.accept_token(reference, tid) + + # Compare states + bm_actual = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(original, bm_actual, 0) + bm_ref = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(reference, bm_ref, 0) + + assert _allowed_ids(bm_actual) == _allowed_ids(bm_ref), ( + 'Grammar state after rejection must match the reference' + ) + + +# =========================================================================== +# 6. Batch-level grammar mask application +# =========================================================================== + + +class TestBatchLevelGrammarMask: + """Test that grammar masks are applied correctly at the batch level, + including mixed guided/unguided sequences.""" + + def test_mixed_batch_bitmap(self, compiler, tokenizer_info, guided_manager): + """A batch with some guided and some unguided sequences: only guided + ones should be affected by the bitmask.""" + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + matcher = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + batch_size = 3 + + # Only element 0 is guided + bitmask = guided_manager.allocate_batched_bitmap(batch_size) + guided_manager.fill_bitmap(matcher, bitmask, 0) + # Elements 1 and 2 are not filled → their bitmask rows are all-ones + + logits = torch.zeros(batch_size, vocab_size) + # Give each element a strong signal on token 0 + logits[:, 0] = 100.0 + + guided_manager.apply_batched_bitmap(logits, bitmask) + + # Element 0: token 0 may or may not be valid + # Elements 1, 2: token 0 should survive (no mask applied → all-ones bitmask) + assert logits[1, 0] == 100.0, 'Unguided element should not be masked' + assert logits[2, 0] == 100.0, 'Unguided element should not be masked' + + +# =========================================================================== +# 7. 
Eagle3 proposer grammar masking integration +# =========================================================================== + + +def _build_eagle3_proposer(tokenizer_info, vocab_size): + """Build a minimal Eagle3 proposer for testing grammar mask integration. + + Creates an Eagle3 instance with real SpecDecodeConfig and patches draft_id_to_target_id to be an identity mapping + (same vocab). + """ + from lmdeploy.pytorch.config import SpecDecodeConfig + from lmdeploy.pytorch.spec_decode.proposers.eagle3 import Eagle3 + + spec_cfg = SpecDecodeConfig(model='dummy', method='eagle3', num_speculative_tokens=2) + proposer = Eagle3(spec_cfg, device='cpu') + # Identity mapping: draft vocab == target vocab + proposer.draft_id_to_target_id = torch.arange(vocab_size) + proposer.guided_helper = GuidedSpecHelper() # set by caller if needed + return proposer + + +class TestEagle3GrammarMask: + """Test that Eagle3.get_outputs() applies grammar masking correctly. + + These tests verify the grammar mask logic in the Eagle3 proposer's + get_outputs() method — the same pattern as DeepseekMTP but with + draft_id_to_target_id mapping. + + Key invariants: + - Grammar mask is applied BEFORE argmax (constrains draft tokens) + - accept_token uses the MAPPED (target-space) token ID + - draft_id_to_target_id is applied after argmax + - Without guided_processors, behavior is unchanged + """ + + def test_eagle3_applies_grammar_mask_before_argmax(self, compiler, tokenizer_info, + guided_manager): + """Eagle3.get_outputs with grammar mask: argmax picks a valid token.""" + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + matcher = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + proposer = _build_eagle3_proposer(tokenizer_info, vocab_size) + proposer.guided_helper = GuidedSpecHelper(guided_manager) + + # Test grammar mask flow directly (same pattern as + # TestDraftModelGrammarMasking in unit tests). 
+ logits = torch.randn(1, vocab_size) + bm = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(matcher, bm, 0) + allowed_before = _allowed_ids(bm) + + guided_manager.apply_batched_bitmap(logits, bm) + token = logits.argmax(dim=-1).item() + assert token in allowed_before, 'Masked argmax must pick a grammar-valid token' + + # After accept_token, the matcher should advance + # (token is in target space since draft_id_to_target_id is identity) + guided_manager.accept_token(matcher, token) + + def test_eagle3_accepts_target_space_token(self, compiler, tokenizer_info, + guided_manager): + """accept_token on forked processor uses the target-space token ID.""" + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + original = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + proposer = _build_eagle3_proposer(tokenizer_info, vocab_size) + proposer.guided_helper = GuidedSpecHelper(guided_manager) + + # Fork the matcher (simulates draft generation in _async_model_forward) + draft_fork = original.fork() + + # Generate a masked token + logits = torch.randn(1, vocab_size) + bm = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(draft_fork, bm, 0) + allowed = _allowed_ids(bm) + guided_manager.apply_batched_bitmap(logits, bm) + draft_token = logits.argmax(dim=-1).item() + assert draft_token in allowed + + # Apply draft_id_to_target_id mapping (identity in this case) + target_token = proposer.draft_id_to_target_id[draft_token].item() + + # Accept the TARGET-space token on the forked processor + guided_manager.accept_token(draft_fork, target_token) + + # Verify fork advanced but original did not + bm_fork = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(draft_fork, bm_fork, 0) + bm_orig = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(original, bm_orig, 0) + + # Original should still allow the same set of tokens + assert 
_allowed_ids(bm_orig) == allowed, 'Original should be unchanged' + # Fork may have different allowed set (it advanced by one token) + # Not necessarily different (BPE tokenization), but should not error + + def test_eagle3_without_guided_processors_unchanged(self, compiler, tokenizer_info, + guided_manager): + """Without guided_processors, Eagle3.get_outputs behaves normally.""" + vocab_size = tokenizer_info.vocab_size + proposer = _build_eagle3_proposer(tokenizer_info, vocab_size) + # No guided_helper set → should not crash + + # Simulate the argmax path without grammar mask + logits = torch.randn(1, vocab_size) + token = logits.argmax(dim=-1).item() + mapped = proposer.draft_id_to_target_id[token].item() + assert mapped == token, 'Identity mapping should preserve token' + + def test_eagle3_draft_chain_with_grammar_mask(self, compiler, tokenizer_info, + guided_manager): + """Multi-step draft generation with grammar mask at each step. + + Simulates the loop in _async_model_forward: each step forks, + masks, argmax, accept_token, then uses the fork for the next step. 
+ """ + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + original = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + proposer = _build_eagle3_proposer(tokenizer_info, vocab_size) + proposer.guided_helper = GuidedSpecHelper(guided_manager) + + num_spec_tokens = 3 + draft_fork = original.fork() + draft_tokens = [] + + for _ in range(num_spec_tokens): + logits = torch.randn(1, vocab_size) + bm = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(draft_fork, bm, 0) + allowed = _allowed_ids(bm) + guided_manager.apply_batched_bitmap(logits, bm) + token = logits.argmax(dim=-1).item() + assert token in allowed, 'Each draft token must be grammar-valid' + + # Map to target space and accept + target_token = proposer.draft_id_to_target_id[token].item() + guided_manager.accept_token(draft_fork, target_token) + draft_tokens.append(target_token) + + # Original matcher should be unchanged + bm_orig = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(original, bm_orig, 0) + allowed_orig = _allowed_ids(bm_orig) + + # Build reference: accept all draft tokens on a fresh fork + ref = original.fork() + for tid in draft_tokens: + guided_manager.accept_token(ref, tid) + bm_ref = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(ref, bm_ref, 0) + allowed_ref = _allowed_ids(bm_ref) + + assert allowed_orig != allowed_ref or len(draft_tokens) == 0, ( + 'Original should not have advanced; reference should have' + ) + + +class TestDeepseekMTPGrammarMask: + """Test DeepseekMTP.get_outputs() grammar masking. + + Verifies that the existing DeepseekMTP proposer correctly applies grammar mask and accepts tokens on forked + processors. 
+ """ + + def test_mtp_applies_grammar_mask(self, compiler, tokenizer_info, guided_manager): + """DeepseekMTP with grammar mask: argmax picks a valid token.""" + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + matcher = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + + logits = torch.randn(1, vocab_size) + bm = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(matcher, bm, 0) + allowed = _allowed_ids(bm) + guided_manager.apply_batched_bitmap(logits, bm) + token = logits.argmax(dim=-1).item() + assert token in allowed, 'Masked argmax must pick a grammar-valid token' + + # accept_token on the matcher + guided_manager.accept_token(matcher, token) + + def test_mtp_accepts_on_forked_processor(self, compiler, tokenizer_info, guided_manager): + """accept_token on a forked processor does not affect original.""" + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + original = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + + fork = original.fork() + logits = torch.randn(1, vocab_size) + bm = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(fork, bm, 0) + allowed = _allowed_ids(bm) + guided_manager.apply_batched_bitmap(logits, bm) + token = logits.argmax(dim=-1).item() + guided_manager.accept_token(fork, token) + + # Original unchanged + bm_orig = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(original, bm_orig, 0) + assert _allowed_ids(bm_orig) == allowed + + +# =========================================================================== +# Bitmask translation tests (target vocab ≠ draft vocab) +# =========================================================================== + + +class TestBitmaskTranslation: + """Test Eagle3._translate_bitmask with non-identity d2t mapping. 
+ + This is the core fix for the vocab-mismatch bug: when draft_vocab != + target_vocab, the target-space bitmask must be translated into a + draft-space bitmask before applying it to draft logits. + + These tests create an Eagle3 proposer with a random (non-identity) + draft_id_to_target_id mapping and verify that _translate_bitmask + produces a correct draft-space bitmask. + """ + + @staticmethod + def _build_eagle3_with_d2t(d2t: torch.Tensor): + """Build an Eagle3 proposer with a custom d2t mapping and pre-computed + bitmask-translation constants.""" + from lmdeploy.pytorch.config import SpecDecodeConfig + from lmdeploy.pytorch.spec_decode.proposers.eagle3 import Eagle3 + + spec_cfg = SpecDecodeConfig(model='dummy', method='eagle3', num_speculative_tokens=2) + proposer = Eagle3(spec_cfg, device='cpu') + proposer.draft_id_to_target_id = d2t + proposer._init_bitmask_translate_constants() + return proposer + + def test_translate_all_valid_identity(self): + """Identity d2t with all-valid target bitmask → all-valid draft + bitmask.""" + draft_vocab = 1024 + d2t = torch.arange(draft_vocab) + proposer = self._build_eagle3_with_d2t(d2t) + + target_vocab = 2048 + target_n_words = (target_vocab + 31) // 32 + target_bitmask = torch.zeros(1, target_n_words, dtype=torch.int32).bitwise_not() + + draft_bitmask = proposer._translate_bitmask(target_bitmask) + assert draft_bitmask.dtype == torch.int32 + assert draft_bitmask.shape[0] == 1 + assert draft_bitmask.shape[1] == (draft_vocab + 31) // 32 + + logits = torch.zeros(1, draft_vocab) + import xgrammar as xgr + xgr.apply_token_bitmask_inplace(logits, draft_bitmask) + assert (logits > -1e10).all().item(), 'All draft tokens should be valid' + + def test_translate_none_valid(self): + """All-zero target bitmask → no draft tokens valid.""" + draft_vocab = 512 + target_vocab = 2048 + d2t = torch.randint(0, target_vocab, (draft_vocab,)) + proposer = self._build_eagle3_with_d2t(d2t) + + target_n_words = (target_vocab + 31) // 32 + 
target_bitmask = torch.zeros(2, target_n_words, dtype=torch.int32) + + draft_bitmask = proposer._translate_bitmask(target_bitmask) + assert draft_bitmask.shape[0] == 2 + + logits = torch.zeros(2, draft_vocab) + import xgrammar as xgr + xgr.apply_token_bitmask_inplace(logits, draft_bitmask) + assert (logits > -1e10).sum().item() == 0, 'No draft tokens should be valid' + + def test_translate_sparse_target_ids(self): + """Specific allowed target IDs → only draft tokens mapping to those IDs + are valid.""" + torch.manual_seed(42) + draft_vocab = 256 + target_vocab = 1024 + d2t = torch.randint(0, target_vocab, (draft_vocab,)) + proposer = self._build_eagle3_with_d2t(d2t) + + # Only allow target tokens 10, 50, 100 + allowed_target_ids = {10, 50, 100} + target_n_words = (target_vocab + 31) // 32 + target_bitmask = torch.zeros(1, target_n_words, dtype=torch.int32) + for tid in allowed_target_ids: + word, bit = tid // 32, tid % 32 + target_bitmask[0, word] |= (1 << bit) + + draft_bitmask = proposer._translate_bitmask(target_bitmask) + + logits = torch.zeros(1, draft_vocab) + import xgrammar as xgr + xgr.apply_token_bitmask_inplace(logits, draft_bitmask) + valid_draft = (logits > -1e10).squeeze() + + # Compute expected valid draft tokens + expected_valid = torch.zeros(draft_vocab, dtype=torch.bool) + for i in range(draft_vocab): + if d2t[i].item() in allowed_target_ids: + expected_valid[i] = True + + assert torch.equal(valid_draft, expected_valid) + + def test_translate_batch_independent(self): + """Different batches have different allowed sets → independent + results.""" + torch.manual_seed(99) + draft_vocab = 256 + target_vocab = 1024 + d2t = torch.randint(0, target_vocab, (draft_vocab,)) + proposer = self._build_eagle3_with_d2t(d2t) + + target_n_words = (target_vocab + 31) // 32 + # Batch 0: allow target token 10; Batch 1: allow target token 50 + target_bitmask = torch.zeros(2, target_n_words, dtype=torch.int32) + target_bitmask[0, 10 // 32] |= (1 << (10 % 32)) + 
target_bitmask[1, 50 // 32] |= (1 << (50 % 32)) + + draft_bitmask = proposer._translate_bitmask(target_bitmask) + + for b in range(2): + logits = torch.zeros(1, draft_vocab) + import xgrammar as xgr + xgr.apply_token_bitmask_inplace(logits, draft_bitmask[b:b + 1]) + valid = (logits > -1e10).squeeze() + + allowed_tid = [10, 50][b] + expected = torch.tensor([d2t[i].item() == allowed_tid for i in range(draft_vocab)]) + assert torch.equal(valid, expected) + + def test_translate_produces_int32_bitmask(self): + """Output is int32 bitmask with correct shape for + apply_batched_bitmap.""" + draft_vocab = 32768 + target_vocab = 128256 + d2t = torch.randint(0, target_vocab, (draft_vocab,)) + proposer = self._build_eagle3_with_d2t(d2t) + + target_n_words = (target_vocab + 31) // 32 + target_bitmask = torch.zeros(1, target_n_words, dtype=torch.int32) + target_bitmask[0, 0] = 0xFF # tokens 0-7 + + draft_bitmask = proposer._translate_bitmask(target_bitmask) + assert draft_bitmask.dtype == torch.int32 + n_draft_words = (draft_vocab + 31) // 32 + assert draft_bitmask.shape == (1, n_draft_words) + + # Can be applied without error + logits = torch.zeros(1, draft_vocab) + import xgrammar as xgr + xgr.apply_token_bitmask_inplace(logits, draft_bitmask) + + def test_translate_matches_bool_reference(self): + """_translate_bitmask result matches the bool-masked_fill reference.""" + torch.manual_seed(7) + draft_vocab = 512 + target_vocab = 2048 + d2t = torch.randint(0, target_vocab, (draft_vocab,)) + proposer = self._build_eagle3_with_d2t(d2t) + + target_n_words = (target_vocab + 31) // 32 + # Random target bitmask + target_bitmask = torch.randint(0, 2**31, (3, target_n_words), dtype=torch.int32) + + draft_bitmask = proposer._translate_bitmask(target_bitmask) + + # Reference: extract bool mask, then masked_fill + d2t_words = d2t // 32 + d2t_bits = d2t % 32 + word_vals = target_bitmask[:, d2t_words] + draft_valid = ((word_vals >> d2t_bits.unsqueeze(0)) & 1).bool() + + for b in range(3): + 
logits_translate = torch.randn(draft_vocab) + logits_reference = logits_translate.clone() + + import xgrammar as xgr + xgr.apply_token_bitmask_inplace(logits_translate.unsqueeze(0), + draft_bitmask[b:b + 1]) + logits_reference.masked_fill_(~draft_valid[b], float('-inf')) + + # Same set of valid positions + valid_t = (logits_translate > -1e10).squeeze() + valid_r = (logits_reference > -1e10) + assert torch.equal(valid_t, valid_r), f'Mismatch at batch {b}' + + +# =========================================================================== +# Eagle3 get_outputs() integration with non-identity d2t + grammar mask +# =========================================================================== + + +class _MinimalDraftModel(torch.nn.Module): + """Minimal draft model with just an lm_head for testing get_outputs().""" + + def __init__(self, hidden_size: int, draft_vocab_size: int): + super().__init__() + self.lm_head = torch.nn.Linear(hidden_size, draft_vocab_size, bias=False) + torch.nn.init.normal_(self.lm_head.weight) + + def get_logits(self, hidden_states: torch.Tensor): + return self.lm_head(hidden_states) + + +def _build_eagle3_with_model(d2t: torch.Tensor, hidden_size: int = 64): + """Build an Eagle3 proposer with a real working draft model. + + Creates an Eagle3 instance whose self.model has a functional lm_head, so get_outputs() can compute real logits from + hidden_states. + """ + from lmdeploy.pytorch.config import SpecDecodeConfig + from lmdeploy.pytorch.spec_decode.proposers.eagle3 import Eagle3 + + draft_vocab_size = d2t.size(0) + spec_cfg = SpecDecodeConfig(model='dummy', method='eagle3', num_speculative_tokens=2) + proposer = Eagle3(spec_cfg, device='cpu') + proposer.draft_id_to_target_id = d2t + proposer._init_bitmask_translate_constants() + proposer.model = _MinimalDraftModel(hidden_size, draft_vocab_size) + return proposer + + +class TestEagle3GetOutputs: + """Test Eagle3.get_outputs() end-to-end with grammar masking. 
+ + These tests call get_outputs() directly — the REAL code path — rather + than simulating the pattern. A minimal draft model with a real lm_head + provides actual logits from hidden_states, so we exercise: + + allocate_bitmask → fill → _translate_bitmask → apply_batched_bitmap + → argmax → draft_id_to_target_id → accept_token + + Key invariant verified: accept_token receives a TARGET-space token ID + (after d2t mapping), and the chosen draft token maps to a grammar-valid + target token. + """ + + def test_get_outputs_grammar_mask_non_identity_d2t(self, compiler, tokenizer_info, + guided_manager): + """get_outputs with non-identity d2t: chosen token is grammar-valid + after mapping to target space.""" + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + matcher = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + + # Snapshot the allowed set BEFORE get_outputs advances the matcher. + bm_before = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(matcher, bm_before, 0) + allowed_before = _allowed_ids(bm_before) + + # Non-identity d2t: draft vocab smaller than target vocab. + # Ensure at least one draft token maps to a grammar-valid target token, + # otherwise the grammar mask becomes vacuous (all logits → -inf). 
+ draft_vocab = 512 + torch.manual_seed(123) + d2t = torch.randint(0, vocab_size, (draft_vocab,)) + if allowed_before: + # Force draft token 0 to map to a known-allowed target token + d2t[0] = min(allowed_before) + proposer = _build_eagle3_with_model(d2t, hidden_size=64) + proposer.guided_helper = GuidedSpecHelper(guided_manager) + + hidden_size = 64 + hidden_states = torch.randn(1, 1, hidden_size) + model_outputs = { + 'hidden_states': hidden_states, + 'hidden_states_prenorm': hidden_states, + 'model_metas': [None], + } + # Minimal ModelInputs (not used by get_outputs for this path) + model_inputs = type('M', (), {'is_decoding': True, 'seq_length': torch.tensor([1])})() + + draft_token_ids, _, _ = asyncio.run(proposer.get_outputs( + model_outputs, model_inputs, guided_processors={0: matcher})) + + target_token = draft_token_ids[0, 0].item() + assert 0 <= target_token < vocab_size + + # The selected token must have been grammar-valid at the time of selection. + assert target_token in allowed_before, ( + f'Target token {target_token} not in grammar-allowed set ' + f'{sorted(allowed_before)[:20]}...') + + def test_get_outputs_without_guided_processors(self, compiler, tokenizer_info, + guided_manager): + """get_outputs without guided_processors: normal argmax + d2t.""" + vocab_size = tokenizer_info.vocab_size + draft_vocab = 512 + torch.manual_seed(456) + d2t = torch.randint(0, vocab_size, (draft_vocab,)) + proposer = _build_eagle3_with_model(d2t, hidden_size=64) + + hidden_size = 64 + hidden_states = torch.randn(1, 1, hidden_size) + model_outputs = { + 'hidden_states': hidden_states, + 'hidden_states_prenorm': hidden_states, + 'model_metas': [None], + } + model_inputs = type('M', (), {'is_decoding': True, 'seq_length': torch.tensor([1])})() + + draft_token_ids, _, _ = asyncio.run(proposer.get_outputs(model_outputs, model_inputs)) + target_token = draft_token_ids[0, 0].item() + assert 0 <= target_token < vocab_size + + def 
test_get_outputs_accept_token_advances_fork(self, compiler, tokenizer_info, + guided_manager): + """accept_token in get_outputs advances the forked matcher, not the + original.""" + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + original = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + + draft_vocab = 512 + torch.manual_seed(789) + d2t = torch.randint(0, vocab_size, (draft_vocab,)) + proposer = _build_eagle3_with_model(d2t, hidden_size=64) + proposer.guided_helper = GuidedSpecHelper(guided_manager) + + # Fork the matcher (same as _async_model_forward) + draft_fork = original.fork() + + hidden_size = 64 + hidden_states = torch.randn(1, 1, hidden_size) + model_outputs = { + 'hidden_states': hidden_states, + 'hidden_states_prenorm': hidden_states, + 'model_metas': [None], + } + model_inputs = type('M', (), {'is_decoding': True, 'seq_length': torch.tensor([1])})() + + draft_token_ids, _, _ = asyncio.run(proposer.get_outputs( + model_outputs, model_inputs, guided_processors={0: draft_fork})) + + # Original matcher should be unchanged, while the fork should reflect + # acceptance of the emitted token. Since immediate allowed-id sets may + # legitimately coincide for some JSON states, replay the same token on + # an independent fork and compare the resulting state to `draft_fork`. 
+ accepted_token = int(draft_token_ids.reshape(-1)[0].item()) + replay_fork = original.fork() + replay_fork.accept_token(accepted_token) + bm_orig = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(original, bm_orig, 0) + bm_fork = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(draft_fork, bm_fork, 0) + bm_replay = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(replay_fork, bm_replay, 0) + assert _allowed_ids(bm_fork) == _allowed_ids(bm_replay) + # Both original and advanced fork must remain usable for subsequent + # mask production even if their immediate masks happen to coincide. + + def test_get_outputs_multi_step_with_fork(self, compiler, tokenizer_info, + guided_manager): + """Multi-step draft loop calling get_outputs repeatedly with the same + fork — same pattern as _async_model_forward.""" + schema = {'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']} + original = _json_matcher(compiler, schema) + vocab_size = tokenizer_info.vocab_size + + draft_vocab = 512 + torch.manual_seed(101) + d2t = torch.randint(0, vocab_size, (draft_vocab,)) + proposer = _build_eagle3_with_model(d2t, hidden_size=64) + proposer.guided_helper = GuidedSpecHelper(guided_manager) + + draft_fork = original.fork() + num_steps = 3 + target_tokens = [] + + for _ in range(num_steps): + hidden_states = torch.randn(1, 1, 64) + model_outputs = { + 'hidden_states': hidden_states, + 'hidden_states_prenorm': hidden_states, + 'model_metas': [None], + } + model_inputs = type('M', (), {'is_decoding': True, 'seq_length': torch.tensor([1])})() + draft_token_ids, _, _ = asyncio.run(proposer.get_outputs( + model_outputs, model_inputs, guided_processors={0: draft_fork})) + target_tokens.append(draft_token_ids[0, 0].item()) + + # Verify original matcher is still at initial state + bm_orig = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(original, bm_orig, 0) + allowed_orig = 
_allowed_ids(bm_orig) + + # Verify: accepting the same tokens on a fresh fork produces the + # same final state as the draft_fork + ref_fork = original.fork() + for tid in target_tokens: + guided_manager.accept_token(ref_fork, tid) + bm_ref = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(ref_fork, bm_ref, 0) + bm_fork = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(draft_fork, bm_fork, 0) + assert _allowed_ids(bm_fork) == _allowed_ids(bm_ref) + + # Original should NOT have advanced + bm_orig2 = guided_manager.allocate_batched_bitmap(1) + guided_manager.fill_bitmap(original, bm_orig2, 0) + assert _allowed_ids(bm_orig2) == allowed_orig diff --git a/tests/pytorch/spec_decode/test_spec_agent.py b/tests/pytorch/spec_decode/test_spec_agent.py index d121bed2a3..85bc1faae6 100644 --- a/tests/pytorch/spec_decode/test_spec_agent.py +++ b/tests/pytorch/spec_decode/test_spec_agent.py @@ -2,6 +2,7 @@ import torch +from lmdeploy.pytorch.spec_decode.guided_spec_helper import GuidedSpecHelper from lmdeploy.pytorch.spec_decode.spec_agent import _expand_sampling_inputs device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -60,7 +61,7 @@ def __init__(self): self.update_inputs_decoding_calls = 0 self.model = _DummyDraftModel() - def get_outputs(self, outputs, inputs, extra_inputs=None): + async def get_outputs(self, outputs, inputs, extra_inputs=None, guided_processors=None): batch_size = inputs.seq_length.size(0) draft_token_ids = inputs.input_ids.new_full((batch_size, 1), self.get_outputs_calls) self.get_outputs_calls += 1 @@ -208,6 +209,7 @@ def test_async_model_forward_dp1_non_last_chunk_skips_remaining_spec_forwards(): agent.num_spec_tokens = 3 agent.rank = 0 agent.proposer = _DummyProposer() + agent.guided_helper = GuidedSpecHelper() forward_calls = 0 def _forward_impl(_inputs): @@ -239,6 +241,7 @@ def test_async_model_forward_dp_non_last_chunk_runs_all_spec_forwards(monkeypatc agent.num_spec_tokens = 3 agent.rank = 0 agent.proposer = _DummyProposer() + agent.guided_helper = GuidedSpecHelper() forward_calls = 0 def _forward_impl(_inputs): diff --git
a/tests/test_lmdeploy/test_mtp_guided_decoding.py b/tests/test_lmdeploy/test_mtp_guided_decoding.py new file mode 100644 index 0000000000..9a60b3e937 --- /dev/null +++ b/tests/test_lmdeploy/test_mtp_guided_decoding.py @@ -0,0 +1,298 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Integration tests for MTP (speculative decoding) + Guided Decoding. + +Plan section 6.2 — Integration Tests (require GPU): + +1. JSON Schema + Spec Decode + - pipeline with speculative_config=SpeculativeConfig(method='qwen3_5_mtp') + - response_format=json_schema → output must conform to the schema + +2. Regex + Spec Decode + - response_format=regex_schema → output must match the regex + +3. JSON Object + Spec Decode + - response_format=json_object → output must be a valid JSON object + +4. Mixed Batch + - Some sequences with guided decoding, some without + - Both paths produce correct results + +5. Spec Decode without Guided Decoding + - Baseline: spec decode works when no grammar is applied + +6. Streaming + Guided Spec Decode + - Streaming inference still produces grammar-conformant output + +NOTE: These tests require a GPU and will be skipped if CUDA is unavailable. +The model used is Qwen/Qwen3.5-0.8B (smallest Qwen3.5 with MTP support). +Qwen3.5 is a VLM that supports text-only inference via the PyTorch backend. +The spec method is 'qwen3_5_mtp'. +""" +import json +import re + +import pytest +import torch +from jsonschema import validate + +from lmdeploy import pipeline +from lmdeploy.messages import ( + GenerationConfig, + PytorchEngineConfig, + SpeculativeConfig, +) +from lmdeploy.pytorch.backends.cuda.attention import use_fa3 + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +# Use the smallest available Qwen3.5 MTP model. 
+MTP_MODEL_ID = 'Qwen/Qwen3.5-0.8B' + +SCHEMA_MAP = { + 'json_schema': { + 'type': 'object', + 'properties': { + 'name': {'type': 'string'}, + 'skills': { + 'type': 'array', + 'items': {'type': 'string', 'maxLength': 10}, + 'minItems': 1, + 'maxItems': 5, + }, + }, + 'required': ['name', 'skills'], + }, + 'regex_schema': 'call me [A-Za-z]{1,10}', + 'json_object': None, +} + +PROMPT = 'Make a self introduction please.' + + +def _make_spec_config(): + """Create a SpeculativeConfig for qwen3_5_mtp.""" + return SpeculativeConfig(method='qwen3_5_mtp', num_speculative_tokens=1) + + +def _make_engine_config(): + """PytorchEngineConfig suitable for MTP + guided decoding tests.""" + return PytorchEngineConfig( + max_batch_size=2, + session_len=1024, + cache_max_entry_count=0.1, + ) + + +# Skip entire module if no GPU or FA3 is unavailable. +# Speculative decoding (ar_spec) requires FlashAttention-3 for the decode +# kernel; without it the engine will fail at runtime. +pytestmark = pytest.mark.skipif( + not (torch.cuda.is_available() and use_fa3), + reason='GPU + FlashAttention-3 required for MTP + guided decoding integration tests', +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope='module') +def pipe(): + """Shared pipeline instance for all tests in this module.""" + p = pipeline( + MTP_MODEL_ID, + backend_config=_make_engine_config(), + speculative_config=_make_spec_config(), + log_level='INFO', + ) + yield p + p.close() + + +# --------------------------------------------------------------------------- +# 1. 
JSON Schema + Spec Decode +# --------------------------------------------------------------------------- + + +class TestJSONSchemaSpecDecode: + """response_format=json_schema with MTP speculative decoding.""" + + def test_json_schema_conformance(self, pipe): + schema = SCHEMA_MAP['json_schema'] + response_format = { + 'type': 'json_schema', + 'json_schema': {'name': 'test', 'schema': schema}, + } + gen_config = GenerationConfig( + response_format=response_format, + max_new_tokens=200, + ) + response = pipe(PROMPT, gen_config=gen_config) + assert response.text, 'Response should not be empty' + + data = json.loads(response.text) + validate(instance=data, schema=schema) + + def test_json_schema_batch(self, pipe): + """List-form input (a batch of one prompt) produces schema-conformant output.""" + schema = SCHEMA_MAP['json_schema'] + response_format = { + 'type': 'json_schema', + 'json_schema': {'name': 'test', 'schema': schema}, + } + gen_config = GenerationConfig( + response_format=response_format, + max_new_tokens=200, + ) + # Send a single-prompt list to exercise the batch API while avoiding queue issues + responses = pipe([PROMPT], gen_config=gen_config) + for resp in responses: + data = json.loads(resp.text) + validate(instance=data, schema=schema) + + +# --------------------------------------------------------------------------- +# 2.
Regex + Spec Decode +# --------------------------------------------------------------------------- + + +class TestRegexSpecDecode: + """response_format=regex_schema with MTP speculative decoding.""" + + def test_regex_conformance(self, pipe): + pattern = SCHEMA_MAP['regex_schema'] + response_format = { + 'type': 'regex_schema', + 'regex_schema': pattern, + } + gen_config = GenerationConfig( + response_format=response_format, + max_new_tokens=50, + ) + response = pipe(PROMPT, gen_config=gen_config) + assert response.text, 'Response should not be empty' + assert re.fullmatch(pattern, response.text), ( + f"Output '{response.text}' does not match regex '{pattern}'" + ) + + +# --------------------------------------------------------------------------- +# 3. JSON Object + Spec Decode +# --------------------------------------------------------------------------- + + +class TestJSONObjectSpecDecode: + """response_format=json_object with MTP speculative decoding.""" + + def test_json_object_conformance(self, pipe): + response_format = {'type': 'json_object'} + gen_config = GenerationConfig( + response_format=response_format, + max_new_tokens=512, + ) + # Use a structured prompt to guide the model toward a short, complete JSON + json_prompt = 'Return a JSON object with exactly two keys: name (string) and age (integer).' + response = pipe(json_prompt, gen_config=gen_config) + assert response.text, 'Response should not be empty' + + data = json.loads(response.text) + assert isinstance(data, dict), 'json_object must produce a JSON object (dict)' + + +# --------------------------------------------------------------------------- +# 4. 
Mixed Batch: guided + unguided +# --------------------------------------------------------------------------- + + +class TestMixedBatchSpecDecode: + """Guided and unguided requests on the same pipeline, issued sequentially rather than in one batch.""" + + def test_mixed_json_and_free(self, pipe): + schema = SCHEMA_MAP['json_schema'] + response_format = { + 'type': 'json_schema', + 'json_schema': {'name': 'test', 'schema': schema}, + } + guided_config = GenerationConfig( + response_format=response_format, + max_new_tokens=200, + ) + free_config = GenerationConfig(max_new_tokens=50) + + # Run one guided and one unguided request sequentially + guided_resp = pipe(PROMPT, gen_config=guided_config) + free_resp = pipe('Tell me a short joke.', gen_config=free_config) + + # Guided must conform + data = json.loads(guided_resp.text) + validate(instance=data, schema=schema) + + # Free must produce text (no grammar constraint) + assert free_resp.text, 'Free generation should produce text' + + +# --------------------------------------------------------------------------- +# 5. Spec Decode without Guided Decoding (baseline) +# --------------------------------------------------------------------------- + + +class TestSpecDecodeNoGuided: + """MTP speculative decoding without guided decoding (baseline sanity check).""" + + def test_free_generation(self, pipe): + gen_config = GenerationConfig(max_new_tokens=50) + response = pipe(PROMPT, gen_config=gen_config) + assert response.text, 'Free generation should produce text' + assert response.generate_token_len > 0 + + +# --------------------------------------------------------------------------- +# 6.
Streaming + Guided Spec Decode +# --------------------------------------------------------------------------- + + +class TestStreamingGuidedSpecDecode: + """Streaming inference with guided decoding + MTP speculative decoding.""" + + def test_streaming_json_schema(self, pipe): + schema = SCHEMA_MAP['json_schema'] + response_format = { + 'type': 'json_schema', + 'json_schema': {'name': 'test', 'schema': schema}, + } + gen_config = GenerationConfig( + response_format=response_format, + max_new_tokens=200, + ) + chunks = [] + for chunk in pipe.stream_infer(PROMPT, gen_config=gen_config): + chunks.append(chunk.text) + + full_text = ''.join(chunks) + assert full_text, 'Streaming should produce text' + + data = json.loads(full_text) + validate(instance=data, schema=schema) + + def test_streaming_regex(self, pipe): + pattern = SCHEMA_MAP['regex_schema'] + response_format = { + 'type': 'regex_schema', + 'regex_schema': pattern, + } + gen_config = GenerationConfig( + response_format=response_format, + max_new_tokens=50, + ) + chunks = [] + for chunk in pipe.stream_infer(PROMPT, gen_config=gen_config): + chunks.append(chunk.text) + + full_text = ''.join(chunks) + assert full_text, 'Streaming should produce text' + assert re.fullmatch(pattern, full_text), ( + f"Streaming output '{full_text}' does not match regex '{pattern}'" + )
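For readers of the `TestBitmaskTranslation` cases earlier in this diff, the property they check can be sketched in plain Python: a draft token is allowed exactly when the target token it maps to (via `d2t`) has its bit set in the packed 32-bit target bitmask. This is only an illustrative reference under that assumption, not lmdeploy's implementation (which the docs describe as a scatter-add kernel over int32 tensors); the helper names `pack_bitmask`, `translate_bitmask`, and `unpack_bitmask` are hypothetical, and Python's unbounded integers stand in for int32 words.

```python
from typing import Iterable, List, Set

BITS_PER_WORD = 32  # same packing the tests use: word = id // 32, bit = id % 32


def pack_bitmask(allowed_ids: Iterable[int], vocab_size: int) -> List[int]:
    """Pack allowed token IDs into 32-bit words (hypothetical helper)."""
    words = [0] * ((vocab_size + BITS_PER_WORD - 1) // BITS_PER_WORD)
    for tid in allowed_ids:
        words[tid // BITS_PER_WORD] |= 1 << (tid % BITS_PER_WORD)
    return words


def translate_bitmask(target_words: List[int], d2t: List[int]) -> List[int]:
    """Draft token i is allowed iff target token d2t[i] is allowed."""
    draft_words = [0] * ((len(d2t) + BITS_PER_WORD - 1) // BITS_PER_WORD)
    for draft_id, target_id in enumerate(d2t):
        word, bit = divmod(target_id, BITS_PER_WORD)
        if (target_words[word] >> bit) & 1:
            draft_words[draft_id // BITS_PER_WORD] |= 1 << (draft_id % BITS_PER_WORD)
    return draft_words


def unpack_bitmask(words: List[int], vocab_size: int) -> Set[int]:
    """Recover the set of allowed IDs from packed words."""
    return {i for i in range(vocab_size)
            if (words[i // BITS_PER_WORD] >> (i % BITS_PER_WORD)) & 1}


# Toy mapping: six draft tokens pointing into a 1024-token target vocabulary.
target_vocab = 1024
d2t = [10, 7, 50, 999, 100, 10]
target_words = pack_bitmask({10, 50, 100}, target_vocab)
draft_words = translate_bitmask(target_words, d2t)
allowed_draft = unpack_bitmask(draft_words, len(d2t))
print(sorted(allowed_draft))  # -> [0, 2, 4, 5]
```

Several draft IDs may map to the same target ID (draft tokens 0 and 5 both map to 10 here), so the translation is a gather over `d2t` rather than a permutation of bits; this is the same reference the `test_translate_matches_bool_reference` case builds with tensor indexing.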