diff --git a/dflash/DEVELOPER.md b/dflash/DEVELOPER.md
index 8f948899..1c9460a6 100644
--- a/dflash/DEVELOPER.md
+++ b/dflash/DEVELOPER.md
@@ -227,7 +227,7 @@ dflash/
│ └── draft/model.safetensors
├── scripts/
│ ├── server.py # Main OpenAI/Codex server
-│ ├── server_tools.py # Legacy fork with tool calling (deprecated)
+│ ├── server_tools.py # Legacy fork kept for reference; server.py is the tool/Codex path
│ ├── prefix_cache.py # LRU prefix cache
│ ├── _prefill_hook.py # Speculative prefill compression
│ ├── run.py # CLI text generation
diff --git a/dflash/README.md b/dflash/README.md
index 63c6122f..b73fde56 100644
--- a/dflash/README.md
+++ b/dflash/README.md
@@ -119,7 +119,7 @@ allows capacity checks where the draft and a target layer range share one GPU
before serving integration. `--target-split-dflash` runs the same split target
placement through a chain DFlash decode loop and reports acceptance length.
-**Python flags on `scripts/run.py`, `scripts/server.py`, `scripts/server_tools.py`:**
+**Python flags on `scripts/run.py` and `scripts/server.py` (`scripts/server_tools.py` is legacy):**
```bash
python3 scripts/run.py --ctk q8_0 --ctv q4_0 --prompt "hello"
python3 scripts/run.py --cache-type-k q8_0 --cache-type-v q4_0 --prompt "hello"
diff --git a/dflash/scripts/server.py b/dflash/scripts/server.py
index 9e00d3d9..d83db8eb 100644
--- a/dflash/scripts/server.py
+++ b/dflash/scripts/server.py
@@ -24,6 +24,7 @@
import tempfile
import time
import uuid
+from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, AsyncIterator
@@ -41,6 +42,7 @@
compress_text_via_daemon, _drain_until_sentinel,
)
from prefix_cache import DaemonStdoutBus, PrefixCache
+from tool_memory import ToolMemory
class OpenAICompatError(Exception):
@@ -179,6 +181,8 @@ def _tokenizer_id_from_gguf(gguf_path: Path) -> str:
r"", re.DOTALL)
TOOL_CODE_RE = re.compile(r"(.*?)", re.DOTALL)
TOOL_OPEN_TAG = ""
+FUNCTION_OPEN_TAG = ""
THINK_OPEN_TAG = ""
THINK_CLOSE_TAG = ""
@@ -344,13 +348,13 @@ def add_call(function_name: str, args: dict, start: int, end: int):
if not _tool_allowed(tools, function_name):
return
tool_calls.append({
- "id": "call_" + uuid.uuid4().hex[:24],
- "type": "function",
- "function": {
- "name": function_name,
- "arguments": json.dumps(args, ensure_ascii=False),
- },
- })
+ "id": "call_" + uuid.uuid4().hex[:24],
+ "type": "function",
+ "function": {
+ "name": function_name,
+ "arguments": json.dumps(args, ensure_ascii=False, separators=(",", ":")),
+ },
+ })
removals.append((start, end))
def parse_xml_function(function_name: str, params_region: str) -> dict:
@@ -572,6 +576,850 @@ class ResponsesCreateRequest(BaseModel):
previous_response_id: str | None = None
+@dataclass(frozen=True)
+class ToolPolicy:
+ prompt_tools: list[Any] | None
+ parse_tools: list[Any] | None
+ choice_kind: str
+ choice_name: str | None = None
+ render_tools: bool = False
+ parse_tool_calls: bool = False
+ bypass_compression: bool = False
+
+
+@dataclass
+class _PartialToolCallSnapshot:
+ name: str
+ arguments: str
+ complete: bool = False
+
+
+@dataclass
+class _TrackedToolCallState:
+ index: int
+ id: str
+ name: str | None = None
+ emitted_arguments: str = ""
+ announced: bool = False
+ done: bool = False
+
+
+@dataclass
+class _SharedStreamState:
+ mode: str
+ holdback: int
+ allow_tools: bool
+ window: str = ""
+ tool_buffer: str = ""
+ raw_text: str = ""
+ visible_text: str = ""
+ stop_hit: bool = False
+ tool_states: list[_TrackedToolCallState] = field(default_factory=list)
+ final_tool_calls: list[dict] = field(default_factory=list)
+
+
+def _trim_stream_raw_suffix(raw_text: str, suffix_len: int, keep_len: int) -> str:
+ trim_len = max(0, suffix_len - keep_len)
+ if trim_len <= 0:
+ return raw_text
+ return raw_text[:-trim_len] if trim_len < len(raw_text) else ""
+
+
+def _tool_name(tool: Any) -> str | None:
+ fn = None
+ if hasattr(tool, "function"):
+ fn = tool.function
+ elif isinstance(tool, dict):
+ fn = tool.get("function")
+ if hasattr(fn, "model_dump"):
+ fn = fn.model_dump()
+ if isinstance(fn, dict):
+ name = fn.get("name")
+ if isinstance(name, str) and name:
+ return name
+ return None
+
+
+def _normalize_tool_choice(tool_choice: Any) -> tuple[str, str | None]:
+ if tool_choice is None:
+ return "auto", None
+ if isinstance(tool_choice, str):
+ choice = tool_choice.strip().lower()
+ if choice in {"auto", "none", "required"}:
+ return choice, None
+ if isinstance(tool_choice, dict):
+ choice_type = tool_choice.get("type")
+ if isinstance(choice_type, str):
+ choice_type = choice_type.lower()
+ if choice_type in {"auto", "none", "required"}:
+ return choice_type, None
+ name = tool_choice.get("name")
+ if not isinstance(name, str):
+ fn = tool_choice.get("function")
+ if isinstance(fn, dict):
+ name = fn.get("name")
+ if (choice_type == "function" or "function" in tool_choice or "name" in tool_choice) and isinstance(name, str) and name:
+ return "function", name
+ raise OpenAICompatError(
+ "Unsupported tool_choice value",
+ param="tool_choice",
+ )
+
+
+def _resolve_tool_policy(tools: list[Any] | None, tool_choice: Any) -> ToolPolicy:
+ requested_tools = list(tools or [])
+ has_tools = bool(requested_tools)
+ choice_kind, choice_name = _normalize_tool_choice(tool_choice)
+
+ if choice_kind == "function":
+ if not has_tools:
+ raise OpenAICompatError(
+ "tool_choice=function requires tools",
+ param="tool_choice",
+ )
+ selected = next((t for t in requested_tools if _tool_name(t) == choice_name), None)
+ if selected is None:
+ raise OpenAICompatError(
+ f"tool_choice function '{choice_name}' was not provided in tools",
+ param="tool_choice",
+ )
+ return ToolPolicy(
+ prompt_tools=[selected],
+ parse_tools=requested_tools,
+ choice_kind=choice_kind,
+ choice_name=choice_name,
+ render_tools=True,
+ parse_tool_calls=True,
+ bypass_compression=True,
+ )
+
+ if choice_kind == "required" and not has_tools:
+ raise OpenAICompatError(
+ "tool_choice='required' requires tools",
+ param="tool_choice",
+ )
+
+ if choice_kind == "none":
+ return ToolPolicy(
+ prompt_tools=None,
+ parse_tools=None,
+ choice_kind=choice_kind,
+ render_tools=False,
+ parse_tool_calls=False,
+ bypass_compression=has_tools,
+ )
+
+ if has_tools:
+ return ToolPolicy(
+ prompt_tools=requested_tools,
+ parse_tools=requested_tools,
+ choice_kind=choice_kind,
+ render_tools=True,
+ parse_tool_calls=True,
+ bypass_compression=True,
+ )
+
+ return ToolPolicy(
+ prompt_tools=None,
+ parse_tools=None,
+ choice_kind=choice_kind,
+ render_tools=False,
+ parse_tool_calls=True,
+ bypass_compression=False,
+ )
+
+
+def _parse_generated_tool_calls(text: str, tool_policy: ToolPolicy) -> tuple[str, list[dict]]:
+ if not tool_policy.parse_tool_calls:
+ return text, []
+ return parse_tool_calls(text, tools=tool_policy.parse_tools)
+
+
+def _tool_choice_violation(tool_policy: ToolPolicy, tool_calls: list[dict]) -> OpenAICompatError | None:
+ if tool_policy.choice_kind in {"auto", "none"}:
+ return None
+ if tool_policy.choice_kind == "required":
+ if tool_calls:
+ return None
+ return OpenAICompatError(
+ "tool_choice='required' requires the model to emit a tool call",
+ param="tool_choice",
+ )
+ if tool_policy.choice_kind == "function":
+ disallowed = [
+ tc.get("function", {}).get("name")
+ for tc in tool_calls
+ if tc.get("function", {}).get("name") != tool_policy.choice_name
+ ]
+ disallowed = [name for name in disallowed if isinstance(name, str) and name]
+ if disallowed:
+ leaked = ", ".join(sorted(set(disallowed)))
+ return OpenAICompatError(
+ f"tool_choice function '{tool_policy.choice_name}' does not allow other tool calls ({leaked})",
+ param="tool_choice",
+ )
+ if any(tc.get("function", {}).get("name") == tool_policy.choice_name for tc in tool_calls):
+ return None
+ return OpenAICompatError(
+ f"tool_choice function '{tool_policy.choice_name}' was not satisfied by model output",
+ param="tool_choice",
+ )
+ return None
+
+
+def _tool_param_type(tools, function_name: str, param_name: str) -> str:
+ params = _find_tool_properties(tools, function_name)
+ cfg = params.get(param_name, {}) if isinstance(params, dict) else {}
+ if isinstance(cfg, dict):
+ if isinstance(cfg.get("type"), str):
+ return cfg["type"].strip().lower()
+ if "anyOf" in cfg:
+ return "object"
+ return "string"
+
+
+def _json_string_fragment(value: str) -> str:
+ return json.dumps(value, ensure_ascii=False)[1:-1]
+
+
+def _find_toolish_json_start(text: str, start: int = 0) -> int:
+ """Return the earliest `{` that could begin a supported JSON tool call."""
+ keys = ("name", "function", "arguments")
+ cursor = start
+ while cursor < len(text):
+ brace = text.find("{", cursor)
+ if brace == -1:
+ return -1
+ key_cursor = brace + 1
+ while key_cursor < len(text) and text[key_cursor].isspace():
+ key_cursor += 1
+ if key_cursor >= len(text):
+ return brace
+ if text[key_cursor] != '"':
+ cursor = brace + 1
+ continue
+ key_start = key_cursor + 1
+ key_end = key_start
+ while key_end < len(text) and text[key_end] not in {'"', '\\'}:
+ key_end += 1
+ if key_end >= len(text):
+ fragment = text[key_start:]
+ if any(k.startswith(fragment) for k in keys):
+ return brace
+ cursor = brace + 1
+ continue
+ if text[key_end] == "\\":
+ cursor = brace + 1
+ continue
+ key = text[key_start:key_end]
+ if not any(k.startswith(key) for k in keys):
+ cursor = brace + 1
+ continue
+ colon_cursor = key_end + 1
+ while colon_cursor < len(text) and text[colon_cursor].isspace():
+ colon_cursor += 1
+ if colon_cursor >= len(text) or text[colon_cursor] == ":":
+ return brace
+ cursor = brace + 1
+ return -1
+
+
+def _json_tool_call_spans(text: str) -> list[tuple[int, int]]:
+ spans: list[tuple[int, int]] = []
+ decoder = json.JSONDecoder()
+ cursor = 0
+ while cursor < len(text):
+ start = text.find("{", cursor)
+ if start == -1:
+ break
+ try:
+ obj, consumed = decoder.raw_decode(text[start:])
+ except json.JSONDecodeError:
+ cursor = start + 1
+ continue
+ if _parse_json_tool_call(obj) is not None:
+ spans.append((start, start + consumed))
+ cursor = start + max(consumed, 1)
+ return spans
+
+
+def _build_partial_xml_arguments(text: str, cursor: int, function_name: str, tools) -> tuple[str, bool, int]:
+ param_config = _find_tool_properties(tools, function_name)
+ fragments: list[str] = []
+ complete = False
+
+ while cursor < len(text):
+ while cursor < len(text) and text[cursor].isspace():
+ cursor += 1
+ if text.startswith("", cursor):
+ cursor += len("")
+ complete = True
+ break
+ if text.startswith("", cursor):
+ cursor += len("")
+ continue
+ if not text.startswith("", cursor)
+ if name_end == -1:
+ break
+ param_name = text[cursor:name_end].strip()
+ cursor = name_end + 1
+
+ end_tag = text.find("", cursor)
+ next_param = text.find("", cursor)
+ candidates = [(pos, kind) for pos, kind in (
+ (end_tag, "parameter"),
+ (next_param, "next_param"),
+ (close_fn, "function"),
+ ) if pos != -1]
+ if candidates:
+ value_end, end_kind = min(candidates, key=lambda item: item[0])
+ else:
+ value_end, end_kind = len(text), "eof"
+ raw_value = text[cursor:value_end]
+ if raw_value.startswith("\n"):
+ raw_value = raw_value[1:]
+ if end_kind == "parameter" and raw_value.endswith("\n"):
+ raw_value = raw_value[:-1]
+
+ if fragments:
+ fragments.append(",")
+ else:
+ fragments.append("{")
+ fragments.append(json.dumps(param_name, ensure_ascii=False))
+ fragments.append(":")
+
+ if _tool_param_type(tools, function_name, param_name) == "string":
+ fragments.append('"')
+ fragments.append(_json_string_fragment(raw_value))
+ if end_kind == "parameter":
+ fragments.append('"')
+ cursor = value_end + len("")
+ continue
+ break
+
+ if end_kind == "parameter":
+ value_obj = _convert_param_value(raw_value, param_name, param_config, function_name)
+ fragments.append(json.dumps(value_obj, ensure_ascii=False, separators=(",", ":")))
+ cursor = value_end + len("")
+ continue
+
+ del fragments[prefix_len:]
+ break
+
+ if complete:
+ if not fragments:
+ return "{}", True, cursor
+ fragments.append("}")
+ return "".join(fragments), complete, cursor
+
+
+def _partial_tool_call_snapshots(text: str, tools=None) -> list[_PartialToolCallSnapshot]:
+ snapshots: list[_PartialToolCallSnapshot] = []
+ decoder = json.JSONDecoder()
+ cursor = 0
+ while cursor < len(text):
+ hits = [(i, kind) for i, kind in (
+ (text.find(TOOL_OPEN_TAG, cursor), "tool"),
+ (text.find(FUNCTION_OPEN_TAG, cursor), "function"),
+ (text.find(TOOL_CODE_OPEN_TAG, cursor), "tool_code"),
+ (_find_toolish_json_start(text, cursor), "json"),
+ ) if i != -1]
+ if not hits:
+ break
+ start, kind = min(hits, key=lambda item: item[0])
+ cursor = start
+ if kind == "tool_code":
+ close = text.find("", cursor + len(TOOL_CODE_OPEN_TAG))
+ if close == -1:
+ break
+ try:
+ obj = json.loads(text[cursor + len(TOOL_CODE_OPEN_TAG):close].strip())
+ except json.JSONDecodeError:
+ cursor = close + len("")
+ continue
+ parsed = _parse_json_tool_call(obj)
+ if parsed is not None and _tool_allowed(tools, parsed[0]):
+ snapshots.append(_PartialToolCallSnapshot(
+ name=parsed[0],
+ arguments=json.dumps(parsed[1], ensure_ascii=False, separators=(",", ":")),
+ complete=True,
+ ))
+ cursor = close + len("")
+ continue
+ if kind == "json":
+ try:
+ obj, consumed = decoder.raw_decode(text[cursor:])
+ except json.JSONDecodeError:
+ break
+ parsed = _parse_json_tool_call(obj)
+ if parsed is not None and _tool_allowed(tools, parsed[0]):
+ snapshots.append(_PartialToolCallSnapshot(
+ name=parsed[0],
+ arguments=json.dumps(parsed[1], ensure_ascii=False, separators=(",", ":")),
+ complete=True,
+ ))
+ cursor += max(consumed, 1)
+ continue
+ if text.startswith(TOOL_OPEN_TAG, cursor):
+ cursor += len(TOOL_OPEN_TAG)
+ while cursor < len(text) and text[cursor].isspace():
+ cursor += 1
+ if not text.startswith("", fn_cursor)
+ signature_end = text.find("(", fn_cursor)
+ if signature_end != -1 and (header_end == -1 or signature_end < header_end):
+ close = text.find("", signature_end)
+ if close == -1:
+ break
+ _, calls = parse_tool_calls(text[start:close + len("")], tools=tools)
+ if calls:
+ snapshots.append(_PartialToolCallSnapshot(
+ name=calls[0]["function"]["name"],
+ arguments=calls[0]["function"]["arguments"],
+ complete=True,
+ ))
+ cursor = close + len("")
+ continue
+ if header_end == -1:
+ break
+
+ function_name = text[fn_cursor:header_end].strip()
+ if not function_name:
+ cursor = start + 1
+ continue
+ if not _tool_allowed(tools, function_name):
+ close = text.find("", header_end + 1)
+ cursor = close + len("") if close != -1 else header_end + 1
+ continue
+
+ arguments, complete, body_end = _build_partial_xml_arguments(
+ text, header_end + 1, function_name, tools)
+ snapshots.append(_PartialToolCallSnapshot(
+ name=function_name,
+ arguments=arguments,
+ complete=complete,
+ ))
+ cursor = max(body_end, header_end + 1)
+ if not complete:
+ break
+ tool_close = text.find("", cursor)
+ if tool_close != -1:
+ cursor = max(cursor, tool_close + len(""))
+ return snapshots
+
+
+def _complete_tool_buffer_spans(text: str) -> list[tuple[int, int]]:
+ spans: list[tuple[int, int]] = []
+ for match in TOOL_CALL_COMPLETE_RE.finditer(text):
+ spans.append((match.start(), match.end()))
+ for pattern in (BARE_FUNCTION_XML_RE, FUNCTION_SIGNATURE_RE):
+ for match in pattern.finditer(text):
+ if any(lo <= match.start() < hi for lo, hi in spans):
+ continue
+ spans.append((match.start(), match.end()))
+ for match in TOOL_CODE_RE.finditer(text):
+ try:
+ obj = json.loads(match.group(1).strip())
+ except json.JSONDecodeError:
+ continue
+ if _parse_json_tool_call(obj) is None:
+ continue
+ if any(lo <= match.start() < hi for lo, hi in spans):
+ continue
+ spans.append((match.start(), match.end()))
+ for start, end in _json_tool_call_spans(text):
+ if any(lo <= start < hi for lo, hi in spans):
+ continue
+ spans.append((start, end))
+ spans.sort()
+ return spans
+
+
+def _first_tool_buffer_stop_match(text: str, stops: list[str]) -> int:
+ spans = _complete_tool_buffer_spans(text)
+ if not spans:
+ return -1
+
+ def stop_in_gap(start: int, end: int) -> int:
+ json_open = _find_toolish_json_start(text[start:end])
+ if json_open != -1:
+ json_open += start
+ next_open = min(
+ (pos for pos in (
+ text.find(TOOL_OPEN_TAG, start, end),
+ text.find(FUNCTION_OPEN_TAG, start, end),
+ text.find(TOOL_CODE_OPEN_TAG, start, end),
+ json_open,
+ ) if pos != -1),
+ default=-1,
+ )
+ if next_open != -1:
+ end = next_open
+ if start >= end:
+ return -1
+ gap_index = first_stop_match(text[start:end], stops)
+ return start + gap_index if gap_index != -1 else -1
+
+ cursor = 0
+ for start, end in spans:
+ stop_index = stop_in_gap(cursor, start)
+ if stop_index != -1:
+ return stop_index
+ cursor = max(cursor, end)
+ return stop_in_gap(cursor, len(text))
+
+
+def _sync_partial_tool_stream(tool_buffer: str, tools, states: list[_TrackedToolCallState]) -> list[dict]:
+ events: list[dict] = []
+ for index, snapshot in enumerate(_partial_tool_call_snapshots(tool_buffer, tools=tools)):
+ while len(states) <= index:
+ states.append(_TrackedToolCallState(
+ index=len(states),
+ id="call_" + uuid.uuid4().hex[:24],
+ ))
+ state = states[index]
+ if snapshot.name and state.name is None:
+ state.name = snapshot.name
+ if state.name and not state.announced:
+ state.announced = True
+ events.append({
+ "kind": "tool_start",
+ "index": index,
+ "id": state.id,
+ "name": state.name,
+ })
+ if snapshot.arguments:
+ fragment = snapshot.arguments
+ if snapshot.arguments.startswith(state.emitted_arguments):
+ fragment = snapshot.arguments[len(state.emitted_arguments):]
+ if fragment:
+ state.emitted_arguments = snapshot.arguments
+ events.append({
+ "kind": "tool_args",
+ "index": index,
+ "id": state.id,
+ "name": state.name,
+ "arguments": fragment,
+ })
+ if snapshot.complete and not state.done:
+ state.done = True
+ events.append({
+ "kind": "tool_done",
+ "index": index,
+ "id": state.id,
+ "name": state.name,
+ "arguments": snapshot.arguments or "{}",
+ })
+ return events
+
+
+def _reconcile_final_tool_events(tool_calls: list[dict], states: list[_TrackedToolCallState]) -> list[dict]:
+ events: list[dict] = []
+ for index, tc in enumerate(tool_calls):
+ while len(states) <= index:
+ states.append(_TrackedToolCallState(
+ index=len(states),
+ id="call_" + uuid.uuid4().hex[:24],
+ ))
+ state = states[index]
+ state.name = tc["function"]["name"]
+ tc["id"] = state.id
+ if not state.announced:
+ state.announced = True
+ events.append({
+ "kind": "tool_start",
+ "index": index,
+ "id": state.id,
+ "name": state.name,
+ })
+ final_args = tc["function"]["arguments"]
+ fragment = final_args
+ if final_args.startswith(state.emitted_arguments):
+ fragment = final_args[len(state.emitted_arguments):]
+ if fragment:
+ state.emitted_arguments = final_args
+ events.append({
+ "kind": "tool_args",
+ "index": index,
+ "id": state.id,
+ "name": state.name,
+ "arguments": fragment,
+ })
+ if not state.done:
+ state.done = True
+ events.append({
+ "kind": "tool_done",
+ "index": index,
+ "id": state.id,
+ "name": state.name,
+ "arguments": final_args,
+ })
+ return events
+
+
+def _finalize_tool_buffer_stream(
+ state: _SharedStreamState,
+ tool_policy: ToolPolicy,
+) -> tuple[list[dict], list[dict]]:
+ events: list[dict] = []
+ tool_calls: list[dict] = []
+ cleaned_after, tool_calls = _parse_generated_tool_calls(state.tool_buffer, tool_policy)
+ if tool_calls:
+ events.extend(_reconcile_final_tool_events(tool_calls, state.tool_states))
+ if cleaned_after:
+ state.visible_text += cleaned_after
+ events.append({"kind": "content", "text": cleaned_after})
+ elif state.tool_buffer:
+ state.visible_text += state.tool_buffer
+ events.append({"kind": "content", "text": state.tool_buffer})
+ state.final_tool_calls = tool_calls
+ state.tool_buffer = ""
+ state.mode = "content"
+ return events, tool_calls
+
+
+def _make_shared_stream_state(started_in_thinking: bool, stops: list[str], tool_policy: ToolPolicy) -> _SharedStreamState:
+ tag_holdback = max(
+ len(THINK_OPEN_TAG),
+ len(THINK_CLOSE_TAG),
+ len(TOOL_OPEN_TAG) if tool_policy.parse_tool_calls else 0,
+ len(FUNCTION_OPEN_TAG) if tool_policy.parse_tool_calls else 0,
+ len(TOOL_CODE_OPEN_TAG) if tool_policy.parse_tool_calls else 0,
+ )
+ stop_holdback = max((len(s) for s in stops), default=0)
+ return _SharedStreamState(
+ mode="reasoning" if started_in_thinking else "content",
+ holdback=max(tag_holdback, stop_holdback),
+ allow_tools=tool_policy.parse_tool_calls,
+ )
+
+
+def _feed_shared_stream_piece(
+ state: _SharedStreamState,
+ piece: str,
+ *,
+ stops: list[str],
+ tool_policy: ToolPolicy,
+) -> list[dict]:
+ events: list[dict] = []
+ state.raw_text += piece
+ state.window += piece
+
+ if stops:
+ if state.mode == "tool_buffer":
+ combined = state.tool_buffer + state.window
+ stop_index = _first_tool_buffer_stop_match(combined, stops)
+ if stop_index != -1:
+ state.raw_text = _trim_stream_raw_suffix(
+ state.raw_text, len(combined), stop_index)
+ state.tool_buffer = combined[:stop_index]
+ state.window = ""
+ state.stop_hit = True
+ else:
+ window_len = len(state.window)
+ stop_index = first_stop_match(state.window, stops)
+ if stop_index != -1:
+ state.raw_text = _trim_stream_raw_suffix(
+ state.raw_text, window_len, stop_index)
+ state.window = state.window[:stop_index]
+ state.stop_hit = True
+
+ while True:
+ if state.mode == "tool_buffer":
+ state.tool_buffer += state.window
+ state.window = ""
+ events.extend(_sync_partial_tool_stream(
+ state.tool_buffer, tool_policy.parse_tools, state.tool_states))
+ if state.stop_hit:
+ finalized_events, _ = _finalize_tool_buffer_stream(state, tool_policy)
+ events.extend(finalized_events)
+ break
+
+ if state.mode == "reasoning":
+ idx = state.window.find(THINK_CLOSE_TAG)
+ if idx != -1:
+ pre = state.window[:idx]
+ if pre:
+ events.append({"kind": "reasoning", "text": pre})
+ state.window = state.window[idx + len(THINK_CLOSE_TAG):]
+ state.mode = "content"
+ continue
+ if len(state.window) > state.holdback or (state.stop_hit and state.window):
+ safe = state.window if state.stop_hit else state.window[:-state.holdback]
+ if safe:
+ events.append({"kind": "reasoning", "text": safe})
+ state.window = "" if state.stop_hit else state.window[-state.holdback:]
+ break
+
+ think_idx = state.window.find(THINK_OPEN_TAG)
+ think_close_idx = state.window.find(THINK_CLOSE_TAG)
+ tool_idx = state.window.find(TOOL_OPEN_TAG) if state.allow_tools else -1
+ bare_fn_idx = state.window.find(FUNCTION_OPEN_TAG) if state.allow_tools else -1
+ tool_code_idx = state.window.find(TOOL_CODE_OPEN_TAG) if state.allow_tools else -1
+ json_tool_idx = _find_toolish_json_start(state.window) if state.allow_tools else -1
+ hits = [(i, t) for i, t in (
+ (think_idx, "think"),
+ (think_close_idx, "think_close"),
+ (tool_idx, "tool"),
+ (bare_fn_idx, "tool"),
+ (tool_code_idx, "tool"),
+ (json_tool_idx, "tool"),
+ ) if i != -1]
+ if hits:
+ hits.sort()
+ idx, which = hits[0]
+ pre = state.window[:idx]
+ if pre:
+ state.visible_text += pre
+ events.append({"kind": "content", "text": pre})
+ if which == "think":
+ state.window = state.window[idx + len(THINK_OPEN_TAG):]
+ state.mode = "reasoning"
+ elif which == "think_close":
+ state.window = state.window[idx + len(THINK_CLOSE_TAG):]
+ else:
+ state.tool_buffer += state.window[idx:]
+ state.window = ""
+ state.mode = "tool_buffer"
+ events.extend(_sync_partial_tool_stream(
+ state.tool_buffer, tool_policy.parse_tools, state.tool_states))
+ if state.stop_hit:
+ finalized_events, _ = _finalize_tool_buffer_stream(state, tool_policy)
+ events.extend(finalized_events)
+ continue
+ if len(state.window) > state.holdback or (state.stop_hit and state.window):
+ safe = state.window if state.stop_hit else state.window[:-state.holdback]
+ if safe:
+ state.visible_text += safe
+ events.append({"kind": "content", "text": safe})
+ state.window = "" if state.stop_hit else state.window[-state.holdback:]
+ break
+ return events
+
+
+def _flush_shared_stream(
+ state: _SharedStreamState,
+ *,
+ tool_policy: ToolPolicy,
+) -> tuple[list[dict], list[dict]]:
+ events: list[dict] = []
+ tool_calls: list[dict] = []
+
+ if state.mode == "reasoning" and state.window:
+ events.append({"kind": "reasoning", "text": state.window})
+ elif state.mode == "content" and state.window:
+ state.visible_text += state.window
+ events.append({"kind": "content", "text": state.window})
+ elif state.mode == "tool_buffer":
+ state.tool_buffer += state.window
+ events.extend(_sync_partial_tool_stream(
+ state.tool_buffer, tool_policy.parse_tools, state.tool_states))
+ state.window = ""
+
+ if state.mode == "tool_buffer":
+ finalized_events, tool_calls = _finalize_tool_buffer_stream(state, tool_policy)
+ events.extend(finalized_events)
+
+ return events, tool_calls
+
+
+def _parse_responses_non_stream_output(
+ text: str,
+ tool_policy: ToolPolicy,
+ msg_item_id: str,
+ *,
+ started_in_thinking: bool,
+) -> tuple[str, list[dict], list[tuple[str, str]]]:
+ state = _make_shared_stream_state(started_in_thinking, [], tool_policy)
+ output_order: list[tuple[str, str]] = []
+ message_announced = False
+
+ def record_output_order(events: list[dict]) -> None:
+ nonlocal message_announced
+ for event in events:
+ if event["kind"] == "content":
+ if not message_announced:
+ output_order.append(("message", msg_item_id))
+ message_announced = True
+ continue
+ if event["kind"] == "tool_start":
+ output_order.append(("function_call", event["id"]))
+
+ stream_events = _feed_shared_stream_piece(
+ state, text, stops=[], tool_policy=tool_policy)
+ record_output_order(stream_events)
+ flush_events, tool_calls = _flush_shared_stream(state, tool_policy=tool_policy)
+ record_output_order(flush_events)
+ return state.visible_text.strip(), tool_calls, output_order
+
+
+def _responses_message_item(msg_item_id: str, text: str, *, status: str = "completed") -> dict:
+ return {
+ "type": "message",
+ "id": msg_item_id,
+ "status": status,
+ "role": "assistant",
+ "content": [{"type": "output_text", "text": text, "annotations": []}],
+ }
+
+
+def _responses_function_call_item(tc: dict, *, status: str = "completed") -> dict:
+ return {
+ "type": "function_call",
+ "id": tc["id"],
+ "status": status,
+ "call_id": tc["id"],
+ "name": tc["function"]["name"],
+ "arguments": tc["function"]["arguments"],
+ }
+
+
+def _build_responses_output(
+ cleaned_text: str,
+ tool_calls: list[dict],
+ msg_item_id: str,
+ *,
+ output_order: list[tuple[str, str]] | None = None,
+) -> list[dict]:
+ items: dict[tuple[str, str], dict] = {}
+ if cleaned_text or not tool_calls:
+ items[("message", msg_item_id)] = _responses_message_item(msg_item_id, cleaned_text)
+ for tc in tool_calls:
+ items[("function_call", tc["id"])] = _responses_function_call_item(tc)
+
+ if not output_order:
+ output: list[dict] = []
+ if ("message", msg_item_id) in items:
+ output.append(items.pop(("message", msg_item_id)))
+ for tc in tool_calls:
+ item = items.pop(("function_call", tc["id"]), None)
+ if item is not None:
+ output.append(item)
+ return output
+
+ output = []
+ seen: set[tuple[str, str]] = set()
+ for key in output_order:
+ item = items.get(key)
+ if item is not None and key not in seen:
+ output.append(item)
+ seen.add(key)
+ for key, item in items.items():
+ if key not in seen:
+ output.append(item)
+ return output
+
+
def _samp_suffix(req) -> str:
# Render ` samp=temp,top_p,top_k,rep_pen[,seed]` tail when the request asks for
# non-greedy decoding. Empty string keeps the daemon protocol greedy-compatible.
@@ -696,7 +1544,22 @@ def _resolve_kv_k_type():
cap=prefix_cache_slots,
)
if prefill_cfg is not None and prefill_cache_slots > 0:
- prefix_cache.init_full_cache(prefill_cache_slots, budget_bytes=prefill_cache_bytes)
+ prefix_cache.init_full_cache(prefill_cache_slots)
+ tool_memory = ToolMemory(
+ max_entries=int(os.environ.get("DFLASH_TOOL_MEMORY_MAX_ENTRIES", "50000")),
+ max_bytes=int(os.environ.get("DFLASH_TOOL_MEMORY_MAX_BYTES", str(64 * 1024 * 1024))),
+ )
+
+ def _remember_tool_call_text(raw_text: str, tool_calls: list[dict] | None) -> None:
+ if not raw_text or not tool_calls:
+ return
+ call_ids = [
+ tc.get("id")
+ for tc in tool_calls
+ if isinstance(tc, dict) and isinstance(tc.get("id"), str) and tc.get("id")
+ ]
+ if call_ids:
+ tool_memory.remember(call_ids, raw_text)
@app.on_event("startup")
async def _startup():
@@ -796,18 +1659,23 @@ def _render_messages(msgs_list: list[dict],
param="messages")
return _ids_to_bin(ids), ids, prompt
- def _tokenize_prompt(req: ChatRequest) -> tuple[Path, list[int], list[dict], bool]:
+ def _tokenize_prompt(req: ChatRequest, tool_policy: ToolPolicy) -> tuple[Path, list[int], list[dict], bool]:
"""Returns (bin, ids, raw_msgs, started_in_thinking)."""
msgs: list[dict] = []
for m in req.messages:
d: dict = {"role": m.role}
- if m.content is not None:
+ replay_raw_text = None
+ if m.role == "assistant" and m.tool_calls is not None:
+ replay_raw_text = tool_memory.lookup_message(m.tool_calls)
+ if replay_raw_text is not None:
+ d["content"] = replay_raw_text
+ elif m.content is not None:
d["content"] = _content_to_str(m.content)
if m.name is not None:
d["name"] = m.name
if m.tool_call_id is not None:
d["tool_call_id"] = m.tool_call_id
- if m.tool_calls is not None:
+ if m.tool_calls is not None and replay_raw_text is None:
d["tool_calls"] = []
for tc in m.tool_calls:
args = tc.function.arguments
@@ -823,16 +1691,23 @@ def _tokenize_prompt(req: ChatRequest) -> tuple[Path, list[int], list[dict], boo
msgs.append(d)
tools_arg = None
- if req.tools:
- tools_arg = [t.model_dump() for t in req.tools]
+ if tool_policy.render_tools and tool_policy.prompt_tools:
+ tools_arg = [
+ t.model_dump() if hasattr(t, "model_dump") else t
+ for t in tool_policy.prompt_tools
+ ]
path, ids, _prompt = _render_messages(msgs, req.chat_template_kwargs, tools_arg)
started_in_thinking = bool(re.search(r"\s*$", _prompt))
return path, ids, msgs, started_in_thinking
def _maybe_compress(msgs: list[dict], prompt_bin: Path, prompt_ids: list[int],
- template_kwargs: dict | None = None
+ template_kwargs: dict | None = None,
+ *,
+ bypass: bool = False,
) -> tuple[Path, list[int]]:
+ if bypass:
+ return prompt_bin, prompt_ids
if not prefill_cfg or not prefill_cfg.enabled:
return prompt_bin, prompt_ids
if not prefill_cfg.should_compress(len(prompt_ids)):
@@ -900,6 +1775,15 @@ async def _collect_tokens_sync(r, n_gen, timing=None) -> list[int]:
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, lambda: list(_token_stream(r, n_gen, timing)))
+ async def _adrain_until_sentinel(r, timing=None) -> None:
+ if timing is not None and timing.get("daemon_done"):
+ return
+ loop = asyncio.get_running_loop()
+ await loop.run_in_executor(None, _drain_until_sentinel, r)
+ if timing is not None:
+ timing["daemon_done"] = True
+ timing["t_last_tok"] = time.monotonic()
+
async def _astream_tokens(r, n_gen, timing=None):
generated = 0
hit_stop = False
@@ -944,16 +1828,19 @@ def _write_cmd(cmd_line: str, timing=None):
sz = bin_path.stat().st_size
if sz == 0:
log.warning("prompt .bin is 0 bytes: %s", bin_path)
- if lazy_draft:
- log.debug("lazy-draft: unpark draft before generate")
- t = time.monotonic()
- daemon_proc.stdin.write(b"unpark draft\n")
+ try:
+ if lazy_draft:
+ log.debug("lazy-draft: unpark draft before generate")
+ t = time.monotonic()
+ daemon_proc.stdin.write(b"unpark draft\n")
+ daemon_proc.stdin.flush()
+ _drain_until_sentinel(r_pipe)
+ if timing is not None:
+ timing["unpark"] = time.monotonic() - t
+ daemon_proc.stdin.write(cmd_line.encode("utf-8"))
daemon_proc.stdin.flush()
- _drain_until_sentinel(r_pipe)
- if timing is not None:
- timing["unpark"] = time.monotonic() - t
- daemon_proc.stdin.write(cmd_line.encode("utf-8"))
- daemon_proc.stdin.flush()
+ except (BrokenPipeError, OSError, ValueError) as exc:
+ raise RuntimeError(f"failed to send command to dflash daemon: {exc}") from exc
if timing is not None:
timing["t_cmd_sent"] = time.monotonic()
@@ -1048,8 +1935,15 @@ def _build_cmd_line(req, cur_bin, cur_ids, gen_len, prefix_cache,
cmd_line += f" snap={snap_prep[1]}:{snap_prep[0]}"
return cmd_line + _samp_suffix(req) + "\n", snap_prep
+ def _abort_reserved_snap(full_snap_prep, snap_prep):
+ if full_snap_prep is not None:
+ fslot, _ = full_snap_prep
+ prefix_cache.abort_full_snap(fslot)
+ elif snap_prep:
+ prefix_cache.abort_inline_snap(snap_prep[0])
+
def _confirm_or_abort_snap(n_tokens: int, full_snap_prep, snap_prep,
- prompt_ids, cur_bin, cur_ids):
+ prompt_ids, cur_bin, cur_ids):
"""Confirm prefix-cache snapshots only when the daemon actually
generated tokens. When the daemon returns 0 tokens (e.g. empty
prompt / file read failure), confirming would register a snapshot
@@ -1063,11 +1957,7 @@ def _confirm_or_abort_snap(n_tokens: int, full_snap_prep, snap_prep,
prefix_cache.confirm_inline_snap(*snap_prep, cur_ids)
else:
# Abort: release the reservation without registering.
- if full_snap_prep is not None:
- fslot, _ = full_snap_prep
- prefix_cache.abort_full_snap(fslot)
- elif snap_prep:
- prefix_cache.abort_inline_snap(snap_prep[0])
+ _abort_reserved_snap(full_snap_prep, snap_prep)
log.warning("0 output tokens — aborted snapshot reservation")
def _gen_len_for(prompt_len: int, max_tokens: int) -> int:
@@ -1077,7 +1967,8 @@ def _gen_len_for(prompt_len: int, max_tokens: int) -> int:
@app.post("/v1/chat/completions")
async def chat_completions(req: ChatRequest):
- prompt_bin, prompt_ids, raw_msgs, started_in_thinking = _tokenize_prompt(req)
+ tool_policy = _resolve_tool_policy(req.tools, req.tool_choice)
+ prompt_bin, prompt_ids, raw_msgs, started_in_thinking = _tokenize_prompt(req, tool_policy)
completion_id = "chatcmpl-" + uuid.uuid4().hex[:24]
created = int(time.time())
prompt_len = len(prompt_ids)
@@ -1096,18 +1987,19 @@ async def chat_completions(req: ChatRequest):
if req.stream:
async def sse() -> AsyncIterator[str]:
- nonlocal started_in_thinking
+ nonlocal started_in_thinking, prompt_len
async with daemon_lock:
timing = {}
full_snap_prep_ref = [None]
snap_prep = None
+ cur_bin = prompt_bin
+ cur_ids = prompt_ids
- full_hit = prefix_cache.lookup_full(prompt_ids)
+ full_hit = None if tool_policy.bypass_compression else prefix_cache.lookup_full(prompt_ids)
if full_hit is not None:
slot, cached_cur_bin, cached_cur_ids_len = full_hit
cur_bin = Path(cached_cur_bin)
prompt_len = cached_cur_ids_len
- started_in_thinking = False # cached: no think prefill
gen_len = _gen_len_for(prompt_len, req.max_tokens)
if gen_len <= 0:
try: prompt_bin.unlink()
@@ -1124,7 +2016,8 @@ async def sse() -> AsyncIterator[str]:
t_compress = time.monotonic()
cur_bin, cur_ids = await asyncio.to_thread(
_maybe_compress, raw_msgs, prompt_bin, prompt_ids,
- req.chat_template_kwargs)
+ req.chat_template_kwargs,
+ bypass=tool_policy.bypass_compression)
timing["compress"] = time.monotonic() - t_compress
prompt_len = len(cur_ids)
gen_len = _gen_len_for(prompt_len, req.max_tokens)
@@ -1147,6 +2040,7 @@ async def sse() -> AsyncIterator[str]:
try:
_write_cmd(cmd_line, timing)
except RuntimeError as e:
+ _abort_reserved_snap(full_snap_prep_ref[0], snap_prep)
yield f"data: {json.dumps({'error': str(e)})}\n\n"
yield "data: [DONE]\n\n"
return
@@ -1159,8 +2053,6 @@ async def sse() -> AsyncIterator[str]:
"finish_reason": None}],
}
yield f"data: {json.dumps(head)}\n\n"
- window, mode = "", ("reasoning" if started_in_thinking else "content")
-
include_usage = bool(req.stream_options and req.stream_options.get("include_usage"))
def chunk(delta_obj, finish=None):
@@ -1169,94 +2061,91 @@ def chunk(delta_obj, finish=None):
"choices": [{"index": 0, "delta": delta_obj,
"finish_reason": finish}]}
- # State machine: mode ∈ {'reasoning', 'content', 'tool_buffer'}
- mode = "reasoning" if started_in_thinking else "content"
- window = ""
- tool_buffer = ""
+ def tool_chunk(index: int, *, call_id: str | None = None,
+ name: str | None = None,
+ arguments: str | None = None) -> str:
+ tc: dict[str, Any] = {"index": index}
+ if call_id is not None:
+ tc["id"] = call_id
+ tc["type"] = "function"
+ fn: dict[str, Any] = {}
+ if name is not None:
+ fn["name"] = name
+ if arguments is not None:
+ fn["arguments"] = arguments
+ if fn:
+ tc["function"] = fn
+ return f"data: {json.dumps(chunk({'tool_calls': [tc]}))}\n\n"
+
stops = normalize_stop(req.stop)
- tag_holdback = max(len(THINK_OPEN_TAG), len(THINK_CLOSE_TAG), len(TOOL_OPEN_TAG))
- stop_holdback = max((len(s) for s in stops), default=0)
- HOLDBACK = max(tag_holdback, stop_holdback)
+ stream_state = _make_shared_stream_state(started_in_thinking, stops, tool_policy)
completion_tokens = 0
- stop_hit = False
-
- def emit_delta(text, kind):
- if not text:
- return None
- return f"data: {json.dumps(chunk({kind: text}))}\n\n"
+ finish_reason = "stop"
+ tool_calls: list[dict] = []
try:
async for tok_id in _astream_tokens(r_pipe, gen_len, timing):
completion_tokens += 1
piece = tokenizer.decode([tok_id])
- window += piece
-
- if stops and mode != "tool_buffer":
- si = first_stop_match(window, stops)
- if si != -1:
- window = window[:si]
- stop_hit = True
- kind = "reasoning_content" if mode == "reasoning" else "content"
- out = emit_delta(window, kind)
- if out: yield out
- window = ""
- break
-
- while True:
- if mode == "tool_buffer":
- tool_buffer += window
- window = ""
- break
-
- if mode == "reasoning":
- idx = window.find(THINK_CLOSE_TAG)
- if idx != -1:
- pre = window[:idx]
- out = emit_delta(pre, "reasoning_content")
- if out: yield out
- window = window[idx + len(THINK_CLOSE_TAG):]
- mode = "content"
- continue
- if len(window) > HOLDBACK:
- safe = window[:-HOLDBACK]
- out = emit_delta(safe, "reasoning_content")
- if out: yield out
- window = window[-HOLDBACK:]
- break
-
- else: # mode == "content"
- think_idx = window.find(THINK_OPEN_TAG)
- think_close_idx = window.find(THINK_CLOSE_TAG)
- tool_idx = window.find(TOOL_OPEN_TAG)
- hits = [(i, t) for i, t in
- ((think_idx, "think"),
- (think_close_idx, "think_close"),
- (tool_idx, "tool")) if i != -1]
- if hits:
- hits.sort()
- idx, which = hits[0]
- pre = window[:idx]
- out = emit_delta(pre, "content")
- if out: yield out
- if which == "think":
- window = window[idx + len(THINK_OPEN_TAG):]
- mode = "reasoning"
- elif which == "think_close":
- window = window[idx + len(THINK_CLOSE_TAG):]
- else:
- tool_buffer = window[idx:]
- window = ""
- mode = "tool_buffer"
- continue
- if len(window) > HOLDBACK:
- safe = window[:-HOLDBACK]
- out = emit_delta(safe, "content")
- if out: yield out
- window = window[-HOLDBACK:]
- break
-
- if stop_hit:
- finish_reason = "stop"
+ for event in _feed_shared_stream_piece(
+ stream_state, piece, stops=stops, tool_policy=tool_policy):
+ if event["kind"] == "content":
+ yield f"data: {json.dumps(chunk({'content': event['text']}))}\n\n"
+ elif event["kind"] == "reasoning":
+ yield f"data: {json.dumps(chunk({'reasoning_content': event['text']}))}\n\n"
+ elif event["kind"] == "tool_start":
+ yield tool_chunk(
+ event["index"],
+ call_id=event["id"],
+ name=event["name"],
+ )
+ elif event["kind"] == "tool_args":
+ yield tool_chunk(
+ event["index"],
+ arguments=event["arguments"],
+ )
+ if stream_state.stop_hit:
+ break
+
+ if stream_state.stop_hit:
+ if not timing.get("daemon_done"):
+ await _adrain_until_sentinel(r_pipe, timing)
+ stream_events, tool_calls = _flush_shared_stream(
+ stream_state, tool_policy=tool_policy)
+ if not tool_calls and stream_state.final_tool_calls:
+ tool_calls = list(stream_state.final_tool_calls)
+ for event in stream_events:
+ if event["kind"] == "content":
+ yield f"data: {json.dumps(chunk({'content': event['text']}))}\n\n"
+ elif event["kind"] == "reasoning":
+ yield f"data: {json.dumps(chunk({'reasoning_content': event['text']}))}\n\n"
+ elif event["kind"] == "tool_start":
+ yield tool_chunk(
+ event["index"],
+ call_id=event["id"],
+ name=event["name"],
+ )
+ elif event["kind"] == "tool_args":
+ yield tool_chunk(
+ event["index"],
+ arguments=event["arguments"],
+ )
+ tool_choice_error = _tool_choice_violation(tool_policy, tool_calls)
+ if tool_choice_error is not None:
+ err = {"error": {
+ "message": tool_choice_error.message,
+ "type": tool_choice_error.error_type,
+ }}
+ if tool_choice_error.param is not None:
+ err["error"]["param"] = tool_choice_error.param
+ yield f"data: {json.dumps(err)}\n\n"
+ yield "data: [DONE]\n\n"
+ return
+ if tool_calls:
+ _remember_tool_call_text(stream_state.raw_text, tool_calls)
+ finish_reason = "tool_calls"
+ else:
+ finish_reason = "stop"
yield f"data: {json.dumps(chunk({}, finish=finish_reason))}\n\n"
if include_usage:
usage_chunk = {"id": completion_id, "object": "chat.completion.chunk",
@@ -1266,41 +2155,44 @@ def emit_delta(text, kind):
"total_tokens": prompt_len + completion_tokens}}
yield f"data: {json.dumps(usage_chunk)}\n\n"
yield "data: [DONE]\n\n"
- if timing.get("daemon_done") and full_hit is None:
- try: cur_bin.unlink()
- except Exception: pass
- _park_draft_if_lazy(timing)
return
- # Flush remaining
- if mode == "reasoning" and window:
- out = emit_delta(window, "reasoning_content")
- if out: yield out
- elif mode == "content" and window:
- out = emit_delta(window, "content")
- if out: yield out
- elif mode == "tool_buffer":
- tool_buffer += window
- window = ""
-
- finish_reason = "stop"
- if mode == "tool_buffer":
- cleaned_after, tool_calls = parse_tool_calls(tool_buffer, tools=req.tools)
- if tool_calls:
- if cleaned_after:
- out = emit_delta(cleaned_after, "content")
- if out: yield out
- tc_delta_list = [{
- "index": i, "id": tc["id"], "type": "function",
- "function": {"name": tc["function"]["name"],
- "arguments": tc["function"]["arguments"]},
- } for i, tc in enumerate(tool_calls)]
- yield f"data: {json.dumps(chunk({'tool_calls': tc_delta_list}))}\n\n"
- finish_reason = "tool_calls"
- else:
- out = emit_delta(tool_buffer, "content")
- if out: yield out
+ stream_events, tool_calls = _flush_shared_stream(
+ stream_state, tool_policy=tool_policy)
+ for event in stream_events:
+ if event["kind"] == "content":
+ yield f"data: {json.dumps(chunk({'content': event['text']}))}\n\n"
+ elif event["kind"] == "reasoning":
+ yield f"data: {json.dumps(chunk({'reasoning_content': event['text']}))}\n\n"
+ elif event["kind"] == "tool_start":
+ yield tool_chunk(
+ event["index"],
+ call_id=event["id"],
+ name=event["name"],
+ )
+ elif event["kind"] == "tool_args":
+ yield tool_chunk(
+ event["index"],
+ arguments=event["arguments"],
+ )
+ tool_choice_error = _tool_choice_violation(tool_policy, tool_calls)
+ if tool_choice_error is not None:
+ err = {"error": {
+ "message": tool_choice_error.message,
+ "type": tool_choice_error.error_type,
+ }}
+ if tool_choice_error.param is not None:
+ err["error"]["param"] = tool_choice_error.param
+ yield f"data: {json.dumps(err)}\n\n"
+ yield "data: [DONE]\n\n"
+ return
+ if tool_calls:
+ _remember_tool_call_text(stream_state.raw_text, tool_calls)
+ finish_reason = "tool_calls"
finally:
+ _confirm_or_abort_snap(
+ completion_tokens, full_snap_prep_ref[0], snap_prep,
+ prompt_ids, cur_bin, cur_ids)
if timing.get("daemon_done"):
if full_hit is None:
try: cur_bin.unlink()
@@ -1312,11 +2204,7 @@ def emit_delta(text, kind):
log.warning(
"stream ended before daemon sentinel; "
"retaining prompt .bin for in-flight daemon read")
-
- _confirm_or_abort_snap(
- completion_tokens, full_snap_prep_ref[0], snap_prep,
- prompt_ids, cur_bin, cur_ids)
- _park_draft_if_lazy(timing)
+ _park_draft_if_lazy(timing)
yield f"data: {json.dumps(chunk({}, finish=finish_reason))}\n\n"
if include_usage:
@@ -1347,7 +2235,7 @@ def emit_delta(text, kind):
full_snap_prep_ref = [None]
snap_prep = None
- full_hit = prefix_cache.lookup_full(prompt_ids)
+ full_hit = None if tool_policy.bypass_compression else prefix_cache.lookup_full(prompt_ids)
if full_hit is not None:
slot, cached_cur_bin, cached_cur_ids_len = full_hit
cur_bin = Path(cached_cur_bin)
@@ -1365,7 +2253,8 @@ def emit_delta(text, kind):
t_compress = time.monotonic()
cur_bin, cur_ids = await asyncio.to_thread(
_maybe_compress, raw_msgs, prompt_bin, prompt_ids,
- req.chat_template_kwargs)
+ req.chat_template_kwargs,
+ bypass=tool_policy.bypass_compression)
timing["compress"] = time.monotonic() - t_compress
prompt_len = len(cur_ids)
gen_len = _gen_len_for(prompt_len, req.max_tokens)
@@ -1383,6 +2272,7 @@ def emit_delta(text, kind):
try:
_write_cmd(cmd_line, timing)
except RuntimeError as e:
+ _abort_reserved_snap(full_snap_prep_ref[0], snap_prep)
return JSONResponse({"detail": str(e)}, status_code=503)
# FIX 6: use run_in_executor instead of list() blocking event loop
@@ -1411,7 +2301,11 @@ def emit_delta(text, kind):
thinking_enabled = True
if req.chat_template_kwargs:
thinking_enabled = req.chat_template_kwargs.get("enable_thinking", True)
- cleaned, tool_calls = parse_tool_calls(text, tools=req.tools)
+ cleaned, tool_calls = _parse_generated_tool_calls(text, tool_policy)
+ tool_choice_error = _tool_choice_violation(tool_policy, tool_calls)
+ if tool_choice_error is not None:
+ raise tool_choice_error
+ _remember_tool_call_text(text, tool_calls)
cleaned, reasoning = parse_reasoning(
cleaned,
thinking_enabled=thinking_enabled,
@@ -1530,6 +2424,7 @@ async def sse() -> AsyncIterator[str]:
try:
_write_cmd(cmd_line, timing)
except RuntimeError as e:
+ _abort_reserved_snap(full_snap_prep_ref[0], snap_prep)
yield f"event: error\ndata: {json.dumps({'type':'error','error':{'type':'server_error','message':str(e)}})}\n\n"
return
@@ -1643,6 +2538,7 @@ async def sse() -> AsyncIterator[str]:
try:
_write_cmd(cmd_line, timing)
except RuntimeError as e:
+ _abort_reserved_snap(full_snap_prep_ref[0], snap_prep)
return JSONResponse({"type": "error", "error": {"type": "server_error",
"message": str(e)}}, status_code=503)
@@ -1688,12 +2584,34 @@ def _map_responses_input(req: ResponsesCreateRequest
) -> tuple[list[ChatMessage], list[ToolDef] | None]:
"""Map Responses API input → ChatMessage list + ToolDef list."""
messages: list[ChatMessage] = []
+ pending_assistant_text: list[str] = []
+ pending_assistant_calls: list[ToolCall] = []
# Collect all system-level content (instructions + developer messages)
# and merge into a single system message at position 0, since
# Qwen's chat template requires the system message at the beginning.
system_parts: list[str] = []
+ def _flush_pending_assistant() -> None:
+ if not pending_assistant_text and not pending_assistant_calls:
+ return
+ messages.append(ChatMessage(
+ role="assistant",
+ content="".join(pending_assistant_text) or None,
+ tool_calls=list(pending_assistant_calls) or None,
+ ))
+ pending_assistant_text.clear()
+ pending_assistant_calls.clear()
+
+ def _responses_item_text(content: Any) -> str:
+ if isinstance(content, list):
+ text_parts = []
+ for part in content:
+ if isinstance(part, dict) and part.get("type") in ("output_text", "text", "input_text"):
+ text_parts.append(part.get("text", ""))
+ return "".join(text_parts)
+ return content if isinstance(content, str) else ""
+
if req.instructions:
system_parts.append(req.instructions)
@@ -1709,18 +2627,14 @@ def _map_responses_input(req: ResponsesCreateRequest
if item_type == "message":
role = item.get("role", "user")
- content = item.get("content", "")
- if isinstance(content, list):
- # Extract text from content parts
- text_parts = []
- for part in content:
- if isinstance(part, dict):
- if part.get("type") in ("output_text", "text", "input_text"):
- text_parts.append(part.get("text", ""))
- content = "".join(text_parts)
+ content = _responses_item_text(item.get("content", ""))
if role == "developer" or role == "system":
+ _flush_pending_assistant()
system_parts.append(content)
+ elif role == "assistant":
+ pending_assistant_text.append(content)
else:
+ _flush_pending_assistant()
messages.append(ChatMessage(role=role, content=content))
elif item_type == "function_call":
@@ -1732,10 +2646,10 @@ def _map_responses_input(req: ResponsesCreateRequest
arguments=item.get("arguments", "{}"),
),
)
- messages.append(ChatMessage(
- role="assistant", content=None, tool_calls=[tc]))
+ pending_assistant_calls.append(tc)
elif item_type == "function_call_output":
+ _flush_pending_assistant()
output = item.get("output", "")
if not isinstance(output, str):
output = json.dumps(output)
@@ -1747,6 +2661,8 @@ def _map_responses_input(req: ResponsesCreateRequest
# Ignore reasoning, local_shell_call, etc. — we just
# need the message/function_call/output items for the model.
+ _flush_pending_assistant()
+
# Prepend merged system message
if system_parts:
messages.insert(0, ChatMessage(
@@ -1792,13 +2708,14 @@ async def responses_create(req: ResponsesCreateRequest):
tool_choice=req.tool_choice,
chat_template_kwargs={"enable_thinking": enable_thinking},
)
+ tool_policy = _resolve_tool_policy(chat_req.tools, chat_req.tool_choice)
response_id = "resp_" + uuid.uuid4().hex[:24]
msg_item_id = "msg_" + uuid.uuid4().hex[:24]
created_at = int(time.time())
# Tokenize
- prompt_bin, prompt_ids, raw_msgs, started_in_thinking = _tokenize_prompt(chat_req)
+ prompt_bin, prompt_ids, raw_msgs, started_in_thinking = _tokenize_prompt(chat_req, tool_policy)
prompt_len = len(prompt_ids)
# Summarise roles for the log line
@@ -1817,17 +2734,17 @@ async def responses_create(req: ResponsesCreateRequest):
if req.stream:
return await _responses_stream(
- chat_req, prompt_bin, prompt_ids, raw_msgs,
+ chat_req, prompt_bin, prompt_ids, raw_msgs, tool_policy,
started_in_thinking, response_id, msg_item_id,
created_at, prompt_len, time.monotonic())
else:
return await _responses_non_stream(
- chat_req, prompt_bin, prompt_ids, raw_msgs,
+ chat_req, prompt_bin, prompt_ids, raw_msgs, tool_policy,
started_in_thinking, response_id, msg_item_id,
created_at, prompt_len, time.monotonic())
async def _responses_non_stream(
- chat_req, prompt_bin, prompt_ids, raw_msgs,
+ chat_req, prompt_bin, prompt_ids, raw_msgs, tool_policy,
started_in_thinking, response_id, msg_item_id,
created_at, prompt_len, t0):
"""Non-streaming Responses API handler."""
@@ -1835,14 +2752,13 @@ async def _responses_non_stream(
timing = {}
full_snap_prep_ref = [None]
snap_prep = None
+ cur_ids = None
- full_hit = prefix_cache.lookup_full(prompt_ids)
+ full_hit = None if tool_policy.bypass_compression else prefix_cache.lookup_full(prompt_ids)
if full_hit is not None:
slot, cached_cur_bin, cached_cur_ids_len = full_hit
cur_bin = Path(cached_cur_bin)
- cur_ids = None
prompt_len = cached_cur_ids_len
- started_in_thinking = False # cached: no think prefill
gen_len = _gen_len_for(prompt_len, chat_req.max_tokens)
if gen_len <= 0:
log.warning(
@@ -1861,7 +2777,8 @@ async def _responses_non_stream(
t_compress = time.monotonic()
cur_bin, cur_ids = await asyncio.to_thread(
_maybe_compress, raw_msgs, prompt_bin, prompt_ids,
- chat_req.chat_template_kwargs)
+ chat_req.chat_template_kwargs,
+ bypass=tool_policy.bypass_compression)
timing["compress"] = time.monotonic() - t_compress
prompt_len = len(cur_ids)
gen_len = _gen_len_for(prompt_len, chat_req.max_tokens)
@@ -1885,6 +2802,7 @@ async def _responses_non_stream(
try:
_write_cmd(cmd_line, timing)
except RuntimeError as e:
+ _abort_reserved_snap(full_snap_prep_ref[0], snap_prep)
return JSONResponse({
"type": "error",
"error": {"type": "server_error", "message": str(e)}
@@ -1908,31 +2826,19 @@ async def _responses_non_stream(
thinking_enabled = True
if chat_req.chat_template_kwargs:
thinking_enabled = chat_req.chat_template_kwargs.get("enable_thinking", True)
- cleaned, tool_calls = parse_tool_calls(text, tools=chat_req.tools)
+ cleaned, tool_calls, output_order = _parse_responses_non_stream_output(
+ text, tool_policy, msg_item_id,
+ started_in_thinking=started_in_thinking)
+ tool_choice_error = _tool_choice_violation(tool_policy, tool_calls)
+ if tool_choice_error is not None:
+ raise tool_choice_error
+ _remember_tool_call_text(text, tool_calls)
cleaned, reasoning = parse_reasoning(
cleaned, thinking_enabled=thinking_enabled,
started_in_thinking=started_in_thinking)
- # Build output items
- output: list[dict] = []
- if tool_calls:
- for tc in tool_calls:
- output.append({
- "type": "function_call",
- "id": tc["id"],
- "status": "completed",
- "call_id": tc["id"],
- "name": tc["function"]["name"],
- "arguments": tc["function"]["arguments"],
- })
- else:
- output.append({
- "type": "message",
- "id": msg_item_id,
- "status": "completed",
- "role": "assistant",
- "content": [{"type": "output_text", "text": cleaned, "annotations": []}],
- })
+ output = _build_responses_output(
+ cleaned, tool_calls, msg_item_id, output_order=output_order)
out_types = [o.get("type") for o in output]
elapsed = time.monotonic() - t0
@@ -1959,7 +2865,7 @@ async def _responses_non_stream(
})
async def _responses_stream(
- chat_req, prompt_bin, prompt_ids, raw_msgs,
+ chat_req, prompt_bin, prompt_ids, raw_msgs, tool_policy,
started_in_thinking, response_id, msg_item_id,
created_at, prompt_len, t0):
"""Streaming Responses API handler — emits Responses SSE events."""
@@ -1971,13 +2877,13 @@ async def sse() -> AsyncIterator[str]:
timing = {}
full_snap_prep_ref = [None]
snap_prep = None
+ cur_ids = None
- full_hit = prefix_cache.lookup_full(prompt_ids)
+ full_hit = None if tool_policy.bypass_compression else prefix_cache.lookup_full(prompt_ids)
if full_hit is not None:
slot, cached_cur_bin, cached_cur_ids_len = full_hit
cur_bin = Path(cached_cur_bin)
prompt_len = cached_cur_ids_len
- started_in_thinking = False
gen_len = _gen_len_for(prompt_len, chat_req.max_tokens)
if gen_len <= 0:
log.warning(
@@ -1995,7 +2901,8 @@ async def sse() -> AsyncIterator[str]:
t_compress = time.monotonic()
cur_bin, cur_ids = await asyncio.to_thread(
_maybe_compress, raw_msgs, prompt_bin, prompt_ids,
- chat_req.chat_template_kwargs)
+ chat_req.chat_template_kwargs,
+ bypass=tool_policy.bypass_compression)
timing["compress"] = time.monotonic() - t_compress
prompt_len = len(cur_ids)
gen_len = _gen_len_for(prompt_len, chat_req.max_tokens)
@@ -2018,106 +2925,157 @@ async def sse() -> AsyncIterator[str]:
try:
_write_cmd(cmd_line, timing)
except RuntimeError as e:
+ _abort_reserved_snap(full_snap_prep_ref[0], snap_prep)
yield _resp_sse("error", {
"error": {"type": "server_error", "message": str(e)}})
return
- # Lifecycle: response.created
yield _resp_sse("response.created", {
"response": _resp_shell(response_id, chat_req.model, created_at,
"in_progress")})
- # Announce output item
- yield _resp_sse("response.output_item.added", {
- "output_index": 0,
- "item": {"type": "message", "id": msg_item_id,
- "status": "in_progress", "role": "assistant",
- "content": []}})
-
- # Announce content part
- yield _resp_sse("response.content_part.added", {
- "item_id": msg_item_id, "output_index": 0,
- "content_index": 0,
- "part": {"type": "output_text", "text": "", "annotations": []}})
-
- # Stream tokens with state machine
- mode = "reasoning" if started_in_thinking else "content"
- window = ""
- tool_buffer = ""
- accumulated_text = ""
- tag_holdback = max(len(THINK_OPEN_TAG), len(THINK_CLOSE_TAG), len(TOOL_OPEN_TAG))
- HOLDBACK = tag_holdback
+ stream_state = _make_shared_stream_state(started_in_thinking, [], tool_policy)
completion_tokens = 0
- tool_call_active = False
+ tool_calls: list[dict] = []
+ next_output_index = 0
+ message_output_index: int | None = None
+ message_announced = False
+ message_done = False
+ tool_output_indices: dict[str, int] = {}
+ output_item_order: list[tuple[str, str]] = []
+
+ def ensure_message_started() -> list[str]:
+ nonlocal next_output_index, message_output_index, message_announced
+ if message_announced:
+ return []
+ message_output_index = next_output_index
+ next_output_index += 1
+ message_announced = True
+ output_item_order.append(("message", msg_item_id))
+ return [
+ _resp_sse("response.output_item.added", {
+ "output_index": message_output_index,
+ "item": {
+ "type": "message",
+ "id": msg_item_id,
+ "status": "in_progress",
+ "role": "assistant",
+ "content": [],
+ },
+ }),
+ _resp_sse("response.content_part.added", {
+ "item_id": msg_item_id,
+ "output_index": message_output_index,
+ "content_index": 0,
+ "part": {"type": "output_text", "text": "", "annotations": []},
+ }),
+ ]
+
+ def ensure_tool_started(event: dict) -> list[str]:
+ nonlocal next_output_index
+ output_index = tool_output_indices.get(event["id"])
+ if output_index is not None:
+ return []
+ output_index = next_output_index
+ next_output_index += 1
+ tool_output_indices[event["id"]] = output_index
+ output_item_order.append(("function_call", event["id"]))
+ return [_resp_sse("response.output_item.added", {
+ "output_index": output_index,
+ "item": {
+ "type": "function_call",
+ "id": event["id"],
+ "status": "in_progress",
+ "call_id": event["id"],
+ "name": event["name"],
+ "arguments": "",
+ },
+ })]
+
+ def finalize_message_item() -> list[str]:
+ nonlocal message_done
+ if not message_announced or message_done or message_output_index is None:
+ return []
+ message_done = True
+ return [
+ _resp_sse("response.output_text.done", {
+ "item_id": msg_item_id,
+ "output_index": message_output_index,
+ "content_index": 0,
+ "text": stream_state.visible_text,
+ }),
+ _resp_sse("response.content_part.done", {
+ "item_id": msg_item_id,
+ "output_index": message_output_index,
+ "content_index": 0,
+ "part": {"type": "output_text", "text": stream_state.visible_text,
+ "annotations": []},
+ }),
+ _resp_sse("response.output_item.done", {
+ "output_index": message_output_index,
+ "item": _responses_message_item(msg_item_id, stream_state.visible_text),
+ }),
+ ]
+
+ def finalize_tool_items(final_tool_calls: list[dict]) -> list[str]:
+ messages: list[str] = []
+ tool_map = {tc["id"]: tc for tc in final_tool_calls}
+ for kind, item_id in output_item_order:
+ if kind != "function_call":
+ continue
+ tc = tool_map.get(item_id)
+ output_index = tool_output_indices.get(item_id)
+ if tc is None or output_index is None:
+ continue
+ messages.append(_resp_sse("response.function_call_arguments.done", {
+ "item_id": item_id,
+ "output_index": output_index,
+ "arguments": tc["function"]["arguments"],
+ "name": tc["function"]["name"],
+ }))
+ messages.append(_resp_sse("response.output_item.done", {
+ "output_index": output_index,
+ "item": _responses_function_call_item(tc),
+ }))
+ return messages
+
+ def emit_stream_event(event: dict) -> list[str]:
+ messages: list[str] = []
+ if event["kind"] == "content":
+ messages.extend(ensure_message_started())
+ assert message_output_index is not None
+ messages.append(_resp_sse("response.output_text.delta", {
+ "item_id": msg_item_id,
+ "output_index": message_output_index,
+ "content_index": 0,
+ "delta": event["text"],
+ }))
+ return messages
+ if event["kind"] == "tool_start":
+ messages.extend(ensure_tool_started(event))
+ return messages
+ if event["kind"] == "tool_args":
+ if event["id"] not in tool_output_indices:
+ messages.extend(ensure_tool_started(event))
+ messages.append(_resp_sse("response.function_call_arguments.delta", {
+ "item_id": event["id"],
+ "output_index": tool_output_indices[event["id"]],
+ "delta": event["arguments"],
+ }))
+ return messages
+ if event["kind"] == "tool_done":
+ if event["id"] not in tool_output_indices:
+ messages.extend(ensure_tool_started(event))
+ return messages
try:
async for tok_id in _astream_tokens(r_pipe, gen_len, timing):
completion_tokens += 1
piece = tokenizer.decode([tok_id])
- window += piece
-
- while True:
- if mode == "tool_buffer":
- tool_buffer += window
- window = ""
- break
-
- if mode == "reasoning":
- idx = window.find(THINK_CLOSE_TAG)
- if idx != -1:
- window = window[idx + len(THINK_CLOSE_TAG):]
- mode = "content"
- continue
- if len(window) > HOLDBACK:
- window = window[-HOLDBACK:]
- break
-
- else: # content
- think_idx = window.find(THINK_OPEN_TAG)
- think_close_idx = window.find(THINK_CLOSE_TAG)
- tool_idx = window.find(TOOL_OPEN_TAG)
- hits = [(i, t) for i, t in
- ((think_idx, "think"),
- (think_close_idx, "think_close"),
- (tool_idx, "tool")) if i != -1]
- if hits:
- hits.sort()
- idx, which = hits[0]
- pre = window[:idx]
- if pre:
- accumulated_text += pre
- yield _resp_sse("response.output_text.delta", {
- "item_id": msg_item_id, "output_index": 0,
- "content_index": 0, "delta": pre})
- if which == "think":
- window = window[idx + len(THINK_OPEN_TAG):]
- mode = "reasoning"
- elif which == "think_close":
- window = window[idx + len(THINK_CLOSE_TAG):]
- else:
- tool_buffer = window[idx:]
- window = ""
- mode = "tool_buffer"
- continue
- if len(window) > HOLDBACK:
- safe = window[:-HOLDBACK]
- accumulated_text += safe
- yield _resp_sse("response.output_text.delta", {
- "item_id": msg_item_id, "output_index": 0,
- "content_index": 0, "delta": safe})
- window = window[-HOLDBACK:]
- break
-
- # Flush remaining window
- if mode == "content" and window:
- accumulated_text += window
- yield _resp_sse("response.output_text.delta", {
- "item_id": msg_item_id, "output_index": 0,
- "content_index": 0, "delta": window})
- elif mode == "tool_buffer":
- tool_buffer += window
- window = ""
-
+ for event in _feed_shared_stream_piece(
+ stream_state, piece, stops=[], tool_policy=tool_policy):
+ for message in emit_stream_event(event):
+ yield message
finally:
if timing.get("daemon_done"):
if full_hit is None:
@@ -2136,70 +3094,45 @@ async def sse() -> AsyncIterator[str]:
prompt_ids, cur_bin, cur_ids)
_park_draft_if_lazy(timing)
- # Build final output items
- final_output: list[dict] = []
- if mode == "tool_buffer" and tool_buffer:
- cleaned_after, tool_calls = parse_tool_calls(tool_buffer, tools=chat_req.tools)
- if tool_calls:
- if cleaned_after:
- accumulated_text += cleaned_after
- for tc in tool_calls:
- tool_call_active = True
- tc_item_id = tc["id"]
- # Emit function_call_arguments.delta for each tool call
- yield _resp_sse("response.function_call_arguments.delta", {
- "item_id": tc_item_id, "output_index": 0,
- "delta": tc["function"]["arguments"]})
- yield _resp_sse("response.function_call_arguments.done", {
- "item_id": tc_item_id, "output_index": 0,
- "arguments": tc["function"]["arguments"],
- "name": tc["function"]["name"]})
- final_output.append({
- "type": "function_call",
- "id": tc_item_id,
- "status": "completed",
- "call_id": tc_item_id,
- "name": tc["function"]["name"],
- "arguments": tc["function"]["arguments"],
- })
- else:
- accumulated_text += tool_buffer
- yield _resp_sse("response.output_text.delta", {
- "item_id": msg_item_id, "output_index": 0,
- "content_index": 0, "delta": tool_buffer})
-
- # Finalize text output
- yield _resp_sse("response.output_text.done", {
- "item_id": msg_item_id, "output_index": 0,
- "content_index": 0, "text": accumulated_text})
- yield _resp_sse("response.content_part.done", {
- "item_id": msg_item_id, "output_index": 0,
- "content_index": 0,
- "part": {"type": "output_text", "text": accumulated_text,
- "annotations": []}})
-
- if not tool_call_active:
- final_output.append({
- "type": "message",
- "id": msg_item_id,
- "status": "completed",
- "role": "assistant",
- "content": [{"type": "output_text", "text": accumulated_text,
- "annotations": []}],
- })
+ stream_events, tool_calls = _flush_shared_stream(
+ stream_state, tool_policy=tool_policy)
+ for event in stream_events:
+ for message in emit_stream_event(event):
+ yield message
+
+ if tool_calls:
+ _remember_tool_call_text(stream_state.raw_text, tool_calls)
+
+ if not message_announced and not tool_calls:
+ for message in ensure_message_started():
+ yield message
+ for message in finalize_message_item():
+ yield message
+
+ tool_choice_error = _tool_choice_violation(tool_policy, tool_calls)
+ if tool_choice_error is not None:
+ yield _resp_sse("response.failed", {
+ "response": _resp_shell(response_id, chat_req.model, created_at,
+ "failed")})
+ err = {"error": {
+ "type": tool_choice_error.error_type,
+ "message": tool_choice_error.message,
+ }}
+ if tool_choice_error.param is not None:
+ err["error"]["param"] = tool_choice_error.param
+ yield _resp_sse("error", err)
+ return
- yield _resp_sse("response.output_item.done", {
- "output_index": 0,
- "item": final_output[0] if final_output else {
- "type": "message", "id": msg_item_id,
- "status": "completed", "role": "assistant",
- "content": []}})
+ for message in finalize_tool_items(tool_calls):
+ yield message
+ final_output = _build_responses_output(
+ stream_state.visible_text, tool_calls, msg_item_id,
+ output_order=output_item_order)
- # response.completed
shell = _resp_shell(response_id, chat_req.model, created_at,
"completed")
shell["output"] = final_output
- shell["output_text"] = accumulated_text
+ shell["output_text"] = stream_state.visible_text
shell["usage"] = {
"input_tokens": prompt_len,
"output_tokens": completion_tokens,
@@ -2211,7 +3144,7 @@ async def sse() -> AsyncIterator[str]:
log.info(
"responses DONE %s in=%d out=%d %.1fs %.1f tok/s output=%s text_len=%d %s",
response_id, prompt_len, completion_tokens,
- elapsed, tok_s, out_types, len(accumulated_text),
+ elapsed, tok_s, out_types, len(stream_state.visible_text),
_timing_summary(timing, completion_tokens),
)
yield _resp_sse("response.completed", {"response": shell})
diff --git a/dflash/scripts/test_server.py b/dflash/scripts/test_server.py
index 7bdfc569..bd4a37e2 100644
--- a/dflash/scripts/test_server.py
+++ b/dflash/scripts/test_server.py
@@ -8,6 +8,7 @@
import pytest
from fastapi.testclient import TestClient
+from _prefill_hook import PrefillConfig
from server import (
build_app, MODEL_NAME,
parse_tool_calls, parse_reasoning,
@@ -43,11 +44,122 @@ def app(mock_tokenizer):
return a
+@pytest.fixture
+def app_with_prefill(mock_tokenizer):
+ """Build a FastAPI app with compression enabled."""
+ drafter_tokenizer = MagicMock()
+ drafter_tokenizer.return_value = {"input_ids": [1, 2, 3]}
+ drafter_tokenizer.decode.return_value = "compressed prompt"
+ with patch("server.subprocess.Popen") as mock_popen:
+ mock_popen.return_value.poll.return_value = None # daemon alive
+ a = build_app(
+ target=Path("target.gguf"),
+ draft=Path("draft.safetensors"),
+ bin_path=Path("test_dflash"),
+ budget=22,
+ max_ctx=131072,
+ tokenizer=mock_tokenizer,
+ stop_ids={2},
+ prefill_cache_slots=0,
+ prefill_cfg=PrefillConfig(
+ mode="always",
+ threshold=1,
+ keep_ratio=0.5,
+ drafter_gguf=Path("drafter.gguf"),
+ drafter_tokenizer_id="mock-tokenizer",
+ ),
+ drafter_tokenizer=drafter_tokenizer,
+ )
+ return a
+
+
@pytest.fixture
def client(app):
return TestClient(app)
+def _build_app_with_process(mock_tokenizer, process, *, enable_prefill: bool = False):
+ kwargs = {}
+ if enable_prefill:
+ drafter_tokenizer = MagicMock()
+ drafter_tokenizer.return_value = {"input_ids": [1, 2, 3]}
+ drafter_tokenizer.decode.return_value = "compressed prompt"
+ kwargs.update(
+ prefill_cache_slots=0,
+ prefill_cfg=PrefillConfig(
+ mode="always",
+ threshold=1,
+ keep_ratio=0.5,
+ drafter_gguf=Path("drafter.gguf"),
+ drafter_tokenizer_id="mock-tokenizer",
+ ),
+ drafter_tokenizer=drafter_tokenizer,
+ )
+ with patch("server.subprocess.Popen") as mock_popen:
+ mock_popen.return_value = process
+ return build_app(
+ target=Path("target.gguf"),
+ draft=Path("draft.safetensors"),
+ bin_path=Path("test_dflash"),
+ budget=22,
+ max_ctx=131072,
+ tokenizer=mock_tokenizer,
+ stop_ids={2},
+ **kwargs,
+ )
+
+
+def _chat_sse_chunks(text: str) -> list[dict]:
+ return [
+ json.loads(line[6:])
+ for line in text.strip().split("\n\n")
+ if line.startswith("data: ") and line != "data: [DONE]"
+ ]
+
+
+def _responses_sse_events(text: str) -> list[tuple[str, dict]]:
+ events: list[tuple[str, dict]] = []
+ for block in text.strip().split("\n\n"):
+ if not block.strip():
+ continue
+ event_line = next((line for line in block.splitlines() if line.startswith("event: ")), None)
+ data_line = next((line for line in block.splitlines() if line.startswith("data: ")), None)
+ if event_line and data_line:
+ events.append((event_line[7:], json.loads(data_line[6:])))
+ return events
+
+
+def _chat_stream_assistant_message(chunks: list[dict]) -> dict:
+ content_parts: list[str] = []
+ tool_calls: dict[int, dict] = {}
+ for chunk in chunks:
+ for choice in chunk.get("choices", []):
+ delta = choice.get("delta", {})
+ text = delta.get("content")
+ if isinstance(text, str):
+ content_parts.append(text)
+ for tc_delta in delta.get("tool_calls", []):
+ index = tc_delta["index"]
+ state = tool_calls.setdefault(index, {
+ "id": tc_delta.get("id"),
+ "type": tc_delta.get("type", "function"),
+ "function": {"name": None, "arguments": ""},
+ })
+ if tc_delta.get("id"):
+ state["id"] = tc_delta["id"]
+ fn_delta = tc_delta.get("function", {})
+ if fn_delta.get("name"):
+ state["function"]["name"] = fn_delta["name"]
+ if fn_delta.get("arguments"):
+ state["function"]["arguments"] += fn_delta["arguments"]
+ msg = {"role": "assistant"}
+ content = "".join(content_parts)
+ msg["content"] = content or None
+ if tool_calls:
+ msg["tool_calls"] = [tool_calls[i] for i in sorted(tool_calls)]
+ return msg
+
+
# ─── parse_reasoning ───────────────────────────────────────────────
class TestParseReasoning:
@@ -385,6 +497,79 @@ def test_chat_completions_non_streaming_with_tool_call(mock_os_read, mock_pipe,
assert tc[0]["function"]["name"] == "read_file"
+@patch("server.os.pipe")
+@patch("server.os.read")
+def test_chat_completions_replays_raw_tool_call_text(mock_os_read, mock_pipe,
+ mock_tokenizer, app):
+ mock_pipe.return_value = (1, 2)
+ raw_tool_text = (
+ "Before\n"
+ ""
+ "test.py"
+ "\n"
+ "After"
+ )
+ mock_tokenizer.decode.side_effect = [raw_tool_text, "followup"]
+ mock_os_read.side_effect = [
+ struct.pack("", "", "8"]
+ mock_lookup_full.return_value = (7, "cached_prompt.bin", 1)
+ mock_tokenizer.apply_chat_template.return_value = "prompt\n"
+
+ def decode_side_effect(ids, *args, **kwargs):
+ token_id = ids[0] if isinstance(ids, list) else ids
+ return {
+ 10: "hidden reasoning",
+ 11: "",
+ 12: "visible answer",
+ }[token_id]
+
+ mock_tokenizer.decode.side_effect = decode_side_effect
mock_os_read.side_effect = [
struct.pack("" not in text
- assert '"content":"8"' in text or '"content": "8"' in text
-
+ chunks = _chat_sse_chunks(response.text)
+ reasoning = "".join(
+ choice.get("delta", {}).get("reasoning_content", "")
+ for chunk in chunks
+ for choice in chunk.get("choices", [])
+ )
+ content = "".join(
+ choice.get("delta", {}).get("content", "")
+ for chunk in chunks
+ for choice in chunk.get("choices", [])
+ )
+ assert reasoning == "hidden reasoning"
+ assert content == "visible answer"
+ assert "" not in content
+ assert "hidden reasoning" not in content
-# ─── POST /v1/responses ───────────────────────────────────────────
@patch("server.os.pipe")
@patch("server.os.read")
-def test_responses_non_streaming(mock_os_read, mock_pipe, mock_tokenizer, app):
- """POST /v1/responses non-streaming returns ResponsesObject."""
+@pytest.mark.parametrize(("decoded_chunks", "leaked_fragments"), [
+ (
+ [
+ "test",
+ ".py",
+ ],
+ ["", "test",
+ ".py",
+ ],
+ ["{"name":"read_file","arguments":{"path":"test',
+ '.py"}}',
+ ],
+ ["", '{"name":"read_file"'],
+ ),
+])
+def test_chat_completions_streaming_tool_call_deltas(mock_os_read, mock_pipe,
+ mock_tokenizer, app,
+ decoded_chunks, leaked_fragments):
mock_pipe.return_value = (1, 2)
- mock_os_read.side_effect = [struct.pack(" 0
- assert data["usage"]["output_tokens"] > 0
- assert data["usage"]["total_tokens"] == data["usage"]["input_tokens"] + data["usage"]["output_tokens"]
+ chunks = _chat_sse_chunks(response.text)
+ tool_deltas = [
+ chunk["choices"][0]["delta"]["tool_calls"][0]
+ for chunk in chunks
+ if chunk.get("choices")
+ and chunk["choices"][0].get("delta", {}).get("tool_calls")
+ ]
+ assert len(tool_deltas) >= 2
+ assert tool_deltas[0]["id"].startswith("call_")
+ assert tool_deltas[0]["function"]["name"] == "read_file"
+ assert "".join(
+ delta.get("function", {}).get("arguments", "")
+ for delta in tool_deltas
+ ) == '{"path":"test.py"}'
+ assert not any(
+ any(fragment in choice.get("delta", {}).get("content", "") for fragment in leaked_fragments)
+ for chunk in chunks
+ for choice in chunk.get("choices", [])
+ )
+ finish_chunk = next(
+ chunk for chunk in reversed(chunks)
+ if chunk.get("choices") and chunk["choices"][0]["finish_reason"] is not None
+ )
+ assert finish_chunk["choices"][0]["finish_reason"] == "tool_calls"
@patch("server.os.pipe")
@patch("server.os.read")
-def test_responses_non_streaming_string_input(mock_os_read, mock_pipe,
- mock_tokenizer, app):
- """Responses API accepts a plain string as input."""
+def test_chat_streaming_stop_sequence_preserves_tool_deltas(mock_os_read, mock_pipe,
+ mock_tokenizer, app):
mock_pipe.return_value = (1, 2)
+ stop_marker = ""
+ mock_tokenizer.decode.return_value = (
+ ""
+ "test.py"
+ ""
+ f"{stop_marker}"
+ )
mock_os_read.side_effect = [struct.pack("" in choice.get("delta", {}).get("content", "")
+ for chunk in chunks
+ for choice in chunk.get("choices", [])
+ )
+ assistant_msg = _chat_stream_assistant_message(chunks)
+ assert assistant_msg["content"] is None
+ assert assistant_msg["tool_calls"][0]["function"]["name"] == "read_file"
+ assert assistant_msg["tool_calls"][0]["function"]["arguments"] == '{"path":"test.py"}'
+ finish_chunk = next(
+ chunk for chunk in reversed(chunks)
+ if chunk.get("choices") and chunk["choices"][0]["finish_reason"] is not None
+ )
+ assert finish_chunk["choices"][0]["finish_reason"] == "tool_calls"
@patch("server.os.pipe")
@patch("server.os.read")
-def test_responses_non_streaming_started_in_thinking(mock_os_read, mock_pipe,
- mock_tokenizer, app):
- """When prompt ends with , reasoning without tags is not misclassified as content."""
+def test_chat_streaming_stop_sequence_after_bare_function_close_tag(
+ mock_os_read, mock_pipe, mock_tokenizer, app):
mock_pipe.return_value = (1, 2)
- mock_os_read.side_effect = [struct.pack("\n
- mock_tokenizer.apply_chat_template.return_value = "prompt\n"
- # Model output has no tags — it's a continuation of the prefilled block
- mock_tokenizer.decode.return_value = "internal reasoning\nactual answer"
+ stop_marker = ""
+ mock_tokenizer.decode.side_effect = [
+ "test",
+ ".py",
+ f"{stop_marker}ignored",
+ ]
+ mock_os_read.side_effect = [
+ struct.pack("',
+ ],
+ ['{"name":"read_file"', ""],
+ ),
+ (
+ [
+ '{"name":"read_file","arguments":{"path":"test',
+ '.py"}}',
+ ],
+ ["", '{"name":"read_file"', ""],
+ ),
+])
+def test_chat_streaming_stop_sequence_after_json_tool_call(
+ mock_os_read, mock_pipe, mock_tokenizer, app, decoded_chunks,
+ leaked_fragments):
mock_pipe.return_value = (1, 2)
- mock_os_read.side_effect = [struct.pack("",
+ "stream": True,
})
assert response.status_code == 200
- # Verify apply_chat_template was called with system message
- calls = mock_tokenizer.apply_chat_template.call_args_list
- last_call = calls[-1]
- msgs = last_call[0][0] # first positional arg
- assert msgs[0]["role"] == "system"
- assert "coding assistant" in msgs[0]["content"]
+ chunks = _chat_sse_chunks(response.text)
+ assistant_msg = _chat_stream_assistant_message(chunks)
+ assert assistant_msg["content"] is None
+ assert assistant_msg["tool_calls"][0]["function"]["name"] == "read_file"
+ assert assistant_msg["tool_calls"][0]["function"]["arguments"] == '{"path":"test.py"}'
+ assert not any(
+ any(fragment in choice.get("delta", {}).get("content", "") for fragment in leaked_fragments)
+ for chunk in chunks
+ for choice in chunk.get("choices", [])
+ )
+ finish_chunk = next(
+ chunk for chunk in reversed(chunks)
+ if chunk.get("choices") and chunk["choices"][0]["finish_reason"] is not None
+ )
+ assert finish_chunk["choices"][0]["finish_reason"] == "tool_calls"
@patch("server.os.pipe")
@patch("server.os.read")
-def test_responses_streaming(mock_os_read, mock_pipe, mock_tokenizer, app):
- """POST /v1/responses streaming emits proper SSE lifecycle events."""
+def test_chat_streaming_replays_raw_tool_call_text(mock_os_read, mock_pipe,
+ mock_tokenizer, app):
mock_pipe.return_value = (1, 2)
- mock_os_read.side_effect = [struct.pack(""
+ "test.py"
+ "\\n"
+ "After"
+ )
+ mock_tokenizer.decode.side_effect = [raw_tool_text, "followup"]
+ mock_os_read.side_effect = [
+ struct.pack("", "", "8"]
+ stop_marker = ""
+ trimmed_raw_tool_text = (
+ "Before\n"
+ ""
+ "test.py"
+ "\n"
+ "After"
+ )
+ mock_tokenizer.decode.side_effect = [
+ trimmed_raw_tool_text + stop_marker + "ignored",
+ "followup",
+ ]
mock_os_read.side_effect = [
- struct.pack("" not in text
- assert '"delta":"8"' in text or '"delta": "8"' in text
+ second = client.post("/v1/chat/completions", json={
+ "model": MODEL_NAME,
+ "messages": [
+ {"role": "user", "content": "read test.py"},
+ assistant_msg,
+ {"role": "tool", "tool_call_id": assistant_msg["tool_calls"][0]["id"], "content": "file body"},
+ {"role": "user", "content": "what next?"},
+ ],
+ "stream": False,
+ })
+ assert second.status_code == 200
+
+ msgs = mock_tokenizer.apply_chat_template.call_args_list[-1][0][0]
+ assistant = next(m for m in msgs if m["role"] == "assistant")
+ assert assistant["content"] == trimmed_raw_tool_text
+ assert stop_marker not in assistant["content"]
+ assert "ignored" not in assistant["content"]
+ assert "tool_calls" not in assistant
@patch("server.os.pipe")
@patch("server.os.read")
-def test_responses_with_tools(mock_os_read, mock_pipe, mock_tokenizer, app):
- """POST /v1/responses with function tools maps correctly."""
+@pytest.mark.parametrize("decoded_chunks", [
+ [
+ "test",
+ ".py",
+ ],
+ [
+ "test",
+ ".py",
+ ],
+])
+def test_chat_streaming_stop_hit_drains_daemon_before_next_stream(
+ mock_os_read, mock_pipe, mock_tokenizer, app, decoded_chunks):
mock_pipe.return_value = (1, 2)
- mock_os_read.side_effect = [struct.pack("",
+ "stream": True,
})
+ assert first.status_code == 200
+ assistant_msg = _chat_stream_assistant_message(_chat_sse_chunks(first.text))
+ assert assistant_msg["tool_calls"][0]["function"]["name"] == "read_file"
- assert response.status_code == 200
- data = response.json()
- assert data["object"] == "response"
- assert data["status"] == "completed"
+ second = client.post("/v1/chat/completions", json={
+ "model": MODEL_NAME,
+ "messages": [{"role": "user", "content": "hello"}],
+ "stream": True,
+ })
+ assert second.status_code == 200
+ second_msg = _chat_stream_assistant_message(_chat_sse_chunks(second.text))
+ assert second_msg["content"] == "fresh reply"
+ assert "tool_calls" not in second_msg
+@patch("server.PrefixCache.abort_full_snap")
+@patch("server.PrefixCache.confirm_full_snap")
+@patch("server.PrefixCache.prepare_full_snap", return_value=(7, 0))
+@patch("server.compress_text_via_daemon", return_value="compressed prompt")
@patch("server.os.pipe")
@patch("server.os.read")
-def test_responses_object_tool_choice(mock_os_read, mock_pipe,
- mock_tokenizer, app):
- """POST /v1/responses with object-style tool_choice must not 422."""
+def test_chat_streaming_stop_hit_confirms_reserved_full_snapshot(
+ mock_os_read, mock_pipe, _mock_compress, _mock_prepare_full_snap,
+ mock_confirm_full_snap, mock_abort_full_snap, mock_tokenizer,
+ app_with_prefill):
mock_pipe.return_value = (1, 2)
- mock_os_read.side_effect = [struct.pack("",
+ 12: "stale",
+ }[token_id]
+
+ mock_tokenizer.decode.side_effect = decode_side_effect
+ mock_os_read.side_effect = [
+ struct.pack("",
+ "stream": True,
+ })
+
+ assert response.status_code == 200
+ assistant_msg = _chat_stream_assistant_message(_chat_sse_chunks(response.text))
+ assert assistant_msg["content"] == "hello"
+ mock_confirm_full_snap.assert_called_once()
+ slot, prompt_ids, cur_bin, cur_ids_len = mock_confirm_full_snap.call_args.args
+ assert slot == 7
+ assert prompt_ids == [1]
+ assert isinstance(cur_bin, Path)
+ assert cur_ids_len == 1
+ mock_abort_full_snap.assert_not_called()
+
+
+@patch("server.PrefixCache.abort_full_snap")
+@patch("server.PrefixCache.prepare_full_snap", return_value=(7, 0))
+@patch("server.compress_text_via_daemon", return_value="compressed prompt")
+@patch("server.os.pipe")
+@patch("server.os.read")
+def test_chat_streaming_write_failure_aborts_reserved_full_snapshot(
+ mock_os_read, mock_pipe, _mock_compress, _mock_prepare_full_snap,
+ mock_abort_full_snap, mock_tokenizer):
+ mock_pipe.return_value = (1, 2)
+ mock_os_read.side_effect = []
+ dead_proc = MagicMock()
+ dead_proc.poll.return_value = 1
+ local_app = _build_app_with_process(
+ mock_tokenizer, dead_proc, enable_prefill=True)
+
+ client = TestClient(local_app)
+ response = client.post("/v1/chat/completions", json={
+ "model": MODEL_NAME,
+ "messages": [{"role": "user", "content": "hello"}],
+ "stream": True,
+ })
+
+ assert response.status_code == 200
+ assert "dflash daemon has exited unexpectedly" in response.text
+ mock_abort_full_snap.assert_called_once_with(7)
+
+
+@patch("server.os.pipe")
+@patch("server.os.read")
+def test_chat_tool_choice_required_rejects_plain_text(mock_os_read, mock_pipe,
+ mock_tokenizer, app):
+ mock_pipe.return_value = (1, 2)
+ mock_tokenizer.decode.return_value = "plain text"
+ mock_os_read.side_effect = [struct.pack("file.txt'
+ ''
+ )
+ mock_os_read.side_effect = [struct.pack(" 0
+ assert data["usage"]["output_tokens"] > 0
+ assert data["usage"]["total_tokens"] == data["usage"]["input_tokens"] + data["usage"]["output_tokens"]
+
+
+@patch("server.os.pipe")
+@patch("server.os.read")
+def test_responses_non_streaming_string_input(mock_os_read, mock_pipe,
+ mock_tokenizer, app):
+ """Responses API accepts a plain string as input."""
+ mock_pipe.return_value = (1, 2)
+ mock_os_read.side_effect = [struct.pack(""
+ "file.txt"
+ ""
+ "After tool"
+ )
+ mock_os_read.side_effect = [struct.pack(", reasoning without tags is not misclassified as content."""
+ mock_pipe.return_value = (1, 2)
+ mock_os_read.side_effect = [struct.pack("\n
+ mock_tokenizer.apply_chat_template.return_value = "prompt\n"
+ # Model output has no tags — it's a continuation of the prefilled block
+ mock_tokenizer.decode.return_value = "internal reasoning\nactual answer"
+
+ client = TestClient(app)
+ response = client.post("/v1/responses", json={
+ "model": MODEL_NAME,
+ "input": [{"type": "message", "role": "user", "content": "hello"}],
+ })
+
+ assert response.status_code == 200
+ data = response.json()
+ assert data["object"] == "response"
+ # The "actual answer" part should be the output text, not the reasoning
+ assert "actual answer" in data["output_text"]
+ # The reasoning should NOT leak into the output text
+ assert "internal reasoning" not in data["output_text"]
+
+
+@patch("server.os.pipe")
+@patch("server.os.read")
+def test_responses_with_instructions(mock_os_read, mock_pipe,
+ mock_tokenizer, app):
+ """Instructions are mapped to system message."""
+ mock_pipe.return_value = (1, 2)
+ mock_os_read.side_effect = [struct.pack("",
+ 12: "visible answer",
+ }[token_id]
+
+ mock_tokenizer.decode.side_effect = decode_side_effect
+ mock_os_read.side_effect = [
+ struct.pack("" in delta or "hidden reasoning" in delta for delta in deltas)
+ completed = next(data for event, data in events if event == "response.completed")
+ assert completed["response"]["output_text"] == "visible answer"
+
+
+@patch("server.os.pipe")
+@patch("server.os.read")
+@pytest.mark.parametrize("decoded_chunks", [
+ [
+ "test",
+ ".py",
+ ],
+ [
+ "test",
+ ".py",
+ ],
+ [
+ '{"name":"read_file","arguments":{"path":"test',
+ '.py"}}',
+ ],
+ [
+ '{"name":"read_file","arguments":{"path":"test',
+ '.py"}}',
+ ],
+])
+def test_responses_streaming_function_call_lifecycle(mock_os_read, mock_pipe,
+ mock_tokenizer, app,
+ decoded_chunks):
+ mock_pipe.return_value = (1, 2)
+ mock_tokenizer.decode.side_effect = decoded_chunks
+ mock_os_read.side_effect = [
+ struct.pack("file.txt',
+ ],
+ ["", "{"name":"write_file","arguments":{"path":"file.txt"}}',
+ ],
+ ["", '{"name":"write_file"'],
+ ),
+])
+def test_responses_streaming_tool_choice_failure_suppresses_terminal_function_events(
+ mock_os_read, mock_pipe, mock_tokenizer, app, decoded_chunks,
+ leaked_fragments):
+ mock_pipe.return_value = (1, 2)
+ mock_tokenizer.decode.side_effect = decoded_chunks
+ mock_os_read.side_effect = [struct.pack("file.txt",
+ "After tool",
+ ]
+ mock_os_read.side_effect = [
+ struct.pack("", "", "8"]
+ mock_os_read.side_effect = [
+ struct.pack("" not in text
+ assert '"delta":"8"' in text or '"delta": "8"' in text
+
+
+@patch("server.os.pipe")
+@patch("server.os.read")
+def test_responses_with_tools(mock_os_read, mock_pipe, mock_tokenizer, app):
+ """POST /v1/responses with function tools maps correctly."""
+ mock_pipe.return_value = (1, 2)
+ mock_os_read.side_effect = [struct.pack("'
+ 'file.txt'
+ ''
+ )
+ mock_tokenizer.decode.return_value = raw_tool_text
+ mock_os_read.side_effect = [struct.pack("'
+ 'file.txt'
+ ''
+ )
+ mock_os_read.side_effect = [struct.pack("'
+ 'other.txt'
+ ''
+ ''
+ 'file.txt'
+ ''
+ )
+ mock_os_read.side_effect = [struct.pack("'
+ 'file.txt'
+ ''
+ )
+ mock_os_read.side_effect = [struct.pack("'
+ 'file.txt'
+ ''
+ )
+ mock_tokenizer.decode.side_effect = [raw_tool_text, "followup"]
+ mock_os_read.side_effect = [
+ struct.pack("'
+ 'file.txt'
+ ''
+ 'After tool'
+ )
+ mock_tokenizer.decode.side_effect = [raw_tool_text, "followup"]
+ mock_os_read.side_effect = [
+ struct.pack("old")
+ mem.remember(["call_new"], "new")
+
+ assert mem.lookup_message([{"id": "call_old"}]) is None
+ assert mem.lookup_message([{"id": "call_new"}]) == "new"
+ assert len(mem.by_block) == 1
+
+
+def test_tool_memory_lookup_message_requires_same_raw_text():
+ mem = ToolMemory(max_entries=8, max_bytes=4096)
+ mem.remember(["call_a"], "a")
+ mem.remember(["call_b"], "b")
+
+ assert mem.lookup_message([{"id": "call_a"}, {"id": "call_b"}]) is None
diff --git a/dflash/scripts/tool_memory.py b/dflash/scripts/tool_memory.py
new file mode 100644
index 00000000..58bb4a83
--- /dev/null
+++ b/dflash/scripts/tool_memory.py
@@ -0,0 +1,124 @@
+from __future__ import annotations
+
+from collections import OrderedDict
+from dataclasses import dataclass
+from typing import Any, Iterable, Sequence
+
+
+@dataclass
+class _ToolMemoryBlock:
+ raw_text: str
+ size_bytes: int
+ refs: int = 0
+
+
+class ToolMemory:
+ """Exact assistant-text memory for tool-calling turns.
+
+ The server receives assistant tool calls back as structured JSON on later
+ turns. Re-rendering those objects can change key order, spacing, or wrapper
+ shape, which changes prompt tokenization and breaks prefix/KV reuse. This
+ store remembers the exact assistant text that originally produced a set of
+ tool-call IDs so replay can inject the original text verbatim.
+ """
+
+ def __init__(self, *, max_entries: int = 50_000, max_bytes: int = 64 * 1024 * 1024):
+ self.max_entries = max(0, int(max_entries))
+ self.max_bytes = max(0, int(max_bytes))
+ self.by_id: dict[str, _ToolMemoryBlock] = {}
+ self.by_block: dict[str, _ToolMemoryBlock] = {}
+ self._lru: OrderedDict[str, None] = OrderedDict()
+ self.total_bytes = 0
+
+ @property
+ def disabled(self) -> bool:
+ return self.max_entries == 0 or self.max_bytes == 0
+
+ def remember(self, call_ids: Iterable[str], raw_text: str) -> None:
+ if self.disabled or not raw_text:
+ return
+ unique_ids = []
+ seen: set[str] = set()
+ for call_id in call_ids:
+ if not isinstance(call_id, str) or not call_id or call_id in seen:
+ continue
+ seen.add(call_id)
+ unique_ids.append(call_id)
+ if not unique_ids:
+ return
+
+ block = self.by_block.get(raw_text)
+ if block is None:
+ block = _ToolMemoryBlock(
+ raw_text=raw_text,
+ size_bytes=len(raw_text.encode("utf-8")),
+ )
+ self.by_block[raw_text] = block
+ self.total_bytes += block.size_bytes
+
+ for call_id in unique_ids:
+ current = self.by_id.get(call_id)
+ if current is block:
+ self._touch(call_id)
+ continue
+ if current is not None:
+ self._drop_entry(call_id, current)
+ self.by_id[call_id] = block
+ block.refs += 1
+ self._touch(call_id)
+
+ self._prune()
+
+ def lookup_message(self, tool_calls: Sequence[Any]) -> str | None:
+ raw_text: str | None = None
+ touched: list[str] = []
+ for item in tool_calls:
+ call_id = self._extract_call_id(item)
+ if not call_id:
+ return None
+ block = self.by_id.get(call_id)
+ if block is None:
+ return None
+ touched.append(call_id)
+ if raw_text is None:
+ raw_text = block.raw_text
+ elif raw_text != block.raw_text:
+ return None
+ if raw_text is None:
+ return None
+ for call_id in touched:
+ self._touch(call_id)
+ return raw_text
+
+ def _touch(self, call_id: str) -> None:
+ self._lru[call_id] = None
+ self._lru.move_to_end(call_id)
+
+ def _prune(self) -> None:
+ while self.by_id and (
+ (self.max_entries > 0 and len(self.by_id) > self.max_entries)
+ or (self.max_bytes > 0 and self.total_bytes > self.max_bytes)
+ ):
+ oldest_id, _ = self._lru.popitem(last=False)
+ block = self.by_id.get(oldest_id)
+ if block is not None:
+ self._drop_entry(oldest_id, block)
+
+ def _drop_entry(self, call_id: str, block: _ToolMemoryBlock) -> None:
+ self.by_id.pop(call_id, None)
+ self._lru.pop(call_id, None)
+ if block.refs > 0:
+ block.refs -= 1
+ if block.refs == 0:
+ self.by_block.pop(block.raw_text, None)
+ self.total_bytes -= block.size_bytes
+ if self.total_bytes < 0:
+ self.total_bytes = 0
+
+ @staticmethod
+ def _extract_call_id(item: Any) -> str | None:
+ if isinstance(item, dict):
+ call_id = item.get("id")
+ return call_id if isinstance(call_id, str) and call_id else None
+ call_id = getattr(item, "id", None)
+ return call_id if isinstance(call_id, str) and call_id else None
diff --git a/pflash/README.md b/pflash/README.md
index 8f6a473f..350704bd 100644
--- a/pflash/README.md
+++ b/pflash/README.md
@@ -95,7 +95,7 @@ python tests/bench_niah_cpp.py \
## OpenAI server flags
-For an OpenAI-compatible server with transparent compression on long prompts, run [`dflash/scripts/server.py`](../dflash/scripts/server.py) (or `server_tools.py` for tool-calling) with these flags:
+For an OpenAI-compatible server with transparent compression on long prompts, run [`dflash/scripts/server.py`](../dflash/scripts/server.py) with these flags:
| Flag | Choices / type | Default | Effect |
|---|---|:---:|---|