Skip to content
21 changes: 18 additions & 3 deletions lightllm/server/api_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,9 @@ async def process_single_prompt(prompt: Union[str, List[int]], prompt_index: int
prompt, individual_sampling_params, multimodal_params, request=raw_request
)

return await _collect_generation_results(generator, request, prompt_str, prompt_index)
return await _collect_generation_results(
generator, request, prompt_str, prompt_index, individual_sampling_params
)

tasks = [asyncio.create_task(process_single_prompt(prompt, i)) for i, prompt in enumerate(prompts)]

Expand Down Expand Up @@ -485,7 +487,9 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)


async def _collect_generation_results(generator, request: CompletionRequest, prompt: str, prompt_index: int):
async def _collect_generation_results(
generator, request: CompletionRequest, prompt: str, prompt_index: int, sampling_params: SamplingParams
):
final_output = []
count_output_tokens = 0
finish_reason = None
Expand Down Expand Up @@ -516,9 +520,20 @@ async def _collect_generation_results(generator, request: CompletionRequest, pro
finish_reason = finish_status.get_finish_reason()
prompt_tokens = metadata["prompt_tokens"]

# 处理停止序列剔除
final_text = "".join(final_output)
if finish_reason == "stop" and sampling_params.stop_sequences.size > 0:
valid_stop_strings = sampling_params.stop_sequences.to_strings()
for stop_str in valid_stop_strings:
stop_index = final_text.rfind(stop_str, max(0, len(final_text) - len(stop_str) - 20), len(final_text))
if stop_index != -1:
logger.debug(f"removed stop sequence in tail: '{final_text[stop_index:]}'")
final_text = final_text[:stop_index]
break

return {
"index": prompt_index,
"text": "".join(final_output),
"text": final_text,
"finish_reason": finish_reason,
"prompt_tokens": prompt_tokens,
"completion_tokens": count_output_tokens,
Expand Down
2 changes: 1 addition & 1 deletion lightllm/server/core/objs/io_objs/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .group_req import GroupReqIndexes, GroupReqObjs, AbortedReqCmd
from .group_req import GroupReqIndexes, GroupReqObjs, AbortedReqCmd, StopStrMatchedReqCmd
5 changes: 5 additions & 0 deletions lightllm/server/core/objs/io_objs/group_req.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,8 @@ def to_group_req_index(self):
@dataclass
class AbortedReqCmd:
req_id: int


@dataclass
class StopStrMatchedReqCmd:
req_id: int
20 changes: 15 additions & 5 deletions lightllm/server/core/objs/req.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def get_status(self):
def is_finished(self):
return self.FINISHED_STOP <= self.status <= self.FINISHED_LENGTH

def is_stopped(self):
return self.status == self.FINISHED_STOP

def get_finish_reason(self):
if self.status == self.FINISHED_STOP:
return "stop"
Expand Down Expand Up @@ -74,10 +77,8 @@ class Req(ctypes.Structure):
("prompt_cache_len", ctypes.c_int), # 用于记录prompt cache 的命中长度,用于统计
("is_paused", ctypes.c_bool), # 标记一个Req因为显存资源管理的原因被临时暂停了。
("finish_status", FinishStatus),
# 这个标记变量是http_server 写入,其他进程读取,用于标记该请求是否因为断网被aborted。
("is_aborted", ctypes.c_bool),
# 这个标记变量是router进程读取到is_aborted信息后,router 进程标记该请求已经被abort处理
# 等待推理进程处理,防止router进程反复给推理进程发送abort指令。
("router_aborted", ctypes.c_bool),
# 当FinishStatus 是正常结束状态时,finish_token_index 用于标识结束的
# token 的index位置
("finish_token_index", ctypes.c_int),
Expand All @@ -97,6 +98,12 @@ class Req(ctypes.Structure):
("mtp_accepted_token_num", ctypes.c_int),
# mtp_step 保存一个mtp使用的常量参数,用于快速访问,不会被外部输入初始化
("_mtp_step", ctypes.c_int),
# stop_str_matched 用于判断停止字符串是否匹配成功, detokenization 进程写入,router 进程读取
# 然后router发停止命令给推理进程,推理进程停止输出
("stop_str_matched", ctypes.c_bool),
# 当 stop_str_matched 条件满足的时候,对应的最后一个生成 token 所在的index位置。
# 该变量为 detokenization 进程写入,http_server 读取
("stop_str_matched_token_index", ctypes.c_int),
]

def get_str(self):
Expand Down Expand Up @@ -124,7 +131,6 @@ def init(
self.is_paused = False
self.finish_status = FinishStatus()
self.is_aborted = False
self.router_aborted = False
self.shm_infer_released = False
self.shm_cur_kv_len = 0
self.shm_cur_output_len = 0
Expand All @@ -150,6 +156,8 @@ def init(
self.shm_prompt_ids.arr[0 : len(prompt_ids)] = prompt_ids
self.mtp_accepted_token_num = 0
self._mtp_step = get_env_start_args().mtp_step
self.stop_str_matched = False
self.stop_str_matched_token_index = -1

self.post_init()

Expand Down Expand Up @@ -210,7 +218,9 @@ def can_release(self):
if self.is_aborted and can_released_mark and ref_count_ok:
return True

if self.finish_status.is_finished() and can_released_mark and ref_count_ok and self.out_tokens_queue.is_empty():
ok_finished_gen_req = self.finish_status.is_finished() or self.stop_str_matched

if ok_finished_gen_req and can_released_mark and ref_count_ok and self.out_tokens_queue.is_empty():
return True

return False
Expand Down
82 changes: 54 additions & 28 deletions lightllm/server/core/objs/sampling_params.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import ctypes
from typing import List, Tuple, Union
from typing import Optional, List, Tuple, Union
from transformers import GenerationConfig
from lightllm.server.req_id_generator import MAX_BEST_OF

Expand All @@ -10,6 +10,7 @@

# 从环境变量获取最大长度限制
STOP_SEQUENCE_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_MAX_LENGTH", 256))
STOP_SEQUENCE_STR_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_STR_MAX_LENGTH", 256))
ALLOWED_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_ALLOWED_TOKEN_IDS_MAX_LENGTH", 256))
MAX_STOP_SEQUENCES = int(os.getenv("LIGHTLLM_MAX_STOP_SEQUENCES", 10))
REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
Expand All @@ -22,17 +23,30 @@ class StopSequence(ctypes.Structure):
_fields_ = [
("sequence", ctypes.c_int * STOP_SEQUENCE_MAX_LENGTH),
("size", ctypes.c_int),
("sequence_str", ctypes.c_char * STOP_SEQUENCE_STR_MAX_LENGTH),
("sequence_str_len", ctypes.c_int),
]

def initialize(self, sequence: List[int]):
def initialize(self, sequence: List[int], sequence_str: Optional[str] = None):
self.size = len(sequence)
assert self.size <= STOP_SEQUENCE_MAX_LENGTH, "stop token length too long."
assert all(isinstance(e, int) for e in sequence), "all must be int"
self.sequence[: self.size] = sequence[:]

def to_list(self):
if sequence_str is not None:
sequence_str_bytes = sequence_str.encode("utf-8")
assert len(sequence_str_bytes) < STOP_SEQUENCE_STR_MAX_LENGTH, "stop sequence string too long."
self.sequence_str = sequence_str_bytes
self.sequence_str_len = len(sequence_str_bytes)
else:
self.sequence_str_len = 0

def to_list(self) -> List[int]:
return list(self.sequence[0 : self.size])

def to_string(self) -> str:
return bytes(self.sequence_str[0 : self.sequence_str_len]).decode("utf-8")


class StopSequenceGroups(ctypes.Structure):
_pack_ = 4
Expand All @@ -41,40 +55,52 @@ class StopSequenceGroups(ctypes.Structure):
("size", ctypes.c_int),
]

def initialize(self, stop_sequences: Union[str, List], tokenizer):
def initialize(self, stop_sequences: Union[str, List[Union[List[int], str]]], tokenizer):
if stop_sequences is None:
stop_sequences = []
elif isinstance(stop_sequences, str):
stop_sequences = [stop_sequences]

groups: List[List[int]] = self.stop_sentences_to_token_ids(stop_sequences, tokenizer)
self.size = len(groups)
assert self.size <= MAX_STOP_SEQUENCES, "Too many stop sequence groups."
for group_idx in range(self.size):
self.groups[group_idx].initialize(groups[group_idx])

def stop_sentences_to_token_ids(self, stop_sequences, tokenizer):
if stop_sequences is None:
stop_sequences = []
else:
if isinstance(stop_sequences, str):
stop_sequences = [stop_sequences]

new_stop_sequences = []
for stop_info in stop_sequences:
if isinstance(stop_info, str):
stop_str_ids = self._stop_str_to_token_ids(stop_info, tokenizer)
if stop_str_ids is not None and len(stop_str_ids) > 0:
new_stop_sequences.append(stop_str_ids)
if isinstance(stop_info, list):
if all(isinstance(x, int) for x in stop_info):
if len(stop_info) > 0:
new_stop_sequences.append(stop_info)
stop_sequences = new_stop_sequences
return stop_sequences

def _stop_str_to_token_ids(self, stop_str: str, tokenizer):
for group_idx in range(self.size):
if isinstance(stop_sequences[group_idx], str):
self.groups[group_idx].initialize(groups[group_idx], sequence_str=stop_sequences[group_idx])
else:
self.groups[group_idx].initialize(groups[group_idx])

def stop_sentences_to_token_ids(self, stop_sequences: List[Union[List[int], str]], tokenizer) -> List[List[int]]:
new_stop_sequences = []
for stop_info in stop_sequences:
if isinstance(stop_info, str):
stop_str_ids = self._stop_str_to_token_ids(stop_info, tokenizer)
if stop_str_ids is not None and len(stop_str_ids) > 0:
new_stop_sequences.append(stop_str_ids)
if isinstance(stop_info, list):
if all(isinstance(x, int) for x in stop_info):
if len(stop_info) > 0:
new_stop_sequences.append(stop_info)
else:
assert False, "stop_sequences item must be type List[int] when it is a list."
return new_stop_sequences

def _stop_str_to_token_ids(self, stop_str: str, tokenizer) -> List[int]:
stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
return stop_str_ids

def to_list(self):
def to_list(self) -> List[List[int]]:
return [self.groups[i].to_list() for i in range(self.size)]

def to_strings(self) -> List[str]:
# 降序匹配,在出现"\n\n"和"\n"情况时,优先匹配“\n\n”
return sorted(
[self.groups[i].to_string() for i in range(self.size) if self.groups[i].sequence_str_len > 0],
key=len,
reverse=True,
)


class RegularConstraint(ctypes.Structure):
_pack_ = 4
Expand Down
34 changes: 33 additions & 1 deletion lightllm/server/detokenization/decode_req.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import os
from typing import List, Dict
from lightllm.server.core.objs import Req
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


LIGHTLLM_DECODE_PREFIX_LENGTH = int(os.getenv("LIGHTLLM_DECODE_PREFIX_LENGTH", 5))

Expand All @@ -15,6 +19,7 @@ def __init__(
self.group_req_id = req.group_req_id
self.prompt_ids = req.shm_prompt_ids.arr[0 : req.input_len].tolist()
self.output_ids = []
self.output_strs = []
self.prefix_offset = max(len(self.prompt_ids) - LIGHTLLM_DECODE_PREFIX_LENGTH, 0)

if is_pd_decode_mode:
Expand All @@ -26,6 +31,9 @@ def __init__(
self.req = req
self.input_len = self.req.input_len
self.prefix_str = ""
self.stop_strs: List[str] = self.req.sample_params.stop_sequences.to_strings()
# to_strings()已经做了倒序排列,第一个元素就是最长字符串
self.stop_str_max_len = len(self.stop_strs[0]) if self.stop_strs else 0

def init_token_healing_prefix_str(self, token_id_to_token: Dict[int, str], tokenizer):
tokens = [token_id_to_token[token_id] for token_id in self.req.prefix_token_ids.get_token_ids()]
Expand All @@ -35,8 +43,30 @@ def init_token_healing_prefix_str(self, token_id_to_token: Dict[int, str], token
self.prefix_str = ""
return

def stop_sequences_str_match(self) -> bool:
stop_strs = self.stop_strs
if not stop_strs or self.stop_str_max_len == 0:
return False

tail_token_len = self.stop_str_max_len + 10 # 10 for safety
tail_token_strs = self.output_strs[-tail_token_len:]
tail_str = "".join(tail_token_strs)

for stop_str in stop_strs:
if stop_str in tail_str:
logger.debug(
f"req_id {self.request_id} Found stop sequence in tail: stop_str='{stop_str}', "
f"tail_str='{tail_str}'"
)
return True
return False

def need_detoken(self):
if (not self.req.is_aborted) and len(self.output_ids) < self.req.candetoken_out_len:
if (
(not self.req.is_aborted)
and (not self.req.stop_str_matched)
and len(self.output_ids) < self.req.candetoken_out_len
):
return True
return False

Expand All @@ -55,6 +85,8 @@ def get_decode_tokens(self):
def can_set_release_mark(self):
if self.req.is_aborted:
return True
if self.req.stop_str_matched:
return True
if (
self.req.finish_status.is_finished()
and self.req.candetoken_out_len == len(self.output_ids)
Expand Down
12 changes: 12 additions & 0 deletions lightllm/server/detokenization/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def gen_token_out(self):
exist_need_detoken = False
exist_decode = False
for decode_req in self.req_id_to_out.values():
# 已经满足停止字符串停止条件,则不再处理后续生成 token
if decode_req.req.stop_str_matched:
continue

if decode_req.need_detoken() and not decode_req.out_queue_is_full():
new_token_id, src_index = decode_req.get_next_token_id_and_index()
decode_req.output_ids.append(new_token_id)
Expand All @@ -131,6 +135,14 @@ def gen_token_out(self):
logger.error(
f"error token healing state, prefix_str {decode_req.prefix_str} new_text {new_text}"
)

decode_req.output_strs.append(new_text)

# 停止字符串匹配
if not decode_req.req.finish_status.is_stopped() and decode_req.stop_sequences_str_match():
decode_req.req.stop_str_matched_token_index = src_index
decode_req.req.stop_str_matched = True

decode_req.req.out_tokens_queue.push(new_text, src_index, special, count_output_tokens)

if decode_req.need_detoken():
Expand Down
12 changes: 10 additions & 2 deletions lightllm/server/httpserver/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,10 +679,18 @@ async def handle_loop(self):

req.out_tokens_queue.pop_no_ret()

if req.finish_token_index != src_index:
finished_token_index = (
req.stop_str_matched_token_index if req.stop_str_matched else req.finish_token_index
)

if finished_token_index != src_index:
token_list.append((req_id, text, metadata, FinishStatus()))
else:
finish_status = FinishStatus(req.finish_status.status)
if req.stop_str_matched:
finish_status = FinishStatus(FinishStatus.FINISHED_STOP)
else:
finish_status = FinishStatus(req.finish_status.status)

token_list.append((req_id, text, metadata, finish_status))
else:
break
Expand Down
Loading