diff --git a/dflash/scripts/server.py b/dflash/scripts/server.py
index 9e00d3d9..0d05f45a 100644
--- a/dflash/scripts/server.py
+++ b/dflash/scripts/server.py
@@ -41,6 +41,7 @@
compress_text_via_daemon, _drain_until_sentinel,
)
from prefix_cache import DaemonStdoutBus, PrefixCache
+from tool_memory import ToolMemory
class OpenAICompatError(Exception):
@@ -696,7 +697,22 @@ def _resolve_kv_k_type():
cap=prefix_cache_slots,
)
if prefill_cfg is not None and prefill_cache_slots > 0:
- prefix_cache.init_full_cache(prefill_cache_slots, budget_bytes=prefill_cache_bytes)
+ prefix_cache.init_full_cache(prefill_cache_slots)
+ tool_memory = ToolMemory(
+ max_entries=int(os.environ.get("DFLASH_TOOL_MEMORY_MAX_ENTRIES", "50000")),
+ max_bytes=int(os.environ.get("DFLASH_TOOL_MEMORY_MAX_BYTES", str(64 * 1024 * 1024))),
+ )
+
def _remember_tool_call_text(raw_text: str, tool_calls: list[dict] | None) -> None:
    """Record the raw assistant text under the IDs of the tool calls it produced.

    Does nothing when ``raw_text`` or ``tool_calls`` is empty/None, or when no
    entry of ``tool_calls`` is a dict carrying a non-empty string ``"id"``.
    Otherwise the collected IDs and the text are handed to
    ``tool_memory.remember``.
    """
    if not (raw_text and tool_calls):
        return

    call_ids = []
    for call in tool_calls:
        # Entries that are not dicts carry no usable ID and are ignored.
        if not isinstance(call, dict):
            continue
        call_id = call.get("id")
        if isinstance(call_id, str) and call_id:
            call_ids.append(call_id)

    if call_ids:
        tool_memory.remember(call_ids, raw_text)
@app.on_event("startup")
async def _startup():
@@ -801,13 +817,18 @@ def _tokenize_prompt(req: ChatRequest) -> tuple[Path, list[int], list[dict], boo
msgs: list[dict] = []
for m in req.messages:
d: dict = {"role": m.role}
- if m.content is not None:
+ replay_raw_text = None
+ if m.role == "assistant" and m.tool_calls is not None:
+ replay_raw_text = tool_memory.lookup_message(m.tool_calls)
+ if replay_raw_text is not None:
+ d["content"] = replay_raw_text
+ elif m.content is not None:
d["content"] = _content_to_str(m.content)
if m.name is not None:
d["name"] = m.name
if m.tool_call_id is not None:
d["tool_call_id"] = m.tool_call_id
- if m.tool_calls is not None:
+ if m.tool_calls is not None and replay_raw_text is None:
d["tool_calls"] = []
for tc in m.tool_calls:
args = tc.function.arguments
@@ -1173,6 +1194,8 @@ def chunk(delta_obj, finish=None):
mode = "reasoning" if started_in_thinking else "content"
window = ""
tool_buffer = ""
+ accumulated_content = ""
+ accumulated_raw_text = ""
stops = normalize_stop(req.stop)
tag_holdback = max(len(THINK_OPEN_TAG), len(THINK_CLOSE_TAG), len(TOOL_OPEN_TAG))
stop_holdback = max((len(s) for s in stops), default=0)
@@ -1189,6 +1212,7 @@ def emit_delta(text, kind):
async for tok_id in _astream_tokens(r_pipe, gen_len, timing):
completion_tokens += 1
piece = tokenizer.decode([tok_id])
+ accumulated_raw_text += piece
window += piece
if stops and mode != "tool_buffer":
@@ -1197,6 +1221,8 @@ def emit_delta(text, kind):
window = window[:si]
stop_hit = True
kind = "reasoning_content" if mode == "reasoning" else "content"
+ if mode == "content":
+ accumulated_content += window
out = emit_delta(window, kind)
if out: yield out
window = ""
@@ -1236,6 +1262,7 @@ def emit_delta(text, kind):
hits.sort()
idx, which = hits[0]
pre = window[:idx]
+ accumulated_content += pre
out = emit_delta(pre, "content")
if out: yield out
if which == "think":
@@ -1250,6 +1277,7 @@ def emit_delta(text, kind):
continue
if len(window) > HOLDBACK:
safe = window[:-HOLDBACK]
+ accumulated_content += safe
out = emit_delta(safe, "content")
if out: yield out
window = window[-HOLDBACK:]
@@ -1277,6 +1305,7 @@ def emit_delta(text, kind):
out = emit_delta(window, "reasoning_content")
if out: yield out
elif mode == "content" and window:
+ accumulated_content += window
out = emit_delta(window, "content")
if out: yield out
elif mode == "tool_buffer":
@@ -1287,6 +1316,7 @@ def emit_delta(text, kind):
if mode == "tool_buffer":
cleaned_after, tool_calls = parse_tool_calls(tool_buffer, tools=req.tools)
if tool_calls:
+ _remember_tool_call_text(accumulated_raw_text, tool_calls)
if cleaned_after:
out = emit_delta(cleaned_after, "content")
if out: yield out
@@ -1412,6 +1442,7 @@ def emit_delta(text, kind):
if req.chat_template_kwargs:
thinking_enabled = req.chat_template_kwargs.get("enable_thinking", True)
cleaned, tool_calls = parse_tool_calls(text, tools=req.tools)
+ _remember_tool_call_text(text, tool_calls)
cleaned, reasoning = parse_reasoning(
cleaned,
thinking_enabled=thinking_enabled,
@@ -1909,6 +1940,7 @@ async def _responses_non_stream(
if chat_req.chat_template_kwargs:
thinking_enabled = chat_req.chat_template_kwargs.get("enable_thinking", True)
cleaned, tool_calls = parse_tool_calls(text, tools=chat_req.tools)
+ _remember_tool_call_text(text, tool_calls)
cleaned, reasoning = parse_reasoning(
cleaned, thinking_enabled=thinking_enabled,
started_in_thinking=started_in_thinking)
@@ -2045,6 +2077,7 @@ async def sse() -> AsyncIterator[str]:
window = ""
tool_buffer = ""
accumulated_text = ""
+ accumulated_raw_text = ""
tag_holdback = max(len(THINK_OPEN_TAG), len(THINK_CLOSE_TAG), len(TOOL_OPEN_TAG))
HOLDBACK = tag_holdback
completion_tokens = 0
@@ -2054,6 +2087,7 @@ async def sse() -> AsyncIterator[str]:
async for tok_id in _astream_tokens(r_pipe, gen_len, timing):
completion_tokens += 1
piece = tokenizer.decode([tok_id])
+ accumulated_raw_text += piece
window += piece
while True:
@@ -2141,6 +2175,7 @@ async def sse() -> AsyncIterator[str]:
if mode == "tool_buffer" and tool_buffer:
cleaned_after, tool_calls = parse_tool_calls(tool_buffer, tools=chat_req.tools)
if tool_calls:
+ _remember_tool_call_text(accumulated_raw_text, tool_calls)
if cleaned_after:
accumulated_text += cleaned_after
for tc in tool_calls:
diff --git a/dflash/scripts/test_server.py b/dflash/scripts/test_server.py
index 7bdfc569..564976eb 100644
--- a/dflash/scripts/test_server.py
+++ b/dflash/scripts/test_server.py
@@ -385,6 +385,51 @@ def test_chat_completions_non_streaming_with_tool_call(mock_os_read, mock_pipe,
assert tc[0]["function"]["name"] == "read_file"
+@patch("server.os.pipe")
+@patch("server.os.read")
+def test_chat_completions_replays_raw_tool_call_text(mock_os_read, mock_pipe,
+ mock_tokenizer, app):
+ mock_pipe.return_value = (1, 2)
+ raw_tool_text = (
+ "Before\n"
+ ""
+ "test.py"
+ "\n"
+ "After"
+ )
+ mock_tokenizer.decode.side_effect = [raw_tool_text, "followup"]
+ mock_os_read.side_effect = [
+ struct.pack("" not in text
assert '"content":"8"' in text or '"content": "8"' in text
+@patch("server.os.pipe")
+@patch("server.os.read")
+def test_chat_completions_streaming_replays_exact_raw_text_with_reasoning(
+ mock_os_read, mock_pipe, mock_tokenizer, app):
+ mock_pipe.return_value = (1, 2)
+ raw_tool_turn = (
+ "private chain"
+ "visible"
+ ""
+ "x.py"
+ ""
+ )
+ mock_tokenizer.decode.side_effect = [
+ "private chain",
+ "",
+ "visible",
+ "x.py",
+ "followup",
+ ]
+ mock_os_read.side_effect = [
+ struct.pack("'
+ 'file.txt'
+ ''
+ )
+ mock_tokenizer.decode.side_effect = [raw_tool_text, "followup"]
+ mock_os_read.side_effect = [
+ struct.pack("old")
+ mem.remember(["call_new"], "new")
+
+ assert mem.lookup_message([{"id": "call_old"}]) is None
+ assert mem.lookup_message([{"id": "call_new"}]) == "new"
+ assert len(mem.by_block) == 1
+
+
def test_tool_memory_lookup_message_requires_same_raw_text():
    """Call IDs remembered under different raw texts must not replay together."""
    memory = ToolMemory(max_entries=8, max_bytes=4096)
    for call_id, raw_text in (("call_a", "a"), ("call_b", "b")):
        memory.remember([call_id], raw_text)

    mixed_calls = [{"id": "call_a"}, {"id": "call_b"}]
    assert memory.lookup_message(mixed_calls) is None
diff --git a/dflash/scripts/tool_memory.py b/dflash/scripts/tool_memory.py
new file mode 100644
index 00000000..58bb4a83
--- /dev/null
+++ b/dflash/scripts/tool_memory.py
@@ -0,0 +1,124 @@
+from __future__ import annotations
+
+from collections import OrderedDict
+from dataclasses import dataclass
+from typing import Any, Iterable, Sequence
+
+
@dataclass
class _ToolMemoryBlock:
    """One distinct raw assistant text, shared by every call ID mapped to it."""

    raw_text: str  # exact assistant text as it was remembered
    size_bytes: int  # UTF-8 encoded length of raw_text, charged against max_bytes
    refs: int = 0  # number of call IDs in ToolMemory.by_id pointing at this block


class ToolMemory:
    """Exact assistant-text memory for tool-calling turns.

    The server receives assistant tool calls back as structured JSON on later
    turns. Re-rendering those objects can change key order, spacing, or wrapper
    shape, which changes prompt tokenization and breaks prefix/KV reuse. This
    store remembers the exact assistant text that originally produced a set of
    tool-call IDs so replay can inject the original text verbatim.

    Entries are keyed by call ID and evicted least-recently-used first once
    either the ID count exceeds ``max_entries`` or the total UTF-8 size of the
    distinct stored texts exceeds ``max_bytes``. A limit of 0 for either bound
    disables the store entirely.
    """

    def __init__(self, *, max_entries: int = 50_000, max_bytes: int = 64 * 1024 * 1024):
        """Create an empty store; negative limits are clamped to 0 (disabled)."""
        self.max_entries = max(0, int(max_entries))
        self.max_bytes = max(0, int(max_bytes))
        # call ID -> block holding the text that produced it
        self.by_id: dict[str, _ToolMemoryBlock] = {}
        # raw text -> its single shared block (de-duplicates identical texts)
        self.by_block: dict[str, _ToolMemoryBlock] = {}
        # call IDs ordered from least to most recently used
        self._lru: OrderedDict[str, None] = OrderedDict()
        # sum of size_bytes over every block currently in by_block
        self.total_bytes = 0

    @property
    def disabled(self) -> bool:
        """True when either limit is 0, in which case nothing is ever stored."""
        return self.max_entries == 0 or self.max_bytes == 0

    def remember(self, call_ids: Iterable[str], raw_text: str) -> None:
        """Associate every valid ID in ``call_ids`` with ``raw_text``.

        Non-string, empty, and duplicate IDs are ignored. An ID that was
        previously mapped to a different text is re-pointed at the new one.
        A text that on its own exceeds ``max_bytes`` is not stored; any older
        mapping for the given IDs is still dropped so a later lookup cannot
        replay outdated text.
        """
        if self.disabled or not raw_text:
            return
        # dict.fromkeys de-duplicates while keeping first-seen order.
        unique_ids = list(
            dict.fromkeys(
                call_id
                for call_id in call_ids
                if isinstance(call_id, str) and call_id
            )
        )
        if not unique_ids:
            return

        block = self.by_block.get(raw_text)
        if block is None:
            size_bytes = len(raw_text.encode("utf-8"))
            if size_bytes > self.max_bytes:
                # This text can never fit. Inserting it would make _prune evict
                # every older entry before finally evicting this one, so skip
                # the insert and only invalidate mappings it would have replaced.
                for call_id in unique_ids:
                    stale = self.by_id.get(call_id)
                    if stale is not None:
                        self._drop_entry(call_id, stale)
                return
            block = _ToolMemoryBlock(raw_text=raw_text, size_bytes=size_bytes)
            self.by_block[raw_text] = block
            self.total_bytes += size_bytes

        for call_id in unique_ids:
            current = self.by_id.get(call_id)
            if current is not block:
                if current is not None:
                    # Release the ID's reference on the text it used to map to.
                    self._drop_entry(call_id, current)
                self.by_id[call_id] = block
                block.refs += 1
            self._touch(call_id)

        self._prune()

    def lookup_message(self, tool_calls: Sequence[Any]) -> str | None:
        """Return the remembered text shared by all of ``tool_calls``, else None.

        Each item may be a dict with an ``"id"`` key or an object with an
        ``id`` attribute. The lookup succeeds only when every item has a known
        ID and all of those IDs map to the same raw text; on success the IDs
        are marked as recently used.
        """
        if not tool_calls:
            return None
        raw_text: str | None = None
        touched: list[str] = []
        for item in tool_calls:
            call_id = self._extract_call_id(item)
            if not call_id:
                return None
            block = self.by_id.get(call_id)
            if block is None:
                return None
            if raw_text is None:
                raw_text = block.raw_text
            elif raw_text != block.raw_text:
                # IDs from different assistant turns cannot be replayed as one.
                return None
            touched.append(call_id)
        # Only refresh recency once the whole message is known to match.
        for call_id in touched:
            self._touch(call_id)
        return raw_text

    def _touch(self, call_id: str) -> None:
        """Mark ``call_id`` as the most recently used entry."""
        self._lru[call_id] = None
        self._lru.move_to_end(call_id)

    def _over_budget(self) -> bool:
        """Return True while either the entry or the byte limit is exceeded."""
        return len(self.by_id) > self.max_entries or self.total_bytes > self.max_bytes

    def _prune(self) -> None:
        """Evict least-recently-used IDs until both limits are satisfied."""
        while self._lru and self._over_budget():
            oldest_id, _ = self._lru.popitem(last=False)
            block = self.by_id.get(oldest_id)
            if block is not None:
                self._drop_entry(oldest_id, block)

    def _drop_entry(self, call_id: str, block: _ToolMemoryBlock) -> None:
        """Remove ``call_id`` and free ``block`` once no ID references it."""
        self.by_id.pop(call_id, None)
        self._lru.pop(call_id, None)
        if block.refs > 0:
            block.refs -= 1
        if block.refs == 0:
            self.by_block.pop(block.raw_text, None)
            # Clamp so an accounting slip can never leave a negative total.
            self.total_bytes = max(0, self.total_bytes - block.size_bytes)

    @staticmethod
    def _extract_call_id(item: Any) -> str | None:
        """Read a non-empty string ID from a dict key or an ``id`` attribute."""
        if isinstance(item, dict):
            call_id = item.get("id")
        else:
            call_id = getattr(item, "id", None)
        return call_id if isinstance(call_id, str) and call_id else None