Skip to content

Commit 6f3afc2

Browse files
committed
feat(gemma): add tool call parser
1 parent 26722cd commit 6f3afc2

8 files changed

Lines changed: 202 additions & 5 deletions

File tree

xinference/model/llm/mlx/core.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from ..llm_family import LLMFamilyV2, LLMSpecV1
5454
from ..utils import (
5555
DEEPSEEK_TOOL_CALL_FAMILY,
56+
GEMMA_TOOL_CALL_FAMILY,
5657
QWEN_TOOL_CALL_FAMILY,
5758
ChatModelMixin,
5859
generate_completion_chunk,
@@ -1186,6 +1187,7 @@ async def async_chat(
11861187
if tools:
11871188
if (
11881189
model_family in QWEN_TOOL_CALL_FAMILY
1190+
or model_family in GEMMA_TOOL_CALL_FAMILY
11891191
or model_family in DEEPSEEK_TOOL_CALL_FAMILY
11901192
):
11911193
full_context_kwargs["tools"] = tools
@@ -1547,7 +1549,10 @@ def chat(
15471549
)
15481550
chat_context_var.set(chat_template_kwargs)
15491551
full_context_kwargs = chat_template_kwargs.copy()
1550-
if tools and model_family in QWEN_TOOL_CALL_FAMILY:
1552+
if tools and (
1553+
model_family in QWEN_TOOL_CALL_FAMILY
1554+
or model_family in GEMMA_TOOL_CALL_FAMILY
1555+
):
15511556
full_context_kwargs["tools"] = tools
15521557
chat_template = self.model_family.chat_template
15531558
tokenizer = None

xinference/model/llm/sglang/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from ..core import chat_context_var
3939
from ..utils import (
4040
DEEPSEEK_TOOL_CALL_FAMILY,
41+
GEMMA_TOOL_CALL_FAMILY,
4142
QWEN_TOOL_CALL_FAMILY,
4243
QWEN_TOOL_CALL_SYMBOLS,
4344
ChatModelMixin,
@@ -730,6 +731,7 @@ async def async_chat(
730731
if tools:
731732
if (
732733
model_family in QWEN_TOOL_CALL_FAMILY
734+
or model_family in GEMMA_TOOL_CALL_FAMILY
733735
or model_family in DEEPSEEK_TOOL_CALL_FAMILY
734736
):
735737
full_context_kwargs["tools"] = tools

xinference/model/llm/tool_parsers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def decorator(cls: Type[Any]) -> Type[Any]:
5353
deepseek_r1_tool_parser,
5454
deepseek_v3_1_tool_parser,
5555
deepseek_v3_tool_parser,
56+
gemma_tool_parser,
5657
glm4_tool_parser,
5758
llama3_tool_parser,
5859
minimax_tool_parser,
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import json
2+
import logging
3+
import re
4+
from typing import Any, Dict, List, Optional, Tuple
5+
6+
from . import register_tool_parser
7+
from .abstract_tool_parser import ToolParser
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
@register_tool_parser("gemma")
13+
class GemmaToolParser(ToolParser):
14+
"""
15+
Tool parser for Gemma-4 style tool call blocks.
16+
17+
Gemma emits tool invocations using tokens like:
18+
<|tool_call>call:get_weather{location:<|"|>Shanghai<|"|>}<tool_call|>
19+
where strings are wrapped with <|"|> ... <|"|>.
20+
"""
21+
22+
def __init__(self):
23+
self.tool_call_start_token = "<|tool_call>"
24+
self.tool_call_end_token = "<tool_call|>"
25+
self.tool_call_regex = re.compile(
26+
r"(<\|tool_call\>.*?<tool_call\|>)", re.DOTALL
27+
)
28+
self.call_header_regex = re.compile(r"call\s*:\s*([^{\s]+)", re.IGNORECASE)
29+
30+
@staticmethod
31+
def _replace_quotes(text: str) -> str:
32+
return text.replace('<|"|>', '"')
33+
34+
@staticmethod
35+
def _quote_keys(text: str) -> str:
36+
pattern = re.compile(r"(?P<prefix>[{,])\s*(?P<key>[A-Za-z0-9_\-]+)\s*:")
37+
38+
def repl(match: re.Match) -> str:
39+
prefix = match.group("prefix")
40+
key = match.group("key")
41+
return f'{prefix}"{key}":'
42+
43+
while True:
44+
new_text, count = pattern.subn(repl, text)
45+
text = new_text
46+
if count == 0:
47+
break
48+
return text
49+
50+
def _parse_arguments(self, arg_block: str) -> Dict[str, Any]:
51+
cleaned = self._replace_quotes(arg_block.strip())
52+
if not cleaned:
53+
return {}
54+
normalized = self._quote_keys(cleaned)
55+
return json.loads(normalized)
56+
57+
def _parse_tool_call_block(
58+
self, block: str
59+
) -> Tuple[Optional[str], Optional[str], Optional[Dict[str, Any]]]:
60+
content = block.strip()
61+
try:
62+
# Remove wrapper tokens
63+
if content.startswith(self.tool_call_start_token):
64+
content = content[len(self.tool_call_start_token) :]
65+
if content.endswith(self.tool_call_end_token):
66+
content = content[: -len(self.tool_call_end_token)]
67+
content = content.strip()
68+
69+
match = self.call_header_regex.search(content)
70+
if not match:
71+
raise ValueError("Missing call header")
72+
func_name = match.group(1).strip()
73+
74+
brace_start = content.find("{", match.end())
75+
brace_end = content.rfind("}")
76+
if brace_start == -1 or brace_end == -1 or brace_end < brace_start:
77+
args = {}
78+
else:
79+
args_str = content[brace_start : brace_end + 1]
80+
args = self._parse_arguments(args_str)
81+
return (None, func_name, args)
82+
except Exception as exc:
83+
logger.warning("Failed to parse Gemma tool call: %s, error: %s", block, exc)
84+
return (block, None, None)
85+
86+
def extract_tool_calls(
87+
self, model_output: str
88+
) -> List[Tuple[Optional[str], Optional[str], Optional[Dict[str, Any]]]]:
89+
if self.tool_call_start_token not in model_output:
90+
return [(model_output, None, None)]
91+
92+
results: List[Tuple[Optional[str], Optional[str], Optional[Dict[str, Any]]]] = (
93+
[]
94+
)
95+
last_end = 0
96+
for match in self.tool_call_regex.finditer(model_output):
97+
if match.start() > last_end:
98+
content = model_output[last_end : match.start()]
99+
if content:
100+
results.append((content, None, None))
101+
block = match.group(0)
102+
results.append(self._parse_tool_call_block(block))
103+
last_end = match.end()
104+
105+
if last_end < len(model_output):
106+
remainder = model_output[last_end:]
107+
if remainder:
108+
results.append((remainder, None, None))
109+
110+
return results or [(model_output, None, None)]
111+
112+
def extract_tool_calls_streaming(
113+
self,
114+
previous_texts: List[str],
115+
current_text: str,
116+
delta_text: str,
117+
) -> Optional[Tuple[Optional[str], Optional[str], Optional[Dict[str, Any]]]]:
118+
if self.tool_call_start_token not in current_text:
119+
return (delta_text, None, None)
120+
121+
matches = list(self.tool_call_regex.finditer(current_text))
122+
if not matches:
123+
return None
124+
125+
prev_text = previous_texts[-1] if previous_texts else ""
126+
last_match = matches[-1]
127+
if last_match.end() <= len(prev_text):
128+
# The latest complete tool call was already processed, return delta as text
129+
return (delta_text, None, None)
130+
131+
block = last_match.group(0)
132+
return self._parse_tool_call_block(block)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import pytest
2+
3+
from ..gemma_tool_parser import GemmaToolParser
4+
5+
6+
@pytest.fixture
7+
def parser():
8+
return GemmaToolParser()
9+
10+
11+
def test_extract_tool_calls(parser):
12+
output = (
13+
"<|tool_call>call:get_weather"
14+
'{location:<|"|>上海<|"|>,unit:<|"|>celsius<|"|>}'
15+
"<tool_call|>"
16+
)
17+
result = parser.extract_tool_calls(output)
18+
assert result == [(None, "get_weather", {"location": "上海", "unit": "celsius"})]
19+
20+
21+
def test_extract_tool_calls_with_surrounding_text(parser):
22+
output = (
23+
"Thought...\n"
24+
"<|tool_call>call:get_weather"
25+
'{location:<|"|>上海<|"|>}'
26+
"<tool_call|>\nThanks"
27+
)
28+
result = parser.extract_tool_calls(output)
29+
assert result == [
30+
("Thought...\n", None, None),
31+
(None, "get_weather", {"location": "上海"}),
32+
("\nThanks", None, None),
33+
]
34+
35+
36+
def test_extract_tool_calls_streaming(parser):
37+
previous = [""]
38+
block = "<|tool_call>call:get_weather" '{location:<|"|>上海<|"|>}' "<tool_call|>"
39+
result = parser.extract_tool_calls_streaming(previous, block, block)
40+
assert result == (None, "get_weather", {"location": "上海"})
41+
42+
43+
def test_streaming_ignores_processed_block(parser):
44+
block = "<|tool_call>call:get_weather" '{location:<|"|>上海<|"|>}' "<tool_call|>"
45+
previous = [block]
46+
current = block + " more text"
47+
result = parser.extract_tool_calls_streaming(previous, current, " more text")
48+
assert result == (" more text", None, None)

xinference/model/llm/transformers/core.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from ..llm_family import LLMFamilyV2, LLMSpecV1
4242
from ..utils import (
4343
DEEPSEEK_TOOL_CALL_FAMILY,
44+
GEMMA_TOOL_CALL_FAMILY,
4445
LLAMA3_TOOL_CALL_FAMILY,
4546
QWEN_TOOL_CALL_FAMILY,
4647
ChatModelMixin,
@@ -1079,9 +1080,9 @@ def _get_full_prompt(self, messages: List[Dict], tools, generate_config: dict):
10791080
)
10801081
chat_context_var.set(chat_template_kwargs)
10811082
full_context_kwargs = chat_template_kwargs.copy()
1082-
if (
1083-
tools
1084-
and model_family in QWEN_TOOL_CALL_FAMILY
1083+
if tools and (
1084+
model_family in QWEN_TOOL_CALL_FAMILY
1085+
or model_family in GEMMA_TOOL_CALL_FAMILY
10851086
or model_family in LLAMA3_TOOL_CALL_FAMILY
10861087
or model_family in DEEPSEEK_TOOL_CALL_FAMILY
10871088
):

xinference/model/llm/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ def get_context_length_from_config(
123123
"qwen3.5",
124124
]
125125

126+
GEMMA_TOOL_CALL_FAMILY = ["gemma-4"]
127+
126128
GLM4_TOOL_CALL_FAMILY = [
127129
"glm4-chat",
128130
"glm4-chat-1m",
@@ -142,6 +144,7 @@ def get_context_length_from_config(
142144

143145
TOOL_CALL_FAMILY = (
144146
QWEN_TOOL_CALL_FAMILY
147+
+ GEMMA_TOOL_CALL_FAMILY
145148
+ GLM4_TOOL_CALL_FAMILY
146149
+ LLAMA3_TOOL_CALL_FAMILY
147150
+ DEEPSEEK_TOOL_CALL_FAMILY

xinference/model/llm/vllm/core.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
from ..llm_family import cache_model_tokenizer_and_config
6161
from ..utils import (
6262
DEEPSEEK_TOOL_CALL_FAMILY,
63+
GEMMA_TOOL_CALL_FAMILY,
6364
QWEN_TOOL_CALL_FAMILY,
6465
QWEN_TOOL_CALL_SYMBOLS,
6566
ChatModelMixin,
@@ -1670,6 +1671,7 @@ async def async_chat(
16701671
if tools:
16711672
if (
16721673
model_family in QWEN_TOOL_CALL_FAMILY
1674+
or model_family in GEMMA_TOOL_CALL_FAMILY
16731675
or model_family in DEEPSEEK_TOOL_CALL_FAMILY
16741676
):
16751677
full_context_kwargs["tools"] = tools
@@ -1963,7 +1965,10 @@ async def async_chat(
19631965
)
19641966
chat_context_var.set(chat_template_kwargs)
19651967
full_context_kwargs = chat_template_kwargs.copy()
1966-
if tools and model_family in QWEN_TOOL_CALL_FAMILY:
1968+
if tools and (
1969+
model_family in QWEN_TOOL_CALL_FAMILY
1970+
or model_family in GEMMA_TOOL_CALL_FAMILY
1971+
):
19671972
full_context_kwargs["tools"] = tools
19681973
assert self.model_family.chat_template is not None
19691974
if "omni" in self.model_family.model_ability:

0 commit comments

Comments
 (0)