From f959a2f0f2a4f73ef8cae89050f8d465b57e2c21 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Mon, 23 Mar 2026 07:56:52 +0000 Subject: [PATCH 1/6] for turbomind --- lmdeploy/messages.py | 169 ++++++++++++++++++---------- lmdeploy/pytorch/messages.py | 8 +- lmdeploy/serve/core/async_engine.py | 43 ++++--- lmdeploy/tokenizer.py | 10 +- lmdeploy/turbomind/turbomind.py | 22 ++-- lmdeploy/utils.py | 35 +++--- 6 files changed, 186 insertions(+), 101 deletions(-) diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index b029d98c26..4dba517209 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -2,7 +2,7 @@ import enum import time from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Literal, Optional +from typing import Any, Callable, Literal import torch from pydantic.dataclasses import dataclass as pydantic_dataclass @@ -109,16 +109,16 @@ class GenerationConfig: repetition_penalty: float = 1.0 ignore_eos: bool = False random_seed: int = None - stop_words: List[str] = None - bad_words: List[str] = None - stop_token_ids: List[int] = None - bad_token_ids: List[int] = None + stop_words: list[str] = None + bad_words: list[str] = None + stop_token_ids: list[int] | list[list[int]] = None + bad_token_ids: list[int] = None min_new_tokens: int = None skip_special_tokens: bool = True spaces_between_special_tokens: bool = True logprobs: int = None - response_format: Optional[Dict] = None - logits_processors: Optional[List[LogitsProcessor]] = None + response_format: dict | None = None + logits_processors: list[LogitsProcessor] | None = None output_logits: Literal['all', 'generation'] = None output_last_hidden_state: Literal['all', 'generation'] = None include_stop_str_in_output: bool = False @@ -126,7 +126,7 @@ class GenerationConfig: # for disaggregation with_cache: bool = False preserve_cache: bool = False - migration_request: Optional[MigrationRequest] = None + migration_request: MigrationRequest | None = None # router replay return_routed_experts: bool = False @@ -135,46 +135,99 @@ class GenerationConfig: repetition_ngram_size: int = 0 repetition_ngram_threshold: int = 0 + @staticmethod + def _normalize_stop_token_ids(ids: list[int] | list[list[int]] | None) -> list[list[int]]: + """Normalize stop_token_ids to list[list[int]].""" + if ids is None: + return [] + out: list[list[int]] = [] + for item in ids: + if isinstance(item, int): + out.append([item]) + else: + out.append(list(item)) + return out + def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer): - """Convert stop_words/bad_sords to ids and append the ids to + """Convert stop_words/bad_words to ids and append the ids to stop_token_ids/bad_token_ids.""" - def special_word_token_ids(words): - if words is not None: - assert isinstance(words, List) and \ - all(isinstance(elem, str) for elem in words), \ - f'stop_words must be a list of str but got {type(words)}' - indexes = [] - for word in words: - indexes += tokenizer.indexes_containing_token(word) - return indexes - return None - - stop_token_ids = special_word_token_ids(self.stop_words) or [] - bad_token_ids = special_word_token_ids(self.bad_words) or [] - stop_token_ids.extend(self.stop_token_ids or []) - bad_token_ids.extend(self.bad_token_ids or []) - self.stop_token_ids = list(set(stop_token_ids)) or None - self.bad_token_ids = list(set(bad_token_ids)) or None + def words_to_token_seqs(words: list[str], prefer_exact: bool = False) -> list[list[int]]: + assert isinstance(words, list) and \ + all(isinstance(elem, str) for elem in words), \ + f'stop_words must be a list of str but got {type(words)}' + seqs: list[list[int]] = [] + for word in words: + # For stop_words, prefer exact tokenization so multi-word phrases + # are represented as an exact token-id sequence. + if prefer_exact: + encoded = tokenizer.encode(word, add_bos=False) + if encoded: + seqs.append(encoded) + continue + + single_matches = tokenizer.indexes_containing_token(word) + if single_matches: + for idx in single_matches: + seqs.append([idx]) + else: + encoded = tokenizer.encode(word, add_bos=False) + if encoded: + seqs.append(encoded) + return seqs + + stop_seqs = words_to_token_seqs(self.stop_words, prefer_exact=True) if self.stop_words else [] + bad_seqs = words_to_token_seqs(self.bad_words) if self.bad_words else [] + + stop_seqs.extend(self._normalize_stop_token_ids(self.stop_token_ids)) + bad_seqs.extend([[i] for i in (self.bad_token_ids or [])]) + + # deduplicate stop_token_ids and bad_token_ids + seen = set() + deduped: list[list[int]] = [] + for seq in stop_seqs: + key = tuple(seq) + if key not in seen: + seen.add(key) + deduped.append(seq) + self.stop_token_ids = deduped or None + + seen_bad = set() + deduped_bad: list[int] = [] + for seq in bad_seqs: + if len(seq) > 1: + logger.warning(f'Multi-token bad word {seq} is not supported and ' + 'will be ignored. Only single-token bad words can be ' + 'masked in logits processing.') + continue + if seq[0] not in seen_bad: + seen_bad.add(seq[0]) + deduped_bad.append(seq[0]) + self.bad_token_ids = deduped_bad or None def update_from_hf_gen_cfg(self, generation_config, tokenizer_eos_token_id): """Update the stop_token_ids.""" - stop_token_ids = set(self.stop_token_ids or []) + stop_seqs = self._normalize_stop_token_ids(self.stop_token_ids) + existing = {tuple(s) for s in stop_seqs} + + def _add_single(tok_id: int): + key = (tok_id, ) + if key not in existing: + existing.add(key) + stop_seqs.append([tok_id]) - # add tokenizer's eos_token_id if tokenizer_eos_token_id is not None: - stop_token_ids.add(tokenizer_eos_token_id) + _add_single(tokenizer_eos_token_id) - # add eos_token_id from model's generation_config.json file if there - # is any. eos_token_id = generation_config.get('eos_token_id') if eos_token_id is not None: if isinstance(eos_token_id, int): - stop_token_ids.add(eos_token_id) + _add_single(eos_token_id) else: - stop_token_ids.update(eos_token_id) + for eid in eos_token_id: + _add_single(eid) - self.stop_token_ids = list(stop_token_ids) + self.stop_token_ids = stop_seqs def __post_init__(self): """Check input validation.""" @@ -184,6 +237,8 @@ def __post_init__(self): assert self.temperature >= 0 and self.temperature <= 2 # [0,2] assert 0 <= self.min_p <= 1, \ f'min_p should be in range [0, 1], but found {self.min_p}' + if self.stop_token_ids is not None: + self.stop_token_ids = self._normalize_stop_token_ids(self.stop_token_ids) @pydantic_dataclass @@ -251,7 +306,7 @@ class TurbomindEngineConfig: """ dtype: str = 'auto' - model_format: Optional[str] = None + model_format: str | None = None tp: int = 1 dp: int = 1 cp: int = 1 @@ -264,9 +319,9 @@ class TurbomindEngineConfig: outer_dp_size: int = None nnodes: int = 1 node_rank: int = 0 - dist_init_addr: Optional[str] = None - devices: List[int] = None - session_len: Optional[int] = None + dist_init_addr: str | None = None + devices: list[int] = None + session_len: int | None = None max_batch_size: int = None cache_max_entry_count: float = 0.8 cache_chunk_size: int = -1 @@ -275,16 +330,16 @@ class TurbomindEngineConfig: quant_policy: int = 0 rope_scaling_factor: float = 0.0 use_logn_attn: bool = False - download_dir: Optional[str] = None - revision: Optional[str] = None + download_dir: str | None = None + revision: str | None = None max_prefill_token_num: int = 8192 num_tokens_per_iter: int = 0 max_prefill_iters: int = 1 async_: int = 1 - devices: Optional[List[int]] = None + devices: list[int] | None = None empty_init: bool = False communicator: str = 'nccl' - hf_overrides: Optional[Dict[str, Any]] = None + hf_overrides: dict[str, Any] | None = None enable_metrics: bool = True def __post_init__(self): @@ -388,13 +443,13 @@ class PytorchEngineConfig: block_size: int = 64 num_cpu_blocks: int = 0 num_gpu_blocks: int = 0 - adapters: Dict[str, str] = None + adapters: dict[str, str] = None max_prefill_token_num: int = 4096 thread_safe: bool = False enable_prefix_caching: bool = False device_type: str = 'cuda' eager_mode: bool = False - custom_module_map: Dict[str, str] = None + custom_module_map: dict[str, str] = None download_dir: str = None revision: str = None quant_policy: Literal[0, 4, 8] = 0 @@ -406,7 +461,7 @@ class PytorchEngineConfig: mp_engine_backend: str = 'mp' model_format: str = None enable_metrics: bool = True - hf_overrides: Optional[Dict[str, Any]] = None + hf_overrides: dict[str, Any] | None = None disable_vision_encoder: bool = False logprobs_mode: str = None # router replay @@ -488,9 +543,9 @@ class Response: text: str generate_token_len: int input_token_len: int - finish_reason: Optional[Literal['stop', 'length']] = None - token_ids: List[int] = field(default_factory=list) - logprobs: List[Dict[int, float]] = None + finish_reason: Literal['stop', 'length'] | None = None + token_ids: list[int] = field(default_factory=list) + logprobs: list[dict[int, float]] = None logits: torch.Tensor = None last_hidden_state: torch.Tensor = None index: int = 0 @@ -511,7 +566,7 @@ def _format_none_text_fields(self): fields.append(f'logprobs={self.logprobs}') # Helper function to format tensor information - def _format_tensor(name: str, tensor: Optional[torch.Tensor]) -> List[str]: + def _format_tensor(name: str, tensor: torch.Tensor | None) -> list[str]: if tensor is None: return [f'{name}=None'] try: @@ -580,7 +635,7 @@ class EngineEvent: timestamp: float @classmethod - def new_event(cls, event_type: EventType, timestamp: Optional[float] = None) -> 'EngineEvent': + def new_event(cls, event_type: EventType, timestamp: float | None = None) -> 'EngineEvent': # Timestamps MUST use wall-clock time (time.time()) to maintain consistency # between csrc(std::chrono::system_clock) and python timestamp = time.time() if timestamp is None else timestamp @@ -604,11 +659,11 @@ class RequestMetrics: Attributes: token_timestamp: A wall-clock time when a token is generated. - engine_events: List of engine events during inference. + engine_events: list of engine events during inference. """ token_timestamp: float = 0.0 - engine_events: List[EngineEvent] = field(default_factory=list) - spec_info: Optional[Dict[str, Any]] = None + engine_events: list[EngineEvent] = field(default_factory=list) + spec_info: dict[str, Any] | None = None @dataclass @@ -625,12 +680,12 @@ class EngineOutput: req_metrics: request metrics information """ status: ResponseType - token_ids: List[int] - logprobs: List[Dict[int, float]] = None + token_ids: list[int] + logprobs: list[dict[int, float]] = None logits: torch.Tensor = None last_hidden_state: torch.Tensor = None - cache_block_ids: Optional[List[int]] = None - req_metrics: Optional[RequestMetrics] = None + cache_block_ids: list[int] | None = None + req_metrics: RequestMetrics | None = None routed_experts: torch.Tensor = None diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 588f458bb6..dfd182bfde 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -52,7 +52,7 @@ class SamplingParam: repetition_penalty: float = 1.0 ignore_eos: bool = False random_seed: int = None - stop_words: List[int] = field(default_factory=list) + stop_words: List[List[int]] = field(default_factory=list) bad_words: List[int] = field(default_factory=list) max_new_tokens: int = 512 min_new_tokens: int = 0 @@ -75,7 +75,11 @@ def from_gen_config(cls, gen_config: GenerationConfig): stop_words = gen_config.stop_token_ids or [] bad_words = gen_config.bad_token_ids or [] if gen_config.ignore_eos: - bad_words += stop_words + if any(len(s) > 1 for s in stop_words): + logger.warning('Multi-token stop words are not supported and ' + 'will be ignored. Only single-token stop words can ' + 'be used to stop generation.') + bad_words += [s[0] for s in stop_words if len(s) == 1] stop_words = [] top_k = gen_config.top_k diff --git a/lmdeploy/serve/core/async_engine.py b/lmdeploy/serve/core/async_engine.py index 78357f1164..a650fbeab7 100644 --- a/lmdeploy/serve/core/async_engine.py +++ b/lmdeploy/serve/core/async_engine.py @@ -134,8 +134,6 @@ def __init__(self, # parameters for member functions self.stop_words = _stop_words(self.chat_template.stop_words, self.tokenizer) - if self.stop_words is not None: - self.stop_words = self.stop_words[0][0].tolist() self.backend = backend self.request_logger = RequestLogger(max_log_len) @@ -389,9 +387,14 @@ async def generate( def is_error(status): return status not in [ResponseType.SUCCESS, ResponseType.FINISH, ResponseType.CANCEL] - stop_ids = [] + single_stop_ids: set = set() + multi_stop_seqs: list = [] if not gen_config.ignore_eos: - stop_ids = gen_config.stop_token_ids or [] + for seq in (gen_config.stop_token_ids or []): + if len(seq) == 1: + single_stop_ids.add(seq[0]) + else: + multi_stop_seqs.append(seq) metrics_processor.increase_total_requests() async with session.request_handle() as handle: @@ -441,12 +444,22 @@ def is_error(status): output_len = len(outputs.token_ids) if hit_stop_token or output_len == 0: continue - - # This assumes the engine will stop when stop token is hit - if output_len and outputs.token_ids[-1] in stop_ids: + print(f'outputs.token_ids: {outputs.token_ids}') + # Check single-token stop + if output_len and outputs.token_ids[-1] in single_stop_ids: hit_stop_token = 1 token_ids += outputs.token_ids[:output_len - hit_stop_token] + + # Check multi-token stop sequences + if not hit_stop_token and multi_stop_seqs: + gen_ids = token_ids[input_len:] + for mseq in multi_stop_seqs: + slen = len(mseq) + if len(gen_ids) >= slen and gen_ids[-slen:] == mseq: + hit_stop_token = slen + token_ids = token_ids[:-slen] + break gen_len = len(token_ids) - input_len ids_offset = state.ids_offset @@ -480,7 +493,8 @@ def is_error(status): if outputs.status == ResponseType.CANCEL: finish_reason = 'abort' else: - finish_reason = 'stop' if outputs.token_ids[-1] in stop_ids else 'length' + is_stop = (outputs.token_ids[-1] in single_stop_ids) or hit_stop_token > 0 + finish_reason = 'stop' if is_stop else 'length' # utf-8 char at the end means it's a potential unfinished byte sequence if not response.endswith('�'): @@ -488,13 +502,14 @@ def is_error(status): response = '' token_ids, logits, last_hidden_state, logprobs = [], None, None, None if gen_config.include_stop_str_in_output and finish_reason == 'stop': - # return the eos token id (MUST be in a list), eos string, eos token's logits and so on - token_ids = outputs.token_ids[-1:] + stop_len = max(hit_stop_token, 1) + token_ids = outputs.token_ids[-stop_len:] response = self.tokenizer.decode(token_ids, skip_special_tokens=False) - logits = outputs.logits[-1:] if outputs.logits is not None else None - last_hidden_state = outputs.last_hidden_state[-1:] if outputs.last_hidden_state else None - logprobs = outputs.logprobs[-1:] if outputs.logprobs else None - gen_len += 1 + logits = outputs.logits[-stop_len:] if outputs.logits is not None else None + last_hidden_state = (outputs.last_hidden_state[-stop_len:] + if outputs.last_hidden_state else None) + logprobs = outputs.logprobs[-stop_len:] if outputs.logprobs else None + gen_len += stop_len # router replay routed_experts = outputs.routed_experts diff --git a/lmdeploy/tokenizer.py b/lmdeploy/tokenizer.py index 2a4fc407f8..643a5f090e 100644 --- a/lmdeploy/tokenizer.py +++ b/lmdeploy/tokenizer.py @@ -182,9 +182,9 @@ def indexes_containing_token(self, token: str): # there might be token id that exceeds self.vocab_size if len(indexes) == 0: indexes = self.encode(token, False) - if len(indexes) != 1: - self.logger.warning(f'The token {token}, its length of indexes {indexes} is ' - 'not 1. Currently, it can not be used as stop words') + if len(indexes) > 1: + # Multi-token encoding: return empty so callers can handle + # the multi-token case via encode() directly. indexes = [] self._indexes_tokens_deque.append((token, indexes)) return indexes @@ -540,7 +540,7 @@ def indexes_containing_token(self, token): the input token.""" encoded = self.encode(token, add_bos=False) if len(encoded) > 1: - self.logger.warning(f'The token {token}, its length of indexes {encoded} is over ' - 'than 1. Currently, it can not be used as stop words') + # Multi-token encoding: return empty so callers can handle + # the multi-token case via encode() directly. return [] return self.model.indexes_containing_token(token) diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index a4a37dc529..f5b3809104 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -41,12 +41,20 @@ MAX_LOGPROBS = 1024 -def _construct_stop_or_bad_words(words: List[int] = None): - if words is None or len(words) == 0: +def _construct_stop_or_bad_words(seqs: List[List[int]] = None): + """Build packed (token_ids, offsets) for TurboMind stop/bad words. + + Each inner list is a token ID sequence. The offset array stores cumulative lengths so the C++ kernel knows where + each sequence ends. + """ + if not seqs: return None - offsets = list(range(1, len(words) + 1)) - combined = [words, offsets] - return combined + tokens: List[int] = [] + offsets: List[int] = [] + for seq in seqs: + tokens.extend(seq) + offsets.append(len(tokens)) + return [tokens, offsets] def _np_dict_to_tm_dict(np_dict: dict): @@ -807,9 +815,9 @@ def _get_generation_config(self, cfg: GenerationConfig): c.min_p = cfg.min_p c.temperature = cfg.temperature if cfg.stop_token_ids: - c.eos_ids = cfg.stop_token_ids + c.eos_ids = [s[0] for s in cfg.stop_token_ids if len(s) == 1] if cfg.bad_token_ids: - c.bad_ids = _construct_stop_or_bad_words(cfg.bad_token_ids) + c.bad_ids = _construct_stop_or_bad_words([[tid] for tid in cfg.bad_token_ids]) if not cfg.ignore_eos and cfg.stop_token_ids: c.stop_ids = _construct_stop_or_bad_words(cfg.stop_token_ids) c.repetition_penalty = cfg.repetition_penalty diff --git a/lmdeploy/utils.py b/lmdeploy/utils.py index 5e06ab5ae9..387f522760 100644 --- a/lmdeploy/utils.py +++ b/lmdeploy/utils.py @@ -195,28 +195,31 @@ def filter_suffix(response: str, suffixes: list[str] | None = None) -> str: return response -# TODO remove stop_word_offsets stuff and make it clean -def _stop_words(stop_words: list[int | str], tokenizer: object): - """Return list of stop-words to numpy.ndarray.""" - import numpy as np +def _stop_words(stop_words: list[int | str], tokenizer: object) -> list[list[int]] | None: + """Convert chat-template stop words to List[List[int]]. + + Each element is a token ID sequence representing one stop word. Single-token matches from vocab scan produce + length-1 lists. Multi-token words that require encoding produce longer lists. + """ if stop_words is None: return None assert isinstance(stop_words, list) and \ all(isinstance(elem, (str, int)) for elem in stop_words), \ f'stop_words must be a list but got {type(stop_words)}' - stop_indexes = [] + seqs: list[list[int]] = [] for stop_word in stop_words: - if isinstance(stop_word, str): - stop_indexes += tokenizer.indexes_containing_token(stop_word) - elif isinstance(stop_word, int): - stop_indexes.append(stop_word) - assert isinstance(stop_indexes, list) and all(isinstance(elem, int) for elem in stop_indexes), 'invalid stop_words' - # each id in stop_indexes represents a stop word - # refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for - # detailed explanation about fastertransformer's stop_indexes - stop_word_offsets = range(1, len(stop_indexes) + 1) - stop_words = np.array([[stop_indexes, stop_word_offsets]]).astype(np.int32) - return stop_words + if isinstance(stop_word, int): + seqs.append([stop_word]) + elif isinstance(stop_word, str): + single_matches = tokenizer.indexes_containing_token(stop_word) + if single_matches: + for idx in single_matches: + seqs.append([idx]) + else: + encoded = tokenizer.encode(stop_word, add_bos=False) + if encoded: + seqs.append(encoded) + return seqs or None def get_hf_gen_cfg(path: str): From 9e5818553abeba376bf9c07c5755a3aa4b216638 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Mon, 23 Mar 2026 09:27:18 +0000 Subject: [PATCH 2/6] buffer multi-token stop phrases to prevent partial prefix leakage in streaming --- lmdeploy/serve/core/async_engine.py | 104 +++++++++++++++++++++++----- 1 file changed, 88 insertions(+), 16 deletions(-) diff --git a/lmdeploy/serve/core/async_engine.py b/lmdeploy/serve/core/async_engine.py index a650fbeab7..e5a115e004 100644 --- a/lmdeploy/serve/core/async_engine.py +++ b/lmdeploy/serve/core/async_engine.py @@ -29,6 +29,53 @@ logger = get_logger('lmdeploy') +def _commit_stream_tokens(all_ids: list[int], input_len: int, pending_ids: list[int], new_ids: list[int], + multi_stop_seqs: list[list[int]], holdback_len: int): + """Commit safe streamed tokens with multi-stop holdback. + + Args: + all_ids: full token_ids buffer (prompt + committed generated tokens). + input_len: length of the prompt prefix in all_ids; generated tokens start here. + pending_ids: tokens buffered but not yet committed (mutated in place). + new_ids: freshly arrived tokens to add this iteration. + multi_stop_seqs: list of multi-token stop sequences to match against. + holdback_len: number of tokens to withhold from the tail (max_stop_len - 1). + + Returns: + hit_stop_token: matched stop length (0 if none). + matched_stop_ids: matched stop token ids. + commit_ids: token ids safe to emit this round. + pending_ids: updated pending buffer. + """ + pending_ids.extend(new_ids) + hit_stop_token = 0 + matched_stop_ids: list[int] = [] + for mseq in multi_stop_seqs: + slen = len(mseq) + plen = len(pending_ids) + total = len(all_ids) - input_len + plen + if total < slen: + continue + if plen >= slen: + tail = pending_ids[-slen:] + else: + need = slen - plen + tail = all_ids[len(all_ids) - need:] + pending_ids + if tail == mseq: + hit_stop_token = slen + matched_stop_ids = list(mseq) + del pending_ids[-slen:] + break + if hit_stop_token: + commit_ids = list(pending_ids) + pending_ids.clear() + else: + commit_len = max(0, len(pending_ids) - holdback_len) + commit_ids = pending_ids[:commit_len] + del pending_ids[:commit_len] + return hit_stop_token, matched_stop_ids, commit_ids, pending_ids + + @dataclasses.dataclass class GenOut: """Pack all response information together.""" @@ -427,6 +474,11 @@ def is_error(status): step=history_len) as gen: logger.debug(f'[generate] session {session_id} started') hit_stop_token = 0 + stop_by_single = False + matched_stop_ids: list[int] = [] + max_multi_stop_len = max((len(s) for s in multi_stop_seqs), default=0) + holdback_len = max(0, max_multi_stop_len - 1) + pending_ids: list[int] = [] req_stats = RequestStats(prompt_tokens=input_len) # per-request stats # We use this as default outputs in case the async_stream_infer of the Engine yields empty generator. @@ -444,22 +496,32 @@ def is_error(status): output_len = len(outputs.token_ids) if hit_stop_token or output_len == 0: continue - print(f'outputs.token_ids: {outputs.token_ids}') + # print(f'outputs.token_ids: {outputs.token_ids}') # Check single-token stop if output_len and outputs.token_ids[-1] in single_stop_ids: hit_stop_token = 1 - - token_ids += outputs.token_ids[:output_len - hit_stop_token] - - # Check multi-token stop sequences - if not hit_stop_token and multi_stop_seqs: - gen_ids = token_ids[input_len:] - for mseq in multi_stop_seqs: - slen = len(mseq) - if len(gen_ids) >= slen and gen_ids[-slen:] == mseq: - hit_stop_token = slen - token_ids = token_ids[:-slen] - break + stop_by_single = True + matched_stop_ids = [outputs.token_ids[-1]] + + new_ids = outputs.token_ids[:output_len - hit_stop_token] + if not hit_stop_token: + if multi_stop_seqs: + hit_stop_token, matched_multi_ids, commit_ids, pending_ids = _commit_stream_tokens( + token_ids, + input_len, + pending_ids, + new_ids, + multi_stop_seqs, + holdback_len, + ) + if matched_multi_ids: + matched_stop_ids = matched_multi_ids + else: + commit_ids = new_ids + else: + commit_ids = pending_ids + new_ids + pending_ids = [] + token_ids.extend(commit_ids) gen_len = len(token_ids) - input_len ids_offset = state.ids_offset @@ -493,7 +555,17 @@ def is_error(status): if outputs.status == ResponseType.CANCEL: finish_reason = 'abort' else: - is_stop = (outputs.token_ids[-1] in single_stop_ids) or hit_stop_token > 0 + if not hit_stop_token and pending_ids: + token_ids.extend(pending_ids) + pending_ids = [] + gen_len = len(token_ids) - input_len + ids_offset = state.ids_offset + response, state = self.tokenizer.detokenize_incrementally( + token_ids, + state, + skip_special_tokens=gen_config.skip_special_tokens, + spaces_between_special_tokens=gen_config.spaces_between_special_tokens) + is_stop = stop_by_single or hit_stop_token > 0 finish_reason = 'stop' if is_stop else 'length' # utf-8 char at the end means it's a potential unfinished byte sequence @@ -502,8 +574,8 @@ def is_error(status): response = '' token_ids, logits, last_hidden_state, logprobs = [], None, None, None if gen_config.include_stop_str_in_output and finish_reason == 'stop': - stop_len = max(hit_stop_token, 1) - token_ids = outputs.token_ids[-stop_len:] + token_ids = matched_stop_ids if matched_stop_ids else outputs.token_ids[-1:] + stop_len = len(token_ids) response = self.tokenizer.decode(token_ids, skip_special_tokens=False) logits = outputs.logits[-stop_len:] if outputs.logits is not None else None last_hidden_state = (outputs.last_hidden_state[-stop_len:] From 51dd33c579b1013a7b3cd63009fc2107384fa37c Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Mon, 23 Mar 2026 12:46:27 +0000 Subject: [PATCH 3/6] for pytorch --- lmdeploy/pytorch/engine/logits_process.py | 14 +- lmdeploy/pytorch/engine/model_agent/agent.py | 5 + lmdeploy/pytorch/strategies/ar/model_agent.py | 122 ++++++++++++++++-- lmdeploy/pytorch/strategies/ar/sampling.py | 25 +++- lmdeploy/pytorch/strategies/dllm/sampling.py | 2 +- 5 files changed, 148 insertions(+), 20 deletions(-) diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index f6f290fc29..a355b061f9 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -210,7 +210,7 @@ def _filter_repetition_ngram_( return scores # use first stop words _, found = ngram(generated_ids, n, threshold, max_n, max_ngram_window_size) - stop_words = stop_words[:, 0] + stop_words = stop_words[:, 0, 0] # fill all scores -inf scores.masked_fill_(found[:, None], -float('inf')) # set stop words to 0 @@ -245,7 +245,7 @@ class SamplingInputs: bad_words: torch.LongTensor = None bad_mask: torch.BoolTensor = None stop_words: torch.LongTensor = None - stop_mask: torch.BoolTensor = None + stop_word_lens: torch.LongTensor = None repetition_penalty: torch.Tensor = None top_k: torch.LongTensor = None top_p: torch.Tensor = None @@ -428,11 +428,13 @@ async def __call__(self, scores: torch.Tensor) -> torch.Tensor: scores = _process_bad_words_(scores, bad_words, bad_mask) stop_words = sampling_inputs.stop_words - if stop_words is not None: + stop_word_lens = sampling_inputs.stop_word_lens + if stop_words is not None and stop_word_lens is not None: ignore_eos = sampling_inputs.num_ignore_eos > 0 - stop_mask = sampling_inputs.stop_mask - stop_mask = torch.where(ignore_eos[:, None], stop_mask, False) - scores = _process_bad_words_(scores, stop_words, stop_mask) + single_mask = (stop_word_lens == 1) & ignore_eos[:, None] + if single_mask.any(): + single_tokens = stop_words[:, :, 0] + scores = _process_bad_words_(scores, single_tokens, single_mask) return scores, logprobs diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py index f9c2919962..3e25f4962d 100644 --- a/lmdeploy/pytorch/engine/model_agent/agent.py +++ b/lmdeploy/pytorch/engine/model_agent/agent.py @@ -695,11 +695,16 @@ async def _step_postprocess_with_output(self, logger.debug(f' rank[{rank}]: synchronize token ids') # stopping criteria + # Use output_token_ids (all tokens accepted this step) so that multi-token + # stop sequences whose last token is not the final spec-decoded token are + # detected correctly. For non-spec AR, output_token_ids == next_token_ids. stopped, stop_pos, stopping_criteria = stopping_criteria.step( next_token_ids, sampling_inputs.stop_words, inputs=inputs, extra_inputs=extra_inputs, + stop_word_lens=sampling_inputs.stop_word_lens, + generated_ids=sampling_inputs.generated_ids, ) # send output diff --git a/lmdeploy/pytorch/strategies/ar/model_agent.py b/lmdeploy/pytorch/strategies/ar/model_agent.py index 9c7abb5887..a6db3b9b8c 100644 --- a/lmdeploy/pytorch/strategies/ar/model_agent.py +++ b/lmdeploy/pytorch/strategies/ar/model_agent.py @@ -54,40 +54,142 @@ class ARExtraOutputs(ExtraOutputs): @dataclass class ARStoppingCriteria(StoppingCriteria): num_appendable_ids: torch.Tensor + # Tail of previously generated tokens, shape [batch, tail_len]. + # Maintained across steps so that multi-token stop sequences spanning two + # decode steps are detected without relying on the (pipelined) generated_ids + # from SamplingInputs, which lags one step behind. + stop_tail: Optional[torch.Tensor] = None def clone(self): """clone.""" - return ARStoppingCriteria(num_appendable_ids=self.num_appendable_ids) + tail = self.stop_tail.clone() if self.stop_tail is not None else None + return ARStoppingCriteria(num_appendable_ids=self.num_appendable_ids, stop_tail=tail) def merge(self, other: 'ARStoppingCriteria'): """Merge two stopping criteria.""" - new_num_appendable = torch.cat([self.num_appendable_ids, other.num_appendable_ids], dim=0) - return ARStoppingCriteria(num_appendable_ids=new_num_appendable) + new_num = torch.cat([self.num_appendable_ids, other.num_appendable_ids], dim=0) + t0, t1 = self.stop_tail, other.stop_tail + if t0 is None and t1 is None: + new_tail = None + else: + bs0 = self.num_appendable_ids.size(0) + bs1 = other.num_appendable_ids.size(0) + dev = (t0 if t0 is not None else t1).device + if t0 is None: + t0 = torch.zeros(bs0, t1.size(1), dtype=torch.long, device=dev) + if t1 is None: + t1 = torch.zeros(bs1, t0.size(1), dtype=torch.long, device=dev) + # Pad the shorter tail to the same length. + l0, l1 = t0.size(1), t1.size(1) + if l0 < l1: + t0 = torch.nn.functional.pad(t0, (l1 - l0, 0)) + elif l1 < l0: + t1 = torch.nn.functional.pad(t1, (l0 - l1, 0)) + new_tail = torch.cat([t0, t1], dim=0) + return ARStoppingCriteria(num_appendable_ids=new_num, stop_tail=new_tail) def update(self, delta: ModelInputsDelta): """Update stopping criteria.""" indices = delta.indices - new_num_appendable = self.num_appendable_ids[indices] - return ARStoppingCriteria(num_appendable_ids=new_num_appendable) + new_num = self.num_appendable_ids[indices] + new_tail = self.stop_tail[indices] if self.stop_tail is not None else None + return ARStoppingCriteria(num_appendable_ids=new_num, stop_tail=new_tail) @record_function('stopping_criteria') def step(self, token_ids: torch.Tensor, stop_words: torch.Tensor, inputs: Optional[ModelInputs] = None, - extra_inputs: Optional[ARExtraInputs] = None): + extra_inputs: Optional[ARExtraInputs] = None, + stop_word_lens: Optional[torch.Tensor] = None, + generated_ids: Optional[torch.Tensor] = None): """Check whether to stop generation.""" num_appendable_ids = self.num_appendable_ids - 1 stopped = num_appendable_ids <= 0 stop_pos = torch.zeros_like(num_appendable_ids) - if stop_words is not None: - sw_stopped = (token_ids[:, None] == stop_words).any(1) + + if stop_words is not None and stop_word_lens is not None: + max_slen = int(stop_word_lens.max().item()) + tail_len = max(0, max_slen - 1) + batch_size = stop_words.size(0) + num_seqs = stop_words.size(1) + sw_stopped = torch.zeros(batch_size, dtype=torch.bool, device=stop_words.device) + + # Normalise to [batch, step_len] as a view so in-place masking propagates. + token_ids_was_1d = (token_ids.ndim == 1) + if token_ids_was_1d: + token_ids = token_ids.unsqueeze(1) + + new_tail = torch.zeros( + (batch_size, tail_len), dtype=torch.long, device=token_ids.device) if tail_len > 0 else None + + for bidx in range(batch_size): + step_tokens = token_ids[bidx] + valid_tokens = step_tokens[step_tokens >= 0] + + # Retrieve the tail from the previous step for this batch item. + if self.stop_tail is not None and tail_len > 0: + prev_tail = self.stop_tail[bidx].to(token_ids.device) + # Trim or pad to the current tail_len. + if prev_tail.size(0) >= tail_len: + prev_tail = prev_tail[-tail_len:] + else: + prev_tail = torch.nn.functional.pad(prev_tail, (tail_len - prev_tail.size(0), 0)) + else: + prev_tail = token_ids.new_zeros(tail_len) + + if valid_tokens.numel() == 0: + # No new tokens this step; carry the tail forward unchanged. + if new_tail is not None: + new_tail[bidx] = prev_tail + continue + + # History = tail of previous steps + tokens from this step. + history = torch.cat([prev_tail, valid_tokens]) if tail_len > 0 else valid_tokens + hist_len = history.size(0) + stop_pos_bidx = valid_tokens.numel() - 1 # default: last valid token + + for si in range(num_seqs): + slen = int(stop_word_lens[bidx, si].item()) + if slen <= 0 or hist_len < slen: + continue + target = stop_words[bidx, si, :slen] + # Scan positions whose end falls within the current step tokens + # (end_pos >= tail_len ensures at least one new token is included). + for end_pos in range(max(slen - 1, tail_len), hist_len): + if (history[end_pos - slen + 1:end_pos + 1] == target).all(): + sw_stopped[bidx] = True + step_end_pos = end_pos - tail_len # 0-indexed within valid_tokens + stop_pos_bidx = min(stop_pos_bidx, step_end_pos) + break + if sw_stopped[bidx]: + break + + if sw_stopped[bidx]: + stop_pos[bidx] = stop_pos_bidx + # Mask tokens generated after the stop position in the same step. + if token_ids.size(1) > (stop_pos_bidx + 1): + token_ids[bidx, stop_pos_bidx + 1:] = -1 + effective_tokens = valid_tokens[:stop_pos_bidx + 1] + else: + effective_tokens = valid_tokens + + # Update tail: last tail_len tokens of [prev_tail, effective_tokens]. + if new_tail is not None: + combined = torch.cat([prev_tail, effective_tokens]) + new_tail[bidx] = combined[-tail_len:] + + if token_ids_was_1d and token_ids.size(1) == 1: + token_ids = token_ids.squeeze(1) + stopped = stopped | sw_stopped one_ids = torch.clamp_max(num_appendable_ids, 0) num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids) - # I don't know why assign inplace does not works... - new_stopping = ARStoppingCriteria(num_appendable_ids=num_appendable_ids) + else: + new_tail = None + + new_stopping = ARStoppingCriteria(num_appendable_ids=num_appendable_ids, stop_tail=new_tail) return stopped, stop_pos, new_stopping diff --git a/lmdeploy/pytorch/strategies/ar/sampling.py b/lmdeploy/pytorch/strategies/ar/sampling.py index 8a8c45d032..95a566d680 100644 --- a/lmdeploy/pytorch/strategies/ar/sampling.py +++ b/lmdeploy/pytorch/strategies/ar/sampling.py @@ -97,7 +97,8 @@ def __gather_params(): bw = param.bad_words sw = param.stop_words if (not param.ignore_eos and seq.num_new_tokens < param.min_new_tokens): - bw = bw + sw + # During min_new_tokens period suppress single-token stops as bad words. + bw = bw + [s[0] for s in sw if len(s) == 1] bad_words[idx] = bw stop_words[idx] = sw logits_processors[idx] = param.logits_processors @@ -143,6 +144,24 @@ def __get_bad_words(bad_words): mask = ret >= 0 return ret, mask + def __get_stop_words(stop_words_list): + """Build stop_words [batch, num_seqs, max_len] and stop_word_lens + [batch, num_seqs].""" + max_num_seqs = max(len(sw) for sw in stop_words_list) + if max_num_seqs == 0: + return None, None + max_len = max((len(s) for sw in stop_words_list for s in sw), default=0) + if max_len == 0: + return None, None + seqs = torch.zeros((batch_size, max_num_seqs, max_len), dtype=torch.long) + lens = torch.zeros((batch_size, max_num_seqs), dtype=torch.long) + for i, sw in enumerate(stop_words_list): + for j, seq in enumerate(sw): + slen = len(seq) + seqs[i, j, :slen] = torch.tensor(seq, dtype=torch.long) + lens[i, j] = slen + return seqs, lens + __gather_params() if all(rp == 1.0 for rp in repetition_penalty): @@ -156,7 +175,7 @@ def __get_bad_words(bad_words): temperature = None bad_words, bad_mask = __get_bad_words(bad_words) - stop_words, stop_mask = __get_bad_words(stop_words) + stop_words, stop_word_lens = __get_stop_words(stop_words) max_top_k = max(top_k) if min(top_k) <= 0: @@ -201,7 +220,7 @@ def __get_bad_words(bad_words): bad_words=bad_words, bad_mask=bad_mask, stop_words=stop_words, - stop_mask=stop_mask, + stop_word_lens=stop_word_lens, repetition_penalty=repetition_penalty, top_k=top_k, top_p=top_p, diff --git a/lmdeploy/pytorch/strategies/dllm/sampling.py b/lmdeploy/pytorch/strategies/dllm/sampling.py index d7c8bc4716..569ac2bde9 100644 --- a/lmdeploy/pytorch/strategies/dllm/sampling.py +++ b/lmdeploy/pytorch/strategies/dllm/sampling.py @@ -34,7 +34,7 @@ def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs: 'bad_words', 'bad_mask', 'stop_words', - 'stop_mask', + 'stop_word_lens', 'repetition_penalty', 'top_k', 'top_p', From 8874b4fbc557a678dff3c55f41304482ca456f05 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Mon, 23 Mar 2026 15:08:53 +0000 Subject: [PATCH 4/6] vectorize step function --- lmdeploy/pytorch/engine/model_agent/agent.py | 3 +- lmdeploy/pytorch/strategies/ar/model_agent.py | 178 +++++++++++------- 2 files changed, 107 insertions(+), 74 deletions(-) diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py index 3e25f4962d..4ccbb19383 100644 --- a/lmdeploy/pytorch/engine/model_agent/agent.py +++ b/lmdeploy/pytorch/engine/model_agent/agent.py @@ -699,12 +699,11 @@ async def _step_postprocess_with_output(self, # stop sequences whose last token is not the final spec-decoded token are # detected correctly. For non-spec AR, output_token_ids == next_token_ids. stopped, stop_pos, stopping_criteria = stopping_criteria.step( - next_token_ids, + output_token_ids, sampling_inputs.stop_words, inputs=inputs, extra_inputs=extra_inputs, stop_word_lens=sampling_inputs.stop_word_lens, - generated_ids=sampling_inputs.generated_ids, ) # send output diff --git a/lmdeploy/pytorch/strategies/ar/model_agent.py b/lmdeploy/pytorch/strategies/ar/model_agent.py index a6db3b9b8c..dcc1a77bb9 100644 --- a/lmdeploy/pytorch/strategies/ar/model_agent.py +++ b/lmdeploy/pytorch/strategies/ar/model_agent.py @@ -101,83 +101,20 @@ def step(self, stop_words: torch.Tensor, inputs: Optional[ModelInputs] = None, extra_inputs: Optional[ARExtraInputs] = None, - stop_word_lens: Optional[torch.Tensor] = None, - generated_ids: Optional[torch.Tensor] = None): + stop_word_lens: Optional[torch.Tensor] = None): """Check whether to stop generation.""" num_appendable_ids = self.num_appendable_ids - 1 stopped = num_appendable_ids <= 0 stop_pos = torch.zeros_like(num_appendable_ids) - if stop_words is not None and stop_word_lens is not None: - max_slen = int(stop_word_lens.max().item()) - tail_len = max(0, max_slen - 1) - batch_size = stop_words.size(0) - num_seqs = stop_words.size(1) - sw_stopped = torch.zeros(batch_size, dtype=torch.bool, device=stop_words.device) - - # Normalise to [batch, step_len] as a view so in-place masking propagates. + if stop_words is None or stop_word_lens is None: + new_tail = None + else: token_ids_was_1d = (token_ids.ndim == 1) if token_ids_was_1d: token_ids = token_ids.unsqueeze(1) - new_tail = torch.zeros( - (batch_size, tail_len), dtype=torch.long, device=token_ids.device) if tail_len > 0 else None - - for bidx in range(batch_size): - step_tokens = token_ids[bidx] - valid_tokens = step_tokens[step_tokens >= 0] - - # Retrieve the tail from the previous step for this batch item. - if self.stop_tail is not None and tail_len > 0: - prev_tail = self.stop_tail[bidx].to(token_ids.device) - # Trim or pad to the current tail_len. - if prev_tail.size(0) >= tail_len: - prev_tail = prev_tail[-tail_len:] - else: - prev_tail = torch.nn.functional.pad(prev_tail, (tail_len - prev_tail.size(0), 0)) - else: - prev_tail = token_ids.new_zeros(tail_len) - - if valid_tokens.numel() == 0: - # No new tokens this step; carry the tail forward unchanged. - if new_tail is not None: - new_tail[bidx] = prev_tail - continue - - # History = tail of previous steps + tokens from this step. - history = torch.cat([prev_tail, valid_tokens]) if tail_len > 0 else valid_tokens - hist_len = history.size(0) - stop_pos_bidx = valid_tokens.numel() - 1 # default: last valid token - - for si in range(num_seqs): - slen = int(stop_word_lens[bidx, si].item()) - if slen <= 0 or hist_len < slen: - continue - target = stop_words[bidx, si, :slen] - # Scan positions whose end falls within the current step tokens - # (end_pos >= tail_len ensures at least one new token is included). - for end_pos in range(max(slen - 1, tail_len), hist_len): - if (history[end_pos - slen + 1:end_pos + 1] == target).all(): - sw_stopped[bidx] = True - step_end_pos = end_pos - tail_len # 0-indexed within valid_tokens - stop_pos_bidx = min(stop_pos_bidx, step_end_pos) - break - if sw_stopped[bidx]: - break - - if sw_stopped[bidx]: - stop_pos[bidx] = stop_pos_bidx - # Mask tokens generated after the stop position in the same step. - if token_ids.size(1) > (stop_pos_bidx + 1): - token_ids[bidx, stop_pos_bidx + 1:] = -1 - effective_tokens = valid_tokens[:stop_pos_bidx + 1] - else: - effective_tokens = valid_tokens - - # Update tail: last tail_len tokens of [prev_tail, effective_tokens]. - if new_tail is not None: - combined = torch.cat([prev_tail, effective_tokens]) - new_tail[bidx] = combined[-tail_len:] + sw_stopped, stop_pos, new_tail = self._check_stop_words(token_ids, stop_words, stop_word_lens) if token_ids_was_1d and token_ids.size(1) == 1: token_ids = token_ids.squeeze(1) @@ -186,11 +123,108 @@ def step(self, one_ids = torch.clamp_max(num_appendable_ids, 0) num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids) + return (stopped, stop_pos, ARStoppingCriteria(num_appendable_ids=num_appendable_ids, stop_tail=new_tail)) + + def _check_stop_words(self, token_ids: torch.Tensor, stop_words: torch.Tensor, stop_word_lens: torch.Tensor): + """Vectorized multi-token stop word detection. + + Uses ``unfold`` for sliding-window matching so that no Python loops + over batch items, stop-word entries, or window positions are needed. + + Args: + token_ids: [batch, step_len], -1 for invalid positions. + Modified **in-place** (tokens after stop are set to -1). + stop_words: [batch, num_seqs, max_slen] + stop_word_lens: [batch, num_seqs] + + Returns: + sw_stopped: [batch] bool + stop_pos: [batch] long – step-relative index of the stop token + new_tail: [batch, tail_len] or None + """ + max_slen = int(stop_word_lens.max().item()) + tail_len = max(0, max_slen - 1) + batch_size = token_ids.size(0) + step_len = token_ids.size(1) + device = token_ids.device + + # -- 1. build history = [prev_tail | token_ids] -- + prev_tail = self._get_prev_tail(batch_size, tail_len, device) + if prev_tail is not None: + history = torch.cat([prev_tail, token_ids], dim=1) else: - new_tail = None - - new_stopping = ARStoppingCriteria(num_appendable_ids=num_appendable_ids, stop_tail=new_tail) - return stopped, stop_pos, new_stopping + history = token_ids + hist_len = history.size(1) + + # -- 2. sliding-window matching via unfold -- + NO_MATCH = hist_len + best_end = history.new_full((batch_size, ), NO_MATCH) + for slen in stop_word_lens.unique().tolist(): + slen = int(slen) + if slen <= 0 or hist_len < slen: + continue + windows = history.unfold(1, slen, 1) # [B, W, slen] + targets = stop_words[:, :, :slen] # [B, S, slen] + len_mask = (stop_word_lens == slen) # [B, S] + + match = (windows.unsqueeze(2) == targets.unsqueeze(1)).all(-1) + match = match & len_mask.unsqueeze(1) + match_any = match.any(2) # [B, W] + + # discard windows that don't include any new token from this step + min_win = max(0, tail_len - slen + 1) + if min_win > 0: + match_any[:, :min_win] = False + + has_match = match_any.any(1) + first_win = match_any.int().argmax(1) + end_pos = first_win + slen - 1 + better = has_match & (end_pos < best_end) + best_end = torch.where(better, end_pos, best_end) + + sw_stopped = best_end < NO_MATCH + + # -- 3. compute stop_pos and mask trailing tokens -- + step_stop_pos = best_end - tail_len + stop_pos = torch.where(sw_stopped, step_stop_pos, sw_stopped.new_zeros(batch_size, dtype=torch.long)) + + col_idx = torch.arange(step_len, device=device) + after_stop = (col_idx > step_stop_pos.unsqueeze(1)) & sw_stopped.unsqueeze(1) + token_ids[after_stop] = -1 + + # -- 4. update tail -- + new_tail = self._build_new_tail(history, tail_len, sw_stopped, best_end, token_ids) + + return sw_stopped, stop_pos, new_tail + + def _get_prev_tail(self, batch_size: int, tail_len: int, device: torch.device) -> Optional[torch.Tensor]: + """Return the previous tail padded/trimmed to ``tail_len``.""" + if tail_len <= 0: + return None + if self.stop_tail is None: + return torch.zeros(batch_size, tail_len, dtype=torch.long, device=device) + prev = self.stop_tail.to(device) + pt_len = prev.size(1) + if pt_len < tail_len: + prev = torch.nn.functional.pad(prev, (tail_len - pt_len, 0)) + elif pt_len > tail_len: + prev = prev[:, -tail_len:] + return prev + + @staticmethod + def _build_new_tail(history: torch.Tensor, tail_len: int, sw_stopped: torch.Tensor, best_end: torch.Tensor, + token_ids: torch.Tensor) -> Optional[torch.Tensor]: + """Gather the last ``tail_len`` valid tokens from *history*.""" + if tail_len <= 0: + return None + valid_counts = (token_ids >= 0).sum(1) + effective_end = torch.where(sw_stopped, best_end, tail_len + valid_counts - 1) + effective_end = effective_end.clamp(min=tail_len - 1) + + offsets = torch.arange(tail_len, device=history.device) + indices = (effective_end - tail_len + 1).unsqueeze(1) + offsets.unsqueeze(0) + indices = indices.clamp(min=0, max=history.size(1) - 1) + return history.gather(1, indices) class ARModelAgentStrategy(ModelAgentStrategy): From 99856675fac093006bd1f58f8a7995260a66f368 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Tue, 24 Mar 2026 03:58:21 +0000 Subject: [PATCH 5/6] optimize --- lmdeploy/pytorch/strategies/ar/model_agent.py | 54 +++++++++++++++---- 1 file changed, 44 insertions(+), 10 deletions(-) diff --git a/lmdeploy/pytorch/strategies/ar/model_agent.py b/lmdeploy/pytorch/strategies/ar/model_agent.py index dcc1a77bb9..fcdd6c599d 100644 --- a/lmdeploy/pytorch/strategies/ar/model_agent.py +++ b/lmdeploy/pytorch/strategies/ar/model_agent.py @@ -128,9 +128,6 @@ def step(self, def _check_stop_words(self, token_ids: torch.Tensor, stop_words: torch.Tensor, stop_word_lens: torch.Tensor): """Vectorized multi-token stop word detection. - Uses ``unfold`` for sliding-window matching so that no Python loops - over batch items, stop-word entries, or window positions are needed. - Args: token_ids: [batch, step_len], -1 for invalid positions. Modified **in-place** (tokens after stop are set to -1). @@ -143,7 +140,46 @@ def _check_stop_words(self, token_ids: torch.Tensor, stop_words: torch.Tensor, s new_tail: [batch, tail_len] or None """ max_slen = int(stop_word_lens.max().item()) - tail_len = max(0, max_slen - 1) + + if max_slen <= 1: + return self._check_stop_words_single(token_ids, stop_words, stop_word_lens) + + return self._check_stop_words_multi(token_ids, stop_words, stop_word_lens, max_slen) + + def _check_stop_words_single(self, token_ids: torch.Tensor, stop_words: torch.Tensor, stop_word_lens: torch.Tensor): + """Fast path when every stop word is a single token. + + No tail, no unfold, no sliding window — just a broadcast compare. + """ + step_len = token_ids.size(1) + device = token_ids.device + + targets = stop_words[:, :, 0] # [B, S] + valid = (stop_word_lens == 1) # [B, S] + + # [B, L, 1] == [B, 1, S] -> [B, L, S]; mask invalid targets + match = (token_ids.unsqueeze(2) == targets.unsqueeze(1)) & valid.unsqueeze(1) + match_any = match.any(2) # [B, L] + + sw_stopped = match_any.any(1) # [B] + first_match = match_any.int().argmax(1) # [B] + stop_pos = torch.where(sw_stopped, first_match, torch.zeros_like(first_match)) + + col_idx = torch.arange(step_len, device=device) + after_stop = (col_idx > stop_pos.unsqueeze(1)) & sw_stopped.unsqueeze(1) + token_ids[after_stop] = -1 + + return sw_stopped, stop_pos, None + + def _check_stop_words_multi(self, token_ids: torch.Tensor, stop_words: torch.Tensor, stop_word_lens: torch.Tensor, + max_slen: int): + """General path for multi-token stop words. + + Per-length unfold loop (each length needs its own window count), but + iterates ``range(1, max_slen+1)`` instead of calling the GPU-syncing + ``stop_word_lens.unique().tolist()``. + """ + tail_len = max_slen - 1 batch_size = token_ids.size(0) step_len = token_ids.size(1) device = token_ids.device @@ -156,12 +192,11 @@ def _check_stop_words(self, token_ids: torch.Tensor, stop_words: torch.Tensor, s history = token_ids hist_len = history.size(1) - # -- 2. sliding-window matching via unfold -- + # -- 2. sliding-window matching per length -- NO_MATCH = hist_len best_end = history.new_full((batch_size, ), NO_MATCH) - for slen in stop_word_lens.unique().tolist(): - slen = int(slen) - if slen <= 0 or hist_len < slen: + for slen in range(1, max_slen + 1): + if hist_len < slen: continue windows = history.unfold(1, slen, 1) # [B, W, slen] targets = stop_words[:, :, :slen] # [B, S, slen] @@ -171,7 +206,6 @@ def _check_stop_words(self, token_ids: torch.Tensor, stop_words: torch.Tensor, s match = match & len_mask.unsqueeze(1) match_any = match.any(2) # [B, W] - # discard windows that don't include any new token from this step min_win = max(0, tail_len - slen + 1) if min_win > 0: match_any[:, :min_win] = False @@ -206,7 +240,7 @@ def _get_prev_tail(self, batch_size: int, tail_len: int, device: torch.device) - prev = self.stop_tail.to(device) pt_len = prev.size(1) if pt_len < tail_len: - prev = torch.nn.functional.pad(prev, (tail_len - pt_len, 0)) + prev = torch.nn.functional.pad(prev, (tail_len - pt_len, 0), value=-1) elif pt_len > tail_len: prev = prev[:, -tail_len:] return prev From 27bdb55ff4a88d9e03f499d9029d734d77218b65 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Tue, 24 Mar 2026 06:36:53 +0000 Subject: [PATCH 6/6] update _check_single_stop_words --- lmdeploy/messages.py | 14 ++----- lmdeploy/pytorch/strategies/ar/model_agent.py | 42 ++++++------------- lmdeploy/serve/core/async_engine.py | 4 +- 3 files changed, 18 insertions(+), 42 deletions(-) diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 4dba517209..0ff90f9341 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -100,7 +100,7 @@ class GenerationConfig: """ n: int = 1 - max_new_tokens: int = 512 + max_new_tokens: int = None do_sample: bool = False top_p: float = 1.0 top_k: int = 50 @@ -152,20 +152,12 @@ def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer): """Convert stop_words/bad_words to ids and append the ids to stop_token_ids/bad_token_ids.""" - def words_to_token_seqs(words: list[str], prefer_exact: bool = False) -> list[list[int]]: + def words_to_token_seqs(words: list[str]) -> list[list[int]]: assert isinstance(words, list) and \ all(isinstance(elem, str) for elem in words), \ f'stop_words must be a list of str but got {type(words)}' seqs: list[list[int]] = [] for word in words: - # For stop_words, prefer exact tokenization so multi-word phrases - # are represented as an exact token-id sequence. - if prefer_exact: - encoded = tokenizer.encode(word, add_bos=False) - if encoded: - seqs.append(encoded) - continue - single_matches = tokenizer.indexes_containing_token(word) if single_matches: for idx in single_matches: @@ -176,7 +168,7 @@ def words_to_token_seqs(words: list[str], prefer_exact: bool = False) -> list[li seqs.append(encoded) return seqs - stop_seqs = words_to_token_seqs(self.stop_words, prefer_exact=True) if self.stop_words else [] + stop_seqs = words_to_token_seqs(self.stop_words) if self.stop_words else [] bad_seqs = words_to_token_seqs(self.bad_words) if self.bad_words else [] stop_seqs.extend(self._normalize_stop_token_ids(self.stop_token_ids)) diff --git a/lmdeploy/pytorch/strategies/ar/model_agent.py b/lmdeploy/pytorch/strategies/ar/model_agent.py index fcdd6c599d..614a7f6d10 100644 --- a/lmdeploy/pytorch/strategies/ar/model_agent.py +++ b/lmdeploy/pytorch/strategies/ar/model_agent.py @@ -110,15 +110,11 @@ def step(self, if stop_words is None or stop_word_lens is None: new_tail = None else: - token_ids_was_1d = (token_ids.ndim == 1) - if token_ids_was_1d: - token_ids = token_ids.unsqueeze(1) + # Set a uniform shape for token_ids for both single and multi-token stop words + token_ids = token_ids.unsqueeze(1) if token_ids.ndim == 1 else token_ids sw_stopped, stop_pos, new_tail = self._check_stop_words(token_ids, stop_words, stop_word_lens) - if token_ids_was_1d and token_ids.size(1) == 1: - token_ids = token_ids.squeeze(1) - stopped = stopped | sw_stopped one_ids = torch.clamp_max(num_appendable_ids, 0) num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids) @@ -142,36 +138,24 @@ def _check_stop_words(self, token_ids: torch.Tensor, stop_words: torch.Tensor, s max_slen = int(stop_word_lens.max().item()) if max_slen <= 1: - return self._check_stop_words_single(token_ids, stop_words, stop_word_lens) - - return self._check_stop_words_multi(token_ids, stop_words, stop_word_lens, max_slen) + # Fast path when every stop word is a single token + return self._check_single_stop_words(token_ids, stop_words, stop_word_lens) - def _check_stop_words_single(self, token_ids: torch.Tensor, stop_words: torch.Tensor, stop_word_lens: torch.Tensor): - """Fast path when every stop word is a single token. + # General path for multi-token stop words + return self._check_multi_stop_words(token_ids, stop_words, stop_word_lens, max_slen) - No tail, no unfold, no sliding window — just a broadcast compare. - """ - step_len = token_ids.size(1) + def _check_single_stop_words(self, token_ids: torch.Tensor, stop_words: torch.Tensor, stop_word_lens: torch.Tensor): + """Fast path: every stop word is a single token, AR always has L==1.""" + batch_size = token_ids.size(0) device = token_ids.device - targets = stop_words[:, :, 0] # [B, S] valid = (stop_word_lens == 1) # [B, S] - - # [B, L, 1] == [B, 1, S] -> [B, L, S]; mask invalid targets - match = (token_ids.unsqueeze(2) == targets.unsqueeze(1)) & valid.unsqueeze(1) - match_any = match.any(2) # [B, L] - - sw_stopped = match_any.any(1) # [B] - first_match = match_any.int().argmax(1) # [B] - stop_pos = torch.where(sw_stopped, first_match, torch.zeros_like(first_match)) - - col_idx = torch.arange(step_len, device=device) - after_stop = (col_idx > stop_pos.unsqueeze(1)) & sw_stopped.unsqueeze(1) - token_ids[after_stop] = -1 - + # token_ids [B, 1] broadcasts against targets [B, S] + sw_stopped = ((token_ids == targets) & valid).any(1) # [B] + stop_pos = torch.zeros(batch_size, dtype=torch.long, device=device) return sw_stopped, stop_pos, None - def _check_stop_words_multi(self, token_ids: torch.Tensor, stop_words: torch.Tensor, stop_word_lens: torch.Tensor, + def _check_multi_stop_words(self, token_ids: torch.Tensor, stop_words: torch.Tensor, stop_word_lens: torch.Tensor, max_slen: int): """General path for multi-token stop words. diff --git a/lmdeploy/serve/core/async_engine.py b/lmdeploy/serve/core/async_engine.py index e5a115e004..b25b245927 100644 --- a/lmdeploy/serve/core/async_engine.py +++ b/lmdeploy/serve/core/async_engine.py @@ -374,6 +374,8 @@ async def generate( else: logger.warning('chat_template_kwargs["enable_thinking"] is already set, ' 'the value will not be overwritten by enable_thinking') + + gen_config = self._determine_gen_config(session, input_ids, gen_config=gen_config) if messages: prompt = messages self.request_logger.log_prompt(session, prompt=prompt) @@ -399,8 +401,6 @@ async def generate( # Figure out a graceful way to handle the invalid input prompt_input = dict(input_ids=input_ids) - gen_config = self._determine_gen_config(session, input_ids, gen_config=gen_config) - if gen_config.max_new_tokens == 0: logger.info(f'run out of tokens. session={session_id}.') yield GenOut(response='',