Skip to content

Commit 6d11468

Browse files
authored
Fix MiniMax M2 parallel tool calling (#1171)
1 parent aa4f880 commit 6d11468

2 files changed

Lines changed: 51 additions & 34 deletions

File tree

mlx_lm/tool_parsers/minimax_m2.py

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -157,43 +157,44 @@ def _get_param_types_from_config(param_name: str, param_config: dict) -> list[st
157157

158158

159159
def parse_tool_call(text: str, tools: list | None = None):
160-
invoke_match = _invoke_complete_regex.findall(text)
161-
if not invoke_match:
160+
invoke_matches = _invoke_complete_regex.findall(text)
161+
if not invoke_matches:
162162
raise ValueError("No tool call found")
163-
invoke_text = invoke_match[0]
164163

165-
name_match = re.search(r"^([^>]+)", invoke_text)
166-
if not name_match:
167-
return None
168-
169-
function_name = _extract_name(name_match.group(1))
170-
171-
# Get parameter configuration
172-
param_config = {}
164+
param_config_for = {}
173165
if tools:
174166
for tool in tools:
175167
if func := tool.get("function", False):
176-
if func["name"] != function_name:
177-
continue
178168
if params := func.get("parameters", False):
179-
param_config = params.get("properties", {})
180-
181-
# Extract parameters
182-
param_dict = {}
183-
for match in _parameter_complete_regex.findall(invoke_text):
184-
param_match = re.search(r"^([^>]+)>(.*)", match, re.DOTALL)
185-
if param_match:
186-
param_name = _extract_name(param_match.group(1))
187-
param_value = param_match.group(2).strip()
188-
if param_value.startswith("\n"):
189-
param_value = param_value[1:]
190-
if param_value.endswith("\n"):
191-
param_value = param_value[:-1]
192-
193-
param_type = _get_param_types_from_config(param_name, param_config)
194-
195-
param_dict[param_name] = _convert_param_value_with_types(
196-
param_value, param_type
197-
)
198-
199-
return dict(name=function_name, arguments=param_dict)
169+
param_config_for[func["name"]] = params.get("properties", {})
170+
171+
calls = []
172+
for invoke_text in invoke_matches:
173+
name_match = re.search(r"^([^>]+)", invoke_text)
174+
if not name_match:
175+
continue
176+
function_name = _extract_name(name_match.group(1))
177+
param_config = param_config_for.get(function_name, {})
178+
179+
param_dict = {}
180+
for match in _parameter_complete_regex.findall(invoke_text):
181+
param_match = re.search(r"^([^>]+)>(.*)", match, re.DOTALL)
182+
if param_match:
183+
param_name = _extract_name(param_match.group(1))
184+
param_value = param_match.group(2).strip()
185+
if param_value.startswith("\n"):
186+
param_value = param_value[1:]
187+
if param_value.endswith("\n"):
188+
param_value = param_value[:-1]
189+
190+
param_type = _get_param_types_from_config(param_name, param_config)
191+
192+
param_dict[param_name] = _convert_param_value_with_types(
193+
param_value, param_type
194+
)
195+
196+
calls.append(dict(name=function_name, arguments=param_dict))
197+
198+
if len(calls) == 1:
199+
return calls[0]
200+
return calls

tests/test_tool_parsing.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,22 @@ def test_kimi_k2(self):
292292
]
293293
self.assertEqual(tool_calls, expected)
294294

295+
def test_minimax_m2(self):
296+
test_case = (
297+
'<invoke name="search">\n'
298+
'<parameter name="query">weather</parameter>\n'
299+
"</invoke>\n"
300+
'<invoke name="read_file">\n'
301+
'<parameter name="path">/tmp/test.txt</parameter>\n'
302+
"</invoke>"
303+
)
304+
expected = [
305+
{"name": "search", "arguments": {"query": "weather"}},
306+
{"name": "read_file", "arguments": {"path": "/tmp/test.txt"}},
307+
]
308+
tool_calls = minimax_m2.parse_tool_call(test_case, None)
309+
self.assertEqual(expected, tool_calls)
310+
295311

296312
if __name__ == "__main__":
297313
unittest.main()

0 commit comments

Comments
 (0)