diff --git a/mlx_lm/tool_parsers/minimax_m2.py b/mlx_lm/tool_parsers/minimax_m2.py index 04a8acadb..a0b7efe56 100644 --- a/mlx_lm/tool_parsers/minimax_m2.py +++ b/mlx_lm/tool_parsers/minimax_m2.py @@ -157,43 +157,44 @@ def _get_param_types_from_config(param_name: str, param_config: dict) -> list[st def parse_tool_call(text: str, tools: list | None = None): - invoke_match = _invoke_complete_regex.findall(text) - if not invoke_match: + invoke_matches = _invoke_complete_regex.findall(text) + if not invoke_matches: raise ValueError("No tool call found") - invoke_text = invoke_match[0] - name_match = re.search(r"^([^>]+)", invoke_text) - if not name_match: - return None - - function_name = _extract_name(name_match.group(1)) - - # Get parameter configuration - param_config = {} + param_config_for = {} if tools: for tool in tools: if func := tool.get("function", False): - if func["name"] != function_name: - continue if params := func.get("parameters", False): - param_config = params.get("properties", {}) - - # Extract parameters - param_dict = {} - for match in _parameter_complete_regex.findall(invoke_text): - param_match = re.search(r"^([^>]+)>(.*)", match, re.DOTALL) - if param_match: - param_name = _extract_name(param_match.group(1)) - param_value = param_match.group(2).strip() - if param_value.startswith("\n"): - param_value = param_value[1:] - if param_value.endswith("\n"): - param_value = param_value[:-1] - - param_type = _get_param_types_from_config(param_name, param_config) - - param_dict[param_name] = _convert_param_value_with_types( - param_value, param_type - ) - - return dict(name=function_name, arguments=param_dict) + param_config_for[func["name"]] = params.get("properties", {}) + + calls = [] + for invoke_text in invoke_matches: + name_match = re.search(r"^([^>]+)", invoke_text) + if not name_match: + continue + function_name = _extract_name(name_match.group(1)) + param_config = param_config_for.get(function_name, {}) + + param_dict = {} + for match in _parameter_complete_regex.findall(invoke_text): + param_match = re.search(r"^([^>]+)>(.*)", match, re.DOTALL) + if param_match: + param_name = _extract_name(param_match.group(1)) + param_value = param_match.group(2).strip() + if param_value.startswith("\n"): + param_value = param_value[1:] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + param_type = _get_param_types_from_config(param_name, param_config) + + param_dict[param_name] = _convert_param_value_with_types( + param_value, param_type + ) + + calls.append(dict(name=function_name, arguments=param_dict)) + + if len(calls) == 1: + return calls[0] + return calls diff --git a/tests/test_tool_parsing.py b/tests/test_tool_parsing.py index 5ba79d1b2..b19e1d59f 100644 --- a/tests/test_tool_parsing.py +++ b/tests/test_tool_parsing.py @@ -292,6 +292,22 @@ def test_kimi_k2(self): ] self.assertEqual(tool_calls, expected) + def test_minimax_m2(self): + test_case = ( + '\n' + 'weather\n' + "\n" + '\n' + '/tmp/test.txt\n' + "" + ) + expected = [ + {"name": "search", "arguments": {"query": "weather"}}, + {"name": "read_file", "arguments": {"path": "/tmp/test.txt"}}, + ] + tool_calls = minimax_m2.parse_tool_call(test_case, None) + self.assertEqual(expected, tool_calls) + if __name__ == "__main__": unittest.main()