diff --git a/renderers/__init__.py b/renderers/__init__.py index e7cd1c4..0670769 100644 --- a/renderers/__init__.py +++ b/renderers/__init__.py @@ -52,6 +52,7 @@ KimiK25RendererConfig, KimiK2RendererConfig, LagunaXS2RendererConfig, + Llama3RendererConfig, MiniMaxM2RendererConfig, Nemotron3RendererConfig, Qwen35RendererConfig, @@ -82,6 +83,7 @@ "KimiK25Renderer": "renderers.kimi_k25", "KimiK2Renderer": "renderers.kimi_k2", "LagunaXS2Renderer": "renderers.laguna_xs2", + "Llama3Renderer": "renderers.llama_3", "MiniMaxM2Renderer": "renderers.minimax_m2", "Nemotron3Renderer": "renderers.nemotron3", "Qwen35Renderer": "renderers.qwen35", @@ -130,6 +132,8 @@ def __dir__() -> list[str]: "KimiK2RendererConfig", "LagunaXS2Renderer", "LagunaXS2RendererConfig", + "Llama3Renderer", + "Llama3RendererConfig", "MULTIMODAL_MODELS", "Message", "MiniMaxM2Renderer", diff --git a/renderers/base.py b/renderers/base.py index 242adae..e6af3fc 100644 --- a/renderers/base.py +++ b/renderers/base.py @@ -1045,6 +1045,14 @@ def bridge_to_next_turn(self, *args: Any, **kwargs: Any) -> "RenderedTokens | No "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16": "nemotron-3", "nvidia/NVIDIA-Nemotron-3-Ultra-550B-A55B-BF16": "nemotron-3", "nvidia/NVIDIA-Nemotron-3-Ultra-550B-A55B-FP8": "nemotron-3", + # Llama 3.2 (Instruct). Tested against the gated meta-llama repos and + # the unrestricted unsloth/... mirror, which ships a byte-identical + # chat template. ``Llama3Renderer`` defaults ``date_string`` to + # "26 Jul 2024" — matching the chat template's strftime fallback — + # so the renderer is reproducible. Pass ``date_string=...`` at + # construction to pin a different date. + "meta-llama/Llama-3.2-1B-Instruct": "llama-3", + "meta-llama/Llama-3.2-3B-Instruct": "llama-3", # Poolside Laguna. "poolside/Laguna-XS.2": "laguna-xs.2", # GPT-OSS. 
@@ -1334,6 +1342,7 @@ def _populate_registry(): from renderers.kimi_k2 import KimiK2Renderer from renderers.kimi_k25 import KimiK25Renderer from renderers.laguna_xs2 import LagunaXS2Renderer + from renderers.llama_3 import Llama3Renderer from renderers.minimax_m2 import MiniMaxM2Renderer from renderers.nemotron3 import Nemotron3Renderer from renderers.qwen3 import Qwen3Renderer @@ -1356,6 +1365,7 @@ def _populate_registry(): "kimi-k2": KimiK2Renderer, "kimi-k2.5": KimiK25Renderer, "laguna-xs.2": LagunaXS2Renderer, + "llama-3": Llama3Renderer, "nemotron-3": Nemotron3Renderer, "gpt-oss": GptOssRenderer, } diff --git a/renderers/configs.py b/renderers/configs.py index 2c18a17..8262078 100644 --- a/renderers/configs.py +++ b/renderers/configs.py @@ -318,6 +318,31 @@ class LagunaXS2RendererConfig(BaseRendererConfig): chat template's ``render_assistant_messages_raw`` gate.""" +class Llama3RendererConfig(BaseRendererConfig): + """Llama-3.x Instruct renderer config. + + Llama-3 ships no reasoning channel, so the base ``preserve_*_thinking`` + flags don't apply: ``Llama3Renderer`` accepts them as no-ops that + never change the token stream (the same never-preserves contract as + Kimi-K2 / Qwen3-VL). Both fields below mirror real + ``apply_chat_template`` kwargs. + """ + + name: Literal["llama-3"] = "llama-3" + + date_string: str = "26 Jul 2024" + """``Today Date`` value injected into the system preamble. Pinned to + the chat template's ``strftime`` fallback by default so output stays + deterministic; override per instance for production runs that want + today's date. Mirrors the chat template's ``date_string`` kwarg.""" + + tools_in_user_message: bool = True + """When ``True`` (default), tool descriptions + JSON signatures inject + into the first user message; ``False`` routes them into the system + block instead. 
Mirrors the chat template's ``tools_in_user_message`` + kwarg.""" + + class MiniMaxM2RendererConfig(BaseRendererConfig): """MiniMax M2 / M2.5 renderer config.""" @@ -410,6 +435,7 @@ class DeepSeekV3RendererConfig(BaseRendererConfig): KimiK2RendererConfig, KimiK25RendererConfig, LagunaXS2RendererConfig, + Llama3RendererConfig, MiniMaxM2RendererConfig, Nemotron3RendererConfig, DeepSeekV3RendererConfig, @@ -444,6 +470,7 @@ class DeepSeekV3RendererConfig(BaseRendererConfig): "kimi-k2": KimiK2RendererConfig, "kimi-k2.5": KimiK25RendererConfig, "laguna-xs.2": LagunaXS2RendererConfig, + "llama-3": Llama3RendererConfig, "minimax-m2": MiniMaxM2RendererConfig, "nemotron-3": Nemotron3RendererConfig, "deepseek-v3": DeepSeekV3RendererConfig, @@ -486,6 +513,7 @@ def config_from_name(name: str) -> BaseRendererConfig | None: "KimiK25RendererConfig", "KimiK2RendererConfig", "LagunaXS2RendererConfig", + "Llama3RendererConfig", "MiniMaxM2RendererConfig", "Nemotron3RendererConfig", "Qwen35RendererConfig", diff --git a/renderers/llama_3.py b/renderers/llama_3.py new file mode 100644 index 0000000..d15792a --- /dev/null +++ b/renderers/llama_3.py @@ -0,0 +1,516 @@ +"""Llama-3 Renderer — hard-coded Python mirroring Meta's Llama-3 chat template. + +Initial scope: Llama-3.2-1B-Instruct and Llama-3.2-3B-Instruct (and the +unrestricted ``unsloth/Llama-3.2-{1B,3B}-Instruct`` mirror, which ships a +byte-identical chat template). Other Llama-3.x sizes ship slightly +different templates and are NOT covered by this renderer until parity is +verified. + +Notable differences from the Qwen / GLM family renderers: + +* No ``think``-block / reasoning channel — Llama-3 doesn't ship a + reasoning-content concept, so ``preserve_*_thinking`` flags don't + apply. +* ``<|begin_of_text|>`` (BOS) is emitted at the very start of every + render. The chat template never omits it. 
+* The system block is emitted **unconditionally** with a fixed + ``Cutting Knowledge Date: December 2023\\nToday Date: {date_string}\\n\\n`` + preamble — even when no system message is supplied. Empty system + message → block ends with ``\\n\\n<|eot_id|>``. +* Tools default to "first-user-message" mode (matching the chat + template's default ``tools_in_user_message=True``): tool descriptions + + JSON signatures are injected into the first user message rather + than the system block. Set ``Llama3RendererConfig.tools_in_user_message + = False`` to flip to system-block mode. +* ``Llama3RendererConfig.date_string`` is pinned at ``"26 Jul 2024"`` by + default to match the chat template's ``strftime`` fallback (and keep + output deterministic). Override per instance for production runs that + want today's date. +* Tool calls: a single ``{"name": "...", "parameters": ...}`` JSON blob + inside the assistant body. The chat template explicitly raises if + ``message.tool_calls | length != 1``; this renderer matches that. +* Tool responses: rendered with role ``ipython`` regardless of whether + the source message used ``role: "tool"`` or ``role: "ipython"``. The + chat template runs ``content | tojson`` on any mapping/iterable + content — and Jinja considers strings iterable, so plain string + contents get JSON-quoted. We mirror that exactly. +""" + +from __future__ import annotations + +import json +from typing import Any + +from transformers.tokenization_utils import PreTrainedTokenizer + +from renderers.base import ( + Message, + ParsedResponse, + RenderedTokens, + ToolSpec, + attribute_text_segments, + extract_message_tool_names, + reject_assistant_in_extension, + trim_to_turn_close, +) +from renderers.configs import Llama3RendererConfig +from renderers.parsing import parse_llama_3 + +# --------------------------------------------------------------------------- +# Constants — must match the Jinja chat template's literal strings exactly. 
+# --------------------------------------------------------------------------- + +_CUTTING_KNOWLEDGE_DATE = "December 2023" + +# Tools-in-system intro: emitted into the system block when tools is set +# AND tools_in_user_message=False. Note: the chat template puts these +# three string literals back-to-back with NO newline between the second +# and third, so there's no space before "Do not use variables.". +_TOOLS_IN_SYSTEM_INTRO = ( + "You have access to the following functions. To call a function, " + "please respond with JSON for a function call." + 'Respond in the format {"name": function name, "parameters": ' + "dictionary of argument name and its value}." + "Do not use variables.\n\n" +) + +# Tools-in-user intro: emitted into the first user message when tools is +# set AND tools_in_user_message=True (the default). +_TOOLS_IN_USER_INTRO = ( + "Given the following functions, please respond with a JSON for a " + "function call with its proper arguments that best answers the given " + "prompt.\n\n" + 'Respond in the format {"name": function name, "parameters": ' + "dictionary of argument name and its value}." + "Do not use variables.\n\n" +) + + +class Llama3Renderer: + """Deterministic message → token renderer for Llama-3.x Instruct models.""" + + def __init__( + self, + tokenizer: PreTrainedTokenizer, + config: Llama3RendererConfig | None = None, + ): + # ``preserve_*_thinking`` are accepted but no-ops: Llama-3 ships no + # reasoning_content channel, so there's never any past-assistant + # thinking to retain or drop. The flags are stored on ``self.config`` + # for cross-renderer uniformity but never change the token stream — + # the same contract as Kimi-K2 / Qwen3-VL (see the never-preserves + # renderers in tests/test_preserve_thinking.py). 
+ self._tokenizer = tokenizer + self.config = config or Llama3RendererConfig() + + self._bos = self._token_id("<|begin_of_text|>") + self._start_header = self._token_id("<|start_header_id|>") + self._end_header = self._token_id("<|end_header_id|>") + self._eot = self._token_id("<|eot_id|>") + self._end_of_text = self._token_id("<|end_of_text|>") + # ``<|eom_id|>`` shows up in some Llama-3 tool-calling traces (the + # "ipython" / python-tag flow) but the standard 3.2 chat template + # closes turns with ``<|eot_id|>``. We still treat eom as a stop + # token so models that emit it terminate cleanly. + self._eom = self._token_id("<|eom_id|>") + + def _token_id(self, token: str) -> int: + tid = self._tokenizer.convert_tokens_to_ids(token) + assert isinstance(tid, int) and tid != self._tokenizer.unk_token_id, ( + f"Special token {token!r} not found in tokenizer vocabulary" + ) + return tid + + def _encode(self, text: str) -> list[int]: + if not text: + return [] + return self._tokenizer.encode(text, add_special_tokens=False) + + @staticmethod + def _content_str(content: Any) -> str: + """Render content to a plain string. Handles ``str``, list-of-text-parts, + and ``None``. Matches the chat template's ``message.content | trim`` + callers, which expect a string in.""" + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, dict) and "text" in item: + parts.append(item["text"]) + else: + raise ValueError(f"Unexpected content item: {item}") + return "".join(parts) + raise TypeError(f"Unexpected content type: {type(content)}") + + @staticmethod + def _tool_response_str(content: Any) -> str: + """Mirror the chat template's tool-response branch: + ``{% if message.content is mapping or message.content is iterable %} + {{ message.content | tojson }} {% else %} {{ message.content }}``. 
+ + In Jinja, **strings are iterable** — so plain-string tool contents + also go through ``tojson`` (i.e. ``json.dumps``), wrapping them in + quotes and escaping. Non-iterable scalars (numbers, bools, None) + fall through to literal stringification. + """ + if content is None: + return "" + if isinstance(content, (dict, list, str)): + return json.dumps(content, ensure_ascii=False) + return str(content) + + # ------------------------------------------------------------------ + # render + # ------------------------------------------------------------------ + + def render( + self, + messages: list[Message], + *, + tools: list[ToolSpec] | None = None, + add_generation_prompt: bool = False, + ) -> RenderedTokens: + if not messages: + raise ValueError("No messages provided.") + + tokens: list[int] = [] + indices: list[int] = [] + sampled: list[bool] = [] + content_mask: list[bool] = [] + + def emit_special( + token_id: int, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: + tokens.append(token_id) + indices.append(msg_idx) + sampled.append(is_sampled) + content_mask.append(is_content) + + def emit_text( + text: str, msg_idx: int, *, is_sampled: bool, is_content: bool + ) -> None: + ids = self._encode(text) + tokens.extend(ids) + indices.extend([msg_idx] * len(ids)) + sampled.extend([is_sampled] * len(ids)) + content_mask.extend([is_content] * len(ids)) + + def emit_text_segments( + segments: list[tuple[str, bool]], msg_idx: int, *, is_sampled: bool + ) -> None: + """Tokenize concatenated wrap + body as one BPE pass; per-token + ``is_content`` follows each token's source segment. Lets the + scaffold/body split stay attributed without splitting the + encode call (which could shift BPE merges at the boundary).""" + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): + tokens.append(tok_id) + indices.append(msg_idx) + sampled.append(is_sampled) + content_mask.append(is_content) + + # ── 0. 
BOS ────────────────────────────────────────────────── + emit_special(self._bos, -1, is_sampled=False, is_content=False) + + # ── 1. System block (always emitted) ──────────────────────── + first_is_system = messages[0].get("role") == "system" + sys_idx = 0 if first_is_system else -1 + sys_text = ( + self._content_str(messages[0].get("content")).strip() + if first_is_system + else "" + ) + + emit_special(self._start_header, sys_idx, is_sampled=False, is_content=False) + emit_text("system", sys_idx, is_sampled=False, is_content=False) + emit_special(self._end_header, sys_idx, is_sampled=False, is_content=False) + # The Cutting Knowledge / Today Date preamble (and any tools-in-system + # block) is template scaffold; only the caller's system content is + # body. Route both through one BPE pass so the wrap/body boundary + # can't shift merges. + preamble = "\n\n" + if tools is not None: + preamble += "Environment: ipython\n" + preamble += f"Cutting Knowledge Date: {_CUTTING_KNOWLEDGE_DATE}\n" + preamble += f"Today Date: {self.config.date_string}\n\n" + if tools is not None and not self.config.tools_in_user_message: + preamble += _TOOLS_IN_SYSTEM_INTRO + for t in tools: + preamble += json.dumps(t, indent=4, ensure_ascii=False) + "\n\n" + sys_segments: list[tuple[str, bool]] = [(preamble, False)] + if sys_text: + sys_segments.append((sys_text, True)) + emit_text_segments(sys_segments, sys_idx, is_sampled=False) + emit_special(self._eot, sys_idx, is_sampled=False, is_content=False) + + # ── 2. Body messages ──────────────────────────────────────── + body_messages = messages[1:] if first_is_system else messages + offset = 1 if first_is_system else 0 + + i = 0 + # 2a. tools_in_user_message mode pulls the first user message + # into a special block with the tools description prepended. 
+ if tools is not None and self.config.tools_in_user_message: + if i >= len(body_messages): + raise ValueError( + "Cannot place tools in the first user message — no user " + "message was provided." + ) + first_user = body_messages[i] + if first_user.get("role") != "user": + raise ValueError( + "tools_in_user_message=True requires the first non-system " + f"message to be 'user'; got {first_user.get('role')!r}." + ) + user_idx = i + offset + emit_special( + self._start_header, user_idx, is_sampled=False, is_content=False + ) + emit_text("user", user_idx, is_sampled=False, is_content=False) + emit_special(self._end_header, user_idx, is_sampled=False, is_content=False) + user_preamble = "\n\n" + _TOOLS_IN_USER_INTRO + for t in tools: + user_preamble += json.dumps(t, indent=4, ensure_ascii=False) + "\n\n" + user_content = self._content_str(first_user.get("content")).strip() + user_segments: list[tuple[str, bool]] = [(user_preamble, False)] + if user_content: + user_segments.append((user_content, True)) + emit_text_segments(user_segments, user_idx, is_sampled=False) + emit_special(self._eot, user_idx, is_sampled=False, is_content=False) + i += 1 + + # 2b. Remaining messages — plain user/assistant/tool/assistant-with-tool-calls. + for j in range(i, len(body_messages)): + msg = body_messages[j] + msg_idx = j + offset + role = msg.get("role") + tool_calls = msg.get("tool_calls") + + if role in ("tool", "ipython"): + # Tool responses are conversation history the model never + # samples; the response body is caller content, the wrap is + # scaffold. 
+ emit_special( + self._start_header, msg_idx, is_sampled=False, is_content=False + ) + emit_text("ipython", msg_idx, is_sampled=False, is_content=False) + emit_special( + self._end_header, msg_idx, is_sampled=False, is_content=False + ) + tool_body = self._tool_response_str(msg.get("content")) + tool_segments: list[tuple[str, bool]] = [("\n\n", False)] + if tool_body: + tool_segments.append((tool_body, True)) + emit_text_segments(tool_segments, msg_idx, is_sampled=False) + emit_special(self._eot, msg_idx, is_sampled=False, is_content=False) + elif tool_calls: + if len(tool_calls) != 1: + raise ValueError( + "Llama-3 chat template only supports a single tool call " + "per assistant message." + ) + tc = tool_calls[0] + func = tc.get("function") or tc + name = func.get("name", "") + arguments = func.get("arguments", {}) + if isinstance(arguments, str): + args_str = arguments + else: + args_str = json.dumps(arguments, ensure_ascii=False) + emit_special( + self._start_header, msg_idx, is_sampled=False, is_content=False + ) + emit_text("assistant", msg_idx, is_sampled=False, is_content=False) + emit_special( + self._end_header, msg_idx, is_sampled=False, is_content=False + ) + # The ``\n\n`` after the header is gen-prompt scaffold the + # model never samples; the JSON tool-call body and the + # closing ``<|eot_id|>`` are the model's sampled emission. 
+ emit_text("\n\n", msg_idx, is_sampled=False, is_content=False) + emit_text( + '{"name": "' + name + '", "parameters": ' + args_str + "}", + msg_idx, + is_sampled=True, + is_content=True, + ) + emit_special(self._eot, msg_idx, is_sampled=True, is_content=True) + elif role == "assistant": + content = self._content_str(msg.get("content")).strip() + emit_special( + self._start_header, msg_idx, is_sampled=False, is_content=False + ) + emit_text("assistant", msg_idx, is_sampled=False, is_content=False) + emit_special( + self._end_header, msg_idx, is_sampled=False, is_content=False + ) + # ``\n\n`` separator is scaffold (it's the generation prompt); + # the body and the closing ``<|eot_id|>`` are model-sampled. + emit_text("\n\n", msg_idx, is_sampled=False, is_content=False) + if content: + emit_text(content, msg_idx, is_sampled=True, is_content=True) + emit_special(self._eot, msg_idx, is_sampled=True, is_content=True) + else: + # user / non-leading system: caller content, never sampled. + content = self._content_str(msg.get("content")).strip() + emit_special( + self._start_header, msg_idx, is_sampled=False, is_content=False + ) + emit_text(role or "", msg_idx, is_sampled=False, is_content=False) + emit_special( + self._end_header, msg_idx, is_sampled=False, is_content=False + ) + segments: list[tuple[str, bool]] = [("\n\n", False)] + if content: + segments.append((content, True)) + emit_text_segments(segments, msg_idx, is_sampled=False) + emit_special(self._eot, msg_idx, is_sampled=False, is_content=False) + + # ── 3. 
Generation prompt ──────────────────────────────────── + if add_generation_prompt: + emit_special(self._start_header, -1, is_sampled=False, is_content=False) + emit_text("assistant", -1, is_sampled=False, is_content=False) + emit_special(self._end_header, -1, is_sampled=False, is_content=False) + emit_text("\n\n", -1, is_sampled=False, is_content=False) + + return RenderedTokens( + token_ids=tokens, + message_indices=indices, + sampled_mask=sampled, + is_content=content_mask, + message_roles=[m.get("role") or "" for m in messages], + message_tool_names=extract_message_tool_names(messages), + ) + + def render_ids( + self, + messages: list[Message], + *, + tools: list[ToolSpec] | None = None, + add_generation_prompt: bool = False, + ) -> list[int]: + return self.render( + messages, + tools=tools, + add_generation_prompt=add_generation_prompt, + ).token_ids + + def parse_response( + self, + token_ids: list[int], + *, + tools: list[ToolSpec] | None = None, + ) -> ParsedResponse: + return parse_llama_3( + self._tokenizer, + token_ids, + stop_ids={self._eot, self._end_of_text, self._eom}, + ) + + def get_stop_token_ids(self) -> list[int]: + return [self._eot, self._end_of_text, self._eom] + + # ------------------------------------------------------------------ + # bridge_to_next_turn + # ------------------------------------------------------------------ + + def bridge_to_next_turn( + self, + previous_prompt_ids: list[int], + previous_completion_ids: list[int], + new_messages: list[Message], + *, + tools: list[ToolSpec] | None = None, + ) -> RenderedTokens | None: + if ( + not previous_prompt_ids + or not new_messages + or reject_assistant_in_extension(new_messages) + ): + return None + + previous_ids = trim_to_turn_close( + previous_prompt_ids, + previous_completion_ids, + {self._eot, self._end_of_text, self._eom}, + synthesize_close=self._eot, + ) + if previous_ids is None: + return None + + ext: list[int] = [] + ext_indices: list[int] = [] + ext_content: list[bool] = 
[] + + # Every token the bridge emits is template scaffolding for the next + # prompt — none of it is model-sampled — so ``sampled_mask`` is + # uniformly ``False`` (applied over the whole sequence at return). + # ``is_content`` follows the same rules as :meth:`render` so a + # consumer can walk the trajectory and read each step's body mask. + def emit_special(token_id: int, msg_idx: int = -1) -> None: + ext.append(token_id) + ext_indices.append(msg_idx) + ext_content.append(False) + + def emit_text(text: str, msg_idx: int = -1) -> None: + ids = self._encode(text) + ext.extend(ids) + ext_indices.extend([msg_idx] * len(ids)) + ext_content.extend([False] * len(ids)) + + def emit_text_segments( + segments: list[tuple[str, bool]], msg_idx: int = -1 + ) -> None: + for tok_id, is_content in attribute_text_segments( + self._tokenizer, segments + ): + ext.append(tok_id) + ext_indices.append(msg_idx) + ext_content.append(is_content) + + for i, msg in enumerate(new_messages): + role = msg.get("role") + if role in ("system", "user"): + content = self._content_str(msg.get("content")).strip() + emit_special(self._start_header, i) + emit_text(role, i) + emit_special(self._end_header, i) + segs: list[tuple[str, bool]] = [("\n\n", False)] + if content: + segs.append((content, True)) + emit_text_segments(segs, i) + emit_special(self._eot, i) + elif role in ("tool", "ipython"): + tool_body = self._tool_response_str(msg.get("content")) + emit_special(self._start_header, i) + emit_text("ipython", i) + emit_special(self._end_header, i) + tool_segs: list[tuple[str, bool]] = [("\n\n", False)] + if tool_body: + tool_segs.append((tool_body, True)) + emit_text_segments(tool_segs, i) + emit_special(self._eot, i) + else: + return None + + # Generation prompt — matches the gen-prompt branch of ``render()``. 
+ emit_special(self._start_header, -1) + emit_text("assistant", -1) + emit_special(self._end_header, -1) + emit_text("\n\n", -1) + + total_len = len(previous_ids) + len(ext) + return RenderedTokens( + token_ids=previous_ids + ext, + message_indices=[-1] * len(previous_ids) + ext_indices, + sampled_mask=[False] * total_len, + is_content=[False] * len(previous_ids) + ext_content, + message_roles=[m.get("role") or "" for m in new_messages], + message_tool_names=extract_message_tool_names(new_messages), + ) diff --git a/renderers/parsing.py b/renderers/parsing.py index ee26e89..339b5d2 100644 --- a/renderers/parsing.py +++ b/renderers/parsing.py @@ -1270,3 +1270,70 @@ def _gptoss_extract_after_token( return None after = _decode(tokenizer, header_ids[pos + 1 :]).strip() return after.split()[0] if after else None + + +# ── Llama-3: single JSON tool call {"name": "...", "parameters": {...}} ─ + + +def parse_llama_3( + tokenizer, + token_ids: list[int], + *, + stop_ids: set[int], +) -> ParsedResponse: + """Parse Llama-3 completion tokens. + + The Llama-3 chat template emits tool calls as a single JSON blob in + the assistant body — ``{"name": "...", "parameters": {...}}`` — with + no surrounding XML tags or special tokens. Plain replies are just + text. We detect the tool-call shape with a strict starts-with-``{`` + + parses-as-dict-with-name-key check; anything else is treated as + content. Llama-3 doesn't have a built-in reasoning channel, so + ``reasoning_content`` is always ``None``. + + Unlike the delimiter-based formats (Qwen/GLM), the tool call has no + special token to anchor on, so a leading assistant role-header + (``<|start_header_id|>assistant<|end_header_id|>\\n\\n``) would defeat + the starts-with-``{`` check. Callers that slice a completion without + dropping the generation prompt include that scaffold; we skip past the + final ``<|end_header_id|>`` so the body is what we parse. The sampled + stream in production carries no header, making this a no-op there. 
+ """ + ids = _strip_stop_tokens(token_ids, stop_ids) + + # Skip a leading assistant role-header scaffold if present. + body_start = 0 + end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>") + if isinstance(end_header_id, int): + eh_positions = _find_all(ids, end_header_id) + if eh_positions: + body_start = eh_positions[-1] + 1 + body_ids = ids[body_start:] + text = _decode(tokenizer, body_ids).strip() + + if text.startswith("{") and text.endswith("}"): + try: + parsed = json.loads(text) + except json.JSONDecodeError: + parsed = None + if isinstance(parsed, dict) and parsed.get("name"): + arguments = parsed.get("parameters", parsed.get("arguments", {})) + return ParsedResponse( + content="", + reasoning_content=None, + tool_calls=[ + ParsedToolCall( + raw=text, + name=parsed["name"], + arguments=arguments, + token_span=(body_start, len(ids)), + status=ToolCallParseStatus.OK, + ) + ], + ) + + # Not a tool-call shape (plain reply, or a ``{...}`` body that didn't + # parse / lacked a name). Llama-3 has no delimiter to anchor a + # "malformed attempt" against, so it falls through to content rather + # than producing a non-OK ParsedToolCall. + return ParsedResponse(content=text, reasoning_content=None) diff --git a/tests/conftest.py b/tests/conftest.py index 4266487..2d360d4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -36,6 +36,11 @@ # Ultra resolves the Ultra template variant via name (auto → ultra=True). ("nvidia/NVIDIA-Nemotron-3-Ultra-550B-A55B-BF16", "auto"), ("poolside/Laguna-XS.2", "auto"), + # Llama-3 loads via the unrestricted unsloth mirror (byte-identical + # chat template) so CI needs no Meta-gated HF token. Pinned to the + # explicit "llama-3" config because the mirror name isn't in + # MODEL_RENDERER_MAP (so "auto" would fall back to DefaultRenderer). 
+ ("unsloth/Llama-3.2-1B-Instruct", "llama-3"), ("openai/gpt-oss-20b", "gpt-oss"), ("Qwen/Qwen2.5-0.5B-Instruct", "default"), ] @@ -105,3 +110,30 @@ def _skip_gpt_oss_for_hf_parity_tests(request): f"{model_name}: renderer matches openai-harmony / vLLM, not HF " "apply_chat_template — see test_gpt_oss_harmony_parity.py" ) + + +# Llama-3's chat template fills the "Today Date:" line via ``strftime_now``, +# so ``apply_chat_template`` with no explicit ``date_string`` bakes in the +# real wall-clock date — non-deterministic and not byte-stable against a +# renderer pinned to "26 Jul 2024". Generic HF-parity tests can't pass a +# kwarg, so they're skipped here; deterministic byte-parity (with the date +# passed on both sides) is covered in test_llama_3.py. +_LLAMA_HF_PARITY_TEST_FILES = { + "test_render_ids.py", + "test_build_helpers.py", +} + + +@pytest.fixture(autouse=True) +def _skip_llama_for_hf_parity_tests(request): + callspec = getattr(request.node, "callspec", None) + model_name = callspec.params.get("model_name") if callspec else None + if model_name != "unsloth/Llama-3.2-1B-Instruct": + return + test_file = os.path.basename(str(request.node.fspath)) + if test_file in _LLAMA_HF_PARITY_TEST_FILES: + pytest.skip( + f"{model_name}: template uses strftime_now for the date line, so " + "generic apply_chat_template parity is non-deterministic — " + "deterministic byte-parity is covered in test_llama_3.py" + ) diff --git a/tests/test_bridge.py b/tests/test_bridge.py index 81ff2e4..8b7bf77 100644 --- a/tests/test_bridge.py +++ b/tests/test_bridge.py @@ -33,6 +33,7 @@ ("moonshotai/Kimi-K2-Instruct", "auto"), ("moonshotai/Kimi-K2.5", "auto"), ("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", "auto"), + ("unsloth/Llama-3.2-1B-Instruct", "llama-3"), ("openai/gpt-oss-20b", "gpt-oss"), ] diff --git a/tests/test_llama_3.py b/tests/test_llama_3.py new file mode 100644 index 0000000..018c2f3 --- /dev/null +++ b/tests/test_llama_3.py @@ -0,0 +1,407 @@ +"""Llama-3 renderer coverage. 
+ +Covers ``Llama3Renderer`` and the ``meta-llama/Llama-3.2-{1B,3B}-Instruct`` +entries in ``MODEL_RENDERER_MAP``. Tokenizers are loaded via the +unrestricted ``unsloth/Llama-3.2-{1B,3B}-Instruct`` mirrors (verified +byte-identical chat templates) so CI doesn't need an HF token with Meta +license access. +""" + +from __future__ import annotations + +import pytest + +from renderers import Llama3Renderer, Llama3RendererConfig, create_renderer +from renderers.base import ( + MODEL_RENDERER_MAP, + ParsedResponse, + ToolCallParseStatus, + load_tokenizer, +) + +# Pinned date for byte-parity tests. Matches the chat template's +# strftime fallback so we don't have to override on the apply side. +_PINNED_DATE = "26 Jul 2024" + +_MODEL_PAIRS = [ + # (canonical meta-llama id used by MODEL_RENDERER_MAP, unrestricted + # mirror used to actually load the tokenizer in tests) + ("meta-llama/Llama-3.2-1B-Instruct", "unsloth/Llama-3.2-1B-Instruct"), + ("meta-llama/Llama-3.2-3B-Instruct", "unsloth/Llama-3.2-3B-Instruct"), +] + + +@pytest.fixture(scope="module", params=_MODEL_PAIRS, ids=[m for m, _ in _MODEL_PAIRS]) +def llama_pair(request): + canonical, mirror = request.param + tok = load_tokenizer(mirror) + renderer = Llama3Renderer(tok, Llama3RendererConfig(date_string=_PINNED_DATE)) + return canonical, mirror, tok, renderer + + +# --------------------------------------------------------------------------- +# MODEL_RENDERER_MAP shape +# --------------------------------------------------------------------------- + + +def test_canonical_meta_llama_paths_route_to_llama_3(): + for canonical, _ in _MODEL_PAIRS: + assert MODEL_RENDERER_MAP.get(canonical) == "llama-3", ( + f"{canonical}: expected to route to 'llama-3'" + ) + + +def test_create_renderer_via_explicit_config(llama_pair): + """``Llama3RendererConfig`` resolves to Llama3Renderer in the registry.""" + _, _, tok, _ = llama_pair + r = create_renderer(tok, Llama3RendererConfig()) + assert isinstance(r, Llama3Renderer) + + +# 
--------------------------------------------------------------------------- +# Constructor contract +# --------------------------------------------------------------------------- + + +def test_default_date_matches_chat_template_strftime_fallback(llama_pair): + """Default ``date_string`` is ``"26 Jul 2024"`` so output stays + deterministic without an explicit override.""" + _, _, tok, _ = llama_pair + r = Llama3Renderer(tok) + assert r.config.date_string == _PINNED_DATE + + +def test_preserve_thinking_flags_are_noops(llama_pair): + """Llama-3 has no reasoning channel, so the ``preserve_*_thinking`` + flags are accepted but never change the token stream — the same + never-preserves contract as Kimi-K2 / Qwen3-VL. (Cross-renderer + coverage lives in tests/test_preserve_thinking.py.)""" + _, _, tok, _ = llama_pair + msgs = [ + {"role": "user", "content": "Hi."}, + { + "role": "assistant", + "reasoning_content": "internal musings", + "content": "Hello!", + }, + ] + base = Llama3Renderer(tok).render_ids(msgs) + for flag in ("preserve_all_thinking", "preserve_thinking_between_tool_calls"): + r = Llama3Renderer(tok, Llama3RendererConfig(**{flag: True})) + assert r.config.__getattribute__(flag) is True + assert r.render_ids(msgs) == base, f"{flag} must be a no-op for Llama-3" + + +# --------------------------------------------------------------------------- +# Byte parity vs apply_chat_template +# --------------------------------------------------------------------------- + + +def _expected(tok, messages, **kwargs): + kwargs.setdefault("add_generation_prompt", False) + kwargs.setdefault("date_string", _PINNED_DATE) + return list( + tok.apply_chat_template(messages, tokenize=True, return_dict=False, **kwargs) + ) + + +def test_parity_minimal_user(llama_pair): + _, _, tok, r = llama_pair + msgs = [{"role": "user", "content": "Hi."}] + assert r.render_ids(msgs) == _expected(tok, msgs) + + +def test_parity_system_and_user(llama_pair): + _, _, tok, r = llama_pair + msgs = [ + 
{"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi."}, + ] + assert r.render_ids(msgs) == _expected(tok, msgs) + + +def test_parity_system_user_assistant(llama_pair): + _, _, tok, r = llama_pair + msgs = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi."}, + {"role": "assistant", "content": "Hello!"}, + ] + assert r.render_ids(msgs) == _expected(tok, msgs) + + +def test_parity_no_system_with_gen_prompt(llama_pair): + _, _, tok, r = llama_pair + msgs = [{"role": "user", "content": "Hi."}] + assert r.render_ids(msgs, add_generation_prompt=True) == _expected( + tok, msgs, add_generation_prompt=True + ) + + +def test_parity_multi_turn(llama_pair): + _, _, tok, r = llama_pair + msgs = [ + {"role": "user", "content": "A"}, + {"role": "assistant", "content": "B"}, + {"role": "user", "content": "C"}, + {"role": "assistant", "content": "D"}, + ] + assert r.render_ids(msgs) == _expected(tok, msgs) + + +def test_parity_trims_whitespace(llama_pair): + _, _, tok, r = llama_pair + msgs = [ + {"role": "user", "content": " hello "}, + {"role": "assistant", "content": "\n\nworld\n"}, + ] + assert r.render_ids(msgs) == _expected(tok, msgs) + + +def test_parity_custom_date(llama_pair): + """``date_string`` constructor override changes both sides identically.""" + _, _, tok, _ = llama_pair + r = Llama3Renderer(tok, Llama3RendererConfig(date_string="01 Jan 2026")) + msgs = [{"role": "user", "content": "Hi."}] + expected = list( + tok.apply_chat_template( + msgs, tokenize=True, return_dict=False, date_string="01 Jan 2026" + ) + ) + assert r.render_ids(msgs) == expected + + +def test_parity_tools_in_user_default(llama_pair): + _, _, tok, r = llama_pair + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + }, + } + ] + msgs = [ + {"role": "system", "content": "Be 
terse."}, + {"role": "user", "content": "Weather?"}, + ] + assert r.render_ids(msgs, tools=tools) == _expected(tok, msgs, tools=tools) + + +def test_parity_tools_in_system_mode(llama_pair): + """When constructed with ``tools_in_user_message=False``, the renderer + matches ``apply_chat_template(... tools_in_user_message=False)``.""" + _, _, tok, _ = llama_pair + r = Llama3Renderer( + tok, + Llama3RendererConfig(date_string=_PINNED_DATE, tools_in_user_message=False), + ) + tools = [ + { + "type": "function", + "function": {"name": "get_weather", "parameters": {}}, + } + ] + msgs = [ + {"role": "system", "content": "Be terse."}, + {"role": "user", "content": "Weather?"}, + ] + expected = list( + tok.apply_chat_template( + msgs, + tokenize=True, + return_dict=False, + tools=tools, + tools_in_user_message=False, + date_string=_PINNED_DATE, + ) + ) + assert r.render_ids(msgs, tools=tools) == expected + + +def test_parity_tool_call_round_trip(llama_pair): + """Assistant tool_calls + tool response + final assistant — covers + the JSON tool-call body emission and the ``ipython`` response role.""" + _, _, tok, r = llama_pair + msgs = [ + {"role": "user", "content": "Weather?"}, + { + "role": "assistant", + "tool_calls": [ + { + "type": "function", + "function": { + "name": "get_weather", + "arguments": {"city": "NYC"}, + }, + } + ], + }, + {"role": "tool", "content": '{"temp": 72}'}, + {"role": "assistant", "content": "It's 72."}, + ] + assert r.render_ids(msgs) == _expected(tok, msgs) + + +def test_parity_tool_response_dict_content(llama_pair): + """Tool response with mapping content goes through ``tojson`` in the + template; the renderer's ``_tool_response_str`` mirrors that.""" + _, _, tok, r = llama_pair + msgs = [ + {"role": "user", "content": "x"}, + { + "role": "assistant", + "tool_calls": [{"function": {"name": "f", "arguments": {}}}], + }, + {"role": "tool", "content": {"k": "v", "n": 42}}, + {"role": "assistant", "content": "ok"}, + ] + assert r.render_ids(msgs) == 
_expected(tok, msgs) + + +def test_render_raises_on_multiple_tool_calls(llama_pair): + """Llama-3 chat template explicitly raises on >1 tool call per turn — + the renderer mirrors that contract.""" + _, _, _, r = llama_pair + msgs = [ + {"role": "user", "content": "x"}, + { + "role": "assistant", + "tool_calls": [ + {"function": {"name": "f", "arguments": {}}}, + {"function": {"name": "g", "arguments": {}}}, + ], + }, + ] + with pytest.raises(ValueError, match="single tool call"): + r.render_ids(msgs) + + +# --------------------------------------------------------------------------- +# parse_response +# --------------------------------------------------------------------------- + + +def _tokens_for(tok, text: str) -> list[int]: + return tok.encode(text, add_special_tokens=False) + + +def test_parse_response_plain_content(llama_pair): + _, _, tok, r = llama_pair + ids = _tokens_for(tok, "Hello, world!") + [r._eot] + out = r.parse_response(ids) + assert isinstance(out, ParsedResponse) + assert out.content == "Hello, world!" 
+ assert out.tool_calls == [] + assert out.reasoning_content is None + + +def test_parse_response_tool_call(llama_pair): + _, _, tok, r = llama_pair + body = '{"name": "get_weather", "parameters": {"city": "NYC"}}' + ids = _tokens_for(tok, body) + [r._eot] + out = r.parse_response(ids) + assert out.content == "" + assert len(out.tool_calls) == 1 + tc = out.tool_calls[0] + assert tc.status == ToolCallParseStatus.OK + assert tc.name == "get_weather" + assert tc.arguments == {"city": "NYC"} + + +def test_parse_response_malformed_tool_call_falls_through_to_content(llama_pair): + """A body that LOOKS like a tool call but doesn't parse should land + in content rather than dropping silently.""" + _, _, tok, r = llama_pair + body = '{"name": "x", broken' + ids = _tokens_for(tok, body) + [r._eot] + out = r.parse_response(ids) + assert out.tool_calls == [] + assert "{" in out.content + + +# --------------------------------------------------------------------------- +# Bridge contract +# --------------------------------------------------------------------------- + + +def _simulate_prior_turn(r): + prior = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi."}, + ] + asst = [{"role": "assistant", "content": "Hello!"}] + + prev_prompt = r.render_ids(prior, add_generation_prompt=True) + full = r.render_ids(prior + asst, add_generation_prompt=False) + prev_completion = list(full[len(prev_prompt) :]) + + stop = set(r.get_stop_token_ids()) + last = -1 + for i in range(len(prev_completion) - 1, -1, -1): + if prev_completion[i] in stop: + last = i + break + if last >= 0: + prev_completion = prev_completion[: last + 1] + return prev_prompt, prev_completion + + +def test_bridge_extends_prev_verbatim_on_clean_stop(llama_pair): + _, _, _, r = llama_pair + prev_prompt, prev_completion = _simulate_prior_turn(r) + new_messages = [{"role": "user", "content": "What's 2+2?"}] + bridged = r.bridge_to_next_turn(prev_prompt, prev_completion, new_messages) + 
assert bridged is not None + prev = prev_prompt + prev_completion + assert bridged.token_ids[: len(prev)] == prev + assert len(bridged.token_ids) > len(prev) + + +def test_bridge_matches_fresh_render_on_clean_stop(llama_pair): + """The whole point of the bridge: it must produce the same tokens as + a fresh render of the full message list — except sampled tokens are + kept verbatim rather than re-rendered.""" + _, _, _, r = llama_pair + prior = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi."}, + ] + asst = [{"role": "assistant", "content": "Hello!"}] + new_messages = [{"role": "user", "content": "What's 2+2?"}] + + prev_prompt, prev_completion = _simulate_prior_turn(r) + bridged = r.bridge_to_next_turn(prev_prompt, prev_completion, new_messages) + fresh = r.render_ids(prior + asst + new_messages, add_generation_prompt=True) + assert bridged.token_ids == fresh + + +def test_bridge_rejects_assistant_in_extension(llama_pair): + _, _, _, r = llama_pair + prev_prompt, prev_completion = _simulate_prior_turn(r) + bridged = r.bridge_to_next_turn( + prev_prompt, + prev_completion, + [{"role": "assistant", "content": "forbidden"}], + ) + assert bridged is None + + +def test_bridge_synthesises_close_on_truncation(llama_pair): + _, _, _, r = llama_pair + prev_prompt, prev_completion = _simulate_prior_turn(r) + trunc = prev_completion[:-1] + if not trunc: + pytest.skip("simulated prior had no completion tokens to truncate") + bridged = r.bridge_to_next_turn( + prev_prompt, trunc, [{"role": "user", "content": "ping"}] + ) + assert bridged is not None + base = prev_prompt + trunc + assert bridged.token_ids[: len(base)] == base + assert len(bridged.token_ids) > len(base) diff --git a/tests/test_preserve_thinking.py b/tests/test_preserve_thinking.py index 1ef07f5..75b739b 100644 --- a/tests/test_preserve_thinking.py +++ b/tests/test_preserve_thinking.py @@ -51,6 +51,9 @@ def _make(tokenizer, renderer_name, **flags): 
"Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen3-VL-30B-A3B-Instruct", "poolside/Laguna-XS.2", + # Llama-3 has no reasoning channel at all — preserve flags can't add + # or drop anything, so they're pure no-ops. + "unsloth/Llama-3.2-1B-Instruct", } @@ -319,6 +322,9 @@ def test_preserve_btc_on_live_cycle_matches_all( "Qwen/Qwen3-VL-4B-Instruct", "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen3-VL-30B-A3B-Instruct", + # Llama-3 ships no reasoning rendering path, so reasoning_content + # never surfaces in the output regardless of the preserve flags. + "unsloth/Llama-3.2-1B-Instruct", } diff --git a/tests/test_roundtrip.py b/tests/test_roundtrip.py index 7d7ee36..236dca7 100644 --- a/tests/test_roundtrip.py +++ b/tests/test_roundtrip.py @@ -47,6 +47,7 @@ # (no separating newline) — the Ultra-specific glue stresses the round-trip. ("nvidia/NVIDIA-Nemotron-3-Ultra-550B-A55B-BF16", "auto"), ("poolside/Laguna-XS.2", "auto"), + ("unsloth/Llama-3.2-1B-Instruct", "llama-3"), ("openai/gpt-oss-20b", "gpt-oss"), ("Qwen/Qwen2.5-0.5B-Instruct", "default"), ] @@ -210,6 +211,12 @@ def test_roundtrip_multiple_tool_calls( """Parsers that loop over ```` blocks can silently drop the second one; this test catches that.""" _maybe_skip_tool_calls(rt_renderer_name) + if rt_renderer_name == "llama-3": + pytest.skip( + "Llama-3's chat template forbids >1 tool call per assistant " + "message (the renderer raises, mirroring the template); the " + "single-call path is covered by test_roundtrip_single_tool_call." + ) msg = { "role": "assistant",