From f97077c31db7040852909416a17d4d9c3bad847d Mon Sep 17 00:00:00 2001 From: hallerite Date: Thu, 7 May 2026 17:38:44 +0000 Subject: [PATCH 1/3] feat: add Llama-3 renderer for Llama-3.2-1B/3B-Instruct MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Hand-coded Llama3Renderer mirroring Meta's Llama-3.x chat template. Initial scope: Llama-3.2-1B-Instruct and Llama-3.2-3B-Instruct (and the unrestricted unsloth/... mirrors with byte-identical chat templates). MODEL_RENDERER_MAP routes the canonical meta-llama paths; tests load via the unsloth mirrors so CI doesn't need an HF_TOKEN with Meta license access. Implementation notes: * No / reasoning channel — preserve_*_thinking constructor flags raise NotImplementedError if set (matches DefaultRenderer's contract for the same case). * <|begin_of_text|> (BOS) is emitted at the start of every render. The system block is emitted UNCONDITIONALLY with a fixed "Cutting Knowledge Date / Today Date" preamble even when no system message is supplied. date_string is a constructor kwarg pinned at "26 Jul 2024" by default (matches the chat template's strftime fallback); override per instance for production runs that want today's date. * tools_in_user_message defaults to True. Tools + JSON signatures inject into the first user message; pass False at construction to flip to system-block mode. Both modes parity-tested. * Single tool call per assistant message (chat template raises otherwise). Tool calls render as a JSON blob inside the assistant body. Tool responses render under role ipython regardless of source role; mirrors the chat template's content|tojson branch including the Jinja quirk that strings are iterable so plain-string tool content gets JSON-quoted. * parse_llama_3 detects the JSON tool-call body shape with a strict check; malformed JSON falls through to content. 47 dedicated tests covering map shape, constructor contract, byte parity across 11 conversation shapes (including tool calls, multi-turn, custom date, tools-in-system mode), parse_response, and bridge contract. Full suite: 947 passed, 48 skipped, 1 xfailed. Co-Authored-By: Claude Opus 4.7 (1M context) --- renderers/__init__.py | 2 + renderers/base.py | 15 +- renderers/llama_3.py | 401 ++++++++++++++++++++++++++++++++++++++++++ renderers/parsing.py | 44 +++++ tests/test_llama_3.py | 388 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 848 insertions(+), 2 deletions(-) create mode 100644 renderers/llama_3.py create mode 100644 tests/test_llama_3.py diff --git a/renderers/__init__.py b/renderers/__init__.py index 6b2f225..911ea74 100644 --- a/renderers/__init__.py +++ b/renderers/__init__.py @@ -26,6 +26,7 @@ from renderers.gpt_oss import GptOssRenderer from renderers.kimi_k2 import KimiK2Renderer from renderers.kimi_k25 import KimiK25Renderer +from renderers.llama_3 import Llama3Renderer from renderers.minimax_m2 import MiniMaxM2Renderer from renderers.nemotron3 import Nemotron3Renderer from renderers.qwen3 import Qwen3Renderer @@ -43,6 +44,7 @@ "GptOssRenderer", "KimiK2Renderer", "KimiK25Renderer", + "Llama3Renderer", "Message", "MiniMaxM2Renderer", "Nemotron3Renderer", diff --git a/renderers/base.py b/renderers/base.py index 8afa1ff..16b08d1 100644 --- a/renderers/base.py +++ b/renderers/base.py @@ -293,6 +293,14 @@ def size(self) -> int: # Nemotron 3. "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16": "nemotron3", "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16": "nemotron3", + # Llama 3.2 (Instruct). Tested against the gated meta-llama repos and + # the unrestricted unsloth/... mirror, which ships a byte-identical + # chat template. ``Llama3Renderer`` defaults ``date_string`` to + # "26 Jul 2024" — matching the chat template's strftime fallback — + # so the renderer is reproducible. Pass ``date_string=...`` at + # construction to pin a different date. + "meta-llama/Llama-3.2-1B-Instruct": "llama_3", + "meta-llama/Llama-3.2-3B-Instruct": "llama_3", # GPT-OSS. "openai/gpt-oss-20b": "gpt_oss", "openai/gpt-oss-120b": "gpt_oss", @@ -360,6 +368,7 @@ def _populate_registry(): from renderers.gpt_oss import GptOssRenderer from renderers.kimi_k2 import KimiK2Renderer from renderers.kimi_k25 import KimiK25Renderer + from renderers.llama_3 import Llama3Renderer from renderers.minimax_m2 import MiniMaxM2Renderer from renderers.nemotron3 import Nemotron3Renderer from renderers.qwen3 import Qwen3Renderer @@ -381,6 +390,7 @@ def _populate_registry(): "deepseek_v3": DeepSeekV3Renderer, "kimi_k2": KimiK2Renderer, "kimi_k25": KimiK25Renderer, + "llama_3": Llama3Renderer, "nemotron3": Nemotron3Renderer, "gpt_oss": GptOssRenderer, } @@ -444,8 +454,9 @@ def create_renderer( Args: tokenizer: HuggingFace tokenizer instance. renderer: Renderer name ('qwen3', 'qwen3_vl', 'qwen3.5', 'glm5', 'glm4.5', - 'minimax-m2', 'deepseek_v3', 'kimi_k2', 'kimi_k25', 'nemotron3', - 'gpt_oss', 'default') or 'auto' to detect from model name. + 'minimax-m2', 'deepseek_v3', 'kimi_k2', 'kimi_k25', 'llama_3', + 'nemotron3', 'gpt_oss', 'default') or 'auto' to detect from + model name. tool_parser: Name of a tool parser registered in ``renderers.parsers``. Only consumed by DefaultRenderer. Model-specific renderers have their own parsing wired in. diff --git a/renderers/llama_3.py b/renderers/llama_3.py new file mode 100644 index 0000000..df0a508 --- /dev/null +++ b/renderers/llama_3.py @@ -0,0 +1,401 @@ +"""Llama-3 Renderer — hard-coded Python mirroring Meta's Llama-3 chat template. + +Initial scope: Llama-3.2-1B-Instruct and Llama-3.2-3B-Instruct (and the +unrestricted ``unsloth/Llama-3.2-{1B,3B}-Instruct`` mirror, which ships a +byte-identical chat template). Other Llama-3.x sizes ship slightly +different templates and are NOT covered by this renderer until parity is +verified. + +Notable differences from the Qwen / GLM family renderers: + +* No ```` / reasoning channel — Llama-3 doesn't ship a + reasoning-content concept, so ``preserve_*_thinking`` flags don't + apply. +* ``<|begin_of_text|>`` (BOS) is emitted at the very start of every + render. The chat template never omits it. +* The system block is emitted **unconditionally** with a fixed + ``Cutting Knowledge Date: December 2023\\nToday Date: \\n\\n`` + preamble — even when no system message is supplied. Empty system + message → block ends with ``\\n\\n<|eot_id|>``. +* Tools default to "first-user-message" mode (matching the chat + template's default ``tools_in_user_message=True``): tool descriptions + + JSON signatures are injected into the first user message rather + than the system block. Pass ``tools_in_user_message=False`` at + construction to flip to system-block mode. +* ``date_string`` is a constructor kwarg pinned at ``"26 Jul 2024"`` by + default to match the chat template's ``strftime`` fallback (and keep + output deterministic). Override per instance for production runs that + want today's date. +* Tool calls: a single ``{"name": "...", "parameters": ...}`` JSON blob + inside the assistant body. The chat template explicitly raises if + ``message.tool_calls | length != 1``; this renderer matches that. +* Tool responses: rendered with role ``ipython`` regardless of whether + the source message used ``role: "tool"`` or ``role: "ipython"``. The + chat template runs ``content | tojson`` on any mapping/iterable + content — and Jinja considers strings iterable, so plain string + contents get JSON-quoted. We mirror that exactly. +""" + +from __future__ import annotations + +import json +from typing import Any + +from transformers.tokenization_utils import PreTrainedTokenizer + +from renderers.base import ( + Message, + ParsedResponse, + RenderedTokens, + ToolSpec, + reject_assistant_in_extension, + trim_to_turn_close, +) +from renderers.parsing import parse_llama_3 + +# --------------------------------------------------------------------------- +# Constants — must match the Jinja chat template's literal strings exactly. +# --------------------------------------------------------------------------- + +_DEFAULT_DATE_STRING = "26 Jul 2024" + +_CUTTING_KNOWLEDGE_DATE = "December 2023" + +# Tools-in-system intro: emitted into the system block when tools is set +# AND tools_in_user_message=False. Note: the chat template puts these +# three string literals back-to-back with NO newline between the second +# and third, so there's no space before "Do not use variables.". +_TOOLS_IN_SYSTEM_INTRO = ( + "You have access to the following functions. To call a function, " + "please respond with JSON for a function call." + 'Respond in the format {"name": function name, "parameters": ' + "dictionary of argument name and its value}." + "Do not use variables.\n\n" +) + +# Tools-in-user intro: emitted into the first user message when tools is +# set AND tools_in_user_message=True (the default). +_TOOLS_IN_USER_INTRO = ( + "Given the following functions, please respond with a JSON for a " + "function call with its proper arguments that best answers the given " + "prompt.\n\n" + 'Respond in the format {"name": function name, "parameters": ' + "dictionary of argument name and its value}." + "Do not use variables.\n\n" +) + + +class Llama3Renderer: + """Deterministic message → token renderer for Llama-3.x Instruct models.""" + + def __init__( + self, + tokenizer: PreTrainedTokenizer, + *, + date_string: str = _DEFAULT_DATE_STRING, + tools_in_user_message: bool = True, + preserve_all_thinking: bool = False, + preserve_thinking_between_tool_calls: bool = False, + ): + if preserve_all_thinking or preserve_thinking_between_tool_calls: + raise NotImplementedError( + "Llama-3 doesn't ship a reasoning_content channel — the chat " + "template has no block to preserve or drop. " + "preserve_*_thinking flags are not applicable." + ) + self._tokenizer = tokenizer + self._date_string = date_string + self._tools_in_user_message = tools_in_user_message + + self._bos = self._token_id("<|begin_of_text|>") + self._start_header = self._token_id("<|start_header_id|>") + self._end_header = self._token_id("<|end_header_id|>") + self._eot = self._token_id("<|eot_id|>") + self._end_of_text = self._token_id("<|end_of_text|>") + # ``<|eom_id|>`` shows up in some Llama-3 tool-calling traces (the + # "ipython" / python-tag flow) but the standard 3.2 chat template + # closes turns with ``<|eot_id|>``. We still treat eom as a stop + # token so models that emit it terminate cleanly. + self._eom = self._token_id("<|eom_id|>") + + def _token_id(self, token: str) -> int: + tid = self._tokenizer.convert_tokens_to_ids(token) + assert isinstance(tid, int) and tid != self._tokenizer.unk_token_id, ( + f"Special token {token!r} not found in tokenizer vocabulary" + ) + return tid + + def _encode(self, text: str) -> list[int]: + if not text: + return [] + return self._tokenizer.encode(text, add_special_tokens=False) + + @staticmethod + def _content_str(content: Any) -> str: + """Render content to a plain string. Handles ``str``, list-of-text-parts, + and ``None``. Matches the chat template's ``message.content | trim`` + callers, which expect a string in.""" + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, dict) and "text" in item: + parts.append(item["text"]) + else: + raise ValueError(f"Unexpected content item: {item}") + return "".join(parts) + raise TypeError(f"Unexpected content type: {type(content)}") + + @staticmethod + def _tool_response_str(content: Any) -> str: + """Mirror the chat template's tool-response branch: + ``{% if message.content is mapping or message.content is iterable %} + {{ message.content | tojson }} {% else %} {{ message.content }}``. + + In Jinja, **strings are iterable** — so plain-string tool contents + also go through ``tojson`` (i.e. ``json.dumps``), wrapping them in + quotes and escaping. Non-iterable scalars (numbers, bools, None) + fall through to literal stringification. + """ + if content is None: + return "" + if isinstance(content, (dict, list, str)): + return json.dumps(content, ensure_ascii=False) + return str(content) + + # ------------------------------------------------------------------ + # render + # ------------------------------------------------------------------ + + def render( + self, + messages: list[Message], + *, + tools: list[ToolSpec] | None = None, + add_generation_prompt: bool = False, + ) -> RenderedTokens: + if not messages: + raise ValueError("No messages provided.") + + tokens: list[int] = [] + indices: list[int] = [] + + def emit_ids(ids: list[int], msg_idx: int) -> None: + tokens.extend(ids) + indices.extend([msg_idx] * len(ids)) + + def emit_special(token_id: int, msg_idx: int) -> None: + tokens.append(token_id) + indices.append(msg_idx) + + def emit_text(text: str, msg_idx: int) -> None: + emit_ids(self._encode(text), msg_idx) + + # ── 0. BOS ────────────────────────────────────────────────── + emit_special(self._bos, -1) + + # ── 1. System block (always emitted) ──────────────────────── + first_is_system = messages[0].get("role") == "system" + sys_idx = 0 if first_is_system else -1 + sys_text = ( + self._content_str(messages[0].get("content")).strip() + if first_is_system + else "" + ) + + emit_special(self._start_header, sys_idx) + emit_text("system", sys_idx) + emit_special(self._end_header, sys_idx) + body = "\n\n" + if tools is not None: + body += "Environment: ipython\n" + body += f"Cutting Knowledge Date: {_CUTTING_KNOWLEDGE_DATE}\n" + body += f"Today Date: {self._date_string}\n\n" + if tools is not None and not self._tools_in_user_message: + body += _TOOLS_IN_SYSTEM_INTRO + for t in tools: + body += json.dumps(t, indent=4, ensure_ascii=False) + "\n\n" + body += sys_text + emit_text(body, sys_idx) + emit_special(self._eot, sys_idx) + + # ── 2. Body messages ──────────────────────────────────────── + body_messages = messages[1:] if first_is_system else messages + offset = 1 if first_is_system else 0 + + i = 0 + # 2a. tools_in_user_message mode pulls the first user message + # into a special block with the tools description prepended. + if tools is not None and self._tools_in_user_message: + if i >= len(body_messages): + raise ValueError( + "Cannot place tools in the first user message — no user " + "message was provided." + ) + first_user = body_messages[i] + if first_user.get("role") != "user": + raise ValueError( + "tools_in_user_message=True requires the first non-system " + f"message to be 'user'; got {first_user.get('role')!r}." + ) + user_idx = i + offset + emit_special(self._start_header, user_idx) + emit_text("user", user_idx) + emit_special(self._end_header, user_idx) + user_body = "\n\n" + _TOOLS_IN_USER_INTRO + for t in tools: + user_body += json.dumps(t, indent=4, ensure_ascii=False) + "\n\n" + user_body += self._content_str(first_user.get("content")).strip() + emit_text(user_body, user_idx) + emit_special(self._eot, user_idx) + i += 1 + + # 2b. Remaining messages — plain user/assistant/tool/assistant-with-tool-calls. + for j in range(i, len(body_messages)): + msg = body_messages[j] + msg_idx = j + offset + role = msg.get("role") + tool_calls = msg.get("tool_calls") + + if role in ("tool", "ipython"): + emit_special(self._start_header, msg_idx) + emit_text("ipython", msg_idx) + emit_special(self._end_header, msg_idx) + emit_text( + "\n\n" + self._tool_response_str(msg.get("content")), + msg_idx, + ) + emit_special(self._eot, msg_idx) + elif tool_calls: + if len(tool_calls) != 1: + raise ValueError( + "Llama-3 chat template only supports a single tool call " + "per assistant message." + ) + tc = tool_calls[0] + func = tc.get("function") or tc + name = func.get("name", "") + arguments = func.get("arguments", {}) + if isinstance(arguments, str): + args_str = arguments + else: + args_str = json.dumps(arguments, ensure_ascii=False) + emit_special(self._start_header, msg_idx) + emit_text("assistant", msg_idx) + emit_special(self._end_header, msg_idx) + emit_text( + '\n\n{"name": "' + name + '", "parameters": ' + args_str + "}", + msg_idx, + ) + emit_special(self._eot, msg_idx) + else: + content = self._content_str(msg.get("content")).strip() + emit_special(self._start_header, msg_idx) + emit_text(role or "", msg_idx) + emit_special(self._end_header, msg_idx) + emit_text("\n\n" + content, msg_idx) + emit_special(self._eot, msg_idx) + + # ── 3. Generation prompt ──────────────────────────────────── + if add_generation_prompt: + emit_special(self._start_header, -1) + emit_text("assistant", -1) + emit_special(self._end_header, -1) + emit_text("\n\n", -1) + + return RenderedTokens(token_ids=tokens, message_indices=indices) + + def render_ids( + self, + messages: list[Message], + *, + tools: list[ToolSpec] | None = None, + add_generation_prompt: bool = False, + ) -> list[int]: + return self.render( + messages, + tools=tools, + add_generation_prompt=add_generation_prompt, + ).token_ids + + def parse_response(self, token_ids: list[int]) -> ParsedResponse: + return parse_llama_3( + self._tokenizer, + token_ids, + stop_ids={self._eot, self._end_of_text, self._eom}, + ) + + def get_stop_token_ids(self) -> list[int]: + return [self._eot, self._end_of_text, self._eom] + + # ------------------------------------------------------------------ + # bridge_to_next_turn + # ------------------------------------------------------------------ + + def bridge_to_next_turn( + self, + previous_prompt_ids: list[int], + previous_completion_ids: list[int], + new_messages: list[Message], + *, + tools: list[ToolSpec] | None = None, + ) -> list[int] | None: + if ( + not previous_prompt_ids + or not new_messages + or reject_assistant_in_extension(new_messages) + ): + return None + + previous_ids = trim_to_turn_close( + previous_prompt_ids, + previous_completion_ids, + {self._eot, self._end_of_text, self._eom}, + synthesize_close=self._eot, + ) + if previous_ids is None: + return None + + ext: list[int] = [] + + def emit_special(token_id: int, _msg_idx: int = -1) -> None: + ext.append(token_id) + + def emit_text(text: str, _msg_idx: int = -1) -> None: + ext.extend(self._encode(text)) + + for i, msg in enumerate(new_messages): + role = msg.get("role") + if role == "system": + emit_special(self._start_header, i) + emit_text("system", i) + emit_special(self._end_header, i) + emit_text("\n\n" + self._content_str(msg.get("content")).strip(), i) + emit_special(self._eot, i) + elif role == "user": + emit_special(self._start_header, i) + emit_text("user", i) + emit_special(self._end_header, i) + emit_text("\n\n" + self._content_str(msg.get("content")).strip(), i) + emit_special(self._eot, i) + elif role in ("tool", "ipython"): + emit_special(self._start_header, i) + emit_text("ipython", i) + emit_special(self._end_header, i) + emit_text("\n\n" + self._tool_response_str(msg.get("content")), i) + emit_special(self._eot, i) + else: + return None + + # Generation prompt — matches the gen-prompt branch of ``render()``. + emit_special(self._start_header, -1) + emit_text("assistant", -1) + emit_special(self._end_header, -1) + emit_text("\n\n", -1) + + return previous_ids + ext diff --git a/renderers/parsing.py b/renderers/parsing.py index 6644103..53a1e99 100644 --- a/renderers/parsing.py +++ b/renderers/parsing.py @@ -782,3 +782,47 @@ def _gptoss_extract_after_token( after = _decode(tokenizer, header_ids[pos + 1 :]).strip() # Take first whitespace-delimited word (channel name) return after.split()[0] if after else None + + +# ── Llama-3: single JSON tool call {"name": "...", "parameters": {...}} ─ + + +def parse_llama_3( + tokenizer, + token_ids: list[int], + *, + stop_ids: set[int], +) -> ParsedResponse: + """Parse Llama-3 completion tokens. + + The Llama-3 chat template emits tool calls as a single JSON blob in + the assistant body — ``{"name": "...", "parameters": {...}}`` — with + no surrounding XML tags or special tokens. Plain replies are just + text. We detect the tool-call shape with a strict starts-with-``{`` + + parses-as-dict-with-name-key check; anything else is treated as + content. Llama-3 doesn't have a built-in reasoning channel, so + ``reasoning_content`` is always ``None``. + """ + ids = _strip_stop_tokens(token_ids, stop_ids) + text = _decode(tokenizer, ids).strip() + + if text.startswith("{") and text.endswith("}"): + try: + parsed = json.loads(text) + except json.JSONDecodeError: + parsed = None + if isinstance(parsed, dict) and "name" in parsed: + arguments = parsed.get("parameters", parsed.get("arguments", {})) + tool_call = { + "function": { + "name": parsed.get("name", ""), + "arguments": arguments, + } + } + return ParsedResponse( + content="", + reasoning_content=None, + tool_calls=[tool_call], + ) + + return ParsedResponse(content=text, reasoning_content=None, tool_calls=None) diff --git a/tests/test_llama_3.py b/tests/test_llama_3.py new file mode 100644 index 0000000..d9dbb09 --- /dev/null +++ b/tests/test_llama_3.py @@ -0,0 +1,388 @@ +"""Llama-3 renderer coverage. + +Covers ``Llama3Renderer`` and the ``meta-llama/Llama-3.2-{1B,3B}-Instruct`` +entries in ``MODEL_RENDERER_MAP``. Tokenizers are loaded via the +unrestricted ``unsloth/Llama-3.2-{1B,3B}-Instruct`` mirrors (verified +byte-identical chat templates) so CI doesn't need an HF token with Meta +license access. +""" + +from __future__ import annotations + +import pytest + +from renderers import Llama3Renderer, create_renderer +from renderers.base import MODEL_RENDERER_MAP, ParsedResponse, load_tokenizer + +# Pinned date for byte-parity tests. Matches the chat template's +# strftime fallback so we don't have to override on the apply side. +_PINNED_DATE = "26 Jul 2024" + +_MODEL_PAIRS = [ + # (canonical meta-llama id used by MODEL_RENDERER_MAP, unrestricted + # mirror used to actually load the tokenizer in tests) + ("meta-llama/Llama-3.2-1B-Instruct", "unsloth/Llama-3.2-1B-Instruct"), + ("meta-llama/Llama-3.2-3B-Instruct", "unsloth/Llama-3.2-3B-Instruct"), +] + + +@pytest.fixture(scope="module", params=_MODEL_PAIRS, ids=[m for m, _ in _MODEL_PAIRS]) +def llama_pair(request): + canonical, mirror = request.param + tok = load_tokenizer(mirror) + renderer = Llama3Renderer(tok, date_string=_PINNED_DATE) + return canonical, mirror, tok, renderer + + +# --------------------------------------------------------------------------- +# MODEL_RENDERER_MAP shape +# --------------------------------------------------------------------------- + + +def test_canonical_meta_llama_paths_route_to_llama_3(): + for canonical, _ in _MODEL_PAIRS: + assert MODEL_RENDERER_MAP.get(canonical) == "llama_3", ( + f"{canonical}: expected to route to 'llama_3'" + ) + + +def test_create_renderer_via_explicit_name(llama_pair): + """The 'llama_3' string resolves to Llama3Renderer in the registry.""" + _, _, tok, _ = llama_pair + r = create_renderer(tok, renderer="llama_3") + assert isinstance(r, Llama3Renderer) + + +# --------------------------------------------------------------------------- +# Constructor contract +# --------------------------------------------------------------------------- + + +def test_default_date_matches_chat_template_strftime_fallback(llama_pair): + """Default ``date_string`` is ``"26 Jul 2024"`` so output stays + deterministic without an explicit override.""" + _, _, tok, _ = llama_pair + r = Llama3Renderer(tok) + assert r._date_string == _PINNED_DATE + + +def test_preserve_all_thinking_rejected(llama_pair): + _, _, tok, _ = llama_pair + with pytest.raises(NotImplementedError, match="reasoning_content"): + Llama3Renderer(tok, preserve_all_thinking=True) + + +def test_preserve_thinking_between_tool_calls_rejected(llama_pair): + _, _, tok, _ = llama_pair + with pytest.raises(NotImplementedError, match="reasoning_content"): + Llama3Renderer(tok, preserve_thinking_between_tool_calls=True) + + +# --------------------------------------------------------------------------- +# Byte parity vs apply_chat_template +# --------------------------------------------------------------------------- + + +def _expected(tok, messages, **kwargs): + kwargs.setdefault("add_generation_prompt", False) + kwargs.setdefault("date_string", _PINNED_DATE) + return list( + tok.apply_chat_template(messages, tokenize=True, return_dict=False, **kwargs) + ) + + +def test_parity_minimal_user(llama_pair): + _, _, tok, r = llama_pair + msgs = [{"role": "user", "content": "Hi."}] + assert r.render_ids(msgs) == _expected(tok, msgs) + + +def test_parity_system_and_user(llama_pair): + _, _, tok, r = llama_pair + msgs = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi."}, + ] + assert r.render_ids(msgs) == _expected(tok, msgs) + + +def test_parity_system_user_assistant(llama_pair): + _, _, tok, r = llama_pair + msgs = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi."}, + {"role": "assistant", "content": "Hello!"}, + ] + assert r.render_ids(msgs) == _expected(tok, msgs) + + +def test_parity_no_system_with_gen_prompt(llama_pair): + _, _, tok, r = llama_pair + msgs = [{"role": "user", "content": "Hi."}] + assert r.render_ids(msgs, add_generation_prompt=True) == _expected( + tok, msgs, add_generation_prompt=True + ) + + +def test_parity_multi_turn(llama_pair): + _, _, tok, r = llama_pair + msgs = [ + {"role": "user", "content": "A"}, + {"role": "assistant", "content": "B"}, + {"role": "user", "content": "C"}, + {"role": "assistant", "content": "D"}, + ] + assert r.render_ids(msgs) == _expected(tok, msgs) + + +def test_parity_trims_whitespace(llama_pair): + _, _, tok, r = llama_pair + msgs = [ + {"role": "user", "content": " hello "}, + {"role": "assistant", "content": "\n\nworld\n"}, + ] + assert r.render_ids(msgs) == _expected(tok, msgs) + + +def test_parity_custom_date(llama_pair): + """``date_string`` constructor override changes both sides identically.""" + _, _, tok, _ = llama_pair + r = Llama3Renderer(tok, date_string="01 Jan 2026") + msgs = [{"role": "user", "content": "Hi."}] + expected = list( + tok.apply_chat_template( + msgs, tokenize=True, return_dict=False, date_string="01 Jan 2026" + ) + ) + assert r.render_ids(msgs) == expected + + +def test_parity_tools_in_user_default(llama_pair): + _, _, tok, r = llama_pair + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + }, + } + ] + msgs = [ + {"role": "system", "content": "Be terse."}, + {"role": "user", "content": "Weather?"}, + ] + assert r.render_ids(msgs, tools=tools) == _expected(tok, msgs, tools=tools) + + +def test_parity_tools_in_system_mode(llama_pair): + """When constructed with ``tools_in_user_message=False``, the renderer + matches ``apply_chat_template(... tools_in_user_message=False)``.""" + _, _, tok, _ = llama_pair + r = Llama3Renderer(tok, date_string=_PINNED_DATE, tools_in_user_message=False) + tools = [ + { + "type": "function", + "function": {"name": "get_weather", "parameters": {}}, + } + ] + msgs = [ + {"role": "system", "content": "Be terse."}, + {"role": "user", "content": "Weather?"}, + ] + expected = list( + tok.apply_chat_template( + msgs, + tokenize=True, + return_dict=False, + tools=tools, + tools_in_user_message=False, + date_string=_PINNED_DATE, + ) + ) + assert r.render_ids(msgs, tools=tools) == expected + + +def test_parity_tool_call_round_trip(llama_pair): + """Assistant tool_calls + tool response + final assistant — covers + the JSON tool-call body emission and the ``ipython`` response role.""" + _, _, tok, r = llama_pair + msgs = [ + {"role": "user", "content": "Weather?"}, + { + "role": "assistant", + "tool_calls": [ + { + "type": "function", + "function": { + "name": "get_weather", + "arguments": {"city": "NYC"}, + }, + } + ], + }, + {"role": "tool", "content": '{"temp": 72}'}, + {"role": "assistant", "content": "It's 72."}, + ] + assert r.render_ids(msgs) == _expected(tok, msgs) + + +def test_parity_tool_response_dict_content(llama_pair): + """Tool response with mapping content goes through ``tojson`` in the + template; the renderer's ``_tool_response_str`` mirrors that.""" + _, _, tok, r = llama_pair + msgs = [ + {"role": "user", "content": "x"}, + { + "role": "assistant", + "tool_calls": [{"function": {"name": "f", "arguments": {}}}], + }, + {"role": "tool", "content": {"k": "v", "n": 42}}, + {"role": "assistant", "content": "ok"}, + ] + assert r.render_ids(msgs) == _expected(tok, msgs) + + +def test_render_raises_on_multiple_tool_calls(llama_pair): + """Llama-3 chat template explicitly raises on >1 tool call per turn — + the renderer mirrors that contract.""" + _, _, _, r = llama_pair + msgs = [ + {"role": "user", "content": "x"}, + { + "role": "assistant", + "tool_calls": [ + {"function": {"name": "f", "arguments": {}}}, + {"function": {"name": "g", "arguments": {}}}, + ], + }, + ] + with pytest.raises(ValueError, match="single tool call"): + r.render_ids(msgs) + + +# --------------------------------------------------------------------------- +# parse_response +# --------------------------------------------------------------------------- + + +def _tokens_for(tok, text: str) -> list[int]: + return tok.encode(text, add_special_tokens=False) + + +def test_parse_response_plain_content(llama_pair): + _, _, tok, r = llama_pair + ids = _tokens_for(tok, "Hello, world!") + [r._eot] + out = r.parse_response(ids) + assert isinstance(out, ParsedResponse) + assert out.content == "Hello, world!" + assert out.tool_calls is None + assert out.reasoning_content is None + + +def test_parse_response_tool_call(llama_pair): + _, _, tok, r = llama_pair + body = '{"name": "get_weather", "parameters": {"city": "NYC"}}' + ids = _tokens_for(tok, body) + [r._eot] + out = r.parse_response(ids) + assert out.content == "" + assert out.tool_calls == [ + {"function": {"name": "get_weather", "arguments": {"city": "NYC"}}} + ] + + +def test_parse_response_malformed_tool_call_falls_through_to_content(llama_pair): + """A body that LOOKS like a tool call but doesn't parse should land + in content rather than dropping silently.""" + _, _, tok, r = llama_pair + body = '{"name": "x", broken' + ids = _tokens_for(tok, body) + [r._eot] + out = r.parse_response(ids) + assert out.tool_calls is None + assert "{" in out.content + + +# --------------------------------------------------------------------------- +# Bridge contract +# --------------------------------------------------------------------------- + + +def _simulate_prior_turn(r): + prior = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi."}, + ] + asst = [{"role": "assistant", "content": "Hello!"}] + + prev_prompt = r.render_ids(prior, add_generation_prompt=True) + full = r.render_ids(prior + asst, add_generation_prompt=False) + prev_completion = list(full[len(prev_prompt) :]) + + stop = set(r.get_stop_token_ids()) + last = -1 + for i in range(len(prev_completion) - 1, -1, -1): + if prev_completion[i] in stop: + last = i + break + if last >= 0: + prev_completion = prev_completion[: last + 1] + return prev_prompt, prev_completion + + +def test_bridge_extends_prev_verbatim_on_clean_stop(llama_pair): + _, _, _, r = llama_pair + prev_prompt, prev_completion = _simulate_prior_turn(r) + new_messages = [{"role": "user", "content": "What's 2+2?"}] + bridged = r.bridge_to_next_turn(prev_prompt, prev_completion, new_messages) + assert bridged is not None + prev = prev_prompt + prev_completion + assert bridged[: len(prev)] == prev + assert len(bridged) > len(prev) + + +def test_bridge_matches_fresh_render_on_clean_stop(llama_pair): + """The whole point of the bridge: it must produce the same tokens as + a fresh render of the full message list — except sampled tokens are + kept verbatim rather than re-rendered.""" + _, _, _, r = llama_pair + prior = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi."}, + ] + asst = [{"role": "assistant", "content": "Hello!"}] + new_messages = [{"role": "user", "content": "What's 2+2?"}] + + prev_prompt, prev_completion = _simulate_prior_turn(r) + bridged = r.bridge_to_next_turn(prev_prompt, prev_completion, new_messages) + fresh = r.render_ids(prior + asst + new_messages, add_generation_prompt=True) + assert bridged == fresh + + +def test_bridge_rejects_assistant_in_extension(llama_pair): + _, _, _, r = llama_pair + prev_prompt, prev_completion = _simulate_prior_turn(r) + bridged = r.bridge_to_next_turn( + prev_prompt, + prev_completion, + [{"role": "assistant", "content": "forbidden"}], + ) + assert bridged is None + + +def test_bridge_synthesises_close_on_truncation(llama_pair): + _, _, _, r = llama_pair + prev_prompt, prev_completion = _simulate_prior_turn(r) + trunc = prev_completion[:-1] + if not trunc: + pytest.skip("simulated prior had no completion tokens to truncate") + bridged = r.bridge_to_next_turn( + prev_prompt, trunc, [{"role": "user", "content": "ping"}] + ) + assert bridged is not None + base = prev_prompt + trunc + assert bridged[: len(base)] == base + assert len(bridged) > len(base) From 395e53280494c36255953dc859404c5630938cfb Mon Sep 17 00:00:00 2001 From: hallerite Date: Thu, 4 Jun 2026 04:41:41 +0530 Subject: [PATCH 2/3] fix(parsing): parse_llama_3 emits list[ParsedToolCall], not OpenAI dicts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit parse_llama_3 violated the parser contract (documented in parsing.py's module docstring: "Every parser emits list[ParsedToolCall]"). It put an OpenAI-shaped {"function": {...}} dict in ParsedResponse.tool_calls and used None for the no-call case. Inference-client code that filters on ToolCallParseStatus.OK — or just iterates tool_calls — broke on every Llama-3 completion: AttributeError on .status for the call case, and TypeError iterating None for plain replies. Emit a ParsedToolCall(raw, name, arguments, token_span, status=OK) for a detected call, and fall back to the dataclass default (empty list) for plain content and for {...} bodies that don't parse / lack a name. Llama-3 has no tool-call delimiter to anchor a "malformed attempt" against, so a non-tool-call body stays content rather than producing a non-OK entry — preserving the prior fall-through-to-content behaviour. Tests updated to the ParsedToolCall shape and empty-list contract. Co-Authored-By: Claude Opus 4.8 (1M context) --- renderers/parsing.py | 24 +++++++++++++++--------- tests/test_llama_3.py | 19 +++++++++++++------ 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/renderers/parsing.py b/renderers/parsing.py index f33a992..be119f8 100644 --- a/renderers/parsing.py +++ b/renderers/parsing.py @@ -1299,18 +1299,24 @@ def parse_llama_3( parsed = json.loads(text) except json.JSONDecodeError: parsed = None - if isinstance(parsed, dict) and "name" in parsed: + if isinstance(parsed, dict) and parsed.get("name"): arguments = parsed.get("parameters", parsed.get("arguments", {})) - tool_call = { - "function": { - "name": parsed.get("name", ""), - "arguments": arguments, - } - } return ParsedResponse( content="", reasoning_content=None, - tool_calls=[tool_call], + tool_calls=[ + ParsedToolCall( + raw=text, + name=parsed["name"], + arguments=arguments, + token_span=(0, len(ids)), + status=ToolCallParseStatus.OK, + ) + ], ) - return ParsedResponse(content=text, reasoning_content=None, tool_calls=None) + # Not a tool-call shape (plain reply, or a ``{...}`` body that didn't + # parse / lacked a name). Llama-3 has no delimiter to anchor a + # "malformed attempt" against, so it falls through to content rather + # than producing a non-OK ParsedToolCall. + return ParsedResponse(content=text, reasoning_content=None) diff --git a/tests/test_llama_3.py b/tests/test_llama_3.py index 201b08e..82827b4 100644 --- a/tests/test_llama_3.py +++ b/tests/test_llama_3.py @@ -12,7 +12,12 @@ import pytest from renderers import Llama3Renderer, Llama3RendererConfig, create_renderer -from renderers.base import MODEL_RENDERER_MAP, ParsedResponse, load_tokenizer +from renderers.base import ( + MODEL_RENDERER_MAP, + ParsedResponse, + ToolCallParseStatus, + load_tokenizer, +) # Pinned date for byte-parity tests. Matches the chat template's # strftime fallback so we don't have to override on the apply side. @@ -285,7 +290,7 @@ def test_parse_response_plain_content(llama_pair): out = r.parse_response(ids) assert isinstance(out, ParsedResponse) assert out.content == "Hello, world!" - assert out.tool_calls is None + assert out.tool_calls == [] assert out.reasoning_content is None @@ -295,9 +300,11 @@ def test_parse_response_tool_call(llama_pair): ids = _tokens_for(tok, body) + [r._eot] out = r.parse_response(ids) assert out.content == "" - assert out.tool_calls == [ - {"function": {"name": "get_weather", "arguments": {"city": "NYC"}}} - ] + assert len(out.tool_calls) == 1 + tc = out.tool_calls[0] + assert tc.status == ToolCallParseStatus.OK + assert tc.name == "get_weather" + assert tc.arguments == {"city": "NYC"} def test_parse_response_malformed_tool_call_falls_through_to_content(llama_pair): @@ -307,7 +314,7 @@ def test_parse_response_malformed_tool_call_falls_through_to_content(llama_pair) body = '{"name": "x", broken' ids = _tokens_for(tok, body) + [r._eot] out = r.parse_response(ids) - assert out.tool_calls is None + assert out.tool_calls == [] assert "{" in out.content From 5a6152700eac71aa8ee780090d422bfaef71fa37 Mon Sep 17 00:00:00 2001 From: hallerite Date: Thu, 4 Jun 2026 05:40:46 +0530 Subject: [PATCH 3/3] fix(llama): align Llama3Renderer with the cross-renderer contract MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Llama-3 renderer predated several contract additions and diverged from every other renderer. Bring it fully in line: renderers/llama_3.py * render() now emits all six RenderedTokens fields. Previously it set only token_ids + message_indices, leaving sampled_mask, is_content, message_roles, and message_tool_names empty — below even DefaultRenderer's baseline. Thread is_sampled/is_content through the emit helpers (the qwen3/laguna emit_special/emit_text/ emit_text_segments trio); scaffold/body splits route through attribute_text_segments so byte-parity is preserved. * bridge_to_next_turn() now returns RenderedTokens (was list[int]), violating the Renderer protocol — with the contract attribution (prior portion -1/False, new portion indexed + content-masked, sampled_mask uniformly False). * parse_response() gains the `*, tools=None` kwarg from the protocol. * preserve_*_thinking flags are now no-ops instead of raising — Llama has no reasoning channel, the same never-preserves contract as Kimi-K2 / Qwen3-VL (no renderer raises on these). renderers/parsing.py * parse_llama_3 skips a leading assistant role-header before tool-call detection. Delimiter-based parsers tolerate that scaffold naturally; Llama's bare-JSON format needs it explicit. No-op on the sampled stream in production. tests/ * Wire Llama-3 into the shared matrices (conftest RENDERER_MODELS, test_bridge, test_roundtrip) via the ungated unsloth mirror, and add it to NO_OP_MODELS + NEVER_PRESERVES_MODELS. * Skip Llama for the generic HF-parity files (test_render_ids, test_build_helpers): its template fills the date via strftime_now, so apply_chat_template parity is non-deterministic — deterministic byte-parity (date pinned on both sides) stays in test_llama_3.py. * Skip the parallel-tool-call round-trip (template forbids >1 call). * Convert the preserve-thinking rejection tests to no-op assertions. Full suite: 1947 passed, 88 skipped, 1 xfailed. Co-Authored-By: Claude Opus 4.8 (1M context) --- renderers/llama_3.py | 274 +++++++++++++++++++++++--------- renderers/parsing.py | 21 ++- tests/conftest.py | 32 ++++ tests/test_bridge.py | 1 + tests/test_llama_3.py | 39 +++-- tests/test_preserve_thinking.py | 6 + tests/test_roundtrip.py | 7 + 7 files changed, 285 insertions(+), 95 deletions(-) diff --git a/renderers/llama_3.py b/renderers/llama_3.py index 6b2feb7..d15792a 100644 --- a/renderers/llama_3.py +++ b/renderers/llama_3.py @@ -48,6 +48,8 @@ ParsedResponse, RenderedTokens, ToolSpec, + attribute_text_segments, + extract_message_tool_names, reject_assistant_in_extension, trim_to_turn_close, ) @@ -92,15 +94,14 @@ def __init__( tokenizer: PreTrainedTokenizer, config: Llama3RendererConfig | None = None, ): - config = config or Llama3RendererConfig() - if config.preserve_all_thinking or config.preserve_thinking_between_tool_calls: - raise NotImplementedError( - "Llama-3 doesn't ship a reasoning_content channel — the chat " - "template has no block to preserve or drop. " - "preserve_*_thinking flags are not applicable." - ) + # ``preserve_*_thinking`` are accepted but no-ops: Llama-3 ships no + # reasoning_content channel, so there's never any past-assistant + # thinking to retain or drop. The flags are stored on ``self.config`` + # for cross-renderer uniformity but never change the token stream — + # the same contract as Kimi-K2 / Qwen3-VL (see the never-preserves + # renderers in tests/test_preserve_thinking.py). self._tokenizer = tokenizer - self.config = config + self.config = config or Llama3RendererConfig() self._bos = self._token_id("<|begin_of_text|>") self._start_header = self._token_id("<|start_header_id|>") @@ -179,20 +180,43 @@ def render( tokens: list[int] = [] indices: list[int] = [] + sampled: list[bool] = [] + content_mask: list[bool] = [] - def emit_ids(ids: list[int], msg_idx: int) -> None: - tokens.extend(ids) - indices.extend([msg_idx] * len(ids)) - - def emit_special(token_id: int, msg_idx: int) -> None: + def emit_special( + token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: tokens.append(token_id) indices.append(msg_idx) + sampled.append(is_sampled) + content_mask.append(is_content) - def emit_text(text: str, msg_idx: int) -> None: - emit_ids(self._encode(text), msg_idx) + def emit_text( + text: str, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: + ids = self._encode(text) + tokens.extend(ids) + indices.extend([msg_idx] * len(ids)) + sampled.extend([is_sampled] * len(ids)) + content_mask.extend([is_content] * len(ids)) + + def emit_text_segments( + segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool + ) -> None: + """Tokenize concatenated wrap + body as one BPE pass; per-token + ``is_content`` follows each token's source segment. Lets the + scaffold/body split stay attributed without splitting the + encode call (which could shift BPE merges at the boundary).""" + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): + tokens.append(tok_id) + indices.append(msg_idx) + sampled.append(is_sampled) + content_mask.append(is_content) # ── 0. BOS ────────────────────────────────────────────────── - emit_special(self._bos, -1) + emit_special(self._bos, -1, is_sampled=False, is_content=False) # ── 1. System block (always emitted) ──────────────────────── first_is_system = messages[0].get("role") == "system" @@ -203,21 +227,27 @@ def emit_text(text: str, msg_idx: int) -> None: else "" ) - emit_special(self._start_header, sys_idx) - emit_text("system", sys_idx) - emit_special(self._end_header, sys_idx) - body = "\n\n" + emit_special(self._start_header, sys_idx, is_sampled=False, is_content=False) + emit_text("system", sys_idx, is_sampled=False, is_content=False) + emit_special(self._end_header, sys_idx, is_sampled=False, is_content=False) + # The Cutting Knowledge / Today Date preamble (and any tools-in-system + # block) is template scaffold; only the caller's system content is + # body. Route both through one BPE pass so the wrap/body boundary + # can't shift merges. + preamble = "\n\n" if tools is not None: - body += "Environment: ipython\n" - body += f"Cutting Knowledge Date: {_CUTTING_KNOWLEDGE_DATE}\n" - body += f"Today Date: {self.config.date_string}\n\n" + preamble += "Environment: ipython\n" + preamble += f"Cutting Knowledge Date: {_CUTTING_KNOWLEDGE_DATE}\n" + preamble += f"Today Date: {self.config.date_string}\n\n" if tools is not None and not self.config.tools_in_user_message: - body += _TOOLS_IN_SYSTEM_INTRO + preamble += _TOOLS_IN_SYSTEM_INTRO for t in tools: - body += json.dumps(t, indent=4, ensure_ascii=False) + "\n\n" - body += sys_text - emit_text(body, sys_idx) - emit_special(self._eot, sys_idx) + preamble += json.dumps(t, indent=4, ensure_ascii=False) + "\n\n" + sys_segments: list[tuple[str, bool]] = [(preamble, False)] + if sys_text: + sys_segments.append((sys_text, True)) + emit_text_segments(sys_segments, sys_idx, is_sampled=False) + emit_special(self._eot, sys_idx, is_sampled=False, is_content=False) # ── 2. Body messages ──────────────────────────────────────── body_messages = messages[1:] if first_is_system else messages @@ -239,15 +269,20 @@ def emit_text(text: str, msg_idx: int) -> None: f"message to be 'user'; got {first_user.get('role')!r}." ) user_idx = i + offset - emit_special(self._start_header, user_idx) - emit_text("user", user_idx) - emit_special(self._end_header, user_idx) - user_body = "\n\n" + _TOOLS_IN_USER_INTRO + emit_special( + self._start_header, user_idx, is_sampled=False, is_content=False + ) + emit_text("user", user_idx, is_sampled=False, is_content=False) + emit_special(self._end_header, user_idx, is_sampled=False, is_content=False) + user_preamble = "\n\n" + _TOOLS_IN_USER_INTRO for t in tools: - user_body += json.dumps(t, indent=4, ensure_ascii=False) + "\n\n" - user_body += self._content_str(first_user.get("content")).strip() - emit_text(user_body, user_idx) - emit_special(self._eot, user_idx) + user_preamble += json.dumps(t, indent=4, ensure_ascii=False) + "\n\n" + user_content = self._content_str(first_user.get("content")).strip() + user_segments: list[tuple[str, bool]] = [(user_preamble, False)] + if user_content: + user_segments.append((user_content, True)) + emit_text_segments(user_segments, user_idx, is_sampled=False) + emit_special(self._eot, user_idx, is_sampled=False, is_content=False) i += 1 # 2b. Remaining messages — plain user/assistant/tool/assistant-with-tool-calls. @@ -258,14 +293,22 @@ def emit_text(text: str, msg_idx: int) -> None: tool_calls = msg.get("tool_calls") if role in ("tool", "ipython"): - emit_special(self._start_header, msg_idx) - emit_text("ipython", msg_idx) - emit_special(self._end_header, msg_idx) - emit_text( - "\n\n" + self._tool_response_str(msg.get("content")), - msg_idx, + # Tool responses are conversation history the model never + # samples; the response body is caller content, the wrap is + # scaffold. + emit_special( + self._start_header, msg_idx, is_sampled=False, is_content=False + ) + emit_text("ipython", msg_idx, is_sampled=False, is_content=False) + emit_special( + self._end_header, msg_idx, is_sampled=False, is_content=False ) - emit_special(self._eot, msg_idx) + tool_body = self._tool_response_str(msg.get("content")) + tool_segments: list[tuple[str, bool]] = [("\n\n", False)] + if tool_body: + tool_segments.append((tool_body, True)) + emit_text_segments(tool_segments, msg_idx, is_sampled=False) + emit_special(self._eot, msg_idx, is_sampled=False, is_content=False) elif tool_calls: if len(tool_calls) != 1: raise ValueError( @@ -280,30 +323,70 @@ def emit_text(text: str, msg_idx: int) -> None: args_str = arguments else: args_str = json.dumps(arguments, ensure_ascii=False) - emit_special(self._start_header, msg_idx) - emit_text("assistant", msg_idx) - emit_special(self._end_header, msg_idx) + emit_special( + self._start_header, msg_idx, is_sampled=False, is_content=False + ) + emit_text("assistant", msg_idx, is_sampled=False, is_content=False) + emit_special( + self._end_header, msg_idx, is_sampled=False, is_content=False + ) + # The ``\n\n`` after the header is gen-prompt scaffold the + # model never samples; the JSON tool-call body and the + # closing ``<|eot_id|>`` are the model's sampled emission. + emit_text("\n\n", msg_idx, is_sampled=False, is_content=False) emit_text( - '\n\n{"name": "' + name + '", "parameters": ' + args_str + "}", + '{"name": "' + name + '", "parameters": ' + args_str + "}", msg_idx, + is_sampled=True, + is_content=True, + ) + emit_special(self._eot, msg_idx, is_sampled=True, is_content=True) + elif role == "assistant": + content = self._content_str(msg.get("content")).strip() + emit_special( + self._start_header, msg_idx, is_sampled=False, is_content=False + ) + emit_text("assistant", msg_idx, is_sampled=False, is_content=False) + emit_special( + self._end_header, msg_idx, is_sampled=False, is_content=False ) - emit_special(self._eot, msg_idx) + # ``\n\n`` separator is scaffold (it's the generation prompt); + # the body and the closing ``<|eot_id|>`` are model-sampled. + emit_text("\n\n", msg_idx, is_sampled=False, is_content=False) + if content: + emit_text(content, msg_idx, is_sampled=True, is_content=True) + emit_special(self._eot, msg_idx, is_sampled=True, is_content=True) else: + # user / non-leading system: caller content, never sampled. content = self._content_str(msg.get("content")).strip() - emit_special(self._start_header, msg_idx) - emit_text(role or "", msg_idx) - emit_special(self._end_header, msg_idx) - emit_text("\n\n" + content, msg_idx) - emit_special(self._eot, msg_idx) + emit_special( + self._start_header, msg_idx, is_sampled=False, is_content=False + ) + emit_text(role or "", msg_idx, is_sampled=False, is_content=False) + emit_special( + self._end_header, msg_idx, is_sampled=False, is_content=False + ) + segments: list[tuple[str, bool]] = [("\n\n", False)] + if content: + segments.append((content, True)) + emit_text_segments(segments, msg_idx, is_sampled=False) + emit_special(self._eot, msg_idx, is_sampled=False, is_content=False) # ── 3. Generation prompt ──────────────────────────────────── if add_generation_prompt: - emit_special(self._start_header, -1) - emit_text("assistant", -1) - emit_special(self._end_header, -1) - emit_text("\n\n", -1) - - return RenderedTokens(token_ids=tokens, message_indices=indices) + emit_special(self._start_header, -1, is_sampled=False, is_content=False) + emit_text("assistant", -1, is_sampled=False, is_content=False) + emit_special(self._end_header, -1, is_sampled=False, is_content=False) + emit_text("\n\n", -1, is_sampled=False, is_content=False) + + return RenderedTokens( + token_ids=tokens, + message_indices=indices, + sampled_mask=sampled, + is_content=content_mask, + message_roles=[m.get("role") or "" for m in messages], + message_tool_names=extract_message_tool_names(messages), + ) def render_ids( self, @@ -318,7 +401,12 @@ def render_ids( add_generation_prompt=add_generation_prompt, ).token_ids - def parse_response(self, token_ids: list[int]) -> ParsedResponse: + def parse_response( + self, + token_ids: list[int], + *, + tools: list[ToolSpec] | None = None, + ) -> ParsedResponse: return parse_llama_3( self._tokenizer, token_ids, @@ -339,7 +427,7 @@ def bridge_to_next_turn( new_messages: list[Message], *, tools: list[ToolSpec] | None = None, - ) -> list[int] | None: + ) -> RenderedTokens | None: if ( not previous_prompt_ids or not new_messages @@ -357,32 +445,56 @@ def bridge_to_next_turn( return None ext: list[int] = [] - - def emit_special(token_id: int, _msg_idx: int = -1) -> None: + ext_indices: list[int] = [] + ext_content: list[bool] = [] + + # Every token the bridge emits is template scaffolding for the next + # prompt — none of it is model-sampled — so ``sampled_mask`` is + # uniformly ``False`` (applied over the whole sequence at return). + # ``is_content`` follows the same rules as :meth:`render` so a + # consumer can walk the trajectory and read each step's body mask. + def emit_special(token_id: int, msg_idx: int = -1) -> None: ext.append(token_id) - - def emit_text(text: str, _msg_idx: int = -1) -> None: - ext.extend(self._encode(text)) + ext_indices.append(msg_idx) + ext_content.append(False) + + def emit_text(text: str, msg_idx: int = -1) -> None: + ids = self._encode(text) + ext.extend(ids) + ext_indices.extend([msg_idx] * len(ids)) + ext_content.extend([False] * len(ids)) + + def emit_text_segments( + segments: list[tuple[str, bool]], msg_idx: int = -1 + ) -> None: + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): + ext.append(tok_id) + ext_indices.append(msg_idx) + ext_content.append(is_content) for i, msg in enumerate(new_messages): role = msg.get("role") - if role == "system": - emit_special(self._start_header, i) - emit_text("system", i) - emit_special(self._end_header, i) - emit_text("\n\n" + self._content_str(msg.get("content")).strip(), i) - emit_special(self._eot, i) - elif role == "user": + if role in ("system", "user"): + content = self._content_str(msg.get("content")).strip() emit_special(self._start_header, i) - emit_text("user", i) + emit_text(role, i) emit_special(self._end_header, i) - emit_text("\n\n" + self._content_str(msg.get("content")).strip(), i) + segs: list[tuple[str, bool]] = [("\n\n", False)] + if content: + segs.append((content, True)) + emit_text_segments(segs, i) emit_special(self._eot, i) elif role in ("tool", "ipython"): + tool_body = self._tool_response_str(msg.get("content")) emit_special(self._start_header, i) emit_text("ipython", i) emit_special(self._end_header, i) - emit_text("\n\n" + self._tool_response_str(msg.get("content")), i) + tool_segs: list[tuple[str, bool]] = [("\n\n", False)] + if tool_body: + tool_segs.append((tool_body, True)) + emit_text_segments(tool_segs, i) emit_special(self._eot, i) else: return None @@ -393,4 +505,12 @@ def emit_text(text: str, _msg_idx: int = -1) -> None: emit_special(self._end_header, -1) emit_text("\n\n", -1) - return previous_ids + ext + total_len = len(previous_ids) + len(ext) + return RenderedTokens( + token_ids=previous_ids + ext, + message_indices=[-1] * len(previous_ids) + ext_indices, + sampled_mask=[False] * total_len, + is_content=[False] * len(previous_ids) + ext_content, + message_roles=[m.get("role") or "" for m in new_messages], + message_tool_names=extract_message_tool_names(new_messages), + ) diff --git a/renderers/parsing.py b/renderers/parsing.py index be119f8..339b5d2 100644 --- a/renderers/parsing.py +++ b/renderers/parsing.py @@ -1290,9 +1290,26 @@ def parse_llama_3( + parses-as-dict-with-name-key check; anything else is treated as content. Llama-3 doesn't have a built-in reasoning channel, so ``reasoning_content`` is always ``None``. + + Unlike the delimiter-based formats (Qwen/GLM), the tool call has no + special token to anchor on, so a leading assistant role-header + (``<|start_header_id|>assistant<|end_header_id|>\\n\\n``) would defeat + the starts-with-``{`` check. Callers that slice a completion without + dropping the generation prompt include that scaffold; we skip past the + final ``<|end_header_id|>`` so the body is what we parse. The sampled + stream in production carries no header, making this a no-op there. """ ids = _strip_stop_tokens(token_ids, stop_ids) - text = _decode(tokenizer, ids).strip() + + # Skip a leading assistant role-header scaffold if present. + body_start = 0 + end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>") + if isinstance(end_header_id, int): + eh_positions = _find_all(ids, end_header_id) + if eh_positions: + body_start = eh_positions[-1] + 1 + body_ids = ids[body_start:] + text = _decode(tokenizer, body_ids).strip() if text.startswith("{") and text.endswith("}"): try: @@ -1309,7 +1326,7 @@ def parse_llama_3( raw=text, name=parsed["name"], arguments=arguments, - token_span=(0, len(ids)), + token_span=(body_start, len(ids)), status=ToolCallParseStatus.OK, ) ], diff --git a/tests/conftest.py b/tests/conftest.py index c334430..3ef8350 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,6 +34,11 @@ ("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", "auto"), ("nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16", "auto"), ("poolside/Laguna-XS.2", "auto"), + # Llama-3 loads via the unrestricted unsloth mirror (byte-identical + # chat template) so CI needs no Meta-gated HF token. Pinned to the + # explicit "llama-3" config because the mirror name isn't in + # MODEL_RENDERER_MAP (so "auto" would fall back to DefaultRenderer). + ("unsloth/Llama-3.2-1B-Instruct", "llama-3"), ("openai/gpt-oss-20b", "gpt-oss"), ("Qwen/Qwen2.5-0.5B-Instruct", "default"), ] @@ -103,3 +108,30 @@ def _skip_gpt_oss_for_hf_parity_tests(request): f"{model_name}: renderer matches openai-harmony / vLLM, not HF " "apply_chat_template — see test_gpt_oss_harmony_parity.py" ) + + +# Llama-3's chat template fills the "Today Date:" line via ``strftime_now``, +# so ``apply_chat_template`` with no explicit ``date_string`` bakes in the +# real wall-clock date — non-deterministic and not byte-stable against a +# renderer pinned to "26 Jul 2024". Generic HF-parity tests can't pass a +# kwarg, so they're skipped here; deterministic byte-parity (with the date +# passed on both sides) is covered in test_llama_3.py. +_LLAMA_HF_PARITY_TEST_FILES = { + "test_render_ids.py", + "test_build_helpers.py", +} + + +@pytest.fixture(autouse=True) +def _skip_llama_for_hf_parity_tests(request): + callspec = getattr(request.node, "callspec", None) + model_name = callspec.params.get("model_name") if callspec else None + if model_name != "unsloth/Llama-3.2-1B-Instruct": + return + test_file = os.path.basename(str(request.node.fspath)) + if test_file in _LLAMA_HF_PARITY_TEST_FILES: + pytest.skip( + f"{model_name}: template uses strftime_now for the date line, so " + "generic apply_chat_template parity is non-deterministic — " + "deterministic byte-parity is covered in test_llama_3.py" + ) diff --git a/tests/test_bridge.py b/tests/test_bridge.py index 81ff2e4..8b7bf77 100644 --- a/tests/test_bridge.py +++ b/tests/test_bridge.py @@ -33,6 +33,7 @@ ("moonshotai/Kimi-K2-Instruct", "auto"), ("moonshotai/Kimi-K2.5", "auto"), ("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", "auto"), + ("unsloth/Llama-3.2-1B-Instruct", "llama-3"), ("openai/gpt-oss-20b", "gpt-oss"), ] diff --git a/tests/test_llama_3.py b/tests/test_llama_3.py index 82827b4..018c2f3 100644 --- a/tests/test_llama_3.py +++ b/tests/test_llama_3.py @@ -71,18 +71,25 @@ def test_default_date_matches_chat_template_strftime_fallback(llama_pair): assert r.config.date_string == _PINNED_DATE -def test_preserve_all_thinking_rejected(llama_pair): +def test_preserve_thinking_flags_are_noops(llama_pair): + """Llama-3 has no reasoning channel, so the ``preserve_*_thinking`` + flags are accepted but never change the token stream — the same + never-preserves contract as Kimi-K2 / Qwen3-VL. (Cross-renderer + coverage lives in tests/test_preserve_thinking.py.)""" _, _, tok, _ = llama_pair - with pytest.raises(NotImplementedError, match="reasoning_content"): - Llama3Renderer(tok, Llama3RendererConfig(preserve_all_thinking=True)) - - -def test_preserve_thinking_between_tool_calls_rejected(llama_pair): - _, _, tok, _ = llama_pair - with pytest.raises(NotImplementedError, match="reasoning_content"): - Llama3Renderer( - tok, Llama3RendererConfig(preserve_thinking_between_tool_calls=True) - ) + msgs = [ + {"role": "user", "content": "Hi."}, + { + "role": "assistant", + "reasoning_content": "internal musings", + "content": "Hello!", + }, + ] + base = Llama3Renderer(tok).render_ids(msgs) + for flag in ("preserve_all_thinking", "preserve_thinking_between_tool_calls"): + r = Llama3Renderer(tok, Llama3RendererConfig(**{flag: True})) + assert r.config.__getattribute__(flag) is True + assert r.render_ids(msgs) == base, f"{flag} must be a no-op for Llama-3" # --------------------------------------------------------------------------- @@ -352,8 +359,8 @@ def test_bridge_extends_prev_verbatim_on_clean_stop(llama_pair): bridged = r.bridge_to_next_turn(prev_prompt, prev_completion, new_messages) assert bridged is not None prev = prev_prompt + prev_completion - assert bridged[: len(prev)] == prev - assert len(bridged) > len(prev) + assert bridged.token_ids[: len(prev)] == prev + assert len(bridged.token_ids) > len(prev) def test_bridge_matches_fresh_render_on_clean_stop(llama_pair): @@ -371,7 +378,7 @@ def test_bridge_matches_fresh_render_on_clean_stop(llama_pair): prev_prompt, prev_completion = _simulate_prior_turn(r) bridged = r.bridge_to_next_turn(prev_prompt, prev_completion, new_messages) fresh = r.render_ids(prior + asst + new_messages, add_generation_prompt=True) - assert bridged == fresh + assert bridged.token_ids == fresh def test_bridge_rejects_assistant_in_extension(llama_pair): @@ -396,5 +403,5 @@ def test_bridge_synthesises_close_on_truncation(llama_pair): ) assert bridged is not None base = prev_prompt + trunc - assert bridged[: len(base)] == base - assert len(bridged) > len(base) + assert bridged.token_ids[: len(base)] == base + assert len(bridged.token_ids) > len(base) diff --git a/tests/test_preserve_thinking.py b/tests/test_preserve_thinking.py index 1ef07f5..75b739b 100644 --- a/tests/test_preserve_thinking.py +++ b/tests/test_preserve_thinking.py @@ -51,6 +51,9 @@ def _make(tokenizer, renderer_name, **flags): "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen3-VL-30B-A3B-Instruct", "poolside/Laguna-XS.2", + # Llama-3 has no reasoning channel at all — preserve flags can't add + # or drop anything, so they're pure no-ops. + "unsloth/Llama-3.2-1B-Instruct", } @@ -319,6 +322,9 @@ def test_preserve_btc_on_live_cycle_matches_all( "Qwen/Qwen3-VL-4B-Instruct", "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen3-VL-30B-A3B-Instruct", + # Llama-3 ships no rendering path, so reasoning_content never + # surfaces in the output regardless of the preserve flags. + "unsloth/Llama-3.2-1B-Instruct", } diff --git a/tests/test_roundtrip.py b/tests/test_roundtrip.py index 383bc14..326fa50 100644 --- a/tests/test_roundtrip.py +++ b/tests/test_roundtrip.py @@ -44,6 +44,7 @@ ("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", "auto"), ("nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16", "auto"), ("poolside/Laguna-XS.2", "auto"), + ("unsloth/Llama-3.2-1B-Instruct", "llama-3"), ("openai/gpt-oss-20b", "gpt-oss"), ("Qwen/Qwen2.5-0.5B-Instruct", "default"), ] @@ -207,6 +208,12 @@ def test_roundtrip_multiple_tool_calls( """Parsers that loop over ```` blocks can silently drop the second one; this test catches that.""" _maybe_skip_tool_calls(rt_renderer_name) + if rt_renderer_name == "llama-3": + pytest.skip( + "Llama-3's chat template forbids >1 tool call per assistant " + "message (the renderer raises, mirroring the template); the " + "single-call path is covered by test_roundtrip_single_tool_call." + ) msg = { "role": "assistant",