Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 82 additions & 13 deletions aphrodite/tool_parsers/deepseekv32_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,17 @@ def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
self.current_tool_index: int = 0
self._sent_content_idx: int = 0

# Regex patterns for complete parsing
# Regex patterns for complete parsing.
#
# The wrapper tokens are class attributes so subclasses such as
# DeepSeekV4ToolParser can override them.
self.tool_call_complete_regex = re.compile(
re.escape(self.tool_call_start_token) + r"(.*?)" + re.escape(self.tool_call_end_token),
re.DOTALL,
)
self.invoke_complete_regex = re.compile(
r'<|DSML|invoke\s+name="([^"]+)"\s*>(.*?)</|DSML|invoke>', re.DOTALL
r'<|DSML|invoke\s+name="([^"]+)"\s*>(.*?)</|DSML|invoke>',
re.DOTALL,
)
self.parameter_complete_regex = re.compile(
r'<|DSML|parameter\s+name="([^"]+)"\s+string="(?:true|false)"\s*>(.*?)</|DSML|parameter>',
Expand All @@ -83,7 +87,7 @@ def adjust_request(
if request.tools and request.tool_choice != "none":
# Ensure tool call tokens
# (e.g. <|DSML|function_calls>, </|DSML|function_calls>)
# are not skippedduring decoding.
# are not skipped during decoding.
# Even though they are not marked as special tokens,
# setting skip_special_tokens=False ensures proper handling in
# transformers 5.x where decoding behavior may have changed.
Expand All @@ -94,8 +98,8 @@ def _generate_tool_call_id(self) -> str:
"""Generate a unique tool call ID."""
return f"call_{uuid.uuid4().hex[:24]}"

def _parse_invoke_params(self, invoke_str: str) -> dict:
param_dict = dict()
def _parse_invoke_params(self, invoke_str: str) -> dict[str, str]:
param_dict: dict[str, str] = {}
for param_name, param_val in self.parameter_complete_regex.findall(invoke_str):
param_dict[param_name] = param_val
return param_dict
Expand Down Expand Up @@ -123,8 +127,11 @@ def _convert_param_value_checked(self, value: str, param_type: str) -> Any:
else:
return json.loads(value)

def _convert_param_value(self, value: str, param_type: str | list[str]) -> Any:
def _convert_param_value(self, value: Any, param_type: str | list[str]) -> Any:
"""Convert parameter value to the correct type."""
if not isinstance(value, str):
return value

if not isinstance(param_type, list):
param_type = [param_type]
for current_type in param_type:
Expand All @@ -135,6 +142,45 @@ def _convert_param_value(self, value: str, param_type: str | list[str]) -> Any:
# return value as fallback
return value

def _normalize_arguments_wrapper(
self,
converted: dict[str, Any],
) -> dict[str, Any]:
"""Normalize model-generated nested arguments wrapper.

DeepSeek V4 Flash may generate DSML parameters like:

<|DSML|parameter name="arguments" string="false">
{"path": "/tmp/a", "content": "hello"}
</|DSML|parameter>

The parser would otherwise produce:

{"arguments": {"path": "/tmp/a", "content": "hello"}}

OpenAI-compatible function.arguments should be:

{"path": "/tmp/a", "content": "hello"}
"""
if set(converted.keys()) != {"arguments"}:
return converted
Comment on lines +165 to +166

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve explicit arguments parameter key

Do not unwrap every single-key {"arguments": ...} payload here: this helper now strips the arguments key for any tool that legitimately defines one top-level parameter named arguments (for example, properties.arguments in user-provided tool schemas). Because _convert_params_with_schema() calls this in both streaming and non-streaming extraction paths, those tools will receive function.arguments with the wrong shape ({...} instead of {"arguments": {...}}), which breaks downstream argument binding.

Useful? React with 👍 / 👎.


wrapped = converted.get("arguments")

if isinstance(wrapped, dict):
return wrapped

if isinstance(wrapped, str):
try:
parsed = json.loads(wrapped)
except Exception:
return converted

if isinstance(parsed, dict):
return parsed

return converted

def _convert_params_with_schema(
self,
function_name: str,
Expand All @@ -160,7 +206,8 @@ def _convert_params_with_schema(
if name in param_config and isinstance(param_config[name], dict):
param_type = param_config[name].get("type", "string")
converted[name] = self._convert_param_value(value, param_type)
return converted

return self._normalize_arguments_wrapper(converted)

def extract_tool_calls(
self,
Expand All @@ -180,12 +227,16 @@ def extract_tool_calls(
# Find all invokes within this tool_call
for invoke_name, invoke_content in self.invoke_complete_regex.findall(tool_call_match):
param_dict = self._parse_invoke_params(invoke_content)
converted = self._convert_params_with_schema(
invoke_name,
param_dict,
)
tool_calls.append(
ToolCall(
type="function",
function=FunctionCall(
name=invoke_name,
arguments=json.dumps(param_dict, ensure_ascii=False),
arguments=json.dumps(converted, ensure_ascii=False),
),
)
)
Expand Down Expand Up @@ -249,15 +300,29 @@ def _extract_delta_tool_calls(

return delta_tool_calls

def _extract_content(self, current_text: str) -> str | None:
def _extract_content(
self,
current_text: str,
*,
is_final: bool = False,
) -> str | None:
"""Return unsent non-tool-call text, or None.

Holds back any suffix that could be a partial start marker
so that split markers are never leaked as content.

On final streaming step, flush the held-back suffix because it
cannot form a complete tool-call start marker anymore.
"""
if self.tool_call_start_token not in current_text:
overlap = partial_tag_overlap(current_text, self.tool_call_start_token)
sendable_idx = len(current_text) - overlap
if is_final:
sendable_idx = len(current_text)
else:
overlap = partial_tag_overlap(
current_text,
self.tool_call_start_token,
)
sendable_idx = len(current_text) - overlap
else:
sendable_idx = current_text.index(self.tool_call_start_token)

Expand Down Expand Up @@ -288,15 +353,19 @@ def extract_tool_calls_streaming(
if not previous_text:
self._reset_streaming_state()

content = self._extract_content(current_text)
# Empty delta with token ids means EOS or a skipped/closing token.
# Treat it as final for content flushing purposes.
is_final = not delta_text and bool(delta_token_ids)

content = self._extract_content(current_text, is_final=is_final)
delta_tool_calls = self._extract_delta_tool_calls(current_text, request)

if delta_tool_calls or content:
return DeltaMessage(content=content, tool_calls=delta_tool_calls)

# Empty delta with token ids means EOS or closing tag; return
# non-None so the serving framework can finalize finish_reason.
if not delta_text and delta_token_ids and self.prev_tool_call_arr:
if is_final and self.prev_tool_call_arr:
return DeltaMessage(content="")

return None
Loading