Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 60 additions & 6 deletions mlx_lm/tokenizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,38 @@ def _infer_thinking(tokenizer):
return (None, None, None, None)


def _infer_markers_from_config(tokenizer):
"""Discover tool-call markers from tokenizer config fields.

Some models (e.g. Gemma 4) publish structured token fields in
``tokenizer_config.json`` that HuggingFace's ``AutoTokenizer`` exposes
as attributes (via ``_special_tokens_map``). This function checks for
those fields and returns any markers found.

Currently recognises:

* ``stc_token`` / ``etc_token`` – start / end of **tool call**
(Gemma 4 convention, see
https://ai.google.dev/gemma/docs/core/prompt-formatting-gemma4).

Returns:
dict with ``"tool_call_start"`` and ``"tool_call_end"`` (str or None).
"""
result = {
"tool_call_start": None,
"tool_call_end": None,
}

# stc_token / etc_token = start / end of tool call (Gemma 4 convention)
stc = getattr(tokenizer, "stc_token", None)
etc_tok = getattr(tokenizer, "etc_token", None)
if stc is not None and etc_tok is not None:
result["tool_call_start"] = stc
result["tool_call_end"] = etc_tok

return result


class TokenizerWrapper:
"""A wrapper that combines an HF tokenizer and a detokenizer.

Expand All @@ -300,6 +332,8 @@ def __init__(
tool_call_start=None,
tool_call_end=None,
tool_parser=None,
think_start=None,
think_end=None,
):
self._tokenizer = tokenizer
self._detokenizer_class = detokenizer_class
Expand All @@ -308,12 +342,22 @@ def __init__(
if eos_token_ids is not None
else {tokenizer.eos_token_id}
)
(
self._think_start,
self._think_end,
self._think_start_tokens,
self._think_end_tokens,
) = _infer_thinking(tokenizer)
if think_start is not None and think_end is not None:
self._think_start = think_start
self._think_end = think_end
self._think_start_tokens = tuple(
tokenizer.encode(think_start, add_special_tokens=False)
)
self._think_end_tokens = tuple(
tokenizer.encode(think_end, add_special_tokens=False)
)
else:
(
self._think_start,
self._think_end,
self._think_start_tokens,
self._think_end_tokens,
) = _infer_thinking(tokenizer)

self._chat_template = chat_template
self.has_chat_template = (
Expand Down Expand Up @@ -613,6 +657,9 @@ def load(

tokenizer_config = tokenizer.init_kwargs

# Auto-discover markers from tokenizer config fields (e.g. Gemma 4)
config_markers = _infer_markers_from_config(tokenizer)

if chat_template_type := tokenizer_config.get("chat_template_type", False):
chat_template = importlib.import_module(
f"mlx_lm.chat_templates.{chat_template_type}"
Expand All @@ -623,11 +670,18 @@ def load(
)

if tool_parser_type is not None:
# Parser module knows the exact markers it expects
tool_module = importlib.import_module(f"mlx_lm.tool_parsers.{tool_parser_type}")
tool_parser = tool_module.parse_tool_call
tool_call_start = tool_module.tool_call_start
tool_call_end = tool_module.tool_call_end
tokenizer_config["tool_parser_type"] = tool_parser_type
elif config_markers["tool_call_start"] is not None:
# Config provided tool markers but no parser was matched.
# Set markers for state-machine streaming; parser stays None.
tool_parser = None
tool_call_start = config_markers["tool_call_start"]
tool_call_end = config_markers["tool_call_end"]
else:
tool_parser = None
tool_call_start = None
Expand Down
100 changes: 100 additions & 0 deletions tests/test_tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,5 +102,105 @@ def test_thinking(self):
self.assertEqual(tokenizer.think_end, "</think>")


class _StubTokenizer:
"""Minimal tokenizer stub for testing marker-discovery in isolation.

Mirrors the HuggingFace behaviour where unset extra special tokens
(``boi_token``, ``stc_token`` etc.) return ``None`` via ``__getattr__``.
"""

def __init__(self, **named_tokens):
self._named_tokens = named_tokens
self.eos_token_id = 0
self.chat_template = None

def __getattr__(self, name):
# Instance dict is checked first; this only fires for unset names.
if name in self._named_tokens:
return self._named_tokens[name]
if name.endswith("_token"):
return None
raise AttributeError(name)

def get_vocab(self):
return {}

def encode(self, text, add_special_tokens=False):
# Deterministic fake tokenisation; two IDs per input string.
return [100, 101]


class TestMarkerDiscovery(unittest.TestCase):
    """Exercise _infer_markers_from_config and the wrapper plumbing around it."""

    def test_config_discovers_tool_markers(self):
        """stc_token / etc_token → tool_call_start / tool_call_end."""
        from mlx_lm.tokenizer_utils import _infer_markers_from_config

        stub = _StubTokenizer(
            stc_token="<|tool_call>",
            etc_token="<tool_call|>",
        )
        markers = _infer_markers_from_config(stub)
        self.assertEqual(markers["tool_call_start"], "<|tool_call>")
        self.assertEqual(markers["tool_call_end"], "<tool_call|>")

    def test_config_no_markers_returns_none(self):
        """A tokenizer with no config fields yields None for both markers."""
        from mlx_lm.tokenizer_utils import _infer_markers_from_config

        markers = _infer_markers_from_config(_StubTokenizer())
        self.assertIsNone(markers["tool_call_start"])
        self.assertIsNone(markers["tool_call_end"])

    def test_config_partial_markers_ignored(self):
        """stc_token alone (no etc_token) must not set either marker."""
        from mlx_lm.tokenizer_utils import _infer_markers_from_config

        markers = _infer_markers_from_config(
            _StubTokenizer(stc_token="<|tool_call>")
        )
        self.assertIsNone(markers["tool_call_start"])
        self.assertIsNone(markers["tool_call_end"])

    def test_config_markers_enable_tool_calling(self):
        """Passing markers into TokenizerWrapper flips has_tool_calling on."""
        from mlx_lm.tokenizer_utils import TokenizerWrapper

        wrapped = TokenizerWrapper(
            _StubTokenizer(),
            tool_call_start="<|tool_call>",
            tool_call_end="<tool_call|>",
        )
        self.assertTrue(wrapped.has_tool_calling)
        self.assertEqual(wrapped.tool_call_start, "<|tool_call>")
        self.assertEqual(wrapped.tool_call_end, "<tool_call|>")

    def test_think_start_end_params_override_inference(self):
        """Explicit think_start/think_end bypass _infer_thinking."""
        from mlx_lm.tokenizer_utils import TokenizerWrapper

        wrapped = TokenizerWrapper(
            _StubTokenizer(),
            think_start="<think>",
            think_end="</think>",
        )
        self.assertTrue(wrapped.has_thinking)
        self.assertEqual(wrapped.think_start, "<think>")
        self.assertEqual(wrapped.think_end, "</think>")

    def test_parser_markers_take_precedence(self):
        """Integration: when a parser module exists, its markers win.

        Verifies that adding config-based discovery does not regress any
        currently-supported model. Qwen3 is matched by _infer_tool_parser
        and must keep using the parser module's markers.
        """
        qwen_tok = load_tokenizer("mlx-community/Qwen3-4B-4bit")
        self.assertTrue(qwen_tok.has_tool_calling)
        self.assertEqual(qwen_tok.tool_call_start, "<tool_call>")
        self.assertEqual(qwen_tok.tool_call_end, "</tool_call>")


# Allow running this test module directly (e.g. `python tests/test_tokenizers.py`)
# in addition to discovery via a test runner.
if __name__ == "__main__":
    unittest.main()