From c4a97c705458958d1867e4f3384963389b491b33 Mon Sep 17 00:00:00 2001 From: Howard Su Date: Mon, 11 May 2026 09:48:51 +0800 Subject: [PATCH 1/3] Preserve exact tool-call text during prompt replay The OpenAI-compatible dflash server was parsing assistant tool-call output into structured JSON and then rebuilding those turns through the chat template on later requests. That preserved the semantics of the tool call, but not the exact text the model originally emitted. Small formatting differences in the rebuilt assistant turn can change tokenization, which makes prefix and KV reuse less stable for tool-using conversations. Fix this by keeping a small Python-side tool-memory store in the server path and using it during prompt reconstruction. The server now remembers the original assistant text for generated tool-call turns, keyed by tool-call IDs. When a later request sends those same tool calls back as structured history, the prompt tokenizer looks up the IDs and re-injects the original assistant text verbatim instead of re-rendering canonicalized tool-call objects. If a lookup is missing or inconsistent, the existing structured rendering path still applies. This change stays intentionally simple and Python-first. It adds a focused ToolMemory helper plus regression coverage for chat-completions and responses round trips, without introducing a native radix-tree backend before profiling shows the server-side store is a real bottleneck. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dflash/scripts/server.py | 35 +++++++- dflash/scripts/test_server.py | 88 +++++++++++++++++++- dflash/scripts/test_tool_memory.py | 29 +++++++ dflash/scripts/tool_memory.py | 124 +++++++++++++++++++++++++++++ 4 files changed, 273 insertions(+), 3 deletions(-) create mode 100644 dflash/scripts/test_tool_memory.py create mode 100644 dflash/scripts/tool_memory.py diff --git a/dflash/scripts/server.py b/dflash/scripts/server.py index 9e00d3d9..d88dab78 100644 --- a/dflash/scripts/server.py +++ b/dflash/scripts/server.py @@ -41,6 +41,7 @@ compress_text_via_daemon, _drain_until_sentinel, ) from prefix_cache import DaemonStdoutBus, PrefixCache +from tool_memory import ToolMemory class OpenAICompatError(Exception): @@ -697,6 +698,21 @@ def _resolve_kv_k_type(): ) if prefill_cfg is not None and prefill_cache_slots > 0: prefix_cache.init_full_cache(prefill_cache_slots, budget_bytes=prefill_cache_bytes) + tool_memory = ToolMemory( + max_entries=int(os.environ.get("DFLASH_TOOL_MEMORY_MAX_ENTRIES", "50000")), + max_bytes=int(os.environ.get("DFLASH_TOOL_MEMORY_MAX_BYTES", str(64 * 1024 * 1024))), + ) + + def _remember_tool_call_text(raw_text: str, tool_calls: list[dict] | None) -> None: + if not raw_text or not tool_calls: + return + call_ids = [ + tc.get("id") + for tc in tool_calls + if isinstance(tc, dict) and isinstance(tc.get("id"), str) and tc.get("id") + ] + if call_ids: + tool_memory.remember(call_ids, raw_text) @app.on_event("startup") async def _startup(): @@ -801,13 +817,18 @@ def _tokenize_prompt(req: ChatRequest) -> tuple[Path, list[int], list[dict], boo msgs: list[dict] = [] for m in req.messages: d: dict = {"role": m.role} - if m.content is not None: + replay_raw_text = None + if m.role == "assistant" and m.tool_calls is not None: + replay_raw_text = tool_memory.lookup_message(m.tool_calls) + if replay_raw_text is not None: + d["content"] = replay_raw_text + elif m.content is not None: d["content"] = _content_to_str(m.content) if m.name is not None: d["name"] = m.name if m.tool_call_id is not None: d["tool_call_id"] = m.tool_call_id - if m.tool_calls is not None: + if m.tool_calls is not None and replay_raw_text is None: d["tool_calls"] = [] for tc in m.tool_calls: args = tc.function.arguments @@ -1173,6 +1194,7 @@ def chunk(delta_obj, finish=None): mode = "reasoning" if started_in_thinking else "content" window = "" tool_buffer = "" + accumulated_content = "" stops = normalize_stop(req.stop) tag_holdback = max(len(THINK_OPEN_TAG), len(THINK_CLOSE_TAG), len(TOOL_OPEN_TAG)) stop_holdback = max((len(s) for s in stops), default=0) @@ -1197,6 +1219,8 @@ def emit_delta(text, kind): window = window[:si] stop_hit = True kind = "reasoning_content" if mode == "reasoning" else "content" + if mode == "content": + accumulated_content += window out = emit_delta(window, kind) if out: yield out window = "" @@ -1236,6 +1260,7 @@ def emit_delta(text, kind): hits.sort() idx, which = hits[0] pre = window[:idx] + accumulated_content += pre out = emit_delta(pre, "content") if out: yield out if which == "think": @@ -1250,6 +1275,7 @@ def emit_delta(text, kind): continue if len(window) > HOLDBACK: safe = window[:-HOLDBACK] + accumulated_content += safe out = emit_delta(safe, "content") if out: yield out window = window[-HOLDBACK:] @@ -1277,6 +1303,7 @@ def emit_delta(text, kind): out = emit_delta(window, "reasoning_content") if out: yield out elif mode == "content" and window: + accumulated_content += window out = emit_delta(window, "content") if out: yield out elif mode == "tool_buffer": @@ -1287,6 +1314,7 @@ def emit_delta(text, kind): if mode == "tool_buffer": cleaned_after, tool_calls = parse_tool_calls(tool_buffer, tools=req.tools) if tool_calls: + _remember_tool_call_text(accumulated_content + tool_buffer, tool_calls) if cleaned_after: out = emit_delta(cleaned_after, "content") if out: yield out @@ -1412,6 +1440,7 @@ def emit_delta(text, kind): if req.chat_template_kwargs: thinking_enabled = req.chat_template_kwargs.get("enable_thinking", True) cleaned, tool_calls = parse_tool_calls(text, tools=req.tools) + _remember_tool_call_text(text, tool_calls) cleaned, reasoning = parse_reasoning( cleaned, thinking_enabled=thinking_enabled, @@ -1909,6 +1938,7 @@ async def _responses_non_stream( if chat_req.chat_template_kwargs: thinking_enabled = chat_req.chat_template_kwargs.get("enable_thinking", True) cleaned, tool_calls = parse_tool_calls(text, tools=chat_req.tools) + _remember_tool_call_text(text, tool_calls) cleaned, reasoning = parse_reasoning( cleaned, thinking_enabled=thinking_enabled, started_in_thinking=started_in_thinking) @@ -2141,6 +2171,7 @@ async def sse() -> AsyncIterator[str]: if mode == "tool_buffer" and tool_buffer: cleaned_after, tool_calls = parse_tool_calls(tool_buffer, tools=chat_req.tools) if tool_calls: + _remember_tool_call_text(accumulated_text + tool_buffer, tool_calls) if cleaned_after: accumulated_text += cleaned_after for tc in tool_calls: diff --git a/dflash/scripts/test_server.py b/dflash/scripts/test_server.py index 7bdfc569..02c4df7a 100644 --- a/dflash/scripts/test_server.py +++ b/dflash/scripts/test_server.py @@ -385,6 +385,51 @@ def test_chat_completions_non_streaming_with_tool_call(mock_os_read, mock_pipe, assert tc[0]["function"]["name"] == "read_file" +@patch("server.os.pipe") +@patch("server.os.read") +def test_chat_completions_replays_raw_tool_call_text(mock_os_read, mock_pipe, + mock_tokenizer, app): + mock_pipe.return_value = (1, 2) + raw_tool_text = ( + "Before\n" + "" + "test.py" + "\n" + "After" + ) + mock_tokenizer.decode.side_effect = [raw_tool_text, "followup"] + mock_os_read.side_effect = [ + struct.pack("' + 'file.txt' + '' + ) + mock_tokenizer.decode.side_effect = [raw_tool_text, "followup"] + mock_os_read.side_effect = [ + struct.pack("old") + mem.remember(["call_new"], "new") + + assert mem.lookup_message([{"id": "call_old"}]) is None + assert mem.lookup_message([{"id": "call_new"}]) == "new" + assert len(mem.by_block) == 1 + + +def test_tool_memory_lookup_message_requires_same_raw_text(): + mem = ToolMemory(max_entries=8, max_bytes=4096) + mem.remember(["call_a"], "a") + mem.remember(["call_b"], "b") + + assert mem.lookup_message([{"id": "call_a"}, {"id": "call_b"}]) is None diff --git a/dflash/scripts/tool_memory.py b/dflash/scripts/tool_memory.py new file mode 100644 index 00000000..58bb4a83 --- /dev/null +++ b/dflash/scripts/tool_memory.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +from collections import OrderedDict +from dataclasses import dataclass +from typing import Any, Iterable, Sequence + + +@dataclass +class _ToolMemoryBlock: + raw_text: str + size_bytes: int + refs: int = 0 + + +class ToolMemory: + """Exact assistant-text memory for tool-calling turns. + + The server receives assistant tool calls back as structured JSON on later + turns. Re-rendering those objects can change key order, spacing, or wrapper + shape, which changes prompt tokenization and breaks prefix/KV reuse. This + store remembers the exact assistant text that originally produced a set of + tool-call IDs so replay can inject the original text verbatim. + """ + + def __init__(self, *, max_entries: int = 50_000, max_bytes: int = 64 * 1024 * 1024): + self.max_entries = max(0, int(max_entries)) + self.max_bytes = max(0, int(max_bytes)) + self.by_id: dict[str, _ToolMemoryBlock] = {} + self.by_block: dict[str, _ToolMemoryBlock] = {} + self._lru: OrderedDict[str, None] = OrderedDict() + self.total_bytes = 0 + + @property + def disabled(self) -> bool: + return self.max_entries == 0 or self.max_bytes == 0 + + def remember(self, call_ids: Iterable[str], raw_text: str) -> None: + if self.disabled or not raw_text: + return + unique_ids = [] + seen: set[str] = set() + for call_id in call_ids: + if not isinstance(call_id, str) or not call_id or call_id in seen: + continue + seen.add(call_id) + unique_ids.append(call_id) + if not unique_ids: + return + + block = self.by_block.get(raw_text) + if block is None: + block = _ToolMemoryBlock( + raw_text=raw_text, + size_bytes=len(raw_text.encode("utf-8")), + ) + self.by_block[raw_text] = block + self.total_bytes += block.size_bytes + + for call_id in unique_ids: + current = self.by_id.get(call_id) + if current is block: + self._touch(call_id) + continue + if current is not None: + self._drop_entry(call_id, current) + self.by_id[call_id] = block + block.refs += 1 + self._touch(call_id) + + self._prune() + + def lookup_message(self, tool_calls: Sequence[Any]) -> str | None: + raw_text: str | None = None + touched: list[str] = [] + for item in tool_calls: + call_id = self._extract_call_id(item) + if not call_id: + return None + block = self.by_id.get(call_id) + if block is None: + return None + touched.append(call_id) + if raw_text is None: + raw_text = block.raw_text + elif raw_text != block.raw_text: + return None + if raw_text is None: + return None + for call_id in touched: + self._touch(call_id) + return raw_text + + def _touch(self, call_id: str) -> None: + self._lru[call_id] = None + self._lru.move_to_end(call_id) + + def _prune(self) -> None: + while self.by_id and ( + (self.max_entries > 0 and len(self.by_id) > self.max_entries) + or (self.max_bytes > 0 and self.total_bytes > self.max_bytes) + ): + oldest_id, _ = self._lru.popitem(last=False) + block = self.by_id.get(oldest_id) + if block is not None: + self._drop_entry(oldest_id, block) + + def _drop_entry(self, call_id: str, block: _ToolMemoryBlock) -> None: + self.by_id.pop(call_id, None) + self._lru.pop(call_id, None) + if block.refs > 0: + block.refs -= 1 + if block.refs == 0: + self.by_block.pop(block.raw_text, None) + self.total_bytes -= block.size_bytes + if self.total_bytes < 0: + self.total_bytes = 0 + + @staticmethod + def _extract_call_id(item: Any) -> str | None: + if isinstance(item, dict): + call_id = item.get("id") + return call_id if isinstance(call_id, str) and call_id else None + call_id = getattr(item, "id", None) + return call_id if isinstance(call_id, str) and call_id else None From c38c3ba95e1924d327912d085d99ab2e50d14a0d Mon Sep 17 00:00:00 2001 From: Howard Su Date: Mon, 11 May 2026 23:14:18 +0800 Subject: [PATCH 2/3] Improve tool-call streaming and validation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dflash/DEVELOPER.md | 2 +- dflash/README.md | 2 +- dflash/scripts/server.py | 1604 +++++++++++++++++++++++++-------- dflash/scripts/test_server.py | 1390 +++++++++++++++++++++++++--- pflash/README.md | 2 +- 5 files changed, 2542 insertions(+), 458 deletions(-) diff --git a/dflash/DEVELOPER.md b/dflash/DEVELOPER.md index 8f948899..1c9460a6 100644 --- a/dflash/DEVELOPER.md +++ b/dflash/DEVELOPER.md @@ -227,7 +227,7 @@ dflash/ │ └── draft/model.safetensors ├── scripts/ │ ├── server.py # Main OpenAI/Codex server -│ ├── server_tools.py # Legacy fork with tool calling (deprecated) +│ ├── server_tools.py # Legacy fork kept for reference; server.py is the tool/Codex path │ ├── prefix_cache.py # LRU prefix cache │ ├── _prefill_hook.py # Speculative prefill compression │ ├── run.py # CLI text generation diff --git a/dflash/README.md b/dflash/README.md index 63c6122f..b73fde56 100644 --- a/dflash/README.md +++ b/dflash/README.md @@ -119,7 +119,7 @@ allows capacity checks where the draft and a target layer range share one GPU before serving integration. `--target-split-dflash` runs the same split target placement through a chain DFlash decode loop and reports acceptance length. -**Python flags on `scripts/run.py`, `scripts/server.py`, `scripts/server_tools.py`:** +**Python flags on `scripts/run.py` and `scripts/server.py` (`scripts/server_tools.py` is legacy):** ```bash python3 scripts/run.py --ctk q8_0 --ctv q4_0 --prompt "hello" python3 scripts/run.py --cache-type-k q8_0 --cache-type-v q4_0 --prompt "hello" diff --git a/dflash/scripts/server.py b/dflash/scripts/server.py index d88dab78..0232fae1 100644 --- a/dflash/scripts/server.py +++ b/dflash/scripts/server.py @@ -24,6 +24,7 @@ import tempfile import time import uuid +from dataclasses import dataclass, field from pathlib import Path from typing import Any, AsyncIterator @@ -180,6 +181,8 @@ def _tokenizer_id_from_gguf(gguf_path: Path) -> str: r"", re.DOTALL) TOOL_CODE_RE = re.compile(r"(.*?)", re.DOTALL) TOOL_OPEN_TAG = "" +FUNCTION_OPEN_TAG = " dict: @@ -573,6 +576,846 @@ class ResponsesCreateRequest(BaseModel): previous_response_id: str | None = None +@dataclass(frozen=True) +class ToolPolicy: + prompt_tools: list[Any] | None + parse_tools: list[Any] | None + choice_kind: str + choice_name: str | None = None + render_tools: bool = False + parse_tool_calls: bool = False + bypass_compression: bool = False + + +@dataclass +class _PartialToolCallSnapshot: + name: str + arguments: str + complete: bool = False + + +@dataclass +class _TrackedToolCallState: + index: int + id: str + name: str | None = None + emitted_arguments: str = "" + announced: bool = False + done: bool = False + + +@dataclass +class _SharedStreamState: + mode: str + holdback: int + allow_tools: bool + window: str = "" + tool_buffer: str = "" + raw_text: str = "" + visible_text: str = "" + stop_hit: bool = False + tool_states: list[_TrackedToolCallState] = field(default_factory=list) + final_tool_calls: list[dict] = field(default_factory=list) + + +def _trim_stream_raw_suffix(raw_text: str, suffix_len: int, keep_len: int) -> str: + trim_len = max(0, suffix_len - keep_len) + if trim_len <= 0: + return raw_text + return raw_text[:-trim_len] if trim_len < len(raw_text) else "" + + +def _tool_name(tool: Any) -> str | None: + fn = None + if hasattr(tool, "function"): + fn = tool.function + elif isinstance(tool, dict): + fn = tool.get("function") + if hasattr(fn, "model_dump"): + fn = fn.model_dump() + if isinstance(fn, dict): + name = fn.get("name") + if isinstance(name, str) and name: + return name + return None + + +def _normalize_tool_choice(tool_choice: Any) -> tuple[str, str | None]: + if tool_choice is None: + return "auto", None + if isinstance(tool_choice, str): + choice = tool_choice.strip().lower() + if choice in {"auto", "none", "required"}: + return choice, None + if isinstance(tool_choice, dict): + choice_type = tool_choice.get("type") + if isinstance(choice_type, str): + choice_type = choice_type.lower() + if choice_type in {"auto", "none", "required"}: + return choice_type, None + name = tool_choice.get("name") + if not isinstance(name, str): + fn = tool_choice.get("function") + if isinstance(fn, dict): + name = fn.get("name") + if (choice_type == "function" or "function" in tool_choice or "name" in tool_choice) and isinstance(name, str) and name: + return "function", name + raise OpenAICompatError( + "Unsupported tool_choice value", + param="tool_choice", + ) + + +def _resolve_tool_policy(tools: list[Any] | None, tool_choice: Any) -> ToolPolicy: + requested_tools = list(tools or []) + has_tools = bool(requested_tools) + choice_kind, choice_name = _normalize_tool_choice(tool_choice) + + if choice_kind == "function": + if not has_tools: + raise OpenAICompatError( + "tool_choice=function requires tools", + param="tool_choice", + ) + selected = next((t for t in requested_tools if _tool_name(t) == choice_name), None) + if selected is None: + raise OpenAICompatError( + f"tool_choice function '{choice_name}' was not provided in tools", + param="tool_choice", + ) + return ToolPolicy( + prompt_tools=[selected], + parse_tools=requested_tools, + choice_kind=choice_kind, + choice_name=choice_name, + render_tools=True, + parse_tool_calls=True, + bypass_compression=True, + ) + + if choice_kind == "required" and not has_tools: + raise OpenAICompatError( + "tool_choice='required' requires tools", + param="tool_choice", + ) + + if choice_kind == "none": + return ToolPolicy( + prompt_tools=None, + parse_tools=None, + choice_kind=choice_kind, + render_tools=False, + parse_tool_calls=False, + bypass_compression=has_tools, + ) + + if has_tools: + return ToolPolicy( + prompt_tools=requested_tools, + parse_tools=requested_tools, + choice_kind=choice_kind, + render_tools=True, + parse_tool_calls=True, + bypass_compression=True, + ) + + return ToolPolicy( + prompt_tools=None, + parse_tools=None, + choice_kind=choice_kind, + render_tools=False, + parse_tool_calls=True, + bypass_compression=False, + ) + + +def _parse_generated_tool_calls(text: str, tool_policy: ToolPolicy) -> tuple[str, list[dict]]: + if not tool_policy.parse_tool_calls: + return text, [] + return parse_tool_calls(text, tools=tool_policy.parse_tools) + + +def _tool_choice_violation(tool_policy: ToolPolicy, tool_calls: list[dict]) -> OpenAICompatError | None: + if tool_policy.choice_kind in {"auto", "none"}: + return None + if tool_policy.choice_kind == "required": + if tool_calls: + return None + return OpenAICompatError( + "tool_choice='required' requires the model to emit a tool call", + param="tool_choice", + ) + if tool_policy.choice_kind == "function": + disallowed = [ + tc.get("function", {}).get("name") + for tc in tool_calls + if tc.get("function", {}).get("name") != tool_policy.choice_name + ] + disallowed = [name for name in disallowed if isinstance(name, str) and name] + if disallowed: + leaked = ", ".join(sorted(set(disallowed))) + return OpenAICompatError( + f"tool_choice function '{tool_policy.choice_name}' does not allow other tool calls ({leaked})", + param="tool_choice", + ) + if any(tc.get("function", {}).get("name") == tool_policy.choice_name for tc in tool_calls): + return None + return OpenAICompatError( + f"tool_choice function '{tool_policy.choice_name}' was not satisfied by model output", + param="tool_choice", + ) + return None + + +def _tool_param_type(tools, function_name: str, param_name: str) -> str: + params = _find_tool_properties(tools, function_name) + cfg = params.get(param_name, {}) if isinstance(params, dict) else {} + if isinstance(cfg, dict): + if isinstance(cfg.get("type"), str): + return cfg["type"].strip().lower() + if "anyOf" in cfg: + return "object" + return "string" + + +def _json_string_fragment(value: str) -> str: + return json.dumps(value, ensure_ascii=False)[1:-1] + + +def _find_toolish_json_start(text: str, start: int = 0) -> int: + """Return the earliest `{` that could begin a supported JSON tool call.""" + keys = ("name", "function", "arguments") + cursor = start + while cursor < len(text): + brace = text.find("{", cursor) + if brace == -1: + return -1 + key_cursor = brace + 1 + while key_cursor < len(text) and text[key_cursor].isspace(): + key_cursor += 1 + if key_cursor >= len(text): + return brace + if text[key_cursor] != '"': + cursor = brace + 1 + continue + key_start = key_cursor + 1 + key_end = key_start + while key_end < len(text) and text[key_end] not in {'"', '\\'}: + key_end += 1 + if key_end >= len(text): + fragment = text[key_start:] + if any(k.startswith(fragment) for k in keys): + return brace + cursor = brace + 1 + continue + if text[key_end] == "\\": + cursor = brace + 1 + continue + key = text[key_start:key_end] + if not any(k.startswith(key) for k in keys): + cursor = brace + 1 + continue + colon_cursor = key_end + 1 + while colon_cursor < len(text) and text[colon_cursor].isspace(): + colon_cursor += 1 + if colon_cursor >= len(text) or text[colon_cursor] == ":": + return brace + cursor = brace + 1 + return -1 + + +def _json_tool_call_spans(text: str) -> list[tuple[int, int]]: + spans: list[tuple[int, int]] = [] + decoder = json.JSONDecoder() + cursor = 0 + while cursor < len(text): + start = text.find("{", cursor) + if start == -1: + break + try: + obj, consumed = decoder.raw_decode(text[start:]) + except json.JSONDecodeError: + cursor = start + 1 + continue + if _parse_json_tool_call(obj) is not None: + spans.append((start, start + consumed)) + cursor = start + max(consumed, 1) + return spans + + +def _build_partial_xml_arguments(text: str, cursor: int, function_name: str, tools) -> tuple[str, bool, int]: + param_config = _find_tool_properties(tools, function_name) + fragments: list[str] = [] + complete = False + + while cursor < len(text): + while cursor < len(text) and text[cursor].isspace(): + cursor += 1 + if text.startswith("", cursor): + cursor += len("") + complete = True + break + if text.startswith("", cursor): + cursor += len("") + continue + if not text.startswith("", cursor) + if name_end == -1: + break + param_name = text[cursor:name_end].strip() + cursor = name_end + 1 + + end_tag = text.find("", cursor) + next_param = text.find("", cursor) + candidates = [(pos, kind) for pos, kind in ( + (end_tag, "parameter"), + (next_param, "next_param"), + (close_fn, "function"), + ) if pos != -1] + if candidates: + value_end, end_kind = min(candidates, key=lambda item: item[0]) + else: + value_end, end_kind = len(text), "eof" + raw_value = text[cursor:value_end] + if raw_value.startswith("\n"): + raw_value = raw_value[1:] + if end_kind == "parameter" and raw_value.endswith("\n"): + raw_value = raw_value[:-1] + + if fragments: + fragments.append(",") + else: + fragments.append("{") + fragments.append(json.dumps(param_name, ensure_ascii=False)) + fragments.append(":") + + if _tool_param_type(tools, function_name, param_name) == "string": + fragments.append('"') + fragments.append(_json_string_fragment(raw_value)) + if end_kind == "parameter": + fragments.append('"') + cursor = value_end + len("") + continue + break + + if end_kind == "parameter": + value_obj = _convert_param_value(raw_value, param_name, param_config, function_name) + fragments.append(json.dumps(value_obj, ensure_ascii=False, separators=(",", ":"))) + cursor = value_end + len("") + continue + + del fragments[prefix_len:] + break + + if complete: + if not fragments: + return "{}", True, cursor + fragments.append("}") + return "".join(fragments), complete, cursor + + +def _partial_tool_call_snapshots(text: str, tools=None) -> list[_PartialToolCallSnapshot]: + snapshots: list[_PartialToolCallSnapshot] = [] + decoder = json.JSONDecoder() + cursor = 0 + while cursor < len(text): + hits = [(i, kind) for i, kind in ( + (text.find(TOOL_OPEN_TAG, cursor), "tool"), + (text.find(FUNCTION_OPEN_TAG, cursor), "function"), + (text.find(TOOL_CODE_OPEN_TAG, cursor), "tool_code"), + (_find_toolish_json_start(text, cursor), "json"), + ) if i != -1] + if not hits: + break + start, kind = min(hits, key=lambda item: item[0]) + cursor = start + if kind == "tool_code": + close = text.find("", cursor + len(TOOL_CODE_OPEN_TAG)) + if close == -1: + break + try: + obj = json.loads(text[cursor + len(TOOL_CODE_OPEN_TAG):close].strip()) + except json.JSONDecodeError: + cursor = close + len("") + continue + parsed = _parse_json_tool_call(obj) + if parsed is not None and _tool_allowed(tools, parsed[0]): + snapshots.append(_PartialToolCallSnapshot( + name=parsed[0], + arguments=json.dumps(parsed[1], ensure_ascii=False, separators=(",", ":")), + complete=True, + )) + cursor = close + len("") + continue + if kind == "json": + try: + obj, consumed = decoder.raw_decode(text[cursor:]) + except json.JSONDecodeError: + break + parsed = _parse_json_tool_call(obj) + if parsed is not None and _tool_allowed(tools, parsed[0]): + snapshots.append(_PartialToolCallSnapshot( + name=parsed[0], + arguments=json.dumps(parsed[1], ensure_ascii=False, separators=(",", ":")), + complete=True, + )) + cursor += max(consumed, 1) + continue + if text.startswith(TOOL_OPEN_TAG, cursor): + cursor += len(TOOL_OPEN_TAG) + while cursor < len(text) and text[cursor].isspace(): + cursor += 1 + if not text.startswith("", fn_cursor) + signature_end = text.find("(", fn_cursor) + if signature_end != -1 and (header_end == -1 or signature_end < header_end): + close = text.find("", signature_end) + if close == -1: + break + _, calls = parse_tool_calls(text[start:close + len("")], tools=tools) + if calls: + snapshots.append(_PartialToolCallSnapshot( + name=calls[0]["function"]["name"], + arguments=calls[0]["function"]["arguments"], + complete=True, + )) + cursor = close + len("") + continue + if header_end == -1: + break + + function_name = text[fn_cursor:header_end].strip() + if not function_name: + cursor = start + 1 + continue + if not _tool_allowed(tools, function_name): + close = text.find("", header_end + 1) + cursor = close + len("") if close != -1 else header_end + 1 + continue + + arguments, complete, body_end = _build_partial_xml_arguments( + text, header_end + 1, function_name, tools) + snapshots.append(_PartialToolCallSnapshot( + name=function_name, + arguments=arguments, + complete=complete, + )) + cursor = max(body_end, header_end + 1) + if not complete: + break + tool_close = text.find("", cursor) + if tool_close != -1: + cursor = max(cursor, tool_close + len("")) + return snapshots + + +def _complete_tool_buffer_spans(text: str) -> list[tuple[int, int]]: + spans: list[tuple[int, int]] = [] + for match in TOOL_CALL_COMPLETE_RE.finditer(text): + spans.append((match.start(), match.end())) + for pattern in (BARE_FUNCTION_XML_RE, FUNCTION_SIGNATURE_RE): + for match in pattern.finditer(text): + if any(lo <= match.start() < hi for lo, hi in spans): + continue + spans.append((match.start(), match.end())) + for match in TOOL_CODE_RE.finditer(text): + try: + obj = json.loads(match.group(1).strip()) + except json.JSONDecodeError: + continue + if _parse_json_tool_call(obj) is None: + continue + if any(lo <= match.start() < hi for lo, hi in spans): + continue + spans.append((match.start(), match.end())) + for start, end in _json_tool_call_spans(text): + if any(lo <= start < hi for lo, hi in spans): + continue + spans.append((start, end)) + spans.sort() + return spans + + +def _first_tool_buffer_stop_match(text: str, stops: list[str]) -> int: + spans = _complete_tool_buffer_spans(text) + if not spans: + return -1 + + def stop_in_gap(start: int, end: int) -> int: + json_open = _find_toolish_json_start(text[start:end]) + if json_open != -1: + json_open += start + next_open = min( + (pos for pos in ( + text.find(TOOL_OPEN_TAG, start, end), + text.find(FUNCTION_OPEN_TAG, start, end), + text.find(TOOL_CODE_OPEN_TAG, start, end), + json_open, + ) if pos != -1), + default=-1, + ) + if next_open != -1: + end = next_open + if start >= end: + return -1 + gap_index = first_stop_match(text[start:end], stops) + return start + gap_index if gap_index != -1 else -1 + + cursor = 0 + for start, end in spans: + stop_index = stop_in_gap(cursor, start) + if stop_index != -1: + return stop_index + cursor = max(cursor, end) + return stop_in_gap(cursor, len(text)) + + +def _sync_partial_tool_stream(tool_buffer: str, tools, states: list[_TrackedToolCallState]) -> list[dict]: + events: list[dict] = [] + for index, snapshot in enumerate(_partial_tool_call_snapshots(tool_buffer, tools=tools)): + while len(states) <= index: + states.append(_TrackedToolCallState( + index=len(states), + id="call_" + uuid.uuid4().hex[:24], + )) + state = states[index] + if snapshot.name and state.name is None: + state.name = snapshot.name + if state.name and not state.announced: + state.announced = True + events.append({ + "kind": "tool_start", + "index": index, + "id": state.id, + "name": state.name, + }) + if snapshot.arguments: + fragment = snapshot.arguments + if snapshot.arguments.startswith(state.emitted_arguments): + fragment = snapshot.arguments[len(state.emitted_arguments):] + if fragment: + state.emitted_arguments = snapshot.arguments + events.append({ + "kind": "tool_args", + "index": index, + "id": state.id, + "name": state.name, + "arguments": fragment, + }) + if snapshot.complete and not state.done: + state.done = True + events.append({ + "kind": "tool_done", + "index": index, + "id": state.id, + "name": state.name, + "arguments": snapshot.arguments or "{}", + }) + return events + + +def _reconcile_final_tool_events(tool_calls: list[dict], states: list[_TrackedToolCallState]) -> list[dict]: + events: list[dict] = [] + for index, tc in enumerate(tool_calls): + while len(states) <= index: + states.append(_TrackedToolCallState( + index=len(states), + id="call_" + uuid.uuid4().hex[:24], + )) + state = states[index] + state.name = tc["function"]["name"] + tc["id"] = state.id + if not state.announced: + state.announced = True + events.append({ + "kind": "tool_start", + "index": index, + "id": state.id, + "name": state.name, + }) + final_args = tc["function"]["arguments"] + fragment = final_args + if final_args.startswith(state.emitted_arguments): + fragment = final_args[len(state.emitted_arguments):] + if fragment: + state.emitted_arguments = final_args + events.append({ + "kind": "tool_args", + "index": index, + "id": state.id, + "name": state.name, + "arguments": fragment, + }) + if not state.done: + state.done = True + events.append({ + "kind": "tool_done", + "index": index, + "id": state.id, + "name": state.name, + "arguments": final_args, + }) + return events + + +def _finalize_tool_buffer_stream( + state: _SharedStreamState, + tool_policy: ToolPolicy, +) -> tuple[list[dict], list[dict]]: + events: list[dict] = [] + tool_calls: list[dict] = [] + cleaned_after, tool_calls = _parse_generated_tool_calls(state.tool_buffer, tool_policy) + if tool_calls: + events.extend(_reconcile_final_tool_events(tool_calls, state.tool_states)) + if cleaned_after: + state.visible_text += cleaned_after + events.append({"kind": "content", "text": cleaned_after}) + elif state.tool_buffer: + state.visible_text += state.tool_buffer + events.append({"kind": "content", "text": state.tool_buffer}) + state.final_tool_calls = tool_calls + state.tool_buffer = "" + state.mode = "content" + return events, tool_calls + + +def _make_shared_stream_state(started_in_thinking: bool, stops: list[str], tool_policy: ToolPolicy) -> _SharedStreamState: + tag_holdback = max( + len(THINK_OPEN_TAG), + len(THINK_CLOSE_TAG), + len(TOOL_OPEN_TAG) if tool_policy.parse_tool_calls else 0, + len(FUNCTION_OPEN_TAG) if tool_policy.parse_tool_calls else 0, + len(TOOL_CODE_OPEN_TAG) if tool_policy.parse_tool_calls else 0, + ) + stop_holdback = max((len(s) for s in stops), default=0) + return _SharedStreamState( + mode="reasoning" if started_in_thinking else "content", + holdback=max(tag_holdback, stop_holdback), + allow_tools=tool_policy.parse_tool_calls, + ) + + +def _feed_shared_stream_piece( + state: _SharedStreamState, + piece: str, + *, + stops: list[str], + tool_policy: ToolPolicy, +) -> list[dict]: + events: list[dict] = [] + state.raw_text += piece + state.window += piece + + if stops: + if state.mode == "tool_buffer": + combined = state.tool_buffer + state.window + stop_index = _first_tool_buffer_stop_match(combined, stops) + if stop_index != -1: + state.raw_text = _trim_stream_raw_suffix( + state.raw_text, len(combined), stop_index) + state.tool_buffer = combined[:stop_index] + state.window = "" + state.stop_hit = True + else: + window_len = len(state.window) + stop_index = first_stop_match(state.window, stops) + if stop_index != -1: + state.raw_text = _trim_stream_raw_suffix( + state.raw_text, window_len, stop_index) + state.window = state.window[:stop_index] + state.stop_hit = True + + while True: + if state.mode == "tool_buffer": + state.tool_buffer += state.window + state.window = "" + events.extend(_sync_partial_tool_stream( + state.tool_buffer, tool_policy.parse_tools, state.tool_states)) + if state.stop_hit: + finalized_events, _ = _finalize_tool_buffer_stream(state, tool_policy) + events.extend(finalized_events) + break + + if state.mode == "reasoning": + idx = state.window.find(THINK_CLOSE_TAG) + if idx != -1: + pre = state.window[:idx] + if pre: + events.append({"kind": "reasoning", "text": pre}) + state.window = state.window[idx + len(THINK_CLOSE_TAG):] + state.mode = "content" + continue + if len(state.window) > state.holdback or (state.stop_hit and state.window): + safe = state.window if state.stop_hit else state.window[:-state.holdback] + if safe: + events.append({"kind": "reasoning", "text": safe}) + state.window = "" if state.stop_hit else state.window[-state.holdback:] + break + + think_idx = state.window.find(THINK_OPEN_TAG) + tool_idx = state.window.find(TOOL_OPEN_TAG) if state.allow_tools else -1 + bare_fn_idx = state.window.find(FUNCTION_OPEN_TAG) if state.allow_tools else -1 + tool_code_idx = state.window.find(TOOL_CODE_OPEN_TAG) if state.allow_tools else -1 + json_tool_idx = _find_toolish_json_start(state.window) if state.allow_tools else -1 + hits = [(i, t) for i, t in ( + (think_idx, "think"), + (tool_idx, "tool"), + (bare_fn_idx, "tool"), + (tool_code_idx, "tool"), + (json_tool_idx, "tool"), + ) if i != -1] + if hits: + hits.sort() + idx, which = hits[0] + pre = state.window[:idx] + if pre: + state.visible_text += pre + events.append({"kind": "content", "text": pre}) + if which == "think": + state.window = state.window[idx + len(THINK_OPEN_TAG):] + state.mode = "reasoning" + else: + state.tool_buffer += state.window[idx:] + state.window = "" + state.mode = "tool_buffer" + events.extend(_sync_partial_tool_stream( + state.tool_buffer, tool_policy.parse_tools, state.tool_states)) + if state.stop_hit: + finalized_events, _ = _finalize_tool_buffer_stream(state, tool_policy) + events.extend(finalized_events) + continue + if len(state.window) > state.holdback or (state.stop_hit and state.window): + safe = state.window if state.stop_hit else state.window[:-state.holdback] + if safe: + state.visible_text += safe + events.append({"kind": "content", "text": safe}) + state.window = "" if state.stop_hit else state.window[-state.holdback:] + break + return events + + +def _flush_shared_stream( + state: _SharedStreamState, + *, + tool_policy: ToolPolicy, +) -> tuple[list[dict], list[dict]]: + events: list[dict] = [] + tool_calls: list[dict] = [] + + if state.mode == "reasoning" and state.window: + events.append({"kind": "reasoning", "text": state.window}) + elif state.mode == "content" and state.window: + state.visible_text += state.window + events.append({"kind": "content", "text": state.window}) + elif state.mode == "tool_buffer": + state.tool_buffer += state.window + events.extend(_sync_partial_tool_stream( + state.tool_buffer, tool_policy.parse_tools, state.tool_states)) + state.window = "" + + if state.mode == "tool_buffer": + finalized_events, tool_calls = _finalize_tool_buffer_stream(state, tool_policy) + events.extend(finalized_events) + + return events, tool_calls + + +def _parse_responses_non_stream_output( + text: str, + tool_policy: ToolPolicy, + msg_item_id: str, + *, + started_in_thinking: bool, +) -> tuple[str, list[dict], list[tuple[str, str]]]: + state = _make_shared_stream_state(started_in_thinking, [], tool_policy) + output_order: list[tuple[str, str]] = [] + message_announced = False + + def record_output_order(events: list[dict]) -> None: + nonlocal message_announced + for event in events: + if event["kind"] == "content": + if not message_announced: + output_order.append(("message", msg_item_id)) + message_announced = True + continue + if event["kind"] == "tool_start": + output_order.append(("function_call", event["id"])) + + stream_events = _feed_shared_stream_piece( + state, text, stops=[], tool_policy=tool_policy) + record_output_order(stream_events) + flush_events, tool_calls = _flush_shared_stream(state, tool_policy=tool_policy) + record_output_order(flush_events) + return state.visible_text.strip(), tool_calls, output_order + + +def _responses_message_item(msg_item_id: str, text: str, *, status: str = "completed") -> dict: + return { + "type": "message", + "id": msg_item_id, + "status": status, + "role": "assistant", + "content": [{"type": "output_text", "text": text, "annotations": []}], + } + + +def _responses_function_call_item(tc: dict, *, status: str = "completed") -> dict: + return { + "type": "function_call", + "id": tc["id"], + "status": status, + "call_id": tc["id"], + "name": tc["function"]["name"], + "arguments": tc["function"]["arguments"], + } + + +def _build_responses_output( + cleaned_text: str, + tool_calls: list[dict], + msg_item_id: str, + *, + output_order: list[tuple[str, str]] | None = None, +) -> list[dict]: + items: dict[tuple[str, str], dict] = {} + if cleaned_text or not tool_calls: + items[("message", msg_item_id)] = _responses_message_item(msg_item_id, cleaned_text) + for tc in tool_calls: + items[("function_call", tc["id"])] = _responses_function_call_item(tc) + + if not output_order: + output: list[dict] = [] + if ("message", msg_item_id) in items: + output.append(items.pop(("message", msg_item_id))) + for tc in tool_calls: + item = items.pop(("function_call", tc["id"]), None) + if item is not None: + output.append(item) + return output + + output = [] + seen: set[tuple[str, str]] = set() + for key in output_order: + item = items.get(key) + if item is not None and key not in seen: + output.append(item) + seen.add(key) + for key, item in items.items(): + if key not in seen: + output.append(item) + return output + + def _samp_suffix(req) -> str: # Render ` samp=temp,top_p,top_k,rep_pen[,seed]` tail when the request asks for # non-greedy decoding. Empty string keeps the daemon protocol greedy-compatible. @@ -697,7 +1540,7 @@ def _resolve_kv_k_type(): cap=prefix_cache_slots, ) if prefill_cfg is not None and prefill_cache_slots > 0: - prefix_cache.init_full_cache(prefill_cache_slots, budget_bytes=prefill_cache_bytes) + prefix_cache.init_full_cache(prefill_cache_slots) tool_memory = ToolMemory( max_entries=int(os.environ.get("DFLASH_TOOL_MEMORY_MAX_ENTRIES", "50000")), max_bytes=int(os.environ.get("DFLASH_TOOL_MEMORY_MAX_BYTES", str(64 * 1024 * 1024))), @@ -812,7 +1655,7 @@ def _render_messages(msgs_list: list[dict], param="messages") return _ids_to_bin(ids), ids, prompt - def _tokenize_prompt(req: ChatRequest) -> tuple[Path, list[int], list[dict], bool]: + def _tokenize_prompt(req: ChatRequest, tool_policy: ToolPolicy) -> tuple[Path, list[int], list[dict], bool]: """Returns (bin, ids, raw_msgs, started_in_thinking).""" msgs: list[dict] = [] for m in req.messages: @@ -844,16 +1687,23 @@ def _tokenize_prompt(req: ChatRequest) -> tuple[Path, list[int], list[dict], boo msgs.append(d) tools_arg = None - if req.tools: - tools_arg = [t.model_dump() for t in req.tools] + if tool_policy.render_tools and tool_policy.prompt_tools: + tools_arg = [ + t.model_dump() if hasattr(t, "model_dump") else t + for t in tool_policy.prompt_tools + ] path, ids, _prompt = _render_messages(msgs, req.chat_template_kwargs, tools_arg) started_in_thinking = bool(re.search(r"\s*$", _prompt)) return path, ids, msgs, started_in_thinking def _maybe_compress(msgs: list[dict], prompt_bin: Path, prompt_ids: list[int], - template_kwargs: dict | None = None + template_kwargs: dict | None = None, + *, + bypass: bool = False, ) -> tuple[Path, list[int]]: + if bypass: + return prompt_bin, prompt_ids if not prefill_cfg or not prefill_cfg.enabled: return prompt_bin, prompt_ids if not prefill_cfg.should_compress(len(prompt_ids)): @@ -921,6 +1771,15 @@ async def _collect_tokens_sync(r, n_gen, timing=None) -> list[int]: loop = asyncio.get_running_loop() return await loop.run_in_executor(None, lambda: list(_token_stream(r, n_gen, timing))) + async def _adrain_until_sentinel(r, timing=None) -> None: + if timing is not None and timing.get("daemon_done"): + return + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, _drain_until_sentinel, r) + if timing is not None: + timing["daemon_done"] = True + timing["t_last_tok"] = time.monotonic() + async def _astream_tokens(r, n_gen, timing=None): generated = 0 hit_stop = False @@ -965,16 +1824,19 @@ def _write_cmd(cmd_line: str, timing=None): sz = bin_path.stat().st_size if sz == 0: log.warning("prompt .bin is 0 bytes: %s", bin_path) - if lazy_draft: - log.debug("lazy-draft: unpark draft before generate") - t = time.monotonic() - daemon_proc.stdin.write(b"unpark draft\n") + try: + if lazy_draft: + log.debug("lazy-draft: unpark draft before generate") + t = time.monotonic() + daemon_proc.stdin.write(b"unpark draft\n") + daemon_proc.stdin.flush() + _drain_until_sentinel(r_pipe) + if timing is not None: + timing["unpark"] = time.monotonic() - t + daemon_proc.stdin.write(cmd_line.encode("utf-8")) daemon_proc.stdin.flush() - _drain_until_sentinel(r_pipe) - if timing is not None: - timing["unpark"] = time.monotonic() - t - daemon_proc.stdin.write(cmd_line.encode("utf-8")) - daemon_proc.stdin.flush() + except (BrokenPipeError, OSError, ValueError) as exc: + raise RuntimeError(f"failed to send command to dflash daemon: {exc}") from exc if timing is not None: timing["t_cmd_sent"] = time.monotonic() @@ -1069,8 +1931,15 @@ def _build_cmd_line(req, cur_bin, cur_ids, gen_len, prefix_cache, cmd_line += f" snap={snap_prep[1]}:{snap_prep[0]}" return cmd_line + _samp_suffix(req) + "\n", snap_prep + def _abort_reserved_snap(full_snap_prep, snap_prep): + if full_snap_prep is not None: + fslot, _ = full_snap_prep + prefix_cache.abort_full_snap(fslot) + elif snap_prep: + prefix_cache.abort_inline_snap(snap_prep[0]) + def _confirm_or_abort_snap(n_tokens: int, full_snap_prep, snap_prep, - prompt_ids, cur_bin, cur_ids): + prompt_ids, cur_bin, cur_ids): """Confirm prefix-cache snapshots only when the daemon actually generated tokens. When the daemon returns 0 tokens (e.g. empty prompt / file read failure), confirming would register a snapshot @@ -1084,11 +1953,7 @@ def _confirm_or_abort_snap(n_tokens: int, full_snap_prep, snap_prep, prefix_cache.confirm_inline_snap(*snap_prep, cur_ids) else: # Abort: release the reservation without registering. - if full_snap_prep is not None: - fslot, _ = full_snap_prep - prefix_cache.abort_full_snap(fslot) - elif snap_prep: - prefix_cache.abort_inline_snap(snap_prep[0]) + _abort_reserved_snap(full_snap_prep, snap_prep) log.warning("0 output tokens — aborted snapshot reservation") def _gen_len_for(prompt_len: int, max_tokens: int) -> int: @@ -1098,7 +1963,8 @@ def _gen_len_for(prompt_len: int, max_tokens: int) -> int: @app.post("/v1/chat/completions") async def chat_completions(req: ChatRequest): - prompt_bin, prompt_ids, raw_msgs, started_in_thinking = _tokenize_prompt(req) + tool_policy = _resolve_tool_policy(req.tools, req.tool_choice) + prompt_bin, prompt_ids, raw_msgs, started_in_thinking = _tokenize_prompt(req, tool_policy) completion_id = "chatcmpl-" + uuid.uuid4().hex[:24] created = int(time.time()) prompt_len = len(prompt_ids) @@ -1117,18 +1983,19 @@ async def chat_completions(req: ChatRequest): if req.stream: async def sse() -> AsyncIterator[str]: - nonlocal started_in_thinking + nonlocal started_in_thinking, prompt_len async with daemon_lock: timing = {} full_snap_prep_ref = [None] snap_prep = None + cur_bin = prompt_bin + cur_ids = prompt_ids - full_hit = prefix_cache.lookup_full(prompt_ids) + full_hit = None if tool_policy.bypass_compression else prefix_cache.lookup_full(prompt_ids) if full_hit is not None: slot, cached_cur_bin, cached_cur_ids_len = full_hit cur_bin = Path(cached_cur_bin) prompt_len = cached_cur_ids_len - started_in_thinking = False # cached: no think prefill gen_len = _gen_len_for(prompt_len, req.max_tokens) if gen_len <= 0: try: prompt_bin.unlink() @@ -1145,7 +2012,8 @@ async def sse() -> AsyncIterator[str]: t_compress = time.monotonic() cur_bin, cur_ids = await asyncio.to_thread( _maybe_compress, raw_msgs, prompt_bin, prompt_ids, - req.chat_template_kwargs) + req.chat_template_kwargs, + bypass=tool_policy.bypass_compression) timing["compress"] = time.monotonic() - t_compress prompt_len = len(cur_ids) gen_len = _gen_len_for(prompt_len, req.max_tokens) @@ -1168,6 +2036,7 @@ async def sse() -> AsyncIterator[str]: try: _write_cmd(cmd_line, timing) except RuntimeError as e: + _abort_reserved_snap(full_snap_prep_ref[0], snap_prep) yield f"data: {json.dumps({'error': str(e)})}\n\n" yield "data: [DONE]\n\n" return @@ -1180,8 +2049,6 @@ async def sse() -> AsyncIterator[str]: "finish_reason": None}], } yield f"data: {json.dumps(head)}\n\n" - window, mode = "", ("reasoning" if started_in_thinking else "content") - include_usage = bool(req.stream_options and req.stream_options.get("include_usage")) def chunk(delta_obj, finish=None): @@ -1190,99 +2057,91 @@ def chunk(delta_obj, finish=None): "choices": [{"index": 0, "delta": delta_obj, "finish_reason": finish}]} - # State machine: mode ∈ {'reasoning', 'content', 'tool_buffer'} - mode = "reasoning" if started_in_thinking else "content" - window = "" - tool_buffer = "" - accumulated_content = "" + def tool_chunk(index: int, *, call_id: str | None = None, + name: str | None = None, + arguments: str | None = None) -> str: + tc: dict[str, Any] = {"index": index} + if call_id is not None: + tc["id"] = call_id + tc["type"] = "function" + fn: dict[str, Any] = {} + if name is not None: + fn["name"] = name + if arguments is not None: + fn["arguments"] = arguments + if fn: + tc["function"] = fn + return f"data: {json.dumps(chunk({'tool_calls': [tc]}))}\n\n" + stops = normalize_stop(req.stop) - tag_holdback = max(len(THINK_OPEN_TAG), len(THINK_CLOSE_TAG), len(TOOL_OPEN_TAG)) - stop_holdback = max((len(s) for s in stops), default=0) - HOLDBACK = max(tag_holdback, stop_holdback) + stream_state = _make_shared_stream_state(started_in_thinking, stops, tool_policy) completion_tokens = 0 - stop_hit = False - - def emit_delta(text, kind): - if not text: - return None - return f"data: {json.dumps(chunk({kind: text}))}\n\n" + finish_reason = "stop" + tool_calls: list[dict] = [] try: async for tok_id in _astream_tokens(r_pipe, gen_len, timing): completion_tokens += 1 piece = tokenizer.decode([tok_id]) - window += piece - - if stops and mode != "tool_buffer": - si = first_stop_match(window, stops) - if si != -1: - window = window[:si] - stop_hit = True - kind = "reasoning_content" if mode == "reasoning" else "content" - if mode == "content": - accumulated_content += window - out = emit_delta(window, kind) - if out: yield out - window = "" - break - - while True: - if mode == "tool_buffer": - tool_buffer += window - window = "" - break - - if mode == "reasoning": - idx = window.find(THINK_CLOSE_TAG) - if idx != -1: - pre = window[:idx] - out = emit_delta(pre, "reasoning_content") - if out: yield out - window = window[idx + len(THINK_CLOSE_TAG):] - mode = "content" - continue - if len(window) > HOLDBACK: - safe = window[:-HOLDBACK] - out = emit_delta(safe, "reasoning_content") - if out: yield out - window = window[-HOLDBACK:] - break - - else: # mode == "content" - think_idx = window.find(THINK_OPEN_TAG) - think_close_idx = window.find(THINK_CLOSE_TAG) - tool_idx = window.find(TOOL_OPEN_TAG) - hits = [(i, t) for i, t in - ((think_idx, "think"), - (think_close_idx, "think_close"), - (tool_idx, "tool")) if i != -1] - if hits: - hits.sort() - idx, which = hits[0] - pre = window[:idx] - accumulated_content += pre - out = emit_delta(pre, "content") - if out: yield out - if which == "think": - window = window[idx + len(THINK_OPEN_TAG):] - mode = "reasoning" - elif which == "think_close": - window = window[idx + len(THINK_CLOSE_TAG):] - else: - tool_buffer = window[idx:] - window = "" - mode = "tool_buffer" - continue - if len(window) > HOLDBACK: - safe = window[:-HOLDBACK] - accumulated_content += safe - out = emit_delta(safe, "content") - if out: yield out - window = window[-HOLDBACK:] - break - - if stop_hit: - finish_reason = "stop" + for event in _feed_shared_stream_piece( + stream_state, piece, stops=stops, tool_policy=tool_policy): + if event["kind"] == "content": + yield f"data: {json.dumps(chunk({'content': event['text']}))}\n\n" + elif event["kind"] == "reasoning": + yield f"data: {json.dumps(chunk({'reasoning_content': event['text']}))}\n\n" + elif event["kind"] == "tool_start": + yield tool_chunk( + event["index"], + call_id=event["id"], + name=event["name"], + ) + elif event["kind"] == "tool_args": + yield tool_chunk( + event["index"], + arguments=event["arguments"], + ) + if stream_state.stop_hit: + break + + if stream_state.stop_hit: + if not timing.get("daemon_done"): + await _adrain_until_sentinel(r_pipe, timing) + stream_events, tool_calls = _flush_shared_stream( + stream_state, tool_policy=tool_policy) + if not tool_calls and stream_state.final_tool_calls: + tool_calls = list(stream_state.final_tool_calls) + for event in stream_events: + if event["kind"] == "content": + yield f"data: {json.dumps(chunk({'content': event['text']}))}\n\n" + elif event["kind"] == "reasoning": + yield f"data: {json.dumps(chunk({'reasoning_content': event['text']}))}\n\n" + elif event["kind"] == "tool_start": + yield tool_chunk( + event["index"], + call_id=event["id"], + name=event["name"], + ) + elif event["kind"] == "tool_args": + yield tool_chunk( + event["index"], + arguments=event["arguments"], + ) + tool_choice_error = _tool_choice_violation(tool_policy, tool_calls) + if tool_choice_error is not None: + err = {"error": { + "message": tool_choice_error.message, + "type": tool_choice_error.error_type, + }} + if tool_choice_error.param is not None: + err["error"]["param"] = tool_choice_error.param + yield f"data: {json.dumps(err)}\n\n" + yield "data: [DONE]\n\n" + return + if tool_calls: + _remember_tool_call_text(stream_state.raw_text, tool_calls) + finish_reason = "tool_calls" + else: + finish_reason = "stop" yield f"data: {json.dumps(chunk({}, finish=finish_reason))}\n\n" if include_usage: usage_chunk = {"id": completion_id, "object": "chat.completion.chunk", @@ -1292,43 +2151,44 @@ def emit_delta(text, kind): "total_tokens": prompt_len + completion_tokens}} yield f"data: {json.dumps(usage_chunk)}\n\n" yield "data: [DONE]\n\n" - if timing.get("daemon_done") and full_hit is None: - try: cur_bin.unlink() - except Exception: pass - _park_draft_if_lazy(timing) return - # Flush remaining - if mode == "reasoning" and window: - out = emit_delta(window, "reasoning_content") - if out: yield out - elif mode == "content" and window: - accumulated_content += window - out = emit_delta(window, "content") - if out: yield out - elif mode == "tool_buffer": - tool_buffer += window - window = "" - - finish_reason = "stop" - if mode == "tool_buffer": - cleaned_after, tool_calls = parse_tool_calls(tool_buffer, tools=req.tools) - if tool_calls: - _remember_tool_call_text(accumulated_content + tool_buffer, tool_calls) - if cleaned_after: - out = emit_delta(cleaned_after, "content") - if out: yield out - tc_delta_list = [{ - "index": i, "id": tc["id"], "type": "function", - "function": {"name": tc["function"]["name"], - "arguments": tc["function"]["arguments"]}, - } for i, tc in enumerate(tool_calls)] - yield f"data: {json.dumps(chunk({'tool_calls': tc_delta_list}))}\n\n" - finish_reason = "tool_calls" - else: - out = emit_delta(tool_buffer, "content") - if out: yield out + stream_events, tool_calls = _flush_shared_stream( + stream_state, tool_policy=tool_policy) + for event in stream_events: + if event["kind"] == "content": + yield f"data: {json.dumps(chunk({'content': event['text']}))}\n\n" + elif event["kind"] == "reasoning": + yield f"data: {json.dumps(chunk({'reasoning_content': event['text']}))}\n\n" + elif event["kind"] == "tool_start": + yield tool_chunk( + event["index"], + call_id=event["id"], + name=event["name"], + ) + elif event["kind"] == "tool_args": + yield tool_chunk( + event["index"], + arguments=event["arguments"], + ) + tool_choice_error = _tool_choice_violation(tool_policy, tool_calls) + if tool_choice_error is not None: + err = {"error": { + "message": tool_choice_error.message, + "type": tool_choice_error.error_type, + }} + if tool_choice_error.param is not None: + err["error"]["param"] = tool_choice_error.param + yield f"data: {json.dumps(err)}\n\n" + yield "data: [DONE]\n\n" + return + if tool_calls: + _remember_tool_call_text(stream_state.raw_text, tool_calls) + finish_reason = "tool_calls" finally: + _confirm_or_abort_snap( + completion_tokens, full_snap_prep_ref[0], snap_prep, + prompt_ids, cur_bin, cur_ids) if timing.get("daemon_done"): if full_hit is None: try: cur_bin.unlink() @@ -1340,11 +2200,7 @@ def emit_delta(text, kind): log.warning( "stream ended before daemon sentinel; " "retaining prompt .bin for in-flight daemon read") - - _confirm_or_abort_snap( - completion_tokens, full_snap_prep_ref[0], snap_prep, - prompt_ids, cur_bin, cur_ids) - _park_draft_if_lazy(timing) + _park_draft_if_lazy(timing) yield f"data: {json.dumps(chunk({}, finish=finish_reason))}\n\n" if include_usage: @@ -1375,7 +2231,7 @@ def emit_delta(text, kind): full_snap_prep_ref = [None] snap_prep = None - full_hit = prefix_cache.lookup_full(prompt_ids) + full_hit = None if tool_policy.bypass_compression else prefix_cache.lookup_full(prompt_ids) if full_hit is not None: slot, cached_cur_bin, cached_cur_ids_len = full_hit cur_bin = Path(cached_cur_bin) @@ -1393,7 +2249,8 @@ def emit_delta(text, kind): t_compress = time.monotonic() cur_bin, cur_ids = await asyncio.to_thread( _maybe_compress, raw_msgs, prompt_bin, prompt_ids, - req.chat_template_kwargs) + req.chat_template_kwargs, + bypass=tool_policy.bypass_compression) timing["compress"] = time.monotonic() - t_compress prompt_len = len(cur_ids) gen_len = _gen_len_for(prompt_len, req.max_tokens) @@ -1411,6 +2268,7 @@ def emit_delta(text, kind): try: _write_cmd(cmd_line, timing) except RuntimeError as e: + _abort_reserved_snap(full_snap_prep_ref[0], snap_prep) return JSONResponse({"detail": str(e)}, status_code=503) # FIX 6: use run_in_executor instead of list() blocking event loop @@ -1439,7 +2297,10 @@ def emit_delta(text, kind): thinking_enabled = True if req.chat_template_kwargs: thinking_enabled = req.chat_template_kwargs.get("enable_thinking", True) - cleaned, tool_calls = parse_tool_calls(text, tools=req.tools) + cleaned, tool_calls = _parse_generated_tool_calls(text, tool_policy) + tool_choice_error = _tool_choice_violation(tool_policy, tool_calls) + if tool_choice_error is not None: + raise tool_choice_error _remember_tool_call_text(text, tool_calls) cleaned, reasoning = parse_reasoning( cleaned, @@ -1559,6 +2420,7 @@ async def sse() -> AsyncIterator[str]: try: _write_cmd(cmd_line, timing) except RuntimeError as e: + _abort_reserved_snap(full_snap_prep_ref[0], snap_prep) yield f"event: error\ndata: {json.dumps({'type':'error','error':{'type':'server_error','message':str(e)}})}\n\n" return @@ -1672,6 +2534,7 @@ async def sse() -> AsyncIterator[str]: try: _write_cmd(cmd_line, timing) except RuntimeError as e: + _abort_reserved_snap(full_snap_prep_ref[0], snap_prep) return JSONResponse({"type": "error", "error": {"type": "server_error", "message": str(e)}}, status_code=503) @@ -1717,12 +2580,34 @@ def _map_responses_input(req: ResponsesCreateRequest ) -> tuple[list[ChatMessage], list[ToolDef] | None]: """Map Responses API input → ChatMessage list + ToolDef list.""" messages: list[ChatMessage] = [] + pending_assistant_text: list[str] = [] + pending_assistant_calls: list[ToolCall] = [] # Collect all system-level content (instructions + developer messages) # and merge into a single system message at position 0, since # Qwen's chat template requires the system message at the beginning. system_parts: list[str] = [] + def _flush_pending_assistant() -> None: + if not pending_assistant_text and not pending_assistant_calls: + return + messages.append(ChatMessage( + role="assistant", + content="".join(pending_assistant_text) or None, + tool_calls=list(pending_assistant_calls) or None, + )) + pending_assistant_text.clear() + pending_assistant_calls.clear() + + def _responses_item_text(content: Any) -> str: + if isinstance(content, list): + text_parts = [] + for part in content: + if isinstance(part, dict) and part.get("type") in ("output_text", "text", "input_text"): + text_parts.append(part.get("text", "")) + return "".join(text_parts) + return content if isinstance(content, str) else "" + if req.instructions: system_parts.append(req.instructions) @@ -1738,18 +2623,14 @@ def _map_responses_input(req: ResponsesCreateRequest if item_type == "message": role = item.get("role", "user") - content = item.get("content", "") - if isinstance(content, list): - # Extract text from content parts - text_parts = [] - for part in content: - if isinstance(part, dict): - if part.get("type") in ("output_text", "text", "input_text"): - text_parts.append(part.get("text", "")) - content = "".join(text_parts) + content = _responses_item_text(item.get("content", "")) if role == "developer" or role == "system": + _flush_pending_assistant() system_parts.append(content) + elif role == "assistant": + pending_assistant_text.append(content) else: + _flush_pending_assistant() messages.append(ChatMessage(role=role, content=content)) elif item_type == "function_call": @@ -1761,10 +2642,10 @@ def _map_responses_input(req: ResponsesCreateRequest arguments=item.get("arguments", "{}"), ), ) - messages.append(ChatMessage( - role="assistant", content=None, tool_calls=[tc])) + pending_assistant_calls.append(tc) elif item_type == "function_call_output": + _flush_pending_assistant() output = item.get("output", "") if not isinstance(output, str): output = json.dumps(output) @@ -1776,6 +2657,8 @@ def _map_responses_input(req: ResponsesCreateRequest # Ignore reasoning, local_shell_call, etc. — we just # need the message/function_call/output items for the model. + _flush_pending_assistant() + # Prepend merged system message if system_parts: messages.insert(0, ChatMessage( @@ -1821,13 +2704,14 @@ async def responses_create(req: ResponsesCreateRequest): tool_choice=req.tool_choice, chat_template_kwargs={"enable_thinking": enable_thinking}, ) + tool_policy = _resolve_tool_policy(chat_req.tools, chat_req.tool_choice) response_id = "resp_" + uuid.uuid4().hex[:24] msg_item_id = "msg_" + uuid.uuid4().hex[:24] created_at = int(time.time()) # Tokenize - prompt_bin, prompt_ids, raw_msgs, started_in_thinking = _tokenize_prompt(chat_req) + prompt_bin, prompt_ids, raw_msgs, started_in_thinking = _tokenize_prompt(chat_req, tool_policy) prompt_len = len(prompt_ids) # Summarise roles for the log line @@ -1846,17 +2730,17 @@ async def responses_create(req: ResponsesCreateRequest): if req.stream: return await _responses_stream( - chat_req, prompt_bin, prompt_ids, raw_msgs, + chat_req, prompt_bin, prompt_ids, raw_msgs, tool_policy, started_in_thinking, response_id, msg_item_id, created_at, prompt_len, time.monotonic()) else: return await _responses_non_stream( - chat_req, prompt_bin, prompt_ids, raw_msgs, + chat_req, prompt_bin, prompt_ids, raw_msgs, tool_policy, started_in_thinking, response_id, msg_item_id, created_at, prompt_len, time.monotonic()) async def _responses_non_stream( - chat_req, prompt_bin, prompt_ids, raw_msgs, + chat_req, prompt_bin, prompt_ids, raw_msgs, tool_policy, started_in_thinking, response_id, msg_item_id, created_at, prompt_len, t0): """Non-streaming Responses API handler.""" @@ -1864,14 +2748,13 @@ async def _responses_non_stream( timing = {} full_snap_prep_ref = [None] snap_prep = None + cur_ids = None - full_hit = prefix_cache.lookup_full(prompt_ids) + full_hit = None if tool_policy.bypass_compression else prefix_cache.lookup_full(prompt_ids) if full_hit is not None: slot, cached_cur_bin, cached_cur_ids_len = full_hit cur_bin = Path(cached_cur_bin) - cur_ids = None prompt_len = cached_cur_ids_len - started_in_thinking = False # cached: no think prefill gen_len = _gen_len_for(prompt_len, chat_req.max_tokens) if gen_len <= 0: log.warning( @@ -1890,7 +2773,8 @@ async def _responses_non_stream( t_compress = time.monotonic() cur_bin, cur_ids = await asyncio.to_thread( _maybe_compress, raw_msgs, prompt_bin, prompt_ids, - chat_req.chat_template_kwargs) + chat_req.chat_template_kwargs, + bypass=tool_policy.bypass_compression) timing["compress"] = time.monotonic() - t_compress prompt_len = len(cur_ids) gen_len = _gen_len_for(prompt_len, chat_req.max_tokens) @@ -1914,6 +2798,7 @@ async def _responses_non_stream( try: _write_cmd(cmd_line, timing) except RuntimeError as e: + _abort_reserved_snap(full_snap_prep_ref[0], snap_prep) return JSONResponse({ "type": "error", "error": {"type": "server_error", "message": str(e)} @@ -1937,32 +2822,19 @@ async def _responses_non_stream( thinking_enabled = True if chat_req.chat_template_kwargs: thinking_enabled = chat_req.chat_template_kwargs.get("enable_thinking", True) - cleaned, tool_calls = parse_tool_calls(text, tools=chat_req.tools) + cleaned, tool_calls, output_order = _parse_responses_non_stream_output( + text, tool_policy, msg_item_id, + started_in_thinking=started_in_thinking) + tool_choice_error = _tool_choice_violation(tool_policy, tool_calls) + if tool_choice_error is not None: + raise tool_choice_error _remember_tool_call_text(text, tool_calls) cleaned, reasoning = parse_reasoning( cleaned, thinking_enabled=thinking_enabled, started_in_thinking=started_in_thinking) - # Build output items - output: list[dict] = [] - if tool_calls: - for tc in tool_calls: - output.append({ - "type": "function_call", - "id": tc["id"], - "status": "completed", - "call_id": tc["id"], - "name": tc["function"]["name"], - "arguments": tc["function"]["arguments"], - }) - else: - output.append({ - "type": "message", - "id": msg_item_id, - "status": "completed", - "role": "assistant", - "content": [{"type": "output_text", "text": cleaned, "annotations": []}], - }) + output = _build_responses_output( + cleaned, tool_calls, msg_item_id, output_order=output_order) out_types = [o.get("type") for o in output] elapsed = time.monotonic() - t0 @@ -1989,7 +2861,7 @@ async def _responses_non_stream( }) async def _responses_stream( - chat_req, prompt_bin, prompt_ids, raw_msgs, + chat_req, prompt_bin, prompt_ids, raw_msgs, tool_policy, started_in_thinking, response_id, msg_item_id, created_at, prompt_len, t0): """Streaming Responses API handler — emits Responses SSE events.""" @@ -2001,13 +2873,13 @@ async def sse() -> AsyncIterator[str]: timing = {} full_snap_prep_ref = [None] snap_prep = None + cur_ids = None - full_hit = prefix_cache.lookup_full(prompt_ids) + full_hit = None if tool_policy.bypass_compression else prefix_cache.lookup_full(prompt_ids) if full_hit is not None: slot, cached_cur_bin, cached_cur_ids_len = full_hit cur_bin = Path(cached_cur_bin) prompt_len = cached_cur_ids_len - started_in_thinking = False gen_len = _gen_len_for(prompt_len, chat_req.max_tokens) if gen_len <= 0: log.warning( @@ -2025,7 +2897,8 @@ async def sse() -> AsyncIterator[str]: t_compress = time.monotonic() cur_bin, cur_ids = await asyncio.to_thread( _maybe_compress, raw_msgs, prompt_bin, prompt_ids, - chat_req.chat_template_kwargs) + chat_req.chat_template_kwargs, + bypass=tool_policy.bypass_compression) timing["compress"] = time.monotonic() - t_compress prompt_len = len(cur_ids) gen_len = _gen_len_for(prompt_len, chat_req.max_tokens) @@ -2048,106 +2921,157 @@ async def sse() -> AsyncIterator[str]: try: _write_cmd(cmd_line, timing) except RuntimeError as e: + _abort_reserved_snap(full_snap_prep_ref[0], snap_prep) yield _resp_sse("error", { "error": {"type": "server_error", "message": str(e)}}) return - # Lifecycle: response.created yield _resp_sse("response.created", { "response": _resp_shell(response_id, chat_req.model, created_at, "in_progress")}) - # Announce output item - yield _resp_sse("response.output_item.added", { - "output_index": 0, - "item": {"type": "message", "id": msg_item_id, - "status": "in_progress", "role": "assistant", - "content": []}}) - - # Announce content part - yield _resp_sse("response.content_part.added", { - "item_id": msg_item_id, "output_index": 0, - "content_index": 0, - "part": {"type": "output_text", "text": "", "annotations": []}}) - - # Stream tokens with state machine - mode = "reasoning" if started_in_thinking else "content" - window = "" - tool_buffer = "" - accumulated_text = "" - tag_holdback = max(len(THINK_OPEN_TAG), len(THINK_CLOSE_TAG), len(TOOL_OPEN_TAG)) - HOLDBACK = tag_holdback + stream_state = _make_shared_stream_state(started_in_thinking, [], tool_policy) completion_tokens = 0 - tool_call_active = False + tool_calls: list[dict] = [] + next_output_index = 0 + message_output_index: int | None = None + message_announced = False + message_done = False + tool_output_indices: dict[str, int] = {} + output_item_order: list[tuple[str, str]] = [] + + def ensure_message_started() -> list[str]: + nonlocal next_output_index, message_output_index, message_announced + if message_announced: + return [] + message_output_index = next_output_index + next_output_index += 1 + message_announced = True + output_item_order.append(("message", msg_item_id)) + return [ + _resp_sse("response.output_item.added", { + "output_index": message_output_index, + "item": { + "type": "message", + "id": msg_item_id, + "status": "in_progress", + "role": "assistant", + "content": [], + }, + }), + _resp_sse("response.content_part.added", { + "item_id": msg_item_id, + "output_index": message_output_index, + "content_index": 0, + "part": {"type": "output_text", "text": "", "annotations": []}, + }), + ] + + def ensure_tool_started(event: dict) -> list[str]: + nonlocal next_output_index + output_index = tool_output_indices.get(event["id"]) + if output_index is not None: + return [] + output_index = next_output_index + next_output_index += 1 + tool_output_indices[event["id"]] = output_index + output_item_order.append(("function_call", event["id"])) + return [_resp_sse("response.output_item.added", { + "output_index": output_index, + "item": { + "type": "function_call", + "id": event["id"], + "status": "in_progress", + "call_id": event["id"], + "name": event["name"], + "arguments": "", + }, + })] + + def finalize_message_item() -> list[str]: + nonlocal message_done + if not message_announced or message_done or message_output_index is None: + return [] + message_done = True + return [ + _resp_sse("response.output_text.done", { + "item_id": msg_item_id, + "output_index": message_output_index, + "content_index": 0, + "text": stream_state.visible_text, + }), + _resp_sse("response.content_part.done", { + "item_id": msg_item_id, + "output_index": message_output_index, + "content_index": 0, + "part": {"type": "output_text", "text": stream_state.visible_text, + "annotations": []}, + }), + _resp_sse("response.output_item.done", { + "output_index": message_output_index, + "item": _responses_message_item(msg_item_id, stream_state.visible_text), + }), + ] + + def finalize_tool_items(final_tool_calls: list[dict]) -> list[str]: + messages: list[str] = [] + tool_map = {tc["id"]: tc for tc in final_tool_calls} + for kind, item_id in output_item_order: + if kind != "function_call": + continue + tc = tool_map.get(item_id) + output_index = tool_output_indices.get(item_id) + if tc is None or output_index is None: + continue + messages.append(_resp_sse("response.function_call_arguments.done", { + "item_id": item_id, + "output_index": output_index, + "arguments": tc["function"]["arguments"], + "name": tc["function"]["name"], + })) + messages.append(_resp_sse("response.output_item.done", { + "output_index": output_index, + "item": _responses_function_call_item(tc), + })) + return messages + + def emit_stream_event(event: dict) -> list[str]: + messages: list[str] = [] + if event["kind"] == "content": + messages.extend(ensure_message_started()) + assert message_output_index is not None + messages.append(_resp_sse("response.output_text.delta", { + "item_id": msg_item_id, + "output_index": message_output_index, + "content_index": 0, + "delta": event["text"], + })) + return messages + if event["kind"] == "tool_start": + messages.extend(ensure_tool_started(event)) + return messages + if event["kind"] == "tool_args": + if event["id"] not in tool_output_indices: + messages.extend(ensure_tool_started(event)) + messages.append(_resp_sse("response.function_call_arguments.delta", { + "item_id": event["id"], + "output_index": tool_output_indices[event["id"]], + "delta": event["arguments"], + })) + return messages + if event["kind"] == "tool_done": + if event["id"] not in tool_output_indices: + messages.extend(ensure_tool_started(event)) + return messages try: async for tok_id in _astream_tokens(r_pipe, gen_len, timing): completion_tokens += 1 piece = tokenizer.decode([tok_id]) - window += piece - - while True: - if mode == "tool_buffer": - tool_buffer += window - window = "" - break - - if mode == "reasoning": - idx = window.find(THINK_CLOSE_TAG) - if idx != -1: - window = window[idx + len(THINK_CLOSE_TAG):] - mode = "content" - continue - if len(window) > HOLDBACK: - window = window[-HOLDBACK:] - break - - else: # content - think_idx = window.find(THINK_OPEN_TAG) - think_close_idx = window.find(THINK_CLOSE_TAG) - tool_idx = window.find(TOOL_OPEN_TAG) - hits = [(i, t) for i, t in - ((think_idx, "think"), - (think_close_idx, "think_close"), - (tool_idx, "tool")) if i != -1] - if hits: - hits.sort() - idx, which = hits[0] - pre = window[:idx] - if pre: - accumulated_text += pre - yield _resp_sse("response.output_text.delta", { - "item_id": msg_item_id, "output_index": 0, - "content_index": 0, "delta": pre}) - if which == "think": - window = window[idx + len(THINK_OPEN_TAG):] - mode = "reasoning" - elif which == "think_close": - window = window[idx + len(THINK_CLOSE_TAG):] - else: - tool_buffer = window[idx:] - window = "" - mode = "tool_buffer" - continue - if len(window) > HOLDBACK: - safe = window[:-HOLDBACK] - accumulated_text += safe - yield _resp_sse("response.output_text.delta", { - "item_id": msg_item_id, "output_index": 0, - "content_index": 0, "delta": safe}) - window = window[-HOLDBACK:] - break - - # Flush remaining window - if mode == "content" and window: - accumulated_text += window - yield _resp_sse("response.output_text.delta", { - "item_id": msg_item_id, "output_index": 0, - "content_index": 0, "delta": window}) - elif mode == "tool_buffer": - tool_buffer += window - window = "" - + for event in _feed_shared_stream_piece( + stream_state, piece, stops=[], tool_policy=tool_policy): + for message in emit_stream_event(event): + yield message finally: if timing.get("daemon_done"): if full_hit is None: @@ -2166,71 +3090,45 @@ async def sse() -> AsyncIterator[str]: prompt_ids, cur_bin, cur_ids) _park_draft_if_lazy(timing) - # Build final output items - final_output: list[dict] = [] - if mode == "tool_buffer" and tool_buffer: - cleaned_after, tool_calls = parse_tool_calls(tool_buffer, tools=chat_req.tools) - if tool_calls: - _remember_tool_call_text(accumulated_text + tool_buffer, tool_calls) - if cleaned_after: - accumulated_text += cleaned_after - for tc in tool_calls: - tool_call_active = True - tc_item_id = tc["id"] - # Emit function_call_arguments.delta for each tool call - yield _resp_sse("response.function_call_arguments.delta", { - "item_id": tc_item_id, "output_index": 0, - "delta": tc["function"]["arguments"]}) - yield _resp_sse("response.function_call_arguments.done", { - "item_id": tc_item_id, "output_index": 0, - "arguments": tc["function"]["arguments"], - "name": tc["function"]["name"]}) - final_output.append({ - "type": "function_call", - "id": tc_item_id, - "status": "completed", - "call_id": tc_item_id, - "name": tc["function"]["name"], - "arguments": tc["function"]["arguments"], - }) - else: - accumulated_text += tool_buffer - yield _resp_sse("response.output_text.delta", { - "item_id": msg_item_id, "output_index": 0, - "content_index": 0, "delta": tool_buffer}) - - # Finalize text output - yield _resp_sse("response.output_text.done", { - "item_id": msg_item_id, "output_index": 0, - "content_index": 0, "text": accumulated_text}) - yield _resp_sse("response.content_part.done", { - "item_id": msg_item_id, "output_index": 0, - "content_index": 0, - "part": {"type": "output_text", "text": accumulated_text, - "annotations": []}}) - - if not tool_call_active: - final_output.append({ - "type": "message", - "id": msg_item_id, - "status": "completed", - "role": "assistant", - "content": [{"type": "output_text", "text": accumulated_text, - "annotations": []}], - }) + stream_events, tool_calls = _flush_shared_stream( + stream_state, tool_policy=tool_policy) + for event in stream_events: + for message in emit_stream_event(event): + yield message + + if tool_calls: + _remember_tool_call_text(stream_state.raw_text, tool_calls) + + if not message_announced and not tool_calls: + for message in ensure_message_started(): + yield message + for message in finalize_message_item(): + yield message + + tool_choice_error = _tool_choice_violation(tool_policy, tool_calls) + if tool_choice_error is not None: + yield _resp_sse("response.failed", { + "response": _resp_shell(response_id, chat_req.model, created_at, + "failed")}) + err = {"error": { + "type": tool_choice_error.error_type, + "message": tool_choice_error.message, + }} + if tool_choice_error.param is not None: + err["error"]["param"] = tool_choice_error.param + yield _resp_sse("error", err) + return - yield _resp_sse("response.output_item.done", { - "output_index": 0, - "item": final_output[0] if final_output else { - "type": "message", "id": msg_item_id, - "status": "completed", "role": "assistant", - "content": []}}) + for message in finalize_tool_items(tool_calls): + yield message + final_output = _build_responses_output( + stream_state.visible_text, tool_calls, msg_item_id, + output_order=output_item_order) - # response.completed shell = _resp_shell(response_id, chat_req.model, created_at, "completed") shell["output"] = final_output - shell["output_text"] = accumulated_text + shell["output_text"] = stream_state.visible_text shell["usage"] = { "input_tokens": prompt_len, "output_tokens": completion_tokens, @@ -2242,7 +3140,7 @@ async def sse() -> AsyncIterator[str]: log.info( "responses DONE %s in=%d out=%d %.1fs %.1f tok/s output=%s text_len=%d %s", response_id, prompt_len, completion_tokens, - elapsed, tok_s, out_types, len(accumulated_text), + elapsed, tok_s, out_types, len(stream_state.visible_text), _timing_summary(timing, completion_tokens), ) yield _resp_sse("response.completed", {"response": shell}) diff --git a/dflash/scripts/test_server.py b/dflash/scripts/test_server.py index 02c4df7a..bd4a37e2 100644 --- a/dflash/scripts/test_server.py +++ b/dflash/scripts/test_server.py @@ -8,6 +8,7 @@ import pytest from fastapi.testclient import TestClient +from _prefill_hook import PrefillConfig from server import ( build_app, MODEL_NAME, parse_tool_calls, parse_reasoning, @@ -43,11 +44,122 @@ def app(mock_tokenizer): return a +@pytest.fixture +def app_with_prefill(mock_tokenizer): + """Build a FastAPI app with compression enabled.""" + drafter_tokenizer = MagicMock() + drafter_tokenizer.return_value = {"input_ids": [1, 2, 3]} + drafter_tokenizer.decode.return_value = "compressed prompt" + with patch("server.subprocess.Popen") as mock_popen: + mock_popen.return_value.poll.return_value = None # daemon alive + a = build_app( + target=Path("target.gguf"), + draft=Path("draft.safetensors"), + bin_path=Path("test_dflash"), + budget=22, + max_ctx=131072, + tokenizer=mock_tokenizer, + stop_ids={2}, + prefill_cache_slots=0, + prefill_cfg=PrefillConfig( + mode="always", + threshold=1, + keep_ratio=0.5, + drafter_gguf=Path("drafter.gguf"), + drafter_tokenizer_id="mock-tokenizer", + ), + drafter_tokenizer=drafter_tokenizer, + ) + return a + + @pytest.fixture def client(app): return TestClient(app) +def _build_app_with_process(mock_tokenizer, process, *, enable_prefill: bool = False): + kwargs = {} + if enable_prefill: + drafter_tokenizer = MagicMock() + drafter_tokenizer.return_value = {"input_ids": [1, 2, 3]} + drafter_tokenizer.decode.return_value = "compressed prompt" + kwargs.update( + prefill_cache_slots=0, + prefill_cfg=PrefillConfig( + mode="always", + threshold=1, + keep_ratio=0.5, + drafter_gguf=Path("drafter.gguf"), + drafter_tokenizer_id="mock-tokenizer", + ), + drafter_tokenizer=drafter_tokenizer, + ) + with patch("server.subprocess.Popen") as mock_popen: + mock_popen.return_value = process + return build_app( + target=Path("target.gguf"), + draft=Path("draft.safetensors"), + bin_path=Path("test_dflash"), + budget=22, + max_ctx=131072, + tokenizer=mock_tokenizer, + stop_ids={2}, + **kwargs, + ) + + +def _chat_sse_chunks(text: str) -> list[dict]: + return [ + json.loads(line[6:]) + for line in text.strip().split("\n\n") + if line.startswith("data: ") and line != "data: [DONE]" + ] + + +def _responses_sse_events(text: str) -> list[tuple[str, dict]]: + events: list[tuple[str, dict]] = [] + for block in text.strip().split("\n\n"): + if not block.strip(): + continue + event_line = next((line for line in block.splitlines() if line.startswith("event: ")), None) + data_line = next((line for line in block.splitlines() if line.startswith("data: ")), None) + if event_line and data_line: + events.append((event_line[7:], json.loads(data_line[6:]))) + return events + + +def _chat_stream_assistant_message(chunks: list[dict]) -> dict: + content_parts: list[str] = [] + tool_calls: dict[int, dict] = {} + for chunk in chunks: + for choice in chunk.get("choices", []): + delta = choice.get("delta", {}) + text = delta.get("content") + if isinstance(text, str): + content_parts.append(text) + for tc_delta in delta.get("tool_calls", []): + index = tc_delta["index"] + state = tool_calls.setdefault(index, { + "id": tc_delta.get("id"), + "type": tc_delta.get("type", "function"), + "function": {"name": None, "arguments": ""}, + }) + if tc_delta.get("id"): + state["id"] = tc_delta["id"] + fn_delta = tc_delta.get("function", {}) + if fn_delta.get("name"): + state["function"]["name"] = fn_delta["name"] + if fn_delta.get("arguments"): + state["function"]["arguments"] += fn_delta["arguments"] + msg = {"role": "assistant"} + content = "".join(content_parts) + msg["content"] = content or None + if tool_calls: + msg["tool_calls"] = [tool_calls[i] for i in sorted(tool_calls)] + return msg + + # ─── parse_reasoning ─────────────────────────────────────────────── class TestParseReasoning: @@ -430,6 +542,34 @@ def test_chat_completions_replays_raw_tool_call_text(mock_os_read, mock_pipe, assert "tool_calls" not in assistant +@patch("server.compress_text_via_daemon") +@patch("server.os.pipe") +@patch("server.os.read") +def test_chat_tool_requests_skip_compression(mock_os_read, mock_pipe, mock_compress, + mock_tokenizer, app_with_prefill): + mock_pipe.return_value = (1, 2) + mock_os_read.side_effect = [struct.pack("", "", "8"] + mock_lookup_full.return_value = (7, "cached_prompt.bin", 1) + mock_tokenizer.apply_chat_template.return_value = "prompt\n" + + def decode_side_effect(ids, *args, **kwargs): + token_id = ids[0] if isinstance(ids, list) else ids + return { + 10: "hidden reasoning", + 11: "", + 12: "visible answer", + }[token_id] + + mock_tokenizer.decode.side_effect = decode_side_effect mock_os_read.side_effect = [ struct.pack("" not in text - assert '"content":"8"' in text or '"content": "8"' in text - + chunks = _chat_sse_chunks(response.text) + reasoning = "".join( + choice.get("delta", {}).get("reasoning_content", "") + for chunk in chunks + for choice in chunk.get("choices", []) + ) + content = "".join( + choice.get("delta", {}).get("content", "") + for chunk in chunks + for choice in chunk.get("choices", []) + ) + assert reasoning == "hidden reasoning" + assert content == "visible answer" + assert "" not in content + assert "hidden reasoning" not in content -# ─── POST /v1/responses ─────────────────────────────────────────── @patch("server.os.pipe") @patch("server.os.read") -def test_responses_non_streaming(mock_os_read, mock_pipe, mock_tokenizer, app): - """POST /v1/responses non-streaming returns ResponsesObject.""" +@pytest.mark.parametrize(("decoded_chunks", "leaked_fragments"), [ + ( + [ + "test", + ".py", + ], + ["", "test", + ".py", + ], + ["{"name":"read_file","arguments":{"path":"test', + '.py"}}', + ], + ["", '{"name":"read_file"'], + ), +]) +def test_chat_completions_streaming_tool_call_deltas(mock_os_read, mock_pipe, + mock_tokenizer, app, + decoded_chunks, leaked_fragments): mock_pipe.return_value = (1, 2) - mock_os_read.side_effect = [struct.pack(" 0 - assert data["usage"]["output_tokens"] > 0 - assert data["usage"]["total_tokens"] == data["usage"]["input_tokens"] + data["usage"]["output_tokens"] + chunks = _chat_sse_chunks(response.text) + tool_deltas = [ + chunk["choices"][0]["delta"]["tool_calls"][0] + for chunk in chunks + if chunk.get("choices") + and chunk["choices"][0].get("delta", {}).get("tool_calls") + ] + assert len(tool_deltas) >= 2 + assert tool_deltas[0]["id"].startswith("call_") + assert tool_deltas[0]["function"]["name"] == "read_file" + assert "".join( + delta.get("function", {}).get("arguments", "") + for delta in tool_deltas + ) == '{"path":"test.py"}' + assert not any( + any(fragment in choice.get("delta", {}).get("content", "") for fragment in leaked_fragments) + for chunk in chunks + for choice in chunk.get("choices", []) + ) + finish_chunk = next( + chunk for chunk in reversed(chunks) + if chunk.get("choices") and chunk["choices"][0]["finish_reason"] is not None + ) + assert finish_chunk["choices"][0]["finish_reason"] == "tool_calls" @patch("server.os.pipe") @patch("server.os.read") -def test_responses_non_streaming_string_input(mock_os_read, mock_pipe, - mock_tokenizer, app): - """Responses API accepts a plain string as input.""" +def test_chat_streaming_stop_sequence_preserves_tool_deltas(mock_os_read, mock_pipe, + mock_tokenizer, app): mock_pipe.return_value = (1, 2) + stop_marker = "" + mock_tokenizer.decode.return_value = ( + "" + "test.py" + "" + f"{stop_marker}" + ) mock_os_read.side_effect = [struct.pack("" in choice.get("delta", {}).get("content", "") + for chunk in chunks + for choice in chunk.get("choices", []) + ) + assistant_msg = _chat_stream_assistant_message(chunks) + assert assistant_msg["content"] is None + assert assistant_msg["tool_calls"][0]["function"]["name"] == "read_file" + assert assistant_msg["tool_calls"][0]["function"]["arguments"] == '{"path":"test.py"}' + finish_chunk = next( + chunk for chunk in reversed(chunks) + if chunk.get("choices") and chunk["choices"][0]["finish_reason"] is not None + ) + assert finish_chunk["choices"][0]["finish_reason"] == "tool_calls" @patch("server.os.pipe") @patch("server.os.read") -def test_responses_non_streaming_started_in_thinking(mock_os_read, mock_pipe, - mock_tokenizer, app): - """When prompt ends with , reasoning without tags is not misclassified as content.""" +def test_chat_streaming_stop_sequence_after_bare_function_close_tag( + mock_os_read, mock_pipe, mock_tokenizer, app): mock_pipe.return_value = (1, 2) - mock_os_read.side_effect = [struct.pack("\n - mock_tokenizer.apply_chat_template.return_value = "prompt\n" - # Model output has no tags — it's a continuation of the prefilled block - mock_tokenizer.decode.return_value = "internal reasoning\nactual answer" + stop_marker = "" + mock_tokenizer.decode.side_effect = [ + "test", + ".py", + f"{stop_marker}ignored", + ] + mock_os_read.side_effect = [ + struct.pack("', + ], + ['{"name":"read_file"', ""], + ), + ( + [ + '{"name":"read_file","arguments":{"path":"test', + '.py"}}', + ], + ["", '{"name":"read_file"', ""], + ), +]) +def test_chat_streaming_stop_sequence_after_json_tool_call( + mock_os_read, mock_pipe, mock_tokenizer, app, decoded_chunks, + leaked_fragments): mock_pipe.return_value = (1, 2) - mock_os_read.side_effect = [struct.pack("", + "stream": True, }) assert response.status_code == 200 - # Verify apply_chat_template was called with system message - calls = mock_tokenizer.apply_chat_template.call_args_list - last_call = calls[-1] - msgs = last_call[0][0] # first positional arg - assert msgs[0]["role"] == "system" - assert "coding assistant" in msgs[0]["content"] + chunks = _chat_sse_chunks(response.text) + assistant_msg = _chat_stream_assistant_message(chunks) + assert assistant_msg["content"] is None + assert assistant_msg["tool_calls"][0]["function"]["name"] == "read_file" + assert assistant_msg["tool_calls"][0]["function"]["arguments"] == '{"path":"test.py"}' + assert not any( + any(fragment in choice.get("delta", {}).get("content", "") for fragment in leaked_fragments) + for chunk in chunks + for choice in chunk.get("choices", []) + ) + finish_chunk = next( + chunk for chunk in reversed(chunks) + if chunk.get("choices") and chunk["choices"][0]["finish_reason"] is not None + ) + assert finish_chunk["choices"][0]["finish_reason"] == "tool_calls" @patch("server.os.pipe") @patch("server.os.read") -def test_responses_streaming(mock_os_read, mock_pipe, mock_tokenizer, app): - """POST /v1/responses streaming emits proper SSE lifecycle events.""" +def test_chat_streaming_replays_raw_tool_call_text(mock_os_read, mock_pipe, + mock_tokenizer, app): mock_pipe.return_value = (1, 2) - mock_os_read.side_effect = [struct.pack("" + "test.py" + "\\n" + "After" + ) + mock_tokenizer.decode.side_effect = [raw_tool_text, "followup"] + mock_os_read.side_effect = [ + struct.pack("", "", "8"] + stop_marker = "" + trimmed_raw_tool_text = ( + "Before\n" + "" + "test.py" + "\n" + "After" + ) + mock_tokenizer.decode.side_effect = [ + trimmed_raw_tool_text + stop_marker + "ignored", + "followup", + ] mock_os_read.side_effect = [ - struct.pack("" not in text - assert '"delta":"8"' in text or '"delta": "8"' in text + second = client.post("/v1/chat/completions", json={ + "model": MODEL_NAME, + "messages": [ + {"role": "user", "content": "read test.py"}, + assistant_msg, + {"role": "tool", "tool_call_id": assistant_msg["tool_calls"][0]["id"], "content": "file body"}, + {"role": "user", "content": "what next?"}, + ], + "stream": False, + }) + assert second.status_code == 200 + + msgs = mock_tokenizer.apply_chat_template.call_args_list[-1][0][0] + assistant = next(m for m in msgs if m["role"] == "assistant") + assert assistant["content"] == trimmed_raw_tool_text + assert stop_marker not in assistant["content"] + assert "ignored" not in assistant["content"] + assert "tool_calls" not in assistant @patch("server.os.pipe") @patch("server.os.read") -def test_responses_with_tools(mock_os_read, mock_pipe, mock_tokenizer, app): - """POST /v1/responses with function tools maps correctly.""" +@pytest.mark.parametrize("decoded_chunks", [ + [ + "test", + ".py", + ], + [ + "test", + ".py", + ], +]) +def test_chat_streaming_stop_hit_drains_daemon_before_next_stream( + mock_os_read, mock_pipe, mock_tokenizer, app, decoded_chunks): mock_pipe.return_value = (1, 2) - mock_os_read.side_effect = [struct.pack("", + "stream": True, + }) + assert first.status_code == 200 + assistant_msg = _chat_stream_assistant_message(_chat_sse_chunks(first.text)) + assert assistant_msg["tool_calls"][0]["function"]["name"] == "read_file" + + second = client.post("/v1/chat/completions", json={ + "model": MODEL_NAME, + "messages": [{"role": "user", "content": "hello"}], + "stream": True, + }) + assert second.status_code == 200 + second_msg = _chat_stream_assistant_message(_chat_sse_chunks(second.text)) + assert second_msg["content"] == "fresh reply" + assert "tool_calls" not in second_msg + + +@patch("server.PrefixCache.abort_full_snap") +@patch("server.PrefixCache.confirm_full_snap") +@patch("server.PrefixCache.prepare_full_snap", return_value=(7, 0)) +@patch("server.compress_text_via_daemon", return_value="compressed prompt") +@patch("server.os.pipe") +@patch("server.os.read") +def test_chat_streaming_stop_hit_confirms_reserved_full_snapshot( + mock_os_read, mock_pipe, _mock_compress, _mock_prepare_full_snap, + mock_confirm_full_snap, mock_abort_full_snap, mock_tokenizer, + app_with_prefill): + mock_pipe.return_value = (1, 2) + + def decode_side_effect(ids, *args, **kwargs): + token_id = ids[0] if isinstance(ids, list) else ids + return { + 10: "hello", + 11: "", + 12: "stale", + }[token_id] + + mock_tokenizer.decode.side_effect = decode_side_effect + mock_os_read.side_effect = [ + struct.pack("", + "stream": True, + }) + + assert response.status_code == 200 + assistant_msg = _chat_stream_assistant_message(_chat_sse_chunks(response.text)) + assert assistant_msg["content"] == "hello" + mock_confirm_full_snap.assert_called_once() + slot, prompt_ids, cur_bin, cur_ids_len = mock_confirm_full_snap.call_args.args + assert slot == 7 + assert prompt_ids == [1] + assert isinstance(cur_bin, Path) + assert cur_ids_len == 1 + mock_abort_full_snap.assert_not_called() + + +@patch("server.PrefixCache.abort_full_snap") +@patch("server.PrefixCache.prepare_full_snap", return_value=(7, 0)) +@patch("server.compress_text_via_daemon", return_value="compressed prompt") +@patch("server.os.pipe") +@patch("server.os.read") +def test_chat_streaming_write_failure_aborts_reserved_full_snapshot( + mock_os_read, mock_pipe, _mock_compress, _mock_prepare_full_snap, + mock_abort_full_snap, mock_tokenizer): + mock_pipe.return_value = (1, 2) + mock_os_read.side_effect = [] + dead_proc = MagicMock() + dead_proc.poll.return_value = 1 + local_app = _build_app_with_process( + mock_tokenizer, dead_proc, enable_prefill=True) + + client = TestClient(local_app) + response = client.post("/v1/chat/completions", json={ + "model": MODEL_NAME, + "messages": [{"role": "user", "content": "hello"}], + "stream": True, + }) + + assert response.status_code == 200 + assert "dflash daemon has exited unexpectedly" in response.text + mock_abort_full_snap.assert_called_once_with(7) + + +@patch("server.os.pipe") +@patch("server.os.read") +def test_chat_tool_choice_required_rejects_plain_text(mock_os_read, mock_pipe, + mock_tokenizer, app): + mock_pipe.return_value = (1, 2) + mock_tokenizer.decode.return_value = "plain text" + mock_os_read.side_effect = [struct.pack("file.txt' + '' + ) + mock_os_read.side_effect = [struct.pack(" 0 + assert data["usage"]["output_tokens"] > 0 + assert data["usage"]["total_tokens"] == data["usage"]["input_tokens"] + data["usage"]["output_tokens"] + + +@patch("server.os.pipe") +@patch("server.os.read") +def test_responses_non_streaming_string_input(mock_os_read, mock_pipe, + mock_tokenizer, app): + """Responses API accepts a plain string as input.""" + mock_pipe.return_value = (1, 2) + mock_os_read.side_effect = [struct.pack("" + "file.txt" + "" + "After tool" + ) + mock_os_read.side_effect = [struct.pack(", reasoning without tags is not misclassified as content.""" + mock_pipe.return_value = (1, 2) + mock_os_read.side_effect = [struct.pack("\n + mock_tokenizer.apply_chat_template.return_value = "prompt\n" + # Model output has no tags — it's a continuation of the prefilled block + mock_tokenizer.decode.return_value = "internal reasoning\nactual answer" + + client = TestClient(app) + response = client.post("/v1/responses", json={ + "model": MODEL_NAME, + "input": [{"type": "message", "role": "user", "content": "hello"}], + }) + + assert response.status_code == 200 + data = response.json() + assert data["object"] == "response" + # The "actual answer" part should be the output text, not the reasoning + assert "actual answer" in data["output_text"] + # The reasoning should NOT leak into the output text + assert "internal reasoning" not in data["output_text"] + + +@patch("server.os.pipe") +@patch("server.os.read") +def test_responses_with_instructions(mock_os_read, mock_pipe, + mock_tokenizer, app): + """Instructions are mapped to system message.""" + mock_pipe.return_value = (1, 2) + mock_os_read.side_effect = [struct.pack("", + 12: "visible answer", + }[token_id] + + mock_tokenizer.decode.side_effect = decode_side_effect + mock_os_read.side_effect = [ + struct.pack("" in delta or "hidden reasoning" in delta for delta in deltas) + completed = next(data for event, data in events if event == "response.completed") + assert completed["response"]["output_text"] == "visible answer" + + +@patch("server.os.pipe") +@patch("server.os.read") +@pytest.mark.parametrize("decoded_chunks", [ + [ + "test", + ".py", + ], + [ + "test", + ".py", + ], + [ + '{"name":"read_file","arguments":{"path":"test', + '.py"}}', + ], + [ + '{"name":"read_file","arguments":{"path":"test', + '.py"}}', + ], +]) +def test_responses_streaming_function_call_lifecycle(mock_os_read, mock_pipe, + mock_tokenizer, app, + decoded_chunks): + mock_pipe.return_value = (1, 2) + mock_tokenizer.decode.side_effect = decoded_chunks + mock_os_read.side_effect = [ + struct.pack("file.txt', + ], + ["", "{"name":"write_file","arguments":{"path":"file.txt"}}', + ], + ["", '{"name":"write_file"'], + ), +]) +def test_responses_streaming_tool_choice_failure_suppresses_terminal_function_events( + mock_os_read, mock_pipe, mock_tokenizer, app, decoded_chunks, + leaked_fragments): + mock_pipe.return_value = (1, 2) + mock_tokenizer.decode.side_effect = decoded_chunks + mock_os_read.side_effect = [struct.pack("file.txt", + "After tool", + ] + mock_os_read.side_effect = [ + struct.pack("", "", "8"] + mock_os_read.side_effect = [ + struct.pack("" not in text + assert '"delta":"8"' in text or '"delta": "8"' in text + + +@patch("server.os.pipe") +@patch("server.os.read") +def test_responses_with_tools(mock_os_read, mock_pipe, mock_tokenizer, app): + """POST /v1/responses with function tools maps correctly.""" + mock_pipe.return_value = (1, 2) + mock_os_read.side_effect = [struct.pack("' + 'file.txt' + '' + ) + mock_tokenizer.decode.return_value = raw_tool_text + mock_os_read.side_effect = [struct.pack("' + 'file.txt' + '' + ) + mock_os_read.side_effect = [struct.pack("' + 'other.txt' + '' + '' + 'file.txt' + '' + ) + mock_os_read.side_effect = [struct.pack("' + 'file.txt' + '' + ) mock_os_read.side_effect = [struct.pack("' + 'file.txt' + '' + 'After tool' + ) + mock_tokenizer.decode.side_effect = [raw_tool_text, "followup"] + mock_os_read.side_effect = [ + struct.pack(" Date: Tue, 12 May 2026 09:16:50 +0800 Subject: [PATCH 3/3] Handle stray in shared stream parser During rebase conflict resolution, shared streaming path retained stray closing think tags in content mode for SSE responses. Treat think-close markers as control tags in content mode so they are stripped consistently. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dflash/scripts/server.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dflash/scripts/server.py b/dflash/scripts/server.py index 0232fae1..d83db8eb 100644 --- a/dflash/scripts/server.py +++ b/dflash/scripts/server.py @@ -1262,12 +1262,14 @@ def _feed_shared_stream_piece( break think_idx = state.window.find(THINK_OPEN_TAG) + think_close_idx = state.window.find(THINK_CLOSE_TAG) tool_idx = state.window.find(TOOL_OPEN_TAG) if state.allow_tools else -1 bare_fn_idx = state.window.find(FUNCTION_OPEN_TAG) if state.allow_tools else -1 tool_code_idx = state.window.find(TOOL_CODE_OPEN_TAG) if state.allow_tools else -1 json_tool_idx = _find_toolish_json_start(state.window) if state.allow_tools else -1 hits = [(i, t) for i, t in ( (think_idx, "think"), + (think_close_idx, "think_close"), (tool_idx, "tool"), (bare_fn_idx, "tool"), (tool_code_idx, "tool"), @@ -1283,6 +1285,8 @@ def _feed_shared_stream_piece( if which == "think": state.window = state.window[idx + len(THINK_OPEN_TAG):] state.mode = "reasoning" + elif which == "think_close": + state.window = state.window[idx + len(THINK_CLOSE_TAG):] else: state.tool_buffer += state.window[idx:] state.window = ""