Skip to content

Commit 815d3a1

Browse files
authored
[Feature] Add OpenAI-compatible tool_choice support for chat completions (#7882)
* first commit * fix bug * fix unit test * fix * fix review * fix review * fix unit test * 修改条件 * fix review * fix unit test * add condition
1 parent 7f8ce7d commit 815d3a1

10 files changed

Lines changed: 538 additions & 38 deletions

File tree

fastdeploy/engine/common_engine.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1922,7 +1922,11 @@ def _send_error_response(self, request_id, error_msg, error_code: int = 500, wor
19221922
def _decode_token(self, token_ids, req_id, is_end):
19231923
delta_text = ""
19241924
if envs.FD_ENABLE_RETURN_TEXT:
1925-
delta_text, cum_tokens, _ = self.data_processor.ids2tokens(token_ids, req_id)
1925+
delta_text, previous_token_ids, _ = self.data_processor.ids2tokens(token_ids, req_id)
1926+
# Reconstruct the post-extend cumulative list from the pre-delta
1927+
# snapshot + this call's input — ``ids2tokens`` only returns the
1928+
# snapshot to keep its return values aliasing-free.
1929+
cum_tokens = previous_token_ids + list(token_ids)
19261930
if delta_text != "":
19271931
prefix_offset = self.data_processor.decode_status[req_id][0]
19281932
read_offset = self.data_processor.decode_status[req_id][1]

fastdeploy/entrypoints/openai/response_processors.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ def accumulate_token_ids(self, request_output):
7272
else:
7373
self._multipart_buffer.append({"decode_type": decode_type, "request_output": request_output})
7474

75-
async def process_response_chat(self, request_outputs, stream, include_stop_str_in_output, request):
75+
async def process_response_chat(
76+
self, request_outputs, stream, include_stop_str_in_output, request, prompt_tokens=None
77+
):
7678
"""
7779
Process a list of responses into a generator that yields each processed response as it's generated.
7880
Args:
@@ -101,6 +103,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_
101103
audio_tokens=all_audio_tokens,
102104
tts=tts,
103105
request=request,
106+
prompt_tokens=prompt_tokens,
104107
)
105108
else:
106109
response = self.data_processor.process_response_dict(
@@ -110,6 +113,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_
110113
audio_tokens=all_audio_tokens,
111114
tts=tts,
112115
request=request,
116+
prompt_tokens=prompt_tokens,
113117
)
114118
yield response
115119
elif decode_type == 2: # audio
@@ -128,13 +132,15 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_
128132
stream=stream,
129133
include_stop_str_in_output=include_stop_str_in_output,
130134
request=request,
135+
prompt_tokens=prompt_tokens,
131136
)
132137
else:
133138
response = self.data_processor.process_response_dict(
134139
response_dict=request_output,
135140
stream=stream,
136141
include_stop_str_in_output=include_stop_str_in_output,
137142
request=request,
143+
prompt_tokens=prompt_tokens,
138144
)
139145
yield response
140146
elif stream:
@@ -168,13 +174,15 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_
168174
stream=stream,
169175
include_stop_str_in_output=include_stop_str_in_output,
170176
request=request,
177+
prompt_tokens=prompt_tokens,
171178
)
172179
else:
173180
self.data_processor.process_response_dict(
174181
response_dict=request_output,
175182
stream=stream,
176183
include_stop_str_in_output=include_stop_str_in_output,
177184
request=request,
185+
prompt_tokens=prompt_tokens,
178186
)
179187
text = {"type": "text", "text": request_output["outputs"]["text"]}
180188
request_output["outputs"]["multipart"] = [text]
@@ -197,13 +205,15 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_
197205
stream=False,
198206
include_stop_str_in_output=include_stop_str_in_output,
199207
request=request,
208+
prompt_tokens=prompt_tokens,
200209
)
201210
else:
202211
self.data_processor.process_response_dict(
203212
response_dict=request_output,
204213
stream=stream,
205214
include_stop_str_in_output=include_stop_str_in_output,
206215
request=request,
216+
prompt_tokens=prompt_tokens,
207217
)
208218
text = {"type": "text", "text": part["request_output"]["outputs"]["text"]}
209219
multipart.append(text)

fastdeploy/entrypoints/openai/serving_chat.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ async def chat_completion_stream_generator(
317317
stream=True,
318318
include_stop_str_in_output=include_stop_str_in_output,
319319
request=request,
320+
prompt_tokens=prompt_tokens,
320321
)
321322

322323
async for res in generator:
@@ -650,6 +651,7 @@ async def chat_completion_full_generator(
650651
stream=False,
651652
include_stop_str_in_output=include_stop_str_in_output,
652653
request=request,
654+
prompt_tokens=prompt_tokens,
653655
)
654656
async for data in generator:
655657
idx = get_choice_index(data["request_id"])

fastdeploy/entrypoints/openai/tool_parsers/abstract_tool_parser.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ class ToolParser:
3434
derived classes.
3535
"""
3636

37+
# Subclasses should override these with the literal tool-call sentinel
38+
# tokens they recognize (e.g. ``"<tool_call>"`` / ``"</tool_call>"``).
39+
# Used by :meth:`detect_tool_prefix` to support forced tool-call prompt
40+
# prefix injection (named-tool ``tool_choice`` or
41+
# ``chat_template_kwargs.options.tool_choice.mode == "force"``). Empty
42+
# defaults make the detection a no-op for parsers that have not opted in.
43+
tool_call_start_token: str = ""
44+
tool_call_end_token: str = ""
45+
3746
def __init__(self, tokenizer):
3847
self.prev_tool_call_arr: list[dict] = []
3948
# the index of the tool call that is currently being parsed
@@ -43,6 +52,16 @@ def __init__(self, tokenizer):
4352

4453
self.model_tokenizer = tokenizer
4554

55+
# Per-request tool-prefix state populated by the serving layer when
56+
# the chat template injects a forced tool-call prefix into the prompt.
57+
self._tool_prefix: str = ""
58+
self._tool_prefix_token_ids: list[int] = []
59+
# Set after the prefix is computed once for this request.
60+
self._tool_prefix_computed: bool = False
61+
# Set after the prefix has been spliced into the streaming delta
62+
# (only the first chunk needs it).
63+
self._tool_prefix_injected_to_delta: bool = False
64+
4665
@cached_property
4766
def vocab(self) -> dict[str, int]:
4867
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
@@ -55,6 +74,36 @@ def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionReques
5574
"""
5675
return request
5776

77+
def detect_tool_prefix(self, prompt: str) -> str:
    """Return the unterminated tool-call tail of ``prompt``, if any.

    Looks for the final occurrence of :attr:`tool_call_start_token` in
    ``prompt``. When no :attr:`tool_call_end_token` follows it, everything
    from that start token to the end of the prompt is returned and is
    treated as the prefix injected by a forced ``tool_choice``.

    Args:
        prompt: The rendered prompt string.

    Returns:
        ``prompt[last_start:]`` when an unclosed start token exists, else
        ``""`` (also when the start token or the prompt is empty).

    Subclasses with non-paired tag formats, or whose chat templates append
    extra content after the prefix, may override this with stricter rules.
    """
    open_tag = self.tool_call_start_token
    if not (open_tag and prompt):
        # Parser has not opted in, or there is nothing to scan.
        return ""

    open_at = prompt.rfind(open_tag)
    if open_at < 0:
        return ""

    close_tag = self.tool_call_end_token
    if close_tag and close_tag in prompt[open_at + len(open_tag):]:
        # The last start token is already closed: it belongs to a completed
        # tool call (e.g. a previous assistant turn), not an injected prefix.
        return ""

    # The slice runs to the end of the prompt by construction, so the whole
    # tail is taken as the injected prefix.
    return prompt[open_at:]
106+
58107
def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
59108
"""
60109
Static method that should be implemented for extracting tool calls from

fastdeploy/input/base_processor.py

Lines changed: 123 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,28 @@ def text2ids(self, text, max_model_len=None, **kwargs):
159159
)
160160
return tokens["input_ids"][0]
161161

162+
def _text_to_token_ids(self, text: str) -> list:
    """Tokenize ``text`` and return the ids as a plain Python ``list``.

    Common encoding helper used by :meth:`messages2ids` and
    :meth:`_prepare_tool_prefix`.

    Args:
        text: The string to encode.

    Returns:
        The token ids. For non-``ernie4_5`` tokenizers the ``encode``
        result is unwrapped from an ``input_ids`` container, reduced to
        its first row when it reports more than one dimension, and
        converted to a ``list``.
    """
    if self.tokenizer_type == "ernie4_5":
        # NOTE: the ernie4_5 tokenizer hangs on long inputs when
        # ``.encode()`` is used, so tokenize and map to ids in two steps.
        pieces = self.tokenizer.tokenize(text)
        return self.tokenizer.convert_tokens_to_ids(pieces)

    encoded = self.tokenizer.encode(text, add_special_tokens=False)
    # Some tokenizers hand back a mapping/encoding object, not raw ids.
    if hasattr(encoded, "input_ids") or (isinstance(encoded, dict) and "input_ids" in encoded):
        encoded = encoded["input_ids"]
    # Array-like results with a batch axis: keep only the first row.
    if hasattr(encoded, "ndim") and encoded.ndim > 1:
        encoded = encoded[0]
    if hasattr(encoded, "tolist"):
        encoded = encoded.tolist()
    return encoded if isinstance(encoded, list) else list(encoded)
183+
162184
def messages2ids(self, request, **kwargs):
163185
"""Convert a chat-template request into a token-ID list.
164186
@@ -180,19 +202,7 @@ def messages2ids(self, request, **kwargs):
180202
)
181203
request["prompt_tokens"] = spliced_message
182204
req_id = request.get("request_id", None) if isinstance(request, dict) else None
183-
if self.tokenizer_type == "ernie4_5":
184-
# NOTE: ernie4_5 tokenizer will hang when meet long input when use .encode()
185-
token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(spliced_message))
186-
else:
187-
token_ids = self.tokenizer.encode(spliced_message, add_special_tokens=False)
188-
if hasattr(token_ids, "input_ids") or (isinstance(token_ids, dict) and "input_ids" in token_ids):
189-
token_ids = token_ids["input_ids"]
190-
if hasattr(token_ids, "ndim") and token_ids.ndim > 1:
191-
token_ids = token_ids[0]
192-
if hasattr(token_ids, "tolist"):
193-
token_ids = token_ids.tolist()
194-
if not isinstance(token_ids, list):
195-
token_ids = list(token_ids)
205+
token_ids = self._text_to_token_ids(spliced_message)
196206
log_request(
197207
level=1,
198208
message="req_id:{req_id}, token_ids: {token_ids}",
@@ -225,9 +235,16 @@ def ids2tokens(self, token_id, task_id):
225235
Returns:
226236
(delta_text, previous_token_ids, previous_texts)
227237
228-
Both the HF and the PaddleFormers/ERNIE tokeniser paths return the
229-
same tuple shape. The HF path sets ``previous_token_ids`` to ``[]``
230-
since it does not expose per-token ids during batch-decode.
238+
``previous_token_ids`` and ``previous_texts`` are **snapshots of the
239+
accumulated state BEFORE this call's tokens were appended** —
240+
symmetric pre-delta views of what the caller had decoded so far.
241+
Both are owned by the caller (no aliasing of internal state).
242+
243+
Callers that need the post-extend cumulative list should reconstruct
244+
it locally via ``previous_token_ids + token_id``.
245+
246+
The HF path returns ``[]`` for ``previous_token_ids`` since it does
247+
not expose per-token ids during batch-decode.
231248
"""
232249
if envs.FD_USE_HF_TOKENIZER:
233250
if task_id not in self.decode_status:
@@ -246,20 +263,25 @@ def ids2tokens(self, token_id, task_id):
246263
status[2] = decode_str[0]
247264
else:
248265
new_str = ""
249-
# Return consistent three-tuple; previous_token_ids not available.
266+
# NOTE: HF path historically returns the post-delta full string
267+
# here, inconsistent with the non-HF branch (which returns the
268+
# pre-delta snapshot). Preserved as-is to avoid behavior change.
250269
return new_str, [], status[2]
251270
else:
252271
if task_id not in self.decode_status:
253272
# [prefix_offset, read_offset, all_token_ids, accumulated_text]
254273
self.decode_status[task_id] = [0, 0, [], ""]
255274
status = self.decode_status[task_id]
256275
previous_texts = status[3]
276+
# Snapshot BEFORE extend so the returned list is owned by the
277+
# caller and symmetric with ``previous_texts``.
278+
previous_token_ids = list(status[2])
257279
status[2].extend(token_id)
258280
decode_str, prefix_offset, read_offset = self.tokenizer.decode_token(status[2], status[0], status[1])
259281
status[0] = prefix_offset
260282
status[1] = read_offset
261283
status[3] += decode_str
262-
return decode_str, status[2], previous_texts
284+
return decode_str, previous_token_ids, previous_texts
263285

264286
# ------------------------------------------------------------------
265287
# Response processing
@@ -287,6 +309,53 @@ def process_response_dict(self, response_dict, **kwargs):
287309
else:
288310
return self.process_response_dict_normal(response_dict, **kwargs)
289311

312+
@staticmethod
def _is_forced_tool_choice(request):
    """Return True if tool_choice mode requires forced prefix injection.

    Reads ``request.chat_template_kwargs["options"]["tool_choice"]["mode"]``
    and reports whether it is ``"required"`` or ``"force"``.

    ``chat_template_kwargs`` is caller-supplied, so every level of the
    lookup is type-checked: a missing, empty, or non-dict value at any
    level yields ``False`` instead of raising ``AttributeError`` in the
    middle of response processing.

    Args:
        request: The request object (or a falsy value when absent).

    Returns:
        bool: Whether a forced tool call was requested.
    """
    if not request:
        return False
    node = getattr(request, "chat_template_kwargs", None)
    # Descend options -> tool_choice, bailing out on any malformed level.
    for key in ("options", "tool_choice"):
        if not isinstance(node, dict):
            return False
        node = node.get(key)
    if not isinstance(node, dict):
        return False
    return node.get("mode", "") in ("required", "force")
322+
323+
def _prepare_tool_prefix(self, tool_parser, prompt_tokens, request=None):
    """Compute, once per parser instance, the forced tool-call prefix.

    Asks ``tool_parser.detect_tool_prefix`` for the prefix at the tail of
    ``prompt_tokens`` (the rendered prompt string) and stores the result on
    the parser as ``_tool_prefix`` plus its encoded form
    ``_tool_prefix_token_ids``. The ``_tool_prefix_computed`` flag is set on
    the first call, so later calls return immediately.

    Detection is attempted only when ``prompt_tokens`` is a non-empty
    ``str`` and :meth:`_is_forced_tool_choice` reports a forced tool call;
    otherwise both cached values stay empty. Failures in detection or
    encoding are logged and leave the corresponding value empty.

    Args:
        tool_parser: Parser instance that receives the cached prefix state.
        prompt_tokens: Rendered prompt string from the serving layer.
        request: Request object used to read the ``tool_choice`` mode.
    """
    if tool_parser._tool_prefix_computed:
        return

    # Mark as done up front and start from a clean, empty state.
    tool_parser._tool_prefix_computed = True
    tool_parser._tool_prefix = ""
    tool_parser._tool_prefix_token_ids = []

    if not (prompt_tokens and isinstance(prompt_tokens, str)):
        return
    if not self._is_forced_tool_choice(request):
        return

    try:
        detected = tool_parser.detect_tool_prefix(prompt_tokens) or ""
    except Exception:
        data_processor_logger.exception("detect_tool_prefix failed; falling back to empty prefix")
        return

    tool_parser._tool_prefix = detected
    if detected:
        # Also keep the prefix as token ids so the streaming path can splice
        # ``previous/current/delta_token_ids`` — some parsers gate on
        # ``tool_call_start_token_id in current_token_ids`` rather than on
        # text (e.g. ``Ernie45VLThinkingToolParser``).
        try:
            encoded = self._text_to_token_ids(detected)
        except Exception:
            data_processor_logger.exception("encode tool prefix to token ids failed; token-id splice disabled")
            encoded = []
        tool_parser._tool_prefix_token_ids = encoded
358+
290359
def process_response_dict_normal(self, response_dict, **kwargs):
291360
"""Accumulate tokens and build the full completion text (non-streaming)."""
292361
token_ids = response_dict["outputs"]["token_ids"]
@@ -321,7 +390,11 @@ def process_response_dict_normal(self, response_dict, **kwargs):
321390

322391
if self.tool_parser_obj:
323392
tool_parser = self.tool_parser_obj(self.tokenizer)
324-
tool_call_info = tool_parser.extract_tool_calls(full_text, request)
393+
parser_input = full_text
394+
self._prepare_tool_prefix(tool_parser, kwargs.get("prompt_tokens"), request)
395+
if tool_parser._tool_prefix:
396+
parser_input = tool_parser._tool_prefix + full_text
397+
tool_call_info = tool_parser.extract_tool_calls(parser_input, request)
325398
if tool_call_info.tools_called:
326399
response_dict["outputs"]["tool_calls"] = tool_call_info.tool_calls
327400

@@ -375,13 +448,38 @@ def process_response_dict_streaming(self, response_dict, **kwargs):
375448
if req_id not in self.tool_parser_dict:
376449
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
377450
tool_parser = self.tool_parser_dict[req_id]
451+
stream_previous = previous_texts
452+
stream_current = previous_texts + delta_text
453+
stream_delta = delta_text
454+
stream_previous_token_ids = previous_token_ids
455+
stream_current_token_ids = previous_token_ids + token_ids
456+
stream_delta_token_ids = token_ids
457+
self._prepare_tool_prefix(tool_parser, kwargs.get("prompt_tokens"), request)
458+
prefix = tool_parser._tool_prefix
459+
prefix_ids = tool_parser._tool_prefix_token_ids
460+
# Splice the injected prefix back into both text and token-id
461+
# streaming args so parsers that gate on either form (e.g.
462+
# ``Ernie45VLThinkingToolParser`` checks
463+
# ``tool_call_start_token_id in current_token_ids``) work
464+
# unchanged. ``delta_*`` only spliced on the first call.
465+
if prefix:
466+
stream_previous = prefix + stream_previous
467+
stream_current = prefix + stream_current
468+
if prefix_ids:
469+
stream_previous_token_ids = prefix_ids + stream_previous_token_ids
470+
stream_current_token_ids = prefix_ids + stream_current_token_ids
471+
if not tool_parser._tool_prefix_injected_to_delta:
472+
stream_delta = prefix + stream_delta
473+
if prefix_ids:
474+
stream_delta_token_ids = prefix_ids + stream_delta_token_ids
475+
tool_parser._tool_prefix_injected_to_delta = True
378476
tool_call_delta_message = tool_parser.extract_tool_calls_streaming(
379-
previous_texts,
380-
previous_texts + delta_text,
381-
delta_text,
382-
previous_token_ids,
383-
previous_token_ids + token_ids,
384-
token_ids,
477+
stream_previous,
478+
stream_current,
479+
stream_delta,
480+
stream_previous_token_ids,
481+
stream_current_token_ids,
482+
stream_delta_token_ids,
385483
request,
386484
)
387485
if tool_call_delta_message:

0 commit comments

Comments
 (0)