diff --git a/dflash/scripts/server.py b/dflash/scripts/server.py index 9e00d3d9..0d05f45a 100644 --- a/dflash/scripts/server.py +++ b/dflash/scripts/server.py @@ -41,6 +41,7 @@ compress_text_via_daemon, _drain_until_sentinel, ) from prefix_cache import DaemonStdoutBus, PrefixCache +from tool_memory import ToolMemory class OpenAICompatError(Exception): @@ -696,7 +697,22 @@ def _resolve_kv_k_type(): cap=prefix_cache_slots, ) if prefill_cfg is not None and prefill_cache_slots > 0: - prefix_cache.init_full_cache(prefill_cache_slots, budget_bytes=prefill_cache_bytes) + prefix_cache.init_full_cache(prefill_cache_slots) + tool_memory = ToolMemory( + max_entries=int(os.environ.get("DFLASH_TOOL_MEMORY_MAX_ENTRIES", "50000")), + max_bytes=int(os.environ.get("DFLASH_TOOL_MEMORY_MAX_BYTES", str(64 * 1024 * 1024))), + ) + + def _remember_tool_call_text(raw_text: str, tool_calls: list[dict] | None) -> None: + if not raw_text or not tool_calls: + return + call_ids = [ + tc.get("id") + for tc in tool_calls + if isinstance(tc, dict) and isinstance(tc.get("id"), str) and tc.get("id") + ] + if call_ids: + tool_memory.remember(call_ids, raw_text) @app.on_event("startup") async def _startup(): @@ -801,13 +817,18 @@ def _tokenize_prompt(req: ChatRequest) -> tuple[Path, list[int], list[dict], boo msgs: list[dict] = [] for m in req.messages: d: dict = {"role": m.role} - if m.content is not None: + replay_raw_text = None + if m.role == "assistant" and m.tool_calls is not None: + replay_raw_text = tool_memory.lookup_message(m.tool_calls) + if replay_raw_text is not None: + d["content"] = replay_raw_text + elif m.content is not None: d["content"] = _content_to_str(m.content) if m.name is not None: d["name"] = m.name if m.tool_call_id is not None: d["tool_call_id"] = m.tool_call_id - if m.tool_calls is not None: + if m.tool_calls is not None and replay_raw_text is None: d["tool_calls"] = [] for tc in m.tool_calls: args = tc.function.arguments @@ -1173,6 +1194,8 @@ def chunk(delta_obj, 
finish=None): mode = "reasoning" if started_in_thinking else "content" window = "" tool_buffer = "" + accumulated_content = "" + accumulated_raw_text = "" stops = normalize_stop(req.stop) tag_holdback = max(len(THINK_OPEN_TAG), len(THINK_CLOSE_TAG), len(TOOL_OPEN_TAG)) stop_holdback = max((len(s) for s in stops), default=0) @@ -1189,6 +1212,7 @@ def emit_delta(text, kind): async for tok_id in _astream_tokens(r_pipe, gen_len, timing): completion_tokens += 1 piece = tokenizer.decode([tok_id]) + accumulated_raw_text += piece window += piece if stops and mode != "tool_buffer": @@ -1197,6 +1221,8 @@ def emit_delta(text, kind): window = window[:si] stop_hit = True kind = "reasoning_content" if mode == "reasoning" else "content" + if mode == "content": + accumulated_content += window out = emit_delta(window, kind) if out: yield out window = "" @@ -1236,6 +1262,7 @@ def emit_delta(text, kind): hits.sort() idx, which = hits[0] pre = window[:idx] + accumulated_content += pre out = emit_delta(pre, "content") if out: yield out if which == "think": @@ -1250,6 +1277,7 @@ def emit_delta(text, kind): continue if len(window) > HOLDBACK: safe = window[:-HOLDBACK] + accumulated_content += safe out = emit_delta(safe, "content") if out: yield out window = window[-HOLDBACK:] @@ -1277,6 +1305,7 @@ def emit_delta(text, kind): out = emit_delta(window, "reasoning_content") if out: yield out elif mode == "content" and window: + accumulated_content += window out = emit_delta(window, "content") if out: yield out elif mode == "tool_buffer": @@ -1287,6 +1316,7 @@ def emit_delta(text, kind): if mode == "tool_buffer": cleaned_after, tool_calls = parse_tool_calls(tool_buffer, tools=req.tools) if tool_calls: + _remember_tool_call_text(accumulated_raw_text, tool_calls) if cleaned_after: out = emit_delta(cleaned_after, "content") if out: yield out @@ -1412,6 +1442,7 @@ def emit_delta(text, kind): if req.chat_template_kwargs: thinking_enabled = req.chat_template_kwargs.get("enable_thinking", True) 
cleaned, tool_calls = parse_tool_calls(text, tools=req.tools) + _remember_tool_call_text(text, tool_calls) cleaned, reasoning = parse_reasoning( cleaned, thinking_enabled=thinking_enabled, @@ -1909,6 +1940,7 @@ async def _responses_non_stream( if chat_req.chat_template_kwargs: thinking_enabled = chat_req.chat_template_kwargs.get("enable_thinking", True) cleaned, tool_calls = parse_tool_calls(text, tools=chat_req.tools) + _remember_tool_call_text(text, tool_calls) cleaned, reasoning = parse_reasoning( cleaned, thinking_enabled=thinking_enabled, started_in_thinking=started_in_thinking) @@ -2045,6 +2077,7 @@ async def sse() -> AsyncIterator[str]: window = "" tool_buffer = "" accumulated_text = "" + accumulated_raw_text = "" tag_holdback = max(len(THINK_OPEN_TAG), len(THINK_CLOSE_TAG), len(TOOL_OPEN_TAG)) HOLDBACK = tag_holdback completion_tokens = 0 @@ -2054,6 +2087,7 @@ async def sse() -> AsyncIterator[str]: async for tok_id in _astream_tokens(r_pipe, gen_len, timing): completion_tokens += 1 piece = tokenizer.decode([tok_id]) + accumulated_raw_text += piece window += piece while True: @@ -2141,6 +2175,7 @@ async def sse() -> AsyncIterator[str]: if mode == "tool_buffer" and tool_buffer: cleaned_after, tool_calls = parse_tool_calls(tool_buffer, tools=chat_req.tools) if tool_calls: + _remember_tool_call_text(accumulated_raw_text, tool_calls) if cleaned_after: accumulated_text += cleaned_after for tc in tool_calls: diff --git a/dflash/scripts/test_server.py b/dflash/scripts/test_server.py index 7bdfc569..564976eb 100644 --- a/dflash/scripts/test_server.py +++ b/dflash/scripts/test_server.py @@ -385,6 +385,51 @@ def test_chat_completions_non_streaming_with_tool_call(mock_os_read, mock_pipe, assert tc[0]["function"]["name"] == "read_file" +@patch("server.os.pipe") +@patch("server.os.read") +def test_chat_completions_replays_raw_tool_call_text(mock_os_read, mock_pipe, + mock_tokenizer, app): + mock_pipe.return_value = (1, 2) + raw_tool_text = ( + "Before\n" + "" + "test.py" 
from __future__ import annotations

from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Iterable, Sequence


@dataclass
class _ToolMemoryBlock:
    """One remembered assistant turn, shared by every call ID it produced."""

    raw_text: str    # exact text that was generated for the turn
    size_bytes: int  # UTF-8 length of raw_text, charged against max_bytes
    refs: int = 0    # number of call IDs in ToolMemory.by_id pointing here


class ToolMemory:
    """Exact assistant-text memory for tool-calling turns.

    The server receives assistant tool calls back as structured JSON on later
    turns. Re-rendering those objects can change key order, spacing, or wrapper
    shape, which changes prompt tokenization and breaks prefix/KV reuse. This
    store remembers the exact assistant text that originally produced a set of
    tool-call IDs so replay can inject the original text verbatim.

    Entries are keyed by call ID and evicted least-recently-used first once
    either ``max_entries`` (number of call IDs) or ``max_bytes`` (total UTF-8
    size of the distinct remembered texts) is exceeded. Setting either limit
    to 0 disables the store.
    """

    def __init__(self, *, max_entries: int = 50_000, max_bytes: int = 64 * 1024 * 1024):
        # Negative limits are clamped to 0, which disables the store.
        self.max_entries = max(0, int(max_entries))
        self.max_bytes = max(0, int(max_bytes))
        # call ID -> block holding the text of the turn that produced it.
        self.by_id: dict[str, _ToolMemoryBlock] = {}
        # raw text -> block, so identical texts are stored (and counted) once.
        self.by_block: dict[str, _ToolMemoryBlock] = {}
        # Call IDs ordered from least to most recently used.
        self._lru: OrderedDict[str, None] = OrderedDict()
        self.total_bytes = 0

    @property
    def disabled(self) -> bool:
        """Return True when a zero limit means nothing can ever be stored."""
        return self.max_entries == 0 or self.max_bytes == 0

    def remember(self, call_ids: Iterable[str], raw_text: str) -> None:
        """Associate every ID in *call_ids* with the exact text *raw_text*.

        Non-string, empty, and duplicate IDs are ignored. An ID that was
        already mapped to a different text is re-pointed at the new one.
        A text whose UTF-8 size exceeds ``max_bytes`` is not stored.
        """
        if self.disabled or not raw_text:
            return
        # dict.fromkeys de-duplicates while keeping first-seen order.
        unique_ids = list(
            dict.fromkeys(cid for cid in call_ids if isinstance(cid, str) and cid)
        )
        if not unique_ids:
            return

        block = self.by_block.get(raw_text)
        if block is None:
            size_bytes = len(raw_text.encode("utf-8"))
            if size_bytes > self.max_bytes:
                # This text can never fit. Registering it would make _prune()
                # evict every other entry (and finally this one) while trying
                # to get back under budget, so refuse it up front. Mappings
                # these IDs already had are dropped so that a stale text is
                # never replayed for them.
                for call_id in unique_ids:
                    stale = self.by_id.get(call_id)
                    if stale is not None:
                        self._drop_entry(call_id, stale)
                return
            block = _ToolMemoryBlock(raw_text=raw_text, size_bytes=size_bytes)
            self.by_block[raw_text] = block
            self.total_bytes += size_bytes

        for call_id in unique_ids:
            current = self.by_id.get(call_id)
            if current is block:
                # Already pointing at this text; just refresh recency.
                self._touch(call_id)
                continue
            if current is not None:
                self._drop_entry(call_id, current)
            self.by_id[call_id] = block
            block.refs += 1
            self._touch(call_id)

        self._prune()

    def lookup_message(self, tool_calls: Sequence[Any]) -> str | None:
        """Return the remembered text shared by all of *tool_calls*, or None.

        Each item may be a dict with an ``"id"`` key or an object with an
        ``id`` attribute. The lookup fails (returns None) when *tool_calls*
        is empty, when any item has no usable ID, when any ID is unknown, or
        when the IDs map to different texts. Recency is refreshed only on a
        successful lookup.
        """
        raw_text: str | None = None
        touched: list[str] = []
        for item in tool_calls:
            call_id = self._extract_call_id(item)
            if not call_id:
                return None
            block = self.by_id.get(call_id)
            if block is None:
                return None
            touched.append(call_id)
            if raw_text is None:
                raw_text = block.raw_text
            elif raw_text != block.raw_text:
                return None
        if raw_text is None:
            return None
        for call_id in touched:
            self._touch(call_id)
        return raw_text

    def _touch(self, call_id: str) -> None:
        """Mark *call_id* as the most recently used entry."""
        self._lru[call_id] = None
        self._lru.move_to_end(call_id)

    def _prune(self) -> None:
        """Evict least-recently-used call IDs until both limits are met."""
        while self.by_id and (
            (self.max_entries > 0 and len(self.by_id) > self.max_entries)
            or (self.max_bytes > 0 and self.total_bytes > self.max_bytes)
        ):
            oldest_id, _ = self._lru.popitem(last=False)
            block = self.by_id.get(oldest_id)
            if block is not None:
                self._drop_entry(oldest_id, block)

    def _drop_entry(self, call_id: str, block: _ToolMemoryBlock) -> None:
        """Remove *call_id* and release *block* once nothing references it."""
        self.by_id.pop(call_id, None)
        self._lru.pop(call_id, None)
        if block.refs > 0:
            block.refs -= 1
        if block.refs == 0:
            # Last reference gone: free the text and give its bytes back.
            self.by_block.pop(block.raw_text, None)
            self.total_bytes -= block.size_bytes
            if self.total_bytes < 0:
                self.total_bytes = 0

    @staticmethod
    def _extract_call_id(item: Any) -> str | None:
        """Return the non-empty string ID of a dict or object, else None."""
        if isinstance(item, dict):
            call_id = item.get("id")
        else:
            call_id = getattr(item, "id", None)
        return call_id if isinstance(call_id, str) and call_id else None