Skip to content

Commit c005aac

Browse files
authored
fix: normalize nested args in DeepSeek DSML (#1654)
1 parent 74714e2 commit c005aac

1 file changed

Lines changed: 82 additions & 13 deletions

File tree

aphrodite/tool_parsers/deepseekv32_tool_parser.py

Lines changed: 82 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,17 @@ def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
5858
self.current_tool_index: int = 0
5959
self._sent_content_idx: int = 0
6060

61-
# Regex patterns for complete parsing
61+
# Regex patterns for complete parsing.
62+
#
63+
# The wrapper tokens are class attributes so subclasses such as
64+
# DeepSeekV4ToolParser can override them.
6265
self.tool_call_complete_regex = re.compile(
6366
re.escape(self.tool_call_start_token) + r"(.*?)" + re.escape(self.tool_call_end_token),
6467
re.DOTALL,
6568
)
6669
self.invoke_complete_regex = re.compile(
67-
r'<|DSML|invoke\s+name="([^"]+)"\s*>(.*?)</|DSML|invoke>', re.DOTALL
70+
r'<|DSML|invoke\s+name="([^"]+)"\s*>(.*?)</|DSML|invoke>',
71+
re.DOTALL,
6872
)
6973
self.parameter_complete_regex = re.compile(
7074
r'<|DSML|parameter\s+name="([^"]+)"\s+string="(?:true|false)"\s*>(.*?)</|DSML|parameter>',
@@ -83,7 +87,7 @@ def adjust_request(
8387
if request.tools and request.tool_choice != "none":
8488
# Ensure tool call tokens
8589
# (e.g. <|DSML|function_calls>, </|DSML|function_calls>)
86-
# are not skippedduring decoding.
90+
# are not skipped during decoding.
8791
# Even though they are not marked as special tokens,
8892
# setting skip_special_tokens=False ensures proper handling in
8993
# transformers 5.x where decoding behavior may have changed.
@@ -94,8 +98,8 @@ def _generate_tool_call_id(self) -> str:
9498
"""Generate a unique tool call ID."""
9599
return f"call_{uuid.uuid4().hex[:24]}"
96100

97-
def _parse_invoke_params(self, invoke_str: str) -> dict:
98-
param_dict = dict()
101+
def _parse_invoke_params(self, invoke_str: str) -> dict[str, str]:
102+
param_dict: dict[str, str] = {}
99103
for param_name, param_val in self.parameter_complete_regex.findall(invoke_str):
100104
param_dict[param_name] = param_val
101105
return param_dict
@@ -123,8 +127,11 @@ def _convert_param_value_checked(self, value: str, param_type: str) -> Any:
123127
else:
124128
return json.loads(value)
125129

126-
def _convert_param_value(self, value: str, param_type: str | list[str]) -> Any:
130+
def _convert_param_value(self, value: Any, param_type: str | list[str]) -> Any:
127131
"""Convert parameter value to the correct type."""
132+
if not isinstance(value, str):
133+
return value
134+
128135
if not isinstance(param_type, list):
129136
param_type = [param_type]
130137
for current_type in param_type:
@@ -135,6 +142,45 @@ def _convert_param_value(self, value: str, param_type: str | list[str]) -> Any:
135142
# return value as fallback
136143
return value
137144

145+
def _normalize_arguments_wrapper(
146+
self,
147+
converted: dict[str, Any],
148+
) -> dict[str, Any]:
149+
"""Normalize a model-generated nested ``arguments`` wrapper.
150+
151+
DeepSeek V4 Flash may generate DSML parameters like:
152+
153+
<|DSML|parameter name="arguments" string="false">
154+
{"path": "/tmp/a", "content": "hello"}
155+
</|DSML|parameter>
156+
157+
The parser would otherwise produce:
158+
159+
{"arguments": {"path": "/tmp/a", "content": "hello"}}
160+
161+
OpenAI-compatible function.arguments should be:
162+
163+
{"path": "/tmp/a", "content": "hello"}
164+
"""
165+
if set(converted.keys()) != {"arguments"}:
166+
return converted
167+
168+
wrapped = converted.get("arguments")
169+
170+
if isinstance(wrapped, dict):
171+
return wrapped
172+
173+
if isinstance(wrapped, str):
174+
try:
175+
parsed = json.loads(wrapped)
176+
except Exception:
177+
return converted
178+
179+
if isinstance(parsed, dict):
180+
return parsed
181+
182+
return converted
183+
138184
def _convert_params_with_schema(
139185
self,
140186
function_name: str,
@@ -160,7 +206,8 @@ def _convert_params_with_schema(
160206
if name in param_config and isinstance(param_config[name], dict):
161207
param_type = param_config[name].get("type", "string")
162208
converted[name] = self._convert_param_value(value, param_type)
163-
return converted
209+
210+
return self._normalize_arguments_wrapper(converted)
164211

165212
def extract_tool_calls(
166213
self,
@@ -180,12 +227,16 @@ def extract_tool_calls(
180227
# Find all invokes within this tool_call
181228
for invoke_name, invoke_content in self.invoke_complete_regex.findall(tool_call_match):
182229
param_dict = self._parse_invoke_params(invoke_content)
230+
converted = self._convert_params_with_schema(
231+
invoke_name,
232+
param_dict,
233+
)
183234
tool_calls.append(
184235
ToolCall(
185236
type="function",
186237
function=FunctionCall(
187238
name=invoke_name,
188-
arguments=json.dumps(param_dict, ensure_ascii=False),
239+
arguments=json.dumps(converted, ensure_ascii=False),
189240
),
190241
)
191242
)
@@ -249,15 +300,29 @@ def _extract_delta_tool_calls(
249300

250301
return delta_tool_calls
251302

252-
def _extract_content(self, current_text: str) -> str | None:
303+
def _extract_content(
304+
self,
305+
current_text: str,
306+
*,
307+
is_final: bool = False,
308+
) -> str | None:
253309
"""Return unsent non-tool-call text, or None.
254310
255311
Holds back any suffix that could be a partial start marker
256312
so that split markers are never leaked as content.
313+
314+
On the final streaming step, flush the held-back suffix because it
315+
cannot form a complete tool-call start marker anymore.
257316
"""
258317
if self.tool_call_start_token not in current_text:
259-
overlap = partial_tag_overlap(current_text, self.tool_call_start_token)
260-
sendable_idx = len(current_text) - overlap
318+
if is_final:
319+
sendable_idx = len(current_text)
320+
else:
321+
overlap = partial_tag_overlap(
322+
current_text,
323+
self.tool_call_start_token,
324+
)
325+
sendable_idx = len(current_text) - overlap
261326
else:
262327
sendable_idx = current_text.index(self.tool_call_start_token)
263328

@@ -288,15 +353,19 @@ def extract_tool_calls_streaming(
288353
if not previous_text:
289354
self._reset_streaming_state()
290355

291-
content = self._extract_content(current_text)
356+
# Empty delta with token ids means EOS or a skipped/closing token.
357+
# Treat it as final for content flushing purposes.
358+
is_final = not delta_text and bool(delta_token_ids)
359+
360+
content = self._extract_content(current_text, is_final=is_final)
292361
delta_tool_calls = self._extract_delta_tool_calls(current_text, request)
293362

294363
if delta_tool_calls or content:
295364
return DeltaMessage(content=content, tool_calls=delta_tool_calls)
296365

297366
# Empty delta with token ids means EOS or closing tag; return
298367
# non-None so the serving framework can finalize finish_reason.
299-
if not delta_text and delta_token_ids and self.prev_tool_call_arr:
368+
if is_final and self.prev_tool_call_arr:
300369
return DeltaMessage(content="")
301370

302371
return None

0 commit comments

Comments
 (0)