diff --git a/mlx_lm/tool_parsers/minimax_m2.py b/mlx_lm/tool_parsers/minimax_m2.py
index 04a8acadb..a0b7efe56 100644
--- a/mlx_lm/tool_parsers/minimax_m2.py
+++ b/mlx_lm/tool_parsers/minimax_m2.py
@@ -157,43 +157,44 @@ def _get_param_types_from_config(param_name: str, param_config: dict) -> list[st
def parse_tool_call(text: str, tools: list | None = None):
- invoke_match = _invoke_complete_regex.findall(text)
- if not invoke_match:
+ invoke_matches = _invoke_complete_regex.findall(text)
+ if not invoke_matches:
raise ValueError("No tool call found")
- invoke_text = invoke_match[0]
- name_match = re.search(r"^([^>]+)", invoke_text)
- if not name_match:
- return None
-
- function_name = _extract_name(name_match.group(1))
-
- # Get parameter configuration
- param_config = {}
+ param_config_for = {}
if tools:
for tool in tools:
if func := tool.get("function", False):
- if func["name"] != function_name:
- continue
if params := func.get("parameters", False):
- param_config = params.get("properties", {})
-
- # Extract parameters
- param_dict = {}
- for match in _parameter_complete_regex.findall(invoke_text):
- param_match = re.search(r"^([^>]+)>(.*)", match, re.DOTALL)
- if param_match:
- param_name = _extract_name(param_match.group(1))
- param_value = param_match.group(2).strip()
- if param_value.startswith("\n"):
- param_value = param_value[1:]
- if param_value.endswith("\n"):
- param_value = param_value[:-1]
-
- param_type = _get_param_types_from_config(param_name, param_config)
-
- param_dict[param_name] = _convert_param_value_with_types(
- param_value, param_type
- )
-
- return dict(name=function_name, arguments=param_dict)
+ param_config_for[func["name"]] = params.get("properties", {})
+
+ calls = []
+ for invoke_text in invoke_matches:
+ name_match = re.search(r"^([^>]+)", invoke_text)
+ if not name_match:
+ continue
+ function_name = _extract_name(name_match.group(1))
+ param_config = param_config_for.get(function_name, {})
+
+ param_dict = {}
+ for match in _parameter_complete_regex.findall(invoke_text):
+ param_match = re.search(r"^([^>]+)>(.*)", match, re.DOTALL)
+ if param_match:
+ param_name = _extract_name(param_match.group(1))
+ param_value = param_match.group(2).strip()
+ if param_value.startswith("\n"):
+ param_value = param_value[1:]
+ if param_value.endswith("\n"):
+ param_value = param_value[:-1]
+
+ param_type = _get_param_types_from_config(param_name, param_config)
+
+ param_dict[param_name] = _convert_param_value_with_types(
+ param_value, param_type
+ )
+
+ calls.append(dict(name=function_name, arguments=param_dict))
+
+ if len(calls) == 1:
+ return calls[0]
+ return calls
diff --git a/tests/test_tool_parsing.py b/tests/test_tool_parsing.py
index 5ba79d1b2..b19e1d59f 100644
--- a/tests/test_tool_parsing.py
+++ b/tests/test_tool_parsing.py
@@ -292,6 +292,22 @@ def test_kimi_k2(self):
]
self.assertEqual(tool_calls, expected)
+ def test_minimax_m2(self):
+ test_case = (
+ '\n'
+ 'weather\n'
+ "\n"
+ '\n'
+ '/tmp/test.txt\n'
+ ""
+ )
+ expected = [
+ {"name": "search", "arguments": {"query": "weather"}},
+ {"name": "read_file", "arguments": {"path": "/tmp/test.txt"}},
+ ]
+ tool_calls = minimax_m2.parse_tool_call(test_case, None)
+ self.assertEqual(expected, tool_calls)
+
if __name__ == "__main__":
unittest.main()