diff --git a/dflash/DEVELOPER.md b/dflash/DEVELOPER.md index 8f948899..1c9460a6 100644 --- a/dflash/DEVELOPER.md +++ b/dflash/DEVELOPER.md @@ -227,7 +227,7 @@ dflash/ │ └── draft/model.safetensors ├── scripts/ │ ├── server.py # Main OpenAI/Codex server -│ ├── server_tools.py # Legacy fork with tool calling (deprecated) +│ ├── server_tools.py # Legacy fork kept for reference; server.py is the tool/Codex path │ ├── prefix_cache.py # LRU prefix cache │ ├── _prefill_hook.py # Speculative prefill compression │ ├── run.py # CLI text generation diff --git a/dflash/README.md b/dflash/README.md index 63c6122f..b73fde56 100644 --- a/dflash/README.md +++ b/dflash/README.md @@ -119,7 +119,7 @@ allows capacity checks where the draft and a target layer range share one GPU before serving integration. `--target-split-dflash` runs the same split target placement through a chain DFlash decode loop and reports acceptance length. -**Python flags on `scripts/run.py`, `scripts/server.py`, `scripts/server_tools.py`:** +**Python flags on `scripts/run.py` and `scripts/server.py` (`scripts/server_tools.py` is legacy):** ```bash python3 scripts/run.py --ctk q8_0 --ctv q4_0 --prompt "hello" python3 scripts/run.py --cache-type-k q8_0 --cache-type-v q4_0 --prompt "hello" diff --git a/dflash/scripts/server.py b/dflash/scripts/server.py index 9e00d3d9..d83db8eb 100644 --- a/dflash/scripts/server.py +++ b/dflash/scripts/server.py @@ -24,6 +24,7 @@ import tempfile import time import uuid +from dataclasses import dataclass, field from pathlib import Path from typing import Any, AsyncIterator @@ -41,6 +42,7 @@ compress_text_via_daemon, _drain_until_sentinel, ) from prefix_cache import DaemonStdoutBus, PrefixCache +from tool_memory import ToolMemory class OpenAICompatError(Exception): @@ -179,6 +181,8 @@ def _tokenizer_id_from_gguf(gguf_path: Path) -> str: r"", re.DOTALL) TOOL_CODE_RE = re.compile(r"(.*?)", re.DOTALL) TOOL_OPEN_TAG = "" +FUNCTION_OPEN_TAG = " dict: @@ -572,6 +576,850 @@ class ResponsesCreateRequest(BaseModel): previous_response_id: str | None = None +@dataclass(frozen=True) +class ToolPolicy: + prompt_tools: list[Any] | None + parse_tools: list[Any] | None + choice_kind: str + choice_name: str | None = None + render_tools: bool = False + parse_tool_calls: bool = False + bypass_compression: bool = False + + +@dataclass +class _PartialToolCallSnapshot: + name: str + arguments: str + complete: bool = False + + +@dataclass +class _TrackedToolCallState: + index: int + id: str + name: str | None = None + emitted_arguments: str = "" + announced: bool = False + done: bool = False + + +@dataclass +class _SharedStreamState: + mode: str + holdback: int + allow_tools: bool + window: str = "" + tool_buffer: str = "" + raw_text: str = "" + visible_text: str = "" + stop_hit: bool = False + tool_states: list[_TrackedToolCallState] = field(default_factory=list) + final_tool_calls: list[dict] = field(default_factory=list) + + +def _trim_stream_raw_suffix(raw_text: str, suffix_len: int, keep_len: int) -> str: + trim_len = max(0, suffix_len - keep_len) + if trim_len <= 0: + return raw_text + return raw_text[:-trim_len] if trim_len < len(raw_text) else "" + + +def _tool_name(tool: Any) -> str | None: + fn = None + if hasattr(tool, "function"): + fn = tool.function + elif isinstance(tool, dict): + fn = tool.get("function") + if hasattr(fn, "model_dump"): + fn = fn.model_dump() + if isinstance(fn, dict): + name = fn.get("name") + if isinstance(name, str) and name: + return name + return None + + +def _normalize_tool_choice(tool_choice: Any) -> tuple[str, str | None]: + if tool_choice is None: + return "auto", None + if isinstance(tool_choice, str): + choice = tool_choice.strip().lower() + if choice in {"auto", "none", "required"}: + return choice, None + if isinstance(tool_choice, dict): + choice_type = tool_choice.get("type") + if isinstance(choice_type, str): + choice_type = choice_type.lower() + if choice_type in {"auto", "none", "required"}: + return choice_type, None + name = tool_choice.get("name") + if not isinstance(name, str): + fn = tool_choice.get("function") + if isinstance(fn, dict): + name = fn.get("name") + if (choice_type == "function" or "function" in tool_choice or "name" in tool_choice) and isinstance(name, str) and name: + return "function", name + raise OpenAICompatError( + "Unsupported tool_choice value", + param="tool_choice", + ) + + +def _resolve_tool_policy(tools: list[Any] | None, tool_choice: Any) -> ToolPolicy: + requested_tools = list(tools or []) + has_tools = bool(requested_tools) + choice_kind, choice_name = _normalize_tool_choice(tool_choice) + + if choice_kind == "function": + if not has_tools: + raise OpenAICompatError( + "tool_choice=function requires tools", + param="tool_choice", + ) + selected = next((t for t in requested_tools if _tool_name(t) == choice_name), None) + if selected is None: + raise OpenAICompatError( + f"tool_choice function '{choice_name}' was not provided in tools", + param="tool_choice", + ) + return ToolPolicy( + prompt_tools=[selected], + parse_tools=requested_tools, + choice_kind=choice_kind, + choice_name=choice_name, + render_tools=True, + parse_tool_calls=True, + bypass_compression=True, + ) + + if choice_kind == "required" and not has_tools: + raise OpenAICompatError( + "tool_choice='required' requires tools", + param="tool_choice", + ) + + if choice_kind == "none": + return ToolPolicy( + prompt_tools=None, + parse_tools=None, + choice_kind=choice_kind, + render_tools=False, + parse_tool_calls=False, + bypass_compression=has_tools, + ) + + if has_tools: + return ToolPolicy( + prompt_tools=requested_tools, + parse_tools=requested_tools, + choice_kind=choice_kind, + render_tools=True, + parse_tool_calls=True, + bypass_compression=True, + ) + + return ToolPolicy( + prompt_tools=None, + parse_tools=None, + choice_kind=choice_kind, + render_tools=False, + parse_tool_calls=True, + bypass_compression=False, + ) + + +def _parse_generated_tool_calls(text: str, tool_policy: ToolPolicy) -> tuple[str, list[dict]]: + if not tool_policy.parse_tool_calls: + return text, [] + return parse_tool_calls(text, tools=tool_policy.parse_tools) + + +def _tool_choice_violation(tool_policy: ToolPolicy, tool_calls: list[dict]) -> OpenAICompatError | None: + if tool_policy.choice_kind in {"auto", "none"}: + return None + if tool_policy.choice_kind == "required": + if tool_calls: + return None + return OpenAICompatError( + "tool_choice='required' requires the model to emit a tool call", + param="tool_choice", + ) + if tool_policy.choice_kind == "function": + disallowed = [ + tc.get("function", {}).get("name") + for tc in tool_calls + if tc.get("function", {}).get("name") != tool_policy.choice_name + ] + disallowed = [name for name in disallowed if isinstance(name, str) and name] + if disallowed: + leaked = ", ".join(sorted(set(disallowed))) + return OpenAICompatError( + f"tool_choice function '{tool_policy.choice_name}' does not allow other tool calls ({leaked})", + param="tool_choice", + ) + if any(tc.get("function", {}).get("name") == tool_policy.choice_name for tc in tool_calls): + return None + return OpenAICompatError( + f"tool_choice function '{tool_policy.choice_name}' was not satisfied by model output", + param="tool_choice", + ) + return None + + +def _tool_param_type(tools, function_name: str, param_name: str) -> str: + params = _find_tool_properties(tools, function_name) + cfg = params.get(param_name, {}) if isinstance(params, dict) else {} + if isinstance(cfg, dict): + if isinstance(cfg.get("type"), str): + return cfg["type"].strip().lower() + if "anyOf" in cfg: + return "object" + return "string" + + +def _json_string_fragment(value: str) -> str: + return json.dumps(value, ensure_ascii=False)[1:-1] + + +def _find_toolish_json_start(text: str, start: int = 0) -> int: + """Return the earliest `{` that could begin a supported JSON tool call.""" + keys = ("name", "function", "arguments") + cursor = start + while cursor < len(text): + brace = text.find("{", cursor) + if brace == -1: + return -1 + key_cursor = brace + 1 + while key_cursor < len(text) and text[key_cursor].isspace(): + key_cursor += 1 + if key_cursor >= len(text): + return brace + if text[key_cursor] != '"': + cursor = brace + 1 + continue + key_start = key_cursor + 1 + key_end = key_start + while key_end < len(text) and text[key_end] not in {'"', '\\'}: + key_end += 1 + if key_end >= len(text): + fragment = text[key_start:] + if any(k.startswith(fragment) for k in keys): + return brace + cursor = brace + 1 + continue + if text[key_end] == "\\": + cursor = brace + 1 + continue + key = text[key_start:key_end] + if not any(k.startswith(key) for k in keys): + cursor = brace + 1 + continue + colon_cursor = key_end + 1 + while colon_cursor < len(text) and text[colon_cursor].isspace(): + colon_cursor += 1 + if colon_cursor >= len(text) or text[colon_cursor] == ":": + return brace + cursor = brace + 1 + return -1 + + +def _json_tool_call_spans(text: str) -> list[tuple[int, int]]: + spans: list[tuple[int, int]] = [] + decoder = json.JSONDecoder() + cursor = 0 + while cursor < len(text): + start = text.find("{", cursor) + if start == -1: + break + try: + obj, consumed = decoder.raw_decode(text[start:]) + except json.JSONDecodeError: + cursor = start + 1 + continue + if _parse_json_tool_call(obj) is not None: + spans.append((start, start + consumed)) + cursor = start + max(consumed, 1) + return spans + + +def _build_partial_xml_arguments(text: str, cursor: int, function_name: str, tools) -> tuple[str, bool, int]: + param_config = _find_tool_properties(tools, function_name) + fragments: list[str] = [] + complete = False + + while cursor < len(text): + while cursor < len(text) and text[cursor].isspace(): + cursor += 1 + if text.startswith("", cursor): + cursor += len("") + complete = True + break + if text.startswith("", cursor): + cursor += len("") + continue + if not text.startswith("", cursor) + if name_end == -1: + break + param_name = text[cursor:name_end].strip() + cursor = name_end + 1 + + end_tag = text.find("", cursor) + next_param = text.find("", cursor) + candidates = [(pos, kind) for pos, kind in ( + (end_tag, "parameter"), + (next_param, "next_param"), + (close_fn, "function"), + ) if pos != -1] + if candidates: + value_end, end_kind = min(candidates, key=lambda item: item[0]) + else: + value_end, end_kind = len(text), "eof" + raw_value = text[cursor:value_end] + if raw_value.startswith("\n"): + raw_value = raw_value[1:] + if end_kind == "parameter" and raw_value.endswith("\n"): + raw_value = raw_value[:-1] + + if fragments: + fragments.append(",") + else: + fragments.append("{") + fragments.append(json.dumps(param_name, ensure_ascii=False)) + fragments.append(":") + + if _tool_param_type(tools, function_name, param_name) == "string": + fragments.append('"') + fragments.append(_json_string_fragment(raw_value)) + if end_kind == "parameter": + fragments.append('"') + cursor = value_end + len("") + continue + break + + if end_kind == "parameter": + value_obj = _convert_param_value(raw_value, param_name, param_config, function_name) + fragments.append(json.dumps(value_obj, ensure_ascii=False, separators=(",", ":"))) + cursor = value_end + len("") + continue + + del fragments[prefix_len:] + break + + if complete: + if not fragments: + return "{}", True, cursor + fragments.append("}") + return "".join(fragments), complete, cursor + + +def _partial_tool_call_snapshots(text: str, tools=None) -> list[_PartialToolCallSnapshot]: + snapshots: list[_PartialToolCallSnapshot] = [] + decoder = json.JSONDecoder() + cursor = 0 + while cursor < len(text): + hits = [(i, kind) for i, kind in ( + (text.find(TOOL_OPEN_TAG, cursor), "tool"), + (text.find(FUNCTION_OPEN_TAG, cursor), "function"), + (text.find(TOOL_CODE_OPEN_TAG, cursor), "tool_code"), + (_find_toolish_json_start(text, cursor), "json"), + ) if i != -1] + if not hits: + break + start, kind = min(hits, key=lambda item: item[0]) + cursor = start + if kind == "tool_code": + close = text.find("", cursor + len(TOOL_CODE_OPEN_TAG)) + if close == -1: + break + try: + obj = json.loads(text[cursor + len(TOOL_CODE_OPEN_TAG):close].strip()) + except json.JSONDecodeError: + cursor = close + len("") + continue + parsed = _parse_json_tool_call(obj) + if parsed is not None and _tool_allowed(tools, parsed[0]): + snapshots.append(_PartialToolCallSnapshot( + name=parsed[0], + arguments=json.dumps(parsed[1], ensure_ascii=False, separators=(",", ":")), + complete=True, + )) + cursor = close + len("") + continue + if kind == "json": + try: + obj, consumed = decoder.raw_decode(text[cursor:]) + except json.JSONDecodeError: + break + parsed = _parse_json_tool_call(obj) + if parsed is not None and _tool_allowed(tools, parsed[0]): + snapshots.append(_PartialToolCallSnapshot( + name=parsed[0], + arguments=json.dumps(parsed[1], ensure_ascii=False, separators=(",", ":")), + complete=True, + )) + cursor += max(consumed, 1) + continue + if text.startswith(TOOL_OPEN_TAG, cursor): + cursor += len(TOOL_OPEN_TAG) + while cursor < len(text) and text[cursor].isspace(): + cursor += 1 + if not text.startswith("", fn_cursor) + signature_end = text.find("(", fn_cursor) + if signature_end != -1 and (header_end == -1 or signature_end < header_end): + close = text.find("", signature_end) + if close == -1: + break + _, calls = parse_tool_calls(text[start:close + len("")], tools=tools) + if calls: + snapshots.append(_PartialToolCallSnapshot( + name=calls[0]["function"]["name"], + arguments=calls[0]["function"]["arguments"], + complete=True, + )) + cursor = close + len("") + continue + if header_end == -1: + break + + function_name = text[fn_cursor:header_end].strip() + if not function_name: + cursor = start + 1 + continue + if not _tool_allowed(tools, function_name): + close = text.find("", header_end + 1) + cursor = close + len("") if close != -1 else header_end + 1 + continue + + arguments, complete, body_end = _build_partial_xml_arguments( + text, header_end + 1, function_name, tools) + snapshots.append(_PartialToolCallSnapshot( + name=function_name, + arguments=arguments, + complete=complete, + )) + cursor = max(body_end, header_end + 1) + if not complete: + break + tool_close = text.find("", cursor) + if tool_close != -1: + cursor = max(cursor, tool_close + len("")) + return snapshots + + +def _complete_tool_buffer_spans(text: str) -> list[tuple[int, int]]: + spans: list[tuple[int, int]] = [] + for match in TOOL_CALL_COMPLETE_RE.finditer(text): + spans.append((match.start(), match.end())) + for pattern in (BARE_FUNCTION_XML_RE, FUNCTION_SIGNATURE_RE): + for match in pattern.finditer(text): + if any(lo <= match.start() < hi for lo, hi in spans): + continue + spans.append((match.start(), match.end())) + for match in TOOL_CODE_RE.finditer(text): + try: + obj = json.loads(match.group(1).strip()) + except json.JSONDecodeError: + continue + if _parse_json_tool_call(obj) is None: + continue + if any(lo <= match.start() < hi for lo, hi in spans): + continue + spans.append((match.start(), match.end())) + for start, end in _json_tool_call_spans(text): + if any(lo <= start < hi for lo, hi in spans): + continue + spans.append((start, end)) + spans.sort() + return spans + + +def _first_tool_buffer_stop_match(text: str, stops: list[str]) -> int: + spans = _complete_tool_buffer_spans(text) + if not spans: + return -1 + + def stop_in_gap(start: int, end: int) -> int: + json_open = _find_toolish_json_start(text[start:end]) + if json_open != -1: + json_open += start + next_open = min( + (pos for pos in ( + text.find(TOOL_OPEN_TAG, start, end), + text.find(FUNCTION_OPEN_TAG, start, end), + text.find(TOOL_CODE_OPEN_TAG, start, end), + json_open, + ) if pos != -1), + default=-1, + ) + if next_open != -1: + end = next_open + if start >= end: + return -1 + gap_index = first_stop_match(text[start:end], stops) + return start + gap_index if gap_index != -1 else -1 + + cursor = 0 + for start, end in spans: + stop_index = stop_in_gap(cursor, start) + if stop_index != -1: + return stop_index + cursor = max(cursor, end) + return stop_in_gap(cursor, len(text)) + + +def _sync_partial_tool_stream(tool_buffer: str, tools, states: list[_TrackedToolCallState]) -> list[dict]: + events: list[dict] = [] + for index, snapshot in enumerate(_partial_tool_call_snapshots(tool_buffer, tools=tools)): + while len(states) <= index: + states.append(_TrackedToolCallState( + index=len(states), + id="call_" + uuid.uuid4().hex[:24], + )) + state = states[index] + if snapshot.name and state.name is None: + state.name = snapshot.name + if state.name and not state.announced: + state.announced = True + events.append({ + "kind": "tool_start", + "index": index, + "id": state.id, + "name": state.name, + }) + if snapshot.arguments: + fragment = snapshot.arguments + if snapshot.arguments.startswith(state.emitted_arguments): + fragment = snapshot.arguments[len(state.emitted_arguments):] + if fragment: + state.emitted_arguments = snapshot.arguments + events.append({ + "kind": "tool_args", + "index": index, + "id": state.id, + "name": state.name, + "arguments": fragment, + }) + if snapshot.complete and not state.done: + state.done = True + events.append({ + "kind": "tool_done", + "index": index, + "id": state.id, + "name": state.name, + "arguments": snapshot.arguments or "{}", + }) + return events + + +def _reconcile_final_tool_events(tool_calls: list[dict], states: list[_TrackedToolCallState]) -> list[dict]: + events: list[dict] = [] + for index, tc in enumerate(tool_calls): + while len(states) <= index: + states.append(_TrackedToolCallState( + index=len(states), + id="call_" + uuid.uuid4().hex[:24], + )) + state = states[index] + state.name = tc["function"]["name"] + tc["id"] = state.id + if not state.announced: + state.announced = True + events.append({ + "kind": "tool_start", + "index": index, + "id": state.id, + "name": state.name, + }) + final_args = tc["function"]["arguments"] + fragment = final_args + if final_args.startswith(state.emitted_arguments): + fragment = final_args[len(state.emitted_arguments):] + if fragment: + state.emitted_arguments = final_args + events.append({ + "kind": "tool_args", + "index": index, + "id": state.id, + "name": state.name, + "arguments": fragment, + }) + if not state.done: + state.done = True + events.append({ + "kind": "tool_done", + "index": index, + "id": state.id, + "name": state.name, + "arguments": final_args, + }) + return events + + +def _finalize_tool_buffer_stream( + state: _SharedStreamState, + tool_policy: ToolPolicy, +) -> tuple[list[dict], list[dict]]: + events: list[dict] = [] + tool_calls: list[dict] = [] + cleaned_after, tool_calls = _parse_generated_tool_calls(state.tool_buffer, tool_policy) + if tool_calls: + events.extend(_reconcile_final_tool_events(tool_calls, state.tool_states)) + if cleaned_after: + state.visible_text += cleaned_after + events.append({"kind": "content", "text": cleaned_after}) + elif state.tool_buffer: + state.visible_text += state.tool_buffer + events.append({"kind": "content", "text": state.tool_buffer}) + state.final_tool_calls = tool_calls + state.tool_buffer = "" + state.mode = "content" + return events, tool_calls + + +def _make_shared_stream_state(started_in_thinking: bool, stops: list[str], tool_policy: ToolPolicy) -> _SharedStreamState: + tag_holdback = max( + len(THINK_OPEN_TAG), + len(THINK_CLOSE_TAG), + len(TOOL_OPEN_TAG) if tool_policy.parse_tool_calls else 0, + len(FUNCTION_OPEN_TAG) if tool_policy.parse_tool_calls else 0, + len(TOOL_CODE_OPEN_TAG) if tool_policy.parse_tool_calls else 0, + ) + stop_holdback = max((len(s) for s in stops), default=0) + return _SharedStreamState( + mode="reasoning" if started_in_thinking else "content", + holdback=max(tag_holdback, stop_holdback), + allow_tools=tool_policy.parse_tool_calls, + ) + + +def _feed_shared_stream_piece( + state: _SharedStreamState, + piece: str, + *, + stops: list[str], + tool_policy: ToolPolicy, +) -> list[dict]: + events: list[dict] = [] + state.raw_text += piece + state.window += piece + + if stops: + if state.mode == "tool_buffer": + combined = state.tool_buffer + state.window + stop_index = _first_tool_buffer_stop_match(combined, stops) + if stop_index != -1: + state.raw_text = _trim_stream_raw_suffix( + state.raw_text, len(combined), stop_index) + state.tool_buffer = combined[:stop_index] + state.window = "" + state.stop_hit = True + else: + window_len = len(state.window) + stop_index = first_stop_match(state.window, stops) + if stop_index != -1: + state.raw_text = _trim_stream_raw_suffix( + state.raw_text, window_len, stop_index) + state.window = state.window[:stop_index] + state.stop_hit = True + + while True: + if state.mode == "tool_buffer": + state.tool_buffer += state.window + state.window = "" + events.extend(_sync_partial_tool_stream( + state.tool_buffer, tool_policy.parse_tools, state.tool_states)) + if state.stop_hit: + finalized_events, _ = _finalize_tool_buffer_stream(state, tool_policy) + events.extend(finalized_events) + break + + if state.mode == "reasoning": + idx = state.window.find(THINK_CLOSE_TAG) + if idx != -1: + pre = state.window[:idx] + if pre: + events.append({"kind": "reasoning", "text": pre}) + state.window = state.window[idx + len(THINK_CLOSE_TAG):] + state.mode = "content" + continue + if len(state.window) > state.holdback or (state.stop_hit and state.window): + safe = state.window if state.stop_hit else state.window[:-state.holdback] + if safe: + events.append({"kind": "reasoning", "text": safe}) + state.window = "" if state.stop_hit else state.window[-state.holdback:] + break + + think_idx = state.window.find(THINK_OPEN_TAG) + think_close_idx = state.window.find(THINK_CLOSE_TAG) + tool_idx = state.window.find(TOOL_OPEN_TAG) if state.allow_tools else -1 + bare_fn_idx = state.window.find(FUNCTION_OPEN_TAG) if state.allow_tools else -1 + tool_code_idx = state.window.find(TOOL_CODE_OPEN_TAG) if state.allow_tools else -1 + json_tool_idx = _find_toolish_json_start(state.window) if state.allow_tools else -1 + hits = [(i, t) for i, t in ( + (think_idx, "think"), + (think_close_idx, "think_close"), + (tool_idx, "tool"), + (bare_fn_idx, "tool"), + (tool_code_idx, "tool"), + (json_tool_idx, "tool"), + ) if i != -1] + if hits: + hits.sort() + idx, which = hits[0] + pre = state.window[:idx] + if pre: + state.visible_text += pre + events.append({"kind": "content", "text": pre}) + if which == "think": + state.window = state.window[idx + len(THINK_OPEN_TAG):] + state.mode = "reasoning" + elif which == "think_close": + state.window = state.window[idx + len(THINK_CLOSE_TAG):] + else: + state.tool_buffer += state.window[idx:] + state.window = "" + state.mode = "tool_buffer" + events.extend(_sync_partial_tool_stream( + state.tool_buffer, tool_policy.parse_tools, state.tool_states)) + if state.stop_hit: + finalized_events, _ = _finalize_tool_buffer_stream(state, tool_policy) + events.extend(finalized_events) + continue + if len(state.window) > state.holdback or (state.stop_hit and state.window): + safe = state.window if state.stop_hit else state.window[:-state.holdback] + if safe: + state.visible_text += safe + events.append({"kind": "content", "text": safe}) + state.window = "" if state.stop_hit else state.window[-state.holdback:] + break + return events + + +def _flush_shared_stream( + state: _SharedStreamState, + *, + tool_policy: ToolPolicy, +) -> tuple[list[dict], list[dict]]: + events: list[dict] = [] + tool_calls: list[dict] = [] + + if state.mode == "reasoning" and state.window: + events.append({"kind": "reasoning", "text": state.window}) + elif state.mode == "content" and state.window: + state.visible_text += state.window + events.append({"kind": "content", "text": state.window}) + elif state.mode == "tool_buffer": + state.tool_buffer += state.window + events.extend(_sync_partial_tool_stream( + state.tool_buffer, tool_policy.parse_tools, state.tool_states)) + state.window = "" + + if state.mode == "tool_buffer": + finalized_events, tool_calls = _finalize_tool_buffer_stream(state, tool_policy) + events.extend(finalized_events) + + return events, tool_calls + + +def _parse_responses_non_stream_output( + text: str, + tool_policy: ToolPolicy, + msg_item_id: str, + *, + started_in_thinking: bool, +) -> tuple[str, list[dict], list[tuple[str, str]]]: + state = _make_shared_stream_state(started_in_thinking, [], tool_policy) + output_order: list[tuple[str, str]] = [] + message_announced = False + + def record_output_order(events: list[dict]) -> None: + nonlocal message_announced + for event in events: + if event["kind"] == "content": + if not message_announced: + output_order.append(("message", msg_item_id)) + message_announced = True + continue + if event["kind"] == "tool_start": + output_order.append(("function_call", event["id"])) + + stream_events = _feed_shared_stream_piece( + state, text, stops=[], tool_policy=tool_policy) + record_output_order(stream_events) + flush_events, tool_calls = _flush_shared_stream(state, tool_policy=tool_policy) + record_output_order(flush_events) + return state.visible_text.strip(), tool_calls, output_order + + +def _responses_message_item(msg_item_id: str, text: str, *, status: str = "completed") -> dict: + return { + "type": "message", + "id": msg_item_id, + "status": status, + "role": "assistant", + "content": [{"type": "output_text", "text": text, "annotations": []}], + } + + +def _responses_function_call_item(tc: dict, *, status: str = "completed") -> dict: + return { + "type": "function_call", + "id": tc["id"], + "status": status, + "call_id": tc["id"], + "name": tc["function"]["name"], + "arguments": tc["function"]["arguments"], + } + + +def _build_responses_output( + cleaned_text: str, + tool_calls: list[dict], + msg_item_id: str, + *, + output_order: list[tuple[str, str]] | None = None, +) -> list[dict]: + items: dict[tuple[str, str], dict] = {} + if cleaned_text or not tool_calls: + items[("message", msg_item_id)] = _responses_message_item(msg_item_id, cleaned_text) + for tc in tool_calls: + items[("function_call", tc["id"])] = _responses_function_call_item(tc) + + if not output_order: + output: list[dict] = [] + if ("message", msg_item_id) in items: + output.append(items.pop(("message", msg_item_id))) + for tc in tool_calls: + item = items.pop(("function_call", tc["id"]), None) + if item is not None: + output.append(item) + return output + + output = [] + seen: set[tuple[str, str]] = set() + for key in output_order: + item = items.get(key) + if item is not None and key not in seen: + output.append(item) + seen.add(key) + for key, item in items.items(): + if key not in seen: + output.append(item) + return output + + def _samp_suffix(req) -> str: # Render ` samp=temp,top_p,top_k,rep_pen[,seed]` tail when the request asks for # non-greedy decoding. Empty string keeps the daemon protocol greedy-compatible. @@ -696,7 +1544,22 @@ def _resolve_kv_k_type(): cap=prefix_cache_slots, ) if prefill_cfg is not None and prefill_cache_slots > 0: - prefix_cache.init_full_cache(prefill_cache_slots, budget_bytes=prefill_cache_bytes) + prefix_cache.init_full_cache(prefill_cache_slots) + tool_memory = ToolMemory( + max_entries=int(os.environ.get("DFLASH_TOOL_MEMORY_MAX_ENTRIES", "50000")), + max_bytes=int(os.environ.get("DFLASH_TOOL_MEMORY_MAX_BYTES", str(64 * 1024 * 1024))), + ) + + def _remember_tool_call_text(raw_text: str, tool_calls: list[dict] | None) -> None: + if not raw_text or not tool_calls: + return + call_ids = [ + tc.get("id") + for tc in tool_calls + if isinstance(tc, dict) and isinstance(tc.get("id"), str) and tc.get("id") + ] + if call_ids: + tool_memory.remember(call_ids, raw_text) @app.on_event("startup") async def _startup(): @@ -796,18 +1659,23 @@ def _render_messages(msgs_list: list[dict], param="messages") return _ids_to_bin(ids), ids, prompt - def _tokenize_prompt(req: ChatRequest) -> tuple[Path, list[int], list[dict], bool]: + def _tokenize_prompt(req: ChatRequest, tool_policy: ToolPolicy) -> tuple[Path, list[int], list[dict], bool]: """Returns (bin, ids, raw_msgs, started_in_thinking).""" msgs: list[dict] = [] for m in req.messages: d: dict = {"role": m.role} - if m.content is not None: + replay_raw_text = None + if m.role == "assistant" and m.tool_calls is not None: + replay_raw_text = tool_memory.lookup_message(m.tool_calls) + if replay_raw_text is not None: + d["content"] = replay_raw_text + elif m.content is not None: d["content"] = _content_to_str(m.content) if m.name is not None: d["name"] = m.name if m.tool_call_id is not None: d["tool_call_id"] = m.tool_call_id - if m.tool_calls is not None: + if m.tool_calls is not None and replay_raw_text is None: d["tool_calls"] = [] for tc in m.tool_calls: args = tc.function.arguments @@ -823,16 +1691,23 @@ def _tokenize_prompt(req: ChatRequest) -> tuple[Path, list[int], list[dict], boo msgs.append(d) tools_arg = None - if req.tools: - tools_arg = [t.model_dump() for t in req.tools] + if tool_policy.render_tools and tool_policy.prompt_tools: + tools_arg = [ + t.model_dump() if hasattr(t, "model_dump") else t + for t in tool_policy.prompt_tools + ] path, ids, _prompt = _render_messages(msgs, req.chat_template_kwargs, tools_arg) started_in_thinking = bool(re.search(r"\s*$", _prompt)) return path, ids, msgs, started_in_thinking def _maybe_compress(msgs: list[dict], prompt_bin: Path, prompt_ids: list[int], - template_kwargs: dict | None = None + template_kwargs: dict | None = None, + *, + bypass: bool = False, ) -> tuple[Path, list[int]]: + if bypass: + return prompt_bin, prompt_ids if not prefill_cfg or not prefill_cfg.enabled: return prompt_bin, prompt_ids if not prefill_cfg.should_compress(len(prompt_ids)): @@ -900,6 +1775,15 @@ async def _collect_tokens_sync(r, n_gen, timing=None) -> list[int]: loop = asyncio.get_running_loop() return await loop.run_in_executor(None, lambda: list(_token_stream(r, n_gen, timing))) + async def _adrain_until_sentinel(r, timing=None) -> None: + if timing is not None and timing.get("daemon_done"): + return + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, _drain_until_sentinel, r) + if timing is not None: + timing["daemon_done"] = True + timing["t_last_tok"] = time.monotonic() + async def _astream_tokens(r, n_gen, timing=None): generated = 0 hit_stop = False @@ -944,16 +1828,19 @@ def _write_cmd(cmd_line: str, timing=None): sz = bin_path.stat().st_size if sz == 0: log.warning("prompt .bin is 0 bytes: %s", bin_path) - if lazy_draft: - log.debug("lazy-draft: unpark draft before generate") - t = time.monotonic() - daemon_proc.stdin.write(b"unpark draft\n") + try: + if lazy_draft: + log.debug("lazy-draft: unpark draft before generate") + t = time.monotonic() + daemon_proc.stdin.write(b"unpark draft\n") + daemon_proc.stdin.flush() + _drain_until_sentinel(r_pipe) + if timing is not None: + timing["unpark"] = time.monotonic() - t + daemon_proc.stdin.write(cmd_line.encode("utf-8")) daemon_proc.stdin.flush() - _drain_until_sentinel(r_pipe) - if timing is not None: - timing["unpark"] = time.monotonic() - t - daemon_proc.stdin.write(cmd_line.encode("utf-8")) - daemon_proc.stdin.flush() + except (BrokenPipeError, OSError, ValueError) as exc: + raise RuntimeError(f"failed to send command to dflash daemon: {exc}") from exc if timing is not None: timing["t_cmd_sent"] = time.monotonic() @@ -1048,8 +1935,15 @@ def _build_cmd_line(req, cur_bin, cur_ids, gen_len, prefix_cache, cmd_line += f" snap={snap_prep[1]}:{snap_prep[0]}" return cmd_line + _samp_suffix(req) + "\n", snap_prep + def _abort_reserved_snap(full_snap_prep, snap_prep): + if full_snap_prep is not None: + fslot, _ = full_snap_prep + prefix_cache.abort_full_snap(fslot) + elif snap_prep: + prefix_cache.abort_inline_snap(snap_prep[0]) + def _confirm_or_abort_snap(n_tokens: int, full_snap_prep, snap_prep, - prompt_ids, cur_bin, cur_ids): + prompt_ids, cur_bin, cur_ids): """Confirm prefix-cache snapshots only when the daemon actually generated tokens. When the daemon returns 0 tokens (e.g. empty prompt / file read failure), confirming would register a snapshot @@ -1063,11 +1957,7 @@ def _confirm_or_abort_snap(n_tokens: int, full_snap_prep, snap_prep, prefix_cache.confirm_inline_snap(*snap_prep, cur_ids) else: # Abort: release the reservation without registering. - if full_snap_prep is not None: - fslot, _ = full_snap_prep - prefix_cache.abort_full_snap(fslot) - elif snap_prep: - prefix_cache.abort_inline_snap(snap_prep[0]) + _abort_reserved_snap(full_snap_prep, snap_prep) log.warning("0 output tokens — aborted snapshot reservation") def _gen_len_for(prompt_len: int, max_tokens: int) -> int: @@ -1077,7 +1967,8 @@ def _gen_len_for(prompt_len: int, max_tokens: int) -> int: @app.post("/v1/chat/completions") async def chat_completions(req: ChatRequest): - prompt_bin, prompt_ids, raw_msgs, started_in_thinking = _tokenize_prompt(req) + tool_policy = _resolve_tool_policy(req.tools, req.tool_choice) + prompt_bin, prompt_ids, raw_msgs, started_in_thinking = _tokenize_prompt(req, tool_policy) completion_id = "chatcmpl-" + uuid.uuid4().hex[:24] created = int(time.time()) prompt_len = len(prompt_ids) @@ -1096,18 +1987,19 @@ async def chat_completions(req: ChatRequest): if req.stream: async def sse() -> AsyncIterator[str]: - nonlocal started_in_thinking + nonlocal started_in_thinking, prompt_len async with daemon_lock: timing = {} full_snap_prep_ref = [None] snap_prep = None + cur_bin = prompt_bin + cur_ids = prompt_ids - full_hit = prefix_cache.lookup_full(prompt_ids) + full_hit = None if tool_policy.bypass_compression else prefix_cache.lookup_full(prompt_ids) if full_hit is not None: slot, cached_cur_bin, cached_cur_ids_len = full_hit cur_bin = Path(cached_cur_bin) prompt_len = cached_cur_ids_len - started_in_thinking = False # cached: no think prefill gen_len = _gen_len_for(prompt_len, req.max_tokens) if gen_len <= 0: try: prompt_bin.unlink() @@ -1124,7 +2016,8 @@ async def sse() -> AsyncIterator[str]: t_compress = time.monotonic() cur_bin, cur_ids = await asyncio.to_thread( _maybe_compress, raw_msgs, prompt_bin, prompt_ids, - req.chat_template_kwargs) + req.chat_template_kwargs, + bypass=tool_policy.bypass_compression) timing["compress"] = time.monotonic() - t_compress prompt_len = len(cur_ids) gen_len = _gen_len_for(prompt_len, req.max_tokens) @@ -1147,6 +2040,7 @@ async def sse() -> AsyncIterator[str]: try: _write_cmd(cmd_line, timing) except RuntimeError as e: + _abort_reserved_snap(full_snap_prep_ref[0], snap_prep) yield f"data: {json.dumps({'error': str(e)})}\n\n" yield "data: [DONE]\n\n" return @@ -1159,8 +2053,6 @@ async def sse() -> AsyncIterator[str]: "finish_reason": None}], } yield f"data: {json.dumps(head)}\n\n" - window, mode = "", ("reasoning" if started_in_thinking else "content") - include_usage = bool(req.stream_options and req.stream_options.get("include_usage")) def chunk(delta_obj, finish=None): @@ -1169,94 +2061,91 @@ def chunk(delta_obj, finish=None): "choices": [{"index": 0, "delta": delta_obj, "finish_reason": finish}]} - # State machine: mode ∈ {'reasoning', 'content', 'tool_buffer'} - mode = "reasoning" if started_in_thinking else "content" - window = "" - tool_buffer = "" + def tool_chunk(index: int, *, call_id: str | None = None, + name: str | None = None, + arguments: str | None = None) -> str: + tc: dict[str, Any] = {"index": index} + if call_id is not None: + tc["id"] = call_id + tc["type"] = "function" + fn: dict[str, Any] = {} + if name is not None: + fn["name"] = name + if arguments is not None: + fn["arguments"] = arguments + if fn: + tc["function"] = fn + return f"data: {json.dumps(chunk({'tool_calls': [tc]}))}\n\n" + stops = normalize_stop(req.stop) - tag_holdback = max(len(THINK_OPEN_TAG), len(THINK_CLOSE_TAG), len(TOOL_OPEN_TAG)) - stop_holdback = max((len(s) for s in stops), default=0) - HOLDBACK = max(tag_holdback, stop_holdback) + stream_state = _make_shared_stream_state(started_in_thinking, stops, tool_policy) completion_tokens = 0 - stop_hit = False - - def emit_delta(text, kind): - if not text: - return None - return f"data: {json.dumps(chunk({kind: text}))}\n\n" + finish_reason = "stop" + tool_calls: list[dict] = [] try: async for tok_id in _astream_tokens(r_pipe, gen_len, timing): completion_tokens += 1 piece = tokenizer.decode([tok_id]) - window += piece - - if stops and mode != "tool_buffer": - si = first_stop_match(window, stops) - if si != -1: - window = window[:si] - stop_hit = True - kind = "reasoning_content" if mode == "reasoning" else "content" - out = emit_delta(window, kind) - if out: yield out - window = "" - break - - while True: - if mode == "tool_buffer": - tool_buffer += window - window = "" - break - - if mode == "reasoning": - idx = window.find(THINK_CLOSE_TAG) - if idx != -1: - pre = window[:idx] - out = emit_delta(pre, "reasoning_content") - if out: yield out - window = window[idx + len(THINK_CLOSE_TAG):] - mode = "content" - continue - if len(window) > HOLDBACK: - safe = window[:-HOLDBACK] - out = emit_delta(safe, "reasoning_content") - if out: yield out - window = window[-HOLDBACK:] - break - - else: # mode == "content" - think_idx = window.find(THINK_OPEN_TAG) - think_close_idx = window.find(THINK_CLOSE_TAG) - tool_idx = window.find(TOOL_OPEN_TAG) - hits = [(i, t) for i, t in - ((think_idx, "think"), - (think_close_idx, "think_close"), - (tool_idx, "tool")) if i != -1] - if hits: - hits.sort() - idx, which = hits[0] - pre = window[:idx] - out = emit_delta(pre, "content") - if out: yield out - if which == "think": - window = window[idx + len(THINK_OPEN_TAG):] - mode = "reasoning" - elif which == "think_close": - window = window[idx + len(THINK_CLOSE_TAG):] - else: - tool_buffer = window[idx:] - window = "" - mode = "tool_buffer" - continue - if len(window) > HOLDBACK: - safe = window[:-HOLDBACK] - out = emit_delta(safe, "content") - if out: yield out - window = window[-HOLDBACK:] - break - - if stop_hit: - finish_reason = "stop" + for event in _feed_shared_stream_piece( + stream_state, piece, stops=stops, tool_policy=tool_policy): + if event["kind"] == "content": + yield f"data: {json.dumps(chunk({'content': event['text']}))}\n\n" + elif event["kind"] == "reasoning": + yield f"data: {json.dumps(chunk({'reasoning_content': event['text']}))}\n\n" + elif event["kind"] == "tool_start": + yield tool_chunk( + event["index"], + call_id=event["id"], + name=event["name"], + ) + elif event["kind"] == "tool_args": + yield tool_chunk( + event["index"], + arguments=event["arguments"], + ) + if stream_state.stop_hit: + break + + if stream_state.stop_hit: + if not timing.get("daemon_done"): + await _adrain_until_sentinel(r_pipe, timing) + stream_events, tool_calls = _flush_shared_stream( + stream_state, tool_policy=tool_policy) + if not tool_calls and stream_state.final_tool_calls: + tool_calls = list(stream_state.final_tool_calls) + for event in stream_events: + if event["kind"] == "content": + yield f"data: {json.dumps(chunk({'content': event['text']}))}\n\n" + elif event["kind"] == "reasoning": + yield f"data: {json.dumps(chunk({'reasoning_content': event['text']}))}\n\n" + elif event["kind"] == "tool_start": + yield tool_chunk( + event["index"], + call_id=event["id"], + name=event["name"], + ) + elif event["kind"] == "tool_args": + yield tool_chunk( + event["index"], + arguments=event["arguments"], + ) + tool_choice_error = _tool_choice_violation(tool_policy, tool_calls) + if tool_choice_error is not None: + err = {"error": { + "message": tool_choice_error.message, + "type": tool_choice_error.error_type, + }} + if tool_choice_error.param is not None: + err["error"]["param"] = tool_choice_error.param + yield f"data: {json.dumps(err)}\n\n" + yield "data: [DONE]\n\n" + return + if tool_calls: + _remember_tool_call_text(stream_state.raw_text, tool_calls) + finish_reason = "tool_calls" + else: + finish_reason = "stop" yield f"data: {json.dumps(chunk({}, finish=finish_reason))}\n\n" if include_usage: usage_chunk = {"id": completion_id, "object": "chat.completion.chunk", @@ -1266,41 +2155,44 @@ def emit_delta(text, kind): "total_tokens": prompt_len + completion_tokens}} yield f"data: {json.dumps(usage_chunk)}\n\n" yield "data: [DONE]\n\n" - if timing.get("daemon_done") and full_hit is None: - try: cur_bin.unlink() - except Exception: pass - _park_draft_if_lazy(timing) return - # Flush remaining - if mode == "reasoning" and window: - out = emit_delta(window, "reasoning_content") - if out: yield out - elif mode == "content" and window: - out = emit_delta(window, "content") - if out: yield out - elif mode == "tool_buffer": - tool_buffer += window - window = "" - - finish_reason = "stop" - if mode == "tool_buffer": - cleaned_after, tool_calls = parse_tool_calls(tool_buffer, tools=req.tools) - if tool_calls: - if cleaned_after: - out = emit_delta(cleaned_after, "content") - if out: yield out - tc_delta_list = [{ - "index": i, "id": tc["id"], "type": "function", - "function": {"name": tc["function"]["name"], - "arguments": tc["function"]["arguments"]}, - } for i, tc in enumerate(tool_calls)] - yield f"data: {json.dumps(chunk({'tool_calls': tc_delta_list}))}\n\n" - finish_reason = "tool_calls" - else: - out = emit_delta(tool_buffer, "content") - if out: yield out + stream_events, tool_calls = _flush_shared_stream( + stream_state, tool_policy=tool_policy) + for event in stream_events: + if event["kind"] == "content": + yield f"data: {json.dumps(chunk({'content': event['text']}))}\n\n" + elif event["kind"] == "reasoning": + yield f"data: {json.dumps(chunk({'reasoning_content': event['text']}))}\n\n" + elif event["kind"] == "tool_start": + yield tool_chunk( + event["index"], + call_id=event["id"], + name=event["name"], + ) + elif event["kind"] == "tool_args": + yield tool_chunk( + event["index"], + arguments=event["arguments"], + ) + tool_choice_error = _tool_choice_violation(tool_policy, tool_calls) + if tool_choice_error is not None: + err = {"error": { + "message": tool_choice_error.message, + "type": tool_choice_error.error_type, + }} + if tool_choice_error.param is not None: + err["error"]["param"] = tool_choice_error.param + yield f"data: {json.dumps(err)}\n\n" + yield "data: [DONE]\n\n" + return + if tool_calls: + _remember_tool_call_text(stream_state.raw_text, tool_calls) + finish_reason = "tool_calls" finally: + _confirm_or_abort_snap( + completion_tokens, full_snap_prep_ref[0], snap_prep, + prompt_ids, cur_bin, cur_ids) if timing.get("daemon_done"): if full_hit is None: try: cur_bin.unlink() @@ -1312,11 +2204,7 @@ def emit_delta(text, kind): log.warning( "stream ended before daemon sentinel; " "retaining prompt .bin for in-flight daemon read") - - _confirm_or_abort_snap( - completion_tokens, full_snap_prep_ref[0], snap_prep, - prompt_ids, cur_bin, cur_ids) - _park_draft_if_lazy(timing) + _park_draft_if_lazy(timing) yield f"data: {json.dumps(chunk({}, finish=finish_reason))}\n\n" if include_usage: @@ -1347,7 +2235,7 @@ def emit_delta(text, kind): full_snap_prep_ref = [None] snap_prep = None - full_hit = prefix_cache.lookup_full(prompt_ids) + full_hit = None if tool_policy.bypass_compression else prefix_cache.lookup_full(prompt_ids) if full_hit is not None: slot, cached_cur_bin, cached_cur_ids_len = full_hit cur_bin = Path(cached_cur_bin) @@ -1365,7 +2253,8 @@ def emit_delta(text, kind): t_compress = time.monotonic() cur_bin, cur_ids = await asyncio.to_thread( _maybe_compress, raw_msgs, prompt_bin, prompt_ids, - req.chat_template_kwargs) + req.chat_template_kwargs, + bypass=tool_policy.bypass_compression) timing["compress"] = time.monotonic() - t_compress prompt_len = len(cur_ids) gen_len = _gen_len_for(prompt_len, req.max_tokens) @@ -1383,6 +2272,7 @@ def emit_delta(text, kind): try: _write_cmd(cmd_line, timing) except RuntimeError as e: + _abort_reserved_snap(full_snap_prep_ref[0], snap_prep) return JSONResponse({"detail": str(e)}, status_code=503) # FIX 6: use run_in_executor instead of list() blocking event loop @@ -1411,7 +2301,11 @@ def emit_delta(text, kind): thinking_enabled = True if req.chat_template_kwargs: thinking_enabled = req.chat_template_kwargs.get("enable_thinking", True) - cleaned, tool_calls = parse_tool_calls(text, tools=req.tools) + cleaned, tool_calls = _parse_generated_tool_calls(text, tool_policy) + tool_choice_error = _tool_choice_violation(tool_policy, tool_calls) + if tool_choice_error is not None: + raise tool_choice_error + _remember_tool_call_text(text, tool_calls) cleaned, reasoning = parse_reasoning( cleaned, thinking_enabled=thinking_enabled, @@ -1530,6 +2424,7 @@ async def sse() -> AsyncIterator[str]: try: _write_cmd(cmd_line, timing) except RuntimeError as e: + _abort_reserved_snap(full_snap_prep_ref[0], snap_prep) yield f"event: error\ndata: {json.dumps({'type':'error','error':{'type':'server_error','message':str(e)}})}\n\n" return @@ -1643,6 +2538,7 @@ async def sse() -> AsyncIterator[str]: try: _write_cmd(cmd_line, timing) except RuntimeError as e: + _abort_reserved_snap(full_snap_prep_ref[0], snap_prep) return JSONResponse({"type": "error", "error": {"type": "server_error", "message": str(e)}}, status_code=503) @@ -1688,12 +2584,34 @@ def _map_responses_input(req: ResponsesCreateRequest ) -> tuple[list[ChatMessage], list[ToolDef] | None]: """Map Responses API input → ChatMessage list + ToolDef list.""" messages: list[ChatMessage] = [] + pending_assistant_text: list[str] = [] + pending_assistant_calls: list[ToolCall] = [] # Collect all system-level content (instructions + developer messages) # and merge into a single system message at position 0, since # Qwen's chat template requires the system message at the beginning. system_parts: list[str] = [] + def _flush_pending_assistant() -> None: + if not pending_assistant_text and not pending_assistant_calls: + return + messages.append(ChatMessage( + role="assistant", + content="".join(pending_assistant_text) or None, + tool_calls=list(pending_assistant_calls) or None, + )) + pending_assistant_text.clear() + pending_assistant_calls.clear() + + def _responses_item_text(content: Any) -> str: + if isinstance(content, list): + text_parts = [] + for part in content: + if isinstance(part, dict) and part.get("type") in ("output_text", "text", "input_text"): + text_parts.append(part.get("text", "")) + return "".join(text_parts) + return content if isinstance(content, str) else "" + if req.instructions: system_parts.append(req.instructions) @@ -1709,18 +2627,14 @@ def _map_responses_input(req: ResponsesCreateRequest if item_type == "message": role = item.get("role", "user") - content = item.get("content", "") - if isinstance(content, list): - # Extract text from content parts - text_parts = [] - for part in content: - if isinstance(part, dict): - if part.get("type") in ("output_text", "text", "input_text"): - text_parts.append(part.get("text", "")) - content = "".join(text_parts) + content = _responses_item_text(item.get("content", "")) if role == "developer" or role == "system": + _flush_pending_assistant() system_parts.append(content) + elif role == "assistant": + pending_assistant_text.append(content) else: + _flush_pending_assistant() messages.append(ChatMessage(role=role, content=content)) elif item_type == "function_call": @@ -1732,10 +2646,10 @@ def _map_responses_input(req: ResponsesCreateRequest arguments=item.get("arguments", "{}"), ), ) - messages.append(ChatMessage( - role="assistant", content=None, tool_calls=[tc])) + pending_assistant_calls.append(tc) elif item_type == "function_call_output": + _flush_pending_assistant() output = item.get("output", "") if not isinstance(output, str): output = json.dumps(output) @@ -1747,6 +2661,8 @@ def _map_responses_input(req: ResponsesCreateRequest # Ignore reasoning, local_shell_call, etc. — we just # need the message/function_call/output items for the model. + _flush_pending_assistant() + # Prepend merged system message if system_parts: messages.insert(0, ChatMessage( @@ -1792,13 +2708,14 @@ async def responses_create(req: ResponsesCreateRequest): tool_choice=req.tool_choice, chat_template_kwargs={"enable_thinking": enable_thinking}, ) + tool_policy = _resolve_tool_policy(chat_req.tools, chat_req.tool_choice) response_id = "resp_" + uuid.uuid4().hex[:24] msg_item_id = "msg_" + uuid.uuid4().hex[:24] created_at = int(time.time()) # Tokenize - prompt_bin, prompt_ids, raw_msgs, started_in_thinking = _tokenize_prompt(chat_req) + prompt_bin, prompt_ids, raw_msgs, started_in_thinking = _tokenize_prompt(chat_req, tool_policy) prompt_len = len(prompt_ids) # Summarise roles for the log line @@ -1817,17 +2734,17 @@ async def responses_create(req: ResponsesCreateRequest): if req.stream: return await _responses_stream( - chat_req, prompt_bin, prompt_ids, raw_msgs, + chat_req, prompt_bin, prompt_ids, raw_msgs, tool_policy, started_in_thinking, response_id, msg_item_id, created_at, prompt_len, time.monotonic()) else: return await _responses_non_stream( - chat_req, prompt_bin, prompt_ids, raw_msgs, + chat_req, prompt_bin, prompt_ids, raw_msgs, tool_policy, started_in_thinking, response_id, msg_item_id, created_at, prompt_len, time.monotonic()) async def _responses_non_stream( - chat_req, prompt_bin, prompt_ids, raw_msgs, + chat_req, prompt_bin, prompt_ids, raw_msgs, tool_policy, started_in_thinking, response_id, msg_item_id, created_at, prompt_len, t0): """Non-streaming Responses API handler.""" @@ -1835,14 +2752,13 @@ async def _responses_non_stream( timing = {} full_snap_prep_ref = [None] snap_prep = None + cur_ids = None - full_hit = prefix_cache.lookup_full(prompt_ids) + full_hit = None if tool_policy.bypass_compression else prefix_cache.lookup_full(prompt_ids) if full_hit is not None: slot, cached_cur_bin, cached_cur_ids_len = full_hit cur_bin = Path(cached_cur_bin) - cur_ids = None prompt_len = cached_cur_ids_len - started_in_thinking = False # cached: no think prefill gen_len = _gen_len_for(prompt_len, chat_req.max_tokens) if gen_len <= 0: log.warning( @@ -1861,7 +2777,8 @@ async def _responses_non_stream( t_compress = time.monotonic() cur_bin, cur_ids = await asyncio.to_thread( _maybe_compress, raw_msgs, prompt_bin, prompt_ids, - chat_req.chat_template_kwargs) + chat_req.chat_template_kwargs, + bypass=tool_policy.bypass_compression) timing["compress"] = time.monotonic() - t_compress prompt_len = len(cur_ids) gen_len = _gen_len_for(prompt_len, chat_req.max_tokens) @@ -1885,6 +2802,7 @@ async def _responses_non_stream( try: _write_cmd(cmd_line, timing) except RuntimeError as e: + _abort_reserved_snap(full_snap_prep_ref[0], snap_prep) return JSONResponse({ "type": "error", "error": {"type": "server_error", "message": str(e)} @@ -1908,31 +2826,19 @@ async def _responses_non_stream( thinking_enabled = True if chat_req.chat_template_kwargs: thinking_enabled = chat_req.chat_template_kwargs.get("enable_thinking", True) - cleaned, tool_calls = parse_tool_calls(text, tools=chat_req.tools) + cleaned, tool_calls, output_order = _parse_responses_non_stream_output( + text, tool_policy, msg_item_id, + started_in_thinking=started_in_thinking) + tool_choice_error = _tool_choice_violation(tool_policy, tool_calls) + if tool_choice_error is not None: + raise tool_choice_error + _remember_tool_call_text(text, tool_calls) cleaned, reasoning = parse_reasoning( cleaned, thinking_enabled=thinking_enabled, started_in_thinking=started_in_thinking) - # Build output items - output: list[dict] = [] - if tool_calls: - for tc in tool_calls: - output.append({ - "type": "function_call", - "id": tc["id"], - "status": "completed", - "call_id": tc["id"], - "name": tc["function"]["name"], - "arguments": tc["function"]["arguments"], - }) - else: - output.append({ - "type": "message", - "id": msg_item_id, - "status": "completed", - "role": "assistant", - "content": [{"type": "output_text", "text": cleaned, "annotations": []}], - }) + output = _build_responses_output( + cleaned, tool_calls, msg_item_id, output_order=output_order) out_types = [o.get("type") for o in output] elapsed = time.monotonic() - t0 @@ -1959,7 +2865,7 @@ async def _responses_non_stream( }) async def _responses_stream( - chat_req, prompt_bin, prompt_ids, raw_msgs, + chat_req, prompt_bin, prompt_ids, raw_msgs, tool_policy, started_in_thinking, response_id, msg_item_id, created_at, prompt_len, t0): """Streaming Responses API handler — emits Responses SSE events.""" @@ -1971,13 +2877,13 @@ async def sse() -> AsyncIterator[str]: timing = {} full_snap_prep_ref = [None] snap_prep = None + cur_ids = None - full_hit = prefix_cache.lookup_full(prompt_ids) + full_hit = None if tool_policy.bypass_compression else prefix_cache.lookup_full(prompt_ids) if full_hit is not None: slot, cached_cur_bin, cached_cur_ids_len = full_hit cur_bin = Path(cached_cur_bin) prompt_len = cached_cur_ids_len - started_in_thinking = False gen_len = _gen_len_for(prompt_len, chat_req.max_tokens) if gen_len <= 0: log.warning( @@ -1995,7 +2901,8 @@ async def sse() -> AsyncIterator[str]: t_compress = time.monotonic() cur_bin, cur_ids = await asyncio.to_thread( _maybe_compress, raw_msgs, prompt_bin, prompt_ids, - chat_req.chat_template_kwargs) + chat_req.chat_template_kwargs, + bypass=tool_policy.bypass_compression) timing["compress"] = time.monotonic() - t_compress prompt_len = len(cur_ids) gen_len = _gen_len_for(prompt_len, chat_req.max_tokens) @@ -2018,106 +2925,157 @@ async def sse() -> AsyncIterator[str]: try: _write_cmd(cmd_line, timing) except RuntimeError as e: + _abort_reserved_snap(full_snap_prep_ref[0], snap_prep) yield _resp_sse("error", { "error": {"type": "server_error", "message": str(e)}}) return - # Lifecycle: response.created yield _resp_sse("response.created", { "response": _resp_shell(response_id, chat_req.model, created_at, "in_progress")}) - # Announce output item - yield _resp_sse("response.output_item.added", { - "output_index": 0, - "item": {"type": "message", "id": msg_item_id, - "status": "in_progress", "role": "assistant", - "content": []}}) - - # Announce content part - yield _resp_sse("response.content_part.added", { - "item_id": msg_item_id, "output_index": 0, - "content_index": 0, - "part": {"type": "output_text", "text": "", "annotations": []}}) - - # Stream tokens with state machine - mode = "reasoning" if started_in_thinking else "content" - window = "" - tool_buffer = "" - accumulated_text = "" - tag_holdback = max(len(THINK_OPEN_TAG), len(THINK_CLOSE_TAG), len(TOOL_OPEN_TAG)) - HOLDBACK = tag_holdback + stream_state = _make_shared_stream_state(started_in_thinking, [], tool_policy) completion_tokens = 0 - tool_call_active = False + tool_calls: list[dict] = [] + next_output_index = 0 + message_output_index: int | None = None + message_announced = False + message_done = False + tool_output_indices: dict[str, int] = {} + output_item_order: list[tuple[str, str]] = [] + + def ensure_message_started() -> list[str]: + nonlocal next_output_index, message_output_index, message_announced + if message_announced: + return [] + message_output_index = next_output_index + next_output_index += 1 + message_announced = True + output_item_order.append(("message", msg_item_id)) + return [ + _resp_sse("response.output_item.added", { + "output_index": message_output_index, + "item": { + "type": "message", + "id": msg_item_id, + "status": "in_progress", + "role": "assistant", + "content": [], + }, + }), + _resp_sse("response.content_part.added", { + "item_id": msg_item_id, + "output_index": message_output_index, + "content_index": 0, + "part": {"type": "output_text", "text": "", "annotations": []}, + }), + ] + + def ensure_tool_started(event: dict) -> list[str]: + nonlocal next_output_index + output_index = tool_output_indices.get(event["id"]) + if output_index is not None: + return [] + output_index = next_output_index + next_output_index += 1 + tool_output_indices[event["id"]] = output_index + output_item_order.append(("function_call", event["id"])) + return [_resp_sse("response.output_item.added", { + "output_index": output_index, + "item": { + "type": "function_call", + "id": event["id"], + "status": "in_progress", + "call_id": event["id"], + "name": event["name"], + "arguments": "", + }, + })] + + def finalize_message_item() -> list[str]: + nonlocal message_done + if not message_announced or message_done or message_output_index is None: + return [] + message_done = True + return [ + _resp_sse("response.output_text.done", { + "item_id": msg_item_id, + "output_index": message_output_index, + "content_index": 0, + "text": stream_state.visible_text, + }), + _resp_sse("response.content_part.done", { + "item_id": msg_item_id, + "output_index": message_output_index, + "content_index": 0, + "part": {"type": "output_text", "text": stream_state.visible_text, + "annotations": []}, + }), + _resp_sse("response.output_item.done", { + "output_index": message_output_index, + "item": _responses_message_item(msg_item_id, stream_state.visible_text), + }), + ] + + def finalize_tool_items(final_tool_calls: list[dict]) -> list[str]: + messages: list[str] = [] + tool_map = {tc["id"]: tc for tc in final_tool_calls} + for kind, item_id in output_item_order: + if kind != "function_call": + continue + tc = tool_map.get(item_id) + output_index = tool_output_indices.get(item_id) + if tc is None or output_index is None: + continue + messages.append(_resp_sse("response.function_call_arguments.done", { + "item_id": item_id, + "output_index": output_index, + "arguments": tc["function"]["arguments"], + "name": tc["function"]["name"], + })) + messages.append(_resp_sse("response.output_item.done", { + "output_index": output_index, + "item": _responses_function_call_item(tc), + })) + return messages + + def emit_stream_event(event: dict) -> list[str]: + messages: list[str] = [] + if event["kind"] == "content": + messages.extend(ensure_message_started()) + assert message_output_index is not None + messages.append(_resp_sse("response.output_text.delta", { + "item_id": msg_item_id, + "output_index": message_output_index, + "content_index": 0, + "delta": event["text"], + })) + return messages + if event["kind"] == "tool_start": + messages.extend(ensure_tool_started(event)) + return messages + if event["kind"] == "tool_args": + if event["id"] not in tool_output_indices: + messages.extend(ensure_tool_started(event)) + messages.append(_resp_sse("response.function_call_arguments.delta", { + "item_id": event["id"], + "output_index": tool_output_indices[event["id"]], + "delta": event["arguments"], + })) + return messages + if event["kind"] == "tool_done": + if event["id"] not in tool_output_indices: + messages.extend(ensure_tool_started(event)) + return messages try: async for tok_id in _astream_tokens(r_pipe, gen_len, timing): completion_tokens += 1 piece = tokenizer.decode([tok_id]) - window += piece - - while True: - if mode == "tool_buffer": - tool_buffer += window - window = "" - break - - if mode == "reasoning": - idx = window.find(THINK_CLOSE_TAG) - if idx != -1: - window = window[idx + len(THINK_CLOSE_TAG):] - mode = "content" - continue - if len(window) > HOLDBACK: - window = window[-HOLDBACK:] - break - - else: # content - think_idx = window.find(THINK_OPEN_TAG) - think_close_idx = window.find(THINK_CLOSE_TAG) - tool_idx = window.find(TOOL_OPEN_TAG) - hits = [(i, t) for i, t in - ((think_idx, "think"), - (think_close_idx, "think_close"), - (tool_idx, "tool")) if i != -1] - if hits: - hits.sort() - idx, which = hits[0] - pre = window[:idx] - if pre: - accumulated_text += pre - yield _resp_sse("response.output_text.delta", { - "item_id": msg_item_id, "output_index": 0, - "content_index": 0, "delta": pre}) - if which == "think": - window = window[idx + len(THINK_OPEN_TAG):] - mode = "reasoning" - elif which == "think_close": - window = window[idx + len(THINK_CLOSE_TAG):] - else: - tool_buffer = window[idx:] - window = "" - mode = "tool_buffer" - continue - if len(window) > HOLDBACK: - safe = window[:-HOLDBACK] - accumulated_text += safe - yield _resp_sse("response.output_text.delta", { - "item_id": msg_item_id, "output_index": 0, - "content_index": 0, "delta": safe}) - window = window[-HOLDBACK:] - break - - # Flush remaining window - if mode == "content" and window: - accumulated_text += window - yield _resp_sse("response.output_text.delta", { - "item_id": msg_item_id, "output_index": 0, - "content_index": 0, "delta": window}) - elif mode == "tool_buffer": - tool_buffer += window - window = "" - + for event in _feed_shared_stream_piece( + stream_state, piece, stops=[], tool_policy=tool_policy): + for message in emit_stream_event(event): + yield message finally: if timing.get("daemon_done"): if full_hit is None: @@ -2136,70 +3094,45 @@ async def sse() -> AsyncIterator[str]: prompt_ids, cur_bin, cur_ids) _park_draft_if_lazy(timing) - # Build final output items - final_output: list[dict] = [] - if mode == "tool_buffer" and tool_buffer: - cleaned_after, tool_calls = parse_tool_calls(tool_buffer, tools=chat_req.tools) - if tool_calls: - if cleaned_after: - accumulated_text += cleaned_after - for tc in tool_calls: - tool_call_active = True - tc_item_id = tc["id"] - # Emit function_call_arguments.delta for each tool call - yield _resp_sse("response.function_call_arguments.delta", { - "item_id": tc_item_id, "output_index": 0, - "delta": tc["function"]["arguments"]}) - yield _resp_sse("response.function_call_arguments.done", { - "item_id": tc_item_id, "output_index": 0, - "arguments": tc["function"]["arguments"], - "name": tc["function"]["name"]}) - final_output.append({ - "type": "function_call", - "id": tc_item_id, - "status": "completed", - "call_id": tc_item_id, - "name": tc["function"]["name"], - "arguments": tc["function"]["arguments"], - }) - else: - accumulated_text += tool_buffer - yield _resp_sse("response.output_text.delta", { - "item_id": msg_item_id, "output_index": 0, - "content_index": 0, "delta": tool_buffer}) - - # Finalize text output - yield _resp_sse("response.output_text.done", { - "item_id": msg_item_id, "output_index": 0, - "content_index": 0, "text": accumulated_text}) - yield _resp_sse("response.content_part.done", { - "item_id": msg_item_id, "output_index": 0, - "content_index": 0, - "part": {"type": "output_text", "text": accumulated_text, - "annotations": []}}) - - if not tool_call_active: - final_output.append({ - "type": "message", - "id": msg_item_id, - "status": "completed", - "role": "assistant", - "content": [{"type": "output_text", "text": accumulated_text, - "annotations": []}], - }) + stream_events, tool_calls = _flush_shared_stream( + stream_state, tool_policy=tool_policy) + for event in stream_events: + for message in emit_stream_event(event): + yield message + + if tool_calls: + _remember_tool_call_text(stream_state.raw_text, tool_calls) + + if not message_announced and not tool_calls: + for message in ensure_message_started(): + yield message + for message in finalize_message_item(): + yield message + + tool_choice_error = _tool_choice_violation(tool_policy, tool_calls) + if tool_choice_error is not None: + yield _resp_sse("response.failed", { + "response": _resp_shell(response_id, chat_req.model, created_at, + "failed")}) + err = {"error": { + "type": tool_choice_error.error_type, + "message": tool_choice_error.message, + }} + if tool_choice_error.param is not None: + err["error"]["param"] = tool_choice_error.param + yield _resp_sse("error", err) + return - yield _resp_sse("response.output_item.done", { - "output_index": 0, - "item": final_output[0] if final_output else { - "type": "message", "id": msg_item_id, - "status": "completed", "role": "assistant", - "content": []}}) + for message in finalize_tool_items(tool_calls): + yield message + final_output = _build_responses_output( + stream_state.visible_text, tool_calls, msg_item_id, + output_order=output_item_order) - # response.completed shell = _resp_shell(response_id, chat_req.model, created_at, "completed") shell["output"] = final_output - shell["output_text"] = accumulated_text + shell["output_text"] = stream_state.visible_text shell["usage"] = { "input_tokens": prompt_len, "output_tokens": completion_tokens, @@ -2211,7 +3144,7 @@ async def sse() -> AsyncIterator[str]: log.info( "responses DONE %s in=%d out=%d %.1fs %.1f tok/s output=%s text_len=%d %s", response_id, prompt_len, completion_tokens, - elapsed, tok_s, out_types, len(accumulated_text), + elapsed, tok_s, out_types, len(stream_state.visible_text), _timing_summary(timing, completion_tokens), ) yield _resp_sse("response.completed", {"response": shell}) diff --git a/dflash/scripts/test_server.py b/dflash/scripts/test_server.py index 7bdfc569..bd4a37e2 100644 --- a/dflash/scripts/test_server.py +++ b/dflash/scripts/test_server.py @@ -8,6 +8,7 @@ import pytest from fastapi.testclient import TestClient +from _prefill_hook import PrefillConfig from server import ( build_app, MODEL_NAME, parse_tool_calls, parse_reasoning, @@ -43,11 +44,122 @@ def app(mock_tokenizer): return a +@pytest.fixture +def app_with_prefill(mock_tokenizer): + """Build a FastAPI app with compression enabled.""" + drafter_tokenizer = MagicMock() + drafter_tokenizer.return_value = {"input_ids": [1, 2, 3]} + drafter_tokenizer.decode.return_value = "compressed prompt" + with patch("server.subprocess.Popen") as mock_popen: + mock_popen.return_value.poll.return_value = None # daemon alive + a = build_app( + target=Path("target.gguf"), + draft=Path("draft.safetensors"), + bin_path=Path("test_dflash"), + budget=22, + max_ctx=131072, + tokenizer=mock_tokenizer, + stop_ids={2}, + prefill_cache_slots=0, + prefill_cfg=PrefillConfig( + mode="always", + threshold=1, + keep_ratio=0.5, + drafter_gguf=Path("drafter.gguf"), + drafter_tokenizer_id="mock-tokenizer", + ), + drafter_tokenizer=drafter_tokenizer, + ) + return a + + @pytest.fixture def client(app): return TestClient(app) +def _build_app_with_process(mock_tokenizer, process, *, enable_prefill: bool = False): + kwargs = {} + if enable_prefill: + drafter_tokenizer = MagicMock() + drafter_tokenizer.return_value = {"input_ids": [1, 2, 3]} + drafter_tokenizer.decode.return_value = "compressed prompt" + kwargs.update( + prefill_cache_slots=0, + prefill_cfg=PrefillConfig( + mode="always", + threshold=1, + keep_ratio=0.5, + drafter_gguf=Path("drafter.gguf"), + drafter_tokenizer_id="mock-tokenizer", + ), + drafter_tokenizer=drafter_tokenizer, + ) + with patch("server.subprocess.Popen") as mock_popen: + mock_popen.return_value = process + return build_app( + target=Path("target.gguf"), + draft=Path("draft.safetensors"), + bin_path=Path("test_dflash"), + budget=22, + max_ctx=131072, + tokenizer=mock_tokenizer, + stop_ids={2}, + **kwargs, + ) + + +def _chat_sse_chunks(text: str) -> list[dict]: + return [ + json.loads(line[6:]) + for line in text.strip().split("\n\n") + if line.startswith("data: ") and line != "data: [DONE]" + ] + + +def _responses_sse_events(text: str) -> list[tuple[str, dict]]: + events: list[tuple[str, dict]] = [] + for block in text.strip().split("\n\n"): + if not block.strip(): + continue + event_line = next((line for line in block.splitlines() if line.startswith("event: ")), None) + data_line = next((line for line in block.splitlines() if line.startswith("data: ")), None) + if event_line and data_line: + events.append((event_line[7:], json.loads(data_line[6:]))) + return events + + +def _chat_stream_assistant_message(chunks: list[dict]) -> dict: + content_parts: list[str] = [] + tool_calls: dict[int, dict] = {} + for chunk in chunks: + for choice in chunk.get("choices", []): + delta = choice.get("delta", {}) + text = delta.get("content") + if isinstance(text, str): + content_parts.append(text) + for tc_delta in delta.get("tool_calls", []): + index = tc_delta["index"] + state = tool_calls.setdefault(index, { + "id": tc_delta.get("id"), + "type": tc_delta.get("type", "function"), + "function": {"name": None, "arguments": ""}, + }) + if tc_delta.get("id"): + state["id"] = tc_delta["id"] + fn_delta = tc_delta.get("function", {}) + if fn_delta.get("name"): + state["function"]["name"] = fn_delta["name"] + if fn_delta.get("arguments"): + state["function"]["arguments"] += fn_delta["arguments"] + msg = {"role": "assistant"} + content = "".join(content_parts) + msg["content"] = content or None + if tool_calls: + msg["tool_calls"] = [tool_calls[i] for i in sorted(tool_calls)] + return msg + + # ─── parse_reasoning ─────────────────────────────────────────────── class TestParseReasoning: @@ -385,6 +497,79 @@ def test_chat_completions_non_streaming_with_tool_call(mock_os_read, mock_pipe, assert tc[0]["function"]["name"] == "read_file" +@patch("server.os.pipe") +@patch("server.os.read") +def test_chat_completions_replays_raw_tool_call_text(mock_os_read, mock_pipe, + mock_tokenizer, app): + mock_pipe.return_value = (1, 2) + raw_tool_text = ( + "Before\n" + "" + "test.py" + "\n" + "After" + ) + mock_tokenizer.decode.side_effect = [raw_tool_text, "followup"] + mock_os_read.side_effect = [ + struct.pack("", "", "8"] + mock_lookup_full.return_value = (7, "cached_prompt.bin", 1) + mock_tokenizer.apply_chat_template.return_value = "prompt\n" + + def decode_side_effect(ids, *args, **kwargs): + token_id = ids[0] if isinstance(ids, list) else ids + return { + 10: "hidden reasoning", + 11: "", + 12: "visible answer", + }[token_id] + + mock_tokenizer.decode.side_effect = decode_side_effect mock_os_read.side_effect = [ struct.pack("" not in text - assert '"content":"8"' in text or '"content": "8"' in text - + chunks = _chat_sse_chunks(response.text) + reasoning = "".join( + choice.get("delta", {}).get("reasoning_content", "") + for chunk in chunks + for choice in chunk.get("choices", []) + ) + content = "".join( + choice.get("delta", {}).get("content", "") + for chunk in chunks + for choice in chunk.get("choices", []) + ) + assert reasoning == "hidden reasoning" + assert content == "visible answer" + assert "" not in content + assert "hidden reasoning" not in content -# ─── POST /v1/responses ─────────────────────────────────────────── @patch("server.os.pipe") @patch("server.os.read") -def test_responses_non_streaming(mock_os_read, mock_pipe, mock_tokenizer, app): - """POST /v1/responses non-streaming returns ResponsesObject.""" +@pytest.mark.parametrize(("decoded_chunks", "leaked_fragments"), [ + ( + [ + "test", + ".py", + ], + ["", "test", + ".py", + ], + ["{"name":"read_file","arguments":{"path":"test', + '.py"}}', + ], + ["", '{"name":"read_file"'], + ), +]) +def test_chat_completions_streaming_tool_call_deltas(mock_os_read, mock_pipe, + mock_tokenizer, app, + decoded_chunks, leaked_fragments): mock_pipe.return_value = (1, 2) - mock_os_read.side_effect = [struct.pack(" 0 - assert data["usage"]["output_tokens"] > 0 - assert data["usage"]["total_tokens"] == data["usage"]["input_tokens"] + data["usage"]["output_tokens"] + chunks = _chat_sse_chunks(response.text) + tool_deltas = [ + chunk["choices"][0]["delta"]["tool_calls"][0] + for chunk in chunks + if chunk.get("choices") + and chunk["choices"][0].get("delta", {}).get("tool_calls") + ] + assert len(tool_deltas) >= 2 + assert tool_deltas[0]["id"].startswith("call_") + assert tool_deltas[0]["function"]["name"] == "read_file" + assert "".join( + delta.get("function", {}).get("arguments", "") + for delta in tool_deltas + ) == '{"path":"test.py"}' + assert not any( + any(fragment in choice.get("delta", {}).get("content", "") for fragment in leaked_fragments) + for chunk in chunks + for choice in chunk.get("choices", []) + ) + finish_chunk = next( + chunk for chunk in reversed(chunks) + if chunk.get("choices") and chunk["choices"][0]["finish_reason"] is not None + ) + assert finish_chunk["choices"][0]["finish_reason"] == "tool_calls" @patch("server.os.pipe") @patch("server.os.read") -def test_responses_non_streaming_string_input(mock_os_read, mock_pipe, - mock_tokenizer, app): - """Responses API accepts a plain string as input.""" +def test_chat_streaming_stop_sequence_preserves_tool_deltas(mock_os_read, mock_pipe, + mock_tokenizer, app): mock_pipe.return_value = (1, 2) + stop_marker = "" + mock_tokenizer.decode.return_value = ( + "" + "test.py" + "" + f"{stop_marker}" + ) mock_os_read.side_effect = [struct.pack("" in choice.get("delta", {}).get("content", "") + for chunk in chunks + for choice in chunk.get("choices", []) + ) + assistant_msg = _chat_stream_assistant_message(chunks) + assert assistant_msg["content"] is None + assert assistant_msg["tool_calls"][0]["function"]["name"] == "read_file" + assert assistant_msg["tool_calls"][0]["function"]["arguments"] == '{"path":"test.py"}' + finish_chunk = next( + chunk for chunk in reversed(chunks) + if chunk.get("choices") and chunk["choices"][0]["finish_reason"] is not None + ) + assert finish_chunk["choices"][0]["finish_reason"] == "tool_calls" @patch("server.os.pipe") @patch("server.os.read") -def test_responses_non_streaming_started_in_thinking(mock_os_read, mock_pipe, - mock_tokenizer, app): - """When prompt ends with , reasoning without tags is not misclassified as content.""" +def test_chat_streaming_stop_sequence_after_bare_function_close_tag( + mock_os_read, mock_pipe, mock_tokenizer, app): mock_pipe.return_value = (1, 2) - mock_os_read.side_effect = [struct.pack("\n - mock_tokenizer.apply_chat_template.return_value = "prompt\n" - # Model output has no tags — it's a continuation of the prefilled block - mock_tokenizer.decode.return_value = "internal reasoning\nactual answer" + stop_marker = "" + mock_tokenizer.decode.side_effect = [ + "test", + ".py", + f"{stop_marker}ignored", + ] + mock_os_read.side_effect = [ + struct.pack("', + ], + ['{"name":"read_file"', ""], + ), + ( + [ + '{"name":"read_file","arguments":{"path":"test', + '.py"}}', + ], + ["", '{"name":"read_file"', ""], + ), +]) +def test_chat_streaming_stop_sequence_after_json_tool_call( + mock_os_read, mock_pipe, mock_tokenizer, app, decoded_chunks, + leaked_fragments): mock_pipe.return_value = (1, 2) - mock_os_read.side_effect = [struct.pack("", + "stream": True, }) assert response.status_code == 200 - # Verify apply_chat_template was called with system message - calls = mock_tokenizer.apply_chat_template.call_args_list - last_call = calls[-1] - msgs = last_call[0][0] # first positional arg - assert msgs[0]["role"] == "system" - assert "coding assistant" in msgs[0]["content"] + chunks = _chat_sse_chunks(response.text) + assistant_msg = _chat_stream_assistant_message(chunks) + assert assistant_msg["content"] is None + assert assistant_msg["tool_calls"][0]["function"]["name"] == "read_file" + assert assistant_msg["tool_calls"][0]["function"]["arguments"] == '{"path":"test.py"}' + assert not any( + any(fragment in choice.get("delta", {}).get("content", "") for fragment in leaked_fragments) + for chunk in chunks + for choice in chunk.get("choices", []) + ) + finish_chunk = next( + chunk for chunk in reversed(chunks) + if chunk.get("choices") and chunk["choices"][0]["finish_reason"] is not None + ) + assert finish_chunk["choices"][0]["finish_reason"] == "tool_calls" @patch("server.os.pipe") @patch("server.os.read") -def test_responses_streaming(mock_os_read, mock_pipe, mock_tokenizer, app): - """POST /v1/responses streaming emits proper SSE lifecycle events.""" +def test_chat_streaming_replays_raw_tool_call_text(mock_os_read, mock_pipe, + mock_tokenizer, app): mock_pipe.return_value = (1, 2) - mock_os_read.side_effect = [struct.pack("" + "test.py" + "\\n" + "After" + ) + mock_tokenizer.decode.side_effect = [raw_tool_text, "followup"] + mock_os_read.side_effect = [ + struct.pack("", "", "8"] + stop_marker = "" + trimmed_raw_tool_text = ( + "Before\n" + "" + "test.py" + "\n" + "After" + ) + mock_tokenizer.decode.side_effect = [ + trimmed_raw_tool_text + stop_marker + "ignored", + "followup", + ] mock_os_read.side_effect = [ - struct.pack("" not in text - assert '"delta":"8"' in text or '"delta": "8"' in text + second = client.post("/v1/chat/completions", json={ + "model": MODEL_NAME, + "messages": [ + {"role": "user", "content": "read test.py"}, + assistant_msg, + {"role": "tool", "tool_call_id": assistant_msg["tool_calls"][0]["id"], "content": "file body"}, + {"role": "user", "content": "what next?"}, + ], + "stream": False, + }) + assert second.status_code == 200 + + msgs = mock_tokenizer.apply_chat_template.call_args_list[-1][0][0] + assistant = next(m for m in msgs if m["role"] == "assistant") + assert assistant["content"] == trimmed_raw_tool_text + assert stop_marker not in assistant["content"] + assert "ignored" not in assistant["content"] + assert "tool_calls" not in assistant @patch("server.os.pipe") @patch("server.os.read") -def test_responses_with_tools(mock_os_read, mock_pipe, mock_tokenizer, app): - """POST /v1/responses with function tools maps correctly.""" +@pytest.mark.parametrize("decoded_chunks", [ + [ + "test", + ".py", + ], + [ + "test", + ".py", + ], +]) +def test_chat_streaming_stop_hit_drains_daemon_before_next_stream( + mock_os_read, mock_pipe, mock_tokenizer, app, decoded_chunks): mock_pipe.return_value = (1, 2) - mock_os_read.side_effect = [struct.pack("", + "stream": True, }) + assert first.status_code == 200 + assistant_msg = _chat_stream_assistant_message(_chat_sse_chunks(first.text)) + assert assistant_msg["tool_calls"][0]["function"]["name"] == "read_file" - assert response.status_code == 200 - data = response.json() - assert data["object"] == "response" - assert data["status"] == "completed" + second = client.post("/v1/chat/completions", json={ + "model": MODEL_NAME, + "messages": [{"role": "user", "content": "hello"}], + "stream": True, + }) + assert second.status_code == 200 + second_msg = _chat_stream_assistant_message(_chat_sse_chunks(second.text)) + assert second_msg["content"] == "fresh reply" + assert "tool_calls" not in second_msg +@patch("server.PrefixCache.abort_full_snap") +@patch("server.PrefixCache.confirm_full_snap") +@patch("server.PrefixCache.prepare_full_snap", return_value=(7, 0)) +@patch("server.compress_text_via_daemon", return_value="compressed prompt") @patch("server.os.pipe") @patch("server.os.read") -def test_responses_object_tool_choice(mock_os_read, mock_pipe, - mock_tokenizer, app): - """POST /v1/responses with object-style tool_choice must not 422.""" +def test_chat_streaming_stop_hit_confirms_reserved_full_snapshot( + mock_os_read, mock_pipe, _mock_compress, _mock_prepare_full_snap, + mock_confirm_full_snap, mock_abort_full_snap, mock_tokenizer, + app_with_prefill): mock_pipe.return_value = (1, 2) - mock_os_read.side_effect = [struct.pack("", + 12: "stale", + }[token_id] + + mock_tokenizer.decode.side_effect = decode_side_effect + mock_os_read.side_effect = [ + struct.pack("", + "stream": True, + }) + + assert response.status_code == 200 + assistant_msg = _chat_stream_assistant_message(_chat_sse_chunks(response.text)) + assert assistant_msg["content"] == "hello" + mock_confirm_full_snap.assert_called_once() + slot, prompt_ids, cur_bin, cur_ids_len = mock_confirm_full_snap.call_args.args + assert slot == 7 + assert prompt_ids == [1] + assert isinstance(cur_bin, Path) + assert cur_ids_len == 1 + mock_abort_full_snap.assert_not_called() + + +@patch("server.PrefixCache.abort_full_snap") +@patch("server.PrefixCache.prepare_full_snap", return_value=(7, 0)) +@patch("server.compress_text_via_daemon", return_value="compressed prompt") +@patch("server.os.pipe") +@patch("server.os.read") +def test_chat_streaming_write_failure_aborts_reserved_full_snapshot( + mock_os_read, mock_pipe, _mock_compress, _mock_prepare_full_snap, + mock_abort_full_snap, mock_tokenizer): + mock_pipe.return_value = (1, 2) + mock_os_read.side_effect = [] + dead_proc = MagicMock() + dead_proc.poll.return_value = 1 + local_app = _build_app_with_process( + mock_tokenizer, dead_proc, enable_prefill=True) + + client = TestClient(local_app) + response = client.post("/v1/chat/completions", json={ + "model": MODEL_NAME, + "messages": [{"role": "user", "content": "hello"}], + "stream": True, + }) + + assert response.status_code == 200 + assert "dflash daemon has exited unexpectedly" in response.text + mock_abort_full_snap.assert_called_once_with(7) + + +@patch("server.os.pipe") +@patch("server.os.read") +def test_chat_tool_choice_required_rejects_plain_text(mock_os_read, mock_pipe, + mock_tokenizer, app): + mock_pipe.return_value = (1, 2) + mock_tokenizer.decode.return_value = "plain text" + mock_os_read.side_effect = [struct.pack("file.txt' + '' + ) + mock_os_read.side_effect = [struct.pack(" 0 + assert data["usage"]["output_tokens"] > 0 + assert data["usage"]["total_tokens"] == data["usage"]["input_tokens"] + data["usage"]["output_tokens"] + + +@patch("server.os.pipe") +@patch("server.os.read") +def test_responses_non_streaming_string_input(mock_os_read, mock_pipe, + mock_tokenizer, app): + """Responses API accepts a plain string as input.""" + mock_pipe.return_value = (1, 2) + mock_os_read.side_effect = [struct.pack("" + "file.txt" + "" + "After tool" + ) + mock_os_read.side_effect = [struct.pack(", reasoning without tags is not misclassified as content.""" + mock_pipe.return_value = (1, 2) + mock_os_read.side_effect = [struct.pack("\n + mock_tokenizer.apply_chat_template.return_value = "prompt\n" + # Model output has no tags — it's a continuation of the prefilled block + mock_tokenizer.decode.return_value = "internal reasoning\nactual answer" + + client = TestClient(app) + response = client.post("/v1/responses", json={ + "model": MODEL_NAME, + "input": [{"type": "message", "role": "user", "content": "hello"}], + }) + + assert response.status_code == 200 + data = response.json() + assert data["object"] == "response" + # The "actual answer" part should be the output text, not the reasoning + assert "actual answer" in data["output_text"] + # The reasoning should NOT leak into the output text + assert "internal reasoning" not in data["output_text"] + + +@patch("server.os.pipe") +@patch("server.os.read") +def test_responses_with_instructions(mock_os_read, mock_pipe, + mock_tokenizer, app): + """Instructions are mapped to system message.""" + mock_pipe.return_value = (1, 2) + mock_os_read.side_effect = [struct.pack("", + 12: "visible answer", + }[token_id] + + mock_tokenizer.decode.side_effect = decode_side_effect + mock_os_read.side_effect = [ + struct.pack("" in delta or "hidden reasoning" in delta for delta in deltas) + completed = next(data for event, data in events if event == "response.completed") + assert completed["response"]["output_text"] == "visible answer" + + +@patch("server.os.pipe") +@patch("server.os.read") +@pytest.mark.parametrize("decoded_chunks", [ + [ + "test", + ".py", + ], + [ + "test", + ".py", + ], + [ + '{"name":"read_file","arguments":{"path":"test', + '.py"}}', + ], + [ + '{"name":"read_file","arguments":{"path":"test', + '.py"}}', + ], +]) +def test_responses_streaming_function_call_lifecycle(mock_os_read, mock_pipe, + mock_tokenizer, app, + decoded_chunks): + mock_pipe.return_value = (1, 2) + mock_tokenizer.decode.side_effect = decoded_chunks + mock_os_read.side_effect = [ + struct.pack("file.txt', + ], + ["", "{"name":"write_file","arguments":{"path":"file.txt"}}', + ], + ["", '{"name":"write_file"'], + ), +]) +def test_responses_streaming_tool_choice_failure_suppresses_terminal_function_events( + mock_os_read, mock_pipe, mock_tokenizer, app, decoded_chunks, + leaked_fragments): + mock_pipe.return_value = (1, 2) + mock_tokenizer.decode.side_effect = decoded_chunks + mock_os_read.side_effect = [struct.pack("file.txt", + "After tool", + ] + mock_os_read.side_effect = [ + struct.pack("", "", "8"] + mock_os_read.side_effect = [ + struct.pack("" not in text + assert '"delta":"8"' in text or '"delta": "8"' in text + + +@patch("server.os.pipe") +@patch("server.os.read") +def test_responses_with_tools(mock_os_read, mock_pipe, mock_tokenizer, app): + """POST /v1/responses with function tools maps correctly.""" + mock_pipe.return_value = (1, 2) + mock_os_read.side_effect = [struct.pack("' + 'file.txt' + '' + ) + mock_tokenizer.decode.return_value = raw_tool_text + mock_os_read.side_effect = [struct.pack("' + 'file.txt' + '' + ) + mock_os_read.side_effect = [struct.pack("' + 'other.txt' + '' + '' + 'file.txt' + '' + ) + mock_os_read.side_effect = [struct.pack("' + 'file.txt' + '' + ) + mock_os_read.side_effect = [struct.pack("' + 'file.txt' + '' + ) + mock_tokenizer.decode.side_effect = [raw_tool_text, "followup"] + mock_os_read.side_effect = [ + struct.pack("' + 'file.txt' + '' + 'After tool' + ) + mock_tokenizer.decode.side_effect = [raw_tool_text, "followup"] + mock_os_read.side_effect = [ + struct.pack("old") + mem.remember(["call_new"], "new") + + assert mem.lookup_message([{"id": "call_old"}]) is None + assert mem.lookup_message([{"id": "call_new"}]) == "new" + assert len(mem.by_block) == 1 + + +def test_tool_memory_lookup_message_requires_same_raw_text(): + mem = ToolMemory(max_entries=8, max_bytes=4096) + mem.remember(["call_a"], "a") + mem.remember(["call_b"], "b") + + assert mem.lookup_message([{"id": "call_a"}, {"id": "call_b"}]) is None diff --git a/dflash/scripts/tool_memory.py b/dflash/scripts/tool_memory.py new file mode 100644 index 00000000..58bb4a83 --- /dev/null +++ b/dflash/scripts/tool_memory.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +from collections import OrderedDict +from dataclasses import dataclass +from typing import Any, Iterable, Sequence + + +@dataclass +class _ToolMemoryBlock: + raw_text: str + size_bytes: int + refs: int = 0 + + +class ToolMemory: + """Exact assistant-text memory for tool-calling turns. + + The server receives assistant tool calls back as structured JSON on later + turns. Re-rendering those objects can change key order, spacing, or wrapper + shape, which changes prompt tokenization and breaks prefix/KV reuse. This + store remembers the exact assistant text that originally produced a set of + tool-call IDs so replay can inject the original text verbatim. + """ + + def __init__(self, *, max_entries: int = 50_000, max_bytes: int = 64 * 1024 * 1024): + self.max_entries = max(0, int(max_entries)) + self.max_bytes = max(0, int(max_bytes)) + self.by_id: dict[str, _ToolMemoryBlock] = {} + self.by_block: dict[str, _ToolMemoryBlock] = {} + self._lru: OrderedDict[str, None] = OrderedDict() + self.total_bytes = 0 + + @property + def disabled(self) -> bool: + return self.max_entries == 0 or self.max_bytes == 0 + + def remember(self, call_ids: Iterable[str], raw_text: str) -> None: + if self.disabled or not raw_text: + return + unique_ids = [] + seen: set[str] = set() + for call_id in call_ids: + if not isinstance(call_id, str) or not call_id or call_id in seen: + continue + seen.add(call_id) + unique_ids.append(call_id) + if not unique_ids: + return + + block = self.by_block.get(raw_text) + if block is None: + block = _ToolMemoryBlock( + raw_text=raw_text, + size_bytes=len(raw_text.encode("utf-8")), + ) + self.by_block[raw_text] = block + self.total_bytes += block.size_bytes + + for call_id in unique_ids: + current = self.by_id.get(call_id) + if current is block: + self._touch(call_id) + continue + if current is not None: + self._drop_entry(call_id, current) + self.by_id[call_id] = block + block.refs += 1 + self._touch(call_id) + + self._prune() + + def lookup_message(self, tool_calls: Sequence[Any]) -> str | None: + raw_text: str | None = None + touched: list[str] = [] + for item in tool_calls: + call_id = self._extract_call_id(item) + if not call_id: + return None + block = self.by_id.get(call_id) + if block is None: + return None + touched.append(call_id) + if raw_text is None: + raw_text = block.raw_text + elif raw_text != block.raw_text: + return None + if raw_text is None: + return None + for call_id in touched: + self._touch(call_id) + return raw_text + + def _touch(self, call_id: str) -> None: + self._lru[call_id] = None + self._lru.move_to_end(call_id) + + def _prune(self) -> None: + while self.by_id and ( + (self.max_entries > 0 and len(self.by_id) > self.max_entries) + or (self.max_bytes > 0 and self.total_bytes > self.max_bytes) + ): + oldest_id, _ = self._lru.popitem(last=False) + block = self.by_id.get(oldest_id) + if block is not None: + self._drop_entry(oldest_id, block) + + def _drop_entry(self, call_id: str, block: _ToolMemoryBlock) -> None: + self.by_id.pop(call_id, None) + self._lru.pop(call_id, None) + if block.refs > 0: + block.refs -= 1 + if block.refs == 0: + self.by_block.pop(block.raw_text, None) + self.total_bytes -= block.size_bytes + if self.total_bytes < 0: + self.total_bytes = 0 + + @staticmethod + def _extract_call_id(item: Any) -> str | None: + if isinstance(item, dict): + call_id = item.get("id") + return call_id if isinstance(call_id, str) and call_id else None + call_id = getattr(item, "id", None) + return call_id if isinstance(call_id, str) and call_id else None diff --git a/pflash/README.md b/pflash/README.md index 8f6a473f..350704bd 100644 --- a/pflash/README.md +++ b/pflash/README.md @@ -95,7 +95,7 @@ python tests/bench_niah_cpp.py \ ## OpenAI server flags -For an OpenAI-compatible server with transparent compression on long prompts, run [`dflash/scripts/server.py`](../dflash/scripts/server.py) (or `server_tools.py` for tool-calling) with these flags: +For an OpenAI-compatible server with transparent compression on long prompts, run [`dflash/scripts/server.py`](../dflash/scripts/server.py) with these flags: | Flag | Choices / type | Default | Effect | |---|---|:---:|---|