39 changes: 37 additions & 2 deletions dflash/scripts/server.py
@@ -41,6 +41,7 @@
compress_text_via_daemon, _drain_until_sentinel,

P2: Process-wide ToolMemory can replay raw assistant text from unrelated requests when tool-call IDs are reused, causing cross-request data mixing/information disclosure.

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At dflash/scripts/server.py, line 676:

<comment>Process-wide ToolMemory can replay raw assistant text from unrelated requests when tool-call IDs are reused, causing cross-request data mixing/information disclosure.</comment>

<file context>
@@ -672,6 +673,21 @@ def _resolve_kv_k_type():
     )
     if prefill_cfg is not None and prefill_cache_slots > 0:
         prefix_cache.init_full_cache(prefill_cache_slots)
+    tool_memory = ToolMemory(
+        max_entries=int(os.environ.get("DFLASH_TOOL_MEMORY_MAX_ENTRIES", "50000")),
+        max_bytes=int(os.environ.get("DFLASH_TOOL_MEMORY_MAX_BYTES", str(64 * 1024 * 1024))),
</file context>
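
To make the reviewer's concern concrete, here is a minimal sketch of the reported scenario. It only uses the ToolMemory API exercised in test_tool_memory.py below (remember / lookup_message) and the env-var defaults added in server.py; the two-request framing and the reused id "call_1" are hypothetical, purely for illustration.

from tool_memory import ToolMemory

# Defaults mirror DFLASH_TOOL_MEMORY_MAX_ENTRIES / _MAX_BYTES in server.py.
tool_memory = ToolMemory(max_entries=50000, max_bytes=64 * 1024 * 1024)

# Request A: the server remembers its raw assistant turn under the
# tool-call id it just emitted to the client.
tool_memory.remember(["call_1"], "<tool_call>raw text from request A</tool_call>")

# Request B: an unrelated conversation echoes back an assistant message whose
# tool_calls happen to carry the same id. Because tool_memory is a single
# process-wide instance keyed only by id, lookup_message returns request A's
# raw text, which _tokenize_prompt then renders into request B's prompt.
leaked = tool_memory.lookup_message([{"id": "call_1"}])
assert leaked == "<tool_call>raw text from request A</tool_call>"
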

)
from prefix_cache import DaemonStdoutBus, PrefixCache
from tool_memory import ToolMemory


class OpenAICompatError(Exception):
@@ -672,6 +673,21 @@ def _resolve_kv_k_type():
)
if prefill_cfg is not None and prefill_cache_slots > 0:
prefix_cache.init_full_cache(prefill_cache_slots)
tool_memory = ToolMemory(
max_entries=int(os.environ.get("DFLASH_TOOL_MEMORY_MAX_ENTRIES", "50000")),
max_bytes=int(os.environ.get("DFLASH_TOOL_MEMORY_MAX_BYTES", str(64 * 1024 * 1024))),
)

def _remember_tool_call_text(raw_text: str, tool_calls: list[dict] | None) -> None:
if not raw_text or not tool_calls:
return
call_ids = [
tc.get("id")
for tc in tool_calls
if isinstance(tc, dict) and isinstance(tc.get("id"), str) and tc.get("id")
]
if call_ids:
tool_memory.remember(call_ids, raw_text)

@app.on_event("startup")
async def _startup():
@@ -771,13 +787,18 @@ def _tokenize_prompt(req: ChatRequest) -> tuple[Path, list[int], list[dict], boo
msgs: list[dict] = []
for m in req.messages:
d: dict = {"role": m.role}
if m.content is not None:
replay_raw_text = None
if m.role == "assistant" and m.tool_calls is not None:
replay_raw_text = tool_memory.lookup_message(m.tool_calls)
if replay_raw_text is not None:
d["content"] = replay_raw_text
elif m.content is not None:
d["content"] = _content_to_str(m.content)
if m.name is not None:
d["name"] = m.name
if m.tool_call_id is not None:
d["tool_call_id"] = m.tool_call_id
if m.tool_calls is not None:
if m.tool_calls is not None and replay_raw_text is None:
d["tool_calls"] = []
for tc in m.tool_calls:
args = tc.function.arguments
@@ -1116,6 +1137,8 @@ def chunk(delta_obj, finish=None):
mode = "reasoning" if started_in_thinking else "content"
window = ""
tool_buffer = ""
accumulated_content = ""
accumulated_raw_text = ""
stops = normalize_stop(req.stop)
tag_holdback = max(len(THINK_OPEN_TAG), len(THINK_CLOSE_TAG), len(TOOL_OPEN_TAG))
stop_holdback = max((len(s) for s in stops), default=0)
@@ -1132,6 +1155,7 @@ def emit_delta(text, kind):
async for tok_id in _astream_tokens(r_pipe, gen_len, timing):
completion_tokens += 1
piece = tokenizer.decode([tok_id])
accumulated_raw_text += piece
window += piece

if stops and mode != "tool_buffer":
Expand All @@ -1140,6 +1164,8 @@ def emit_delta(text, kind):
window = window[:si]
stop_hit = True
kind = "reasoning_content" if mode == "reasoning" else "content"
if mode == "content":
accumulated_content += window
out = emit_delta(window, kind)
if out: yield out
window = ""
@@ -1176,6 +1202,7 @@ def emit_delta(text, kind):
hits.sort()
idx, which = hits[0]
pre = window[:idx]
accumulated_content += pre
out = emit_delta(pre, "content")
if out: yield out
if which == "think":
@@ -1188,6 +1215,7 @@ def emit_delta(text, kind):
continue
if len(window) > HOLDBACK:
safe = window[:-HOLDBACK]
accumulated_content += safe
out = emit_delta(safe, "content")
if out: yield out
window = window[-HOLDBACK:]
@@ -1215,6 +1243,7 @@ def emit_delta(text, kind):
out = emit_delta(window, "reasoning_content")
if out: yield out
elif mode == "content" and window:
accumulated_content += window
out = emit_delta(window, "content")
if out: yield out
elif mode == "tool_buffer":
@@ -1225,6 +1254,7 @@ def emit_delta(text, kind):
if mode == "tool_buffer":
cleaned_after, tool_calls = parse_tool_calls(tool_buffer, tools=req.tools)
if tool_calls:
_remember_tool_call_text(accumulated_raw_text, tool_calls)
if cleaned_after:
out = emit_delta(cleaned_after, "content")
if out: yield out
@@ -1350,6 +1380,7 @@ def emit_delta(text, kind):
if req.chat_template_kwargs:
thinking_enabled = req.chat_template_kwargs.get("enable_thinking", True)
cleaned, tool_calls = parse_tool_calls(text, tools=req.tools)
_remember_tool_call_text(text, tool_calls)
cleaned, reasoning = parse_reasoning(
cleaned,
thinking_enabled=thinking_enabled,
@@ -1847,6 +1878,7 @@ async def _responses_non_stream(
if chat_req.chat_template_kwargs:
thinking_enabled = chat_req.chat_template_kwargs.get("enable_thinking", True)
cleaned, tool_calls = parse_tool_calls(text, tools=chat_req.tools)
_remember_tool_call_text(text, tool_calls)
cleaned, reasoning = parse_reasoning(
cleaned, thinking_enabled=thinking_enabled,
started_in_thinking=started_in_thinking)
@@ -1983,6 +2015,7 @@ async def sse() -> AsyncIterator[str]:
window = ""
tool_buffer = ""
accumulated_text = ""
accumulated_raw_text = ""
tag_holdback = max(len(THINK_OPEN_TAG), len(THINK_CLOSE_TAG), len(TOOL_OPEN_TAG))
HOLDBACK = tag_holdback
completion_tokens = 0
@@ -1992,6 +2025,7 @@ async def sse() -> AsyncIterator[str]:
async for tok_id in _astream_tokens(r_pipe, gen_len, timing):
completion_tokens += 1
piece = tokenizer.decode([tok_id])
accumulated_raw_text += piece
window += piece

while True:
@@ -2074,6 +2108,7 @@ async def sse() -> AsyncIterator[str]:
if mode == "tool_buffer" and tool_buffer:
cleaned_after, tool_calls = parse_tool_calls(tool_buffer, tools=chat_req.tools)
if tool_calls:
_remember_tool_call_text(accumulated_raw_text, tool_calls)
if cleaned_after:
accumulated_text += cleaned_after
for tc in tool_calls:
153 changes: 152 additions & 1 deletion dflash/scripts/test_server.py
@@ -380,6 +380,51 @@ def test_chat_completions_non_streaming_with_tool_call(mock_os_read, mock_pipe,
assert tc[0]["function"]["name"] == "read_file"


@patch("server.os.pipe")
@patch("server.os.read")
def test_chat_completions_replays_raw_tool_call_text(mock_os_read, mock_pipe,
mock_tokenizer, app):
mock_pipe.return_value = (1, 2)
raw_tool_text = (
"Before\n"
"<tool_call>"
"<function=read_file><parameter=path>test.py</parameter></function>"
"</tool_call>\n"
"After"
)
mock_tokenizer.decode.side_effect = [raw_tool_text, "followup"]
mock_os_read.side_effect = [
struct.pack("<i", 10), struct.pack("<i", -1),
struct.pack("<i", 11), struct.pack("<i", -1),
]

client = TestClient(app)
first = client.post("/v1/chat/completions", json={
"model": MODEL_NAME,
"messages": [{"role": "user", "content": "read test.py"}],
"stream": False,
})
assert first.status_code == 200
assistant_msg = first.json()["choices"][0]["message"]

second = client.post("/v1/chat/completions", json={
"model": MODEL_NAME,
"messages": [
{"role": "user", "content": "read test.py"},
assistant_msg,
{"role": "tool", "tool_call_id": assistant_msg["tool_calls"][0]["id"], "content": "file body"},
{"role": "user", "content": "what next?"},
],
"stream": False,
})
assert second.status_code == 200

msgs = mock_tokenizer.apply_chat_template.call_args_list[-1][0][0]
assistant = next(m for m in msgs if m["role"] == "assistant")
assert assistant["content"] == raw_tool_text
assert "tool_calls" not in assistant


@patch("server.os.pipe")
@patch("server.os.read")
def test_zero_token_prompt_is_rejected_before_daemon(
@@ -430,6 +475,71 @@ def test_chat_completions_streaming(mock_os_read, mock_pipe, mock_tokenizer, app
assert all(c["object"] == "chat.completion.chunk" for c in chunks)


@patch("server.os.pipe")
@patch("server.os.read")
def test_chat_completions_streaming_replays_exact_raw_text_with_reasoning(
mock_os_read, mock_pipe, mock_tokenizer, app):
mock_pipe.return_value = (1, 2)
raw_tool_turn = (
"<think>private chain</think>"
"visible"
"<tool_call>"
"<function=read_file><parameter=path>x.py</parameter></function>"
"</tool_call>"
)
mock_tokenizer.decode.side_effect = [
"<think>private chain",
"</think>",
"visible",
"<tool_call><function=read_file><parameter=path>x.py</parameter></function></tool_call>",
"followup",
]
mock_os_read.side_effect = [
struct.pack("<i", 10), struct.pack("<i", 11),
struct.pack("<i", 12), struct.pack("<i", 13), struct.pack("<i", -1),
struct.pack("<i", 14), struct.pack("<i", -1),
]

client = TestClient(app)
first = client.post("/v1/chat/completions", json={
"model": MODEL_NAME,
"messages": [{"role": "user", "content": "read x.py"}],
"stream": True,
})
assert first.status_code == 200
chunks = [
json.loads(line[6:])
for line in first.text.strip().split("\n\n")
if line.startswith("data: ") and line != "data: [DONE]"
]
tool_delta = next(
c["choices"][0]["delta"]["tool_calls"][0]
for c in chunks
if c["choices"][0]["delta"].get("tool_calls")
)

second = client.post("/v1/chat/completions", json={
"model": MODEL_NAME,
"messages": [
{"role": "user", "content": "read x.py"},
{"role": "assistant", "tool_calls": [{
"id": tool_delta["id"],
"type": "function",
"function": tool_delta["function"],
}]},
{"role": "tool", "tool_call_id": tool_delta["id"], "content": "file body"},
{"role": "user", "content": "what next?"},
],
"stream": False,
})
assert second.status_code == 200

msgs = mock_tokenizer.apply_chat_template.call_args_list[-1][0][0]
assistant = next(m for m in msgs if m["role"] == "assistant")
assert assistant["content"] == raw_tool_turn
assert "tool_calls" not in assistant


# ─── POST /v1/responses ───────────────────────────────────────────

@patch("server.os.pipe")
@@ -619,7 +729,7 @@ def test_responses_object_tool_choice(mock_os_read, mock_pipe,
@patch("server.os.pipe")
@patch("server.os.read")
def test_responses_function_call_output(mock_os_read, mock_pipe,
mock_tokenizer, app):
mock_tokenizer, app):
"""Responses API maps function_call + function_call_output items."""
mock_pipe.return_value = (1, 2)
mock_os_read.side_effect = [struct.pack("<i", 10), struct.pack("<i", -1)]
@@ -649,6 +759,47 @@ def test_responses_function_call_output(mock_os_read, mock_pipe,
assert "tool" in roles


@patch("server.os.pipe")
@patch("server.os.read")
def test_responses_replay_raw_tool_call_text(mock_os_read, mock_pipe,
mock_tokenizer, app):
mock_pipe.return_value = (1, 2)
raw_tool_text = (
'<tool_call>'
'<function=read_file><parameter=path>file.txt</parameter></function>'
'</tool_call>'
)
mock_tokenizer.decode.side_effect = [raw_tool_text, "followup"]
mock_os_read.side_effect = [
struct.pack("<i", 10), struct.pack("<i", -1),
struct.pack("<i", 11), struct.pack("<i", -1),
]

client = TestClient(app)
first = client.post("/v1/responses", json={
"model": MODEL_NAME,
"input": [{"type": "message", "role": "user", "content": "read file.txt"}],
})
assert first.status_code == 200
first_output = first.json()["output"][0]

second = client.post("/v1/responses", json={
"model": MODEL_NAME,
"input": [
{"type": "message", "role": "user", "content": "read file.txt"},
first_output,
{"type": "function_call_output", "call_id": first_output["call_id"], "output": "file body"},
{"type": "message", "role": "user", "content": "what next?"},
],
})
assert second.status_code == 200

msgs = mock_tokenizer.apply_chat_template.call_args_list[-1][0][0]
assistant = next(m for m in msgs if m["role"] == "assistant")
assert assistant["content"] == raw_tool_text
assert "tool_calls" not in assistant


@patch("server.os.pipe")
@patch("server.os.read")
def test_responses_developer_role_mapped_to_system(mock_os_read, mock_pipe,
29 changes: 29 additions & 0 deletions dflash/scripts/test_tool_memory.py
@@ -0,0 +1,29 @@
from tool_memory import ToolMemory


def test_tool_memory_remembers_shared_raw_text_for_multiple_ids():
mem = ToolMemory(max_entries=8, max_bytes=4096)
raw = '<tool_call><function=read_file><parameter=path>a.py</parameter></function></tool_call>'
mem.remember(["call_a", "call_b"], raw)

assert mem.lookup_message([{"id": "call_a"}, {"id": "call_b"}]) == raw
assert len(mem.by_id) == 2
assert len(mem.by_block) == 1


def test_tool_memory_eviction_drops_oldest_entry_and_unique_block():
mem = ToolMemory(max_entries=1, max_bytes=4096)
mem.remember(["call_old"], "<tool_call>old</tool_call>")
mem.remember(["call_new"], "<tool_call>new</tool_call>")

assert mem.lookup_message([{"id": "call_old"}]) is None
assert mem.lookup_message([{"id": "call_new"}]) == "<tool_call>new</tool_call>"
assert len(mem.by_block) == 1


def test_tool_memory_lookup_message_requires_same_raw_text():
mem = ToolMemory(max_entries=8, max_bytes=4096)
mem.remember(["call_a"], "<tool_call>a</tool_call>")
mem.remember(["call_b"], "<tool_call>b</tool_call>")

assert mem.lookup_message([{"id": "call_a"}, {"id": "call_b"}]) is None