Skip to content

Commit 772fb57

Browse files
committed
Enable async tool cancellation feature.
1 parent 7660194 commit 772fb57

3 files changed

Lines changed: 191 additions & 9 deletions

File tree

src/pipecat/adapters/base_llm_adapter.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from loguru import logger
1717

18+
from pipecat.adapters.schemas.function_schema import FunctionSchema
1819
from pipecat.adapters.schemas.tools_schema import ToolsSchema
1920
from pipecat.processors.aggregators.llm_context import (
2021
LLMContext,
@@ -48,6 +49,20 @@ class BaseLLMAdapter(ABC, Generic[TLLMInvocationParams]):
4849
def __init__(self):
4950
"""Initialize the adapter."""
5051
self._warned_system_instruction = False
52+
self._builtin_tools: List[FunctionSchema] = []
53+
54+
@property
55+
def builtin_tools(self) -> List[FunctionSchema]:
56+
"""Built-in tools automatically merged into every inference request.
57+
58+
Mixins (e.g. ``AsyncToolCancellationLLMServiceMixin``) append their
59+
tool schemas here so that the tools are injected transparently without
60+
the user having to add them to their ``ToolsSchema``.
61+
62+
Returns:
63+
Mutable list of ``FunctionSchema`` instances.
64+
"""
65+
return self._builtin_tools
5166

5267
@property
5368
@abstractmethod
@@ -122,15 +137,36 @@ def get_messages(self, context: LLMContext) -> List[LLMContextMessage]:
122137
def from_standard_tools(self, tools: Any) -> List[Any] | NotGiven:
123138
"""Convert tools from standard format to provider format.
124139
140+
Built-in tools are automatically merged into the schema before conversion so that every
141+
inference request receives them without the user having to declare them explicitly.
142+
125143
Args:
126144
tools: Tools in standard format or provider-specific format.
127145
128146
Returns:
129147
List of tools converted to provider format, or original tools
130148
if not in standard format.
131149
"""
150+
if self._builtin_tools:
151+
if isinstance(tools, ToolsSchema):
152+
tools = ToolsSchema(
153+
standard_tools=tools.standard_tools + self._builtin_tools,
154+
custom_tools=tools.custom_tools,
155+
)
156+
else:
157+
# User supplied tools in a legacy/provider-specific format;
158+
# we cannot safely merge — build a schema from builtins only.
159+
if tools is not None:
160+
logger.warning(
161+
"Built-in tools could not be merged because the supplied tools are not"
162+
" a ToolsSchema instance. Only built-in tools will be sent."
163+
)
164+
tools = ToolsSchema(standard_tools=self._builtin_tools)
165+
132166
if isinstance(tools, ToolsSchema):
133167
logger.debug(f"Retrieving the tools using the adapter: {type(self)}")
168+
tool_names = [tool.name for tool in tools.standard_tools]
169+
logger.debug(f"Tool names: {tool_names}")
134170
return self.to_provider_tools_format(tools)
135171
# Fallback to return the same tools in case they are not in a standard format
136172
return tools

src/pipecat/services/llm_service.py

Lines changed: 106 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Awaitable,
1717
Callable,
1818
Dict,
19+
List,
1920
Mapping,
2021
Optional,
2122
Protocol,
@@ -60,6 +61,11 @@
6061
from pipecat.services.settings import LLMSettings
6162
from pipecat.services.websocket_service import WebsocketService
6263
from pipecat.turns.user_turn_completion_mixin import UserTurnCompletionLLMServiceMixin
64+
from pipecat.utils.async_tool_cancellation import (
65+
ASYNC_TOOL_CANCELLATION_INSTRUCTIONS,
66+
CANCEL_ASYNC_TOOL_NAME,
67+
CANCEL_ASYNC_TOOL_SCHEMA,
68+
)
6369
from pipecat.utils.context.llm_context_summarization import (
6470
DEFAULT_SUMMARIZATION_TIMEOUT,
6571
LLMContextSummarizationUtil,
@@ -230,6 +236,7 @@ def __init__(
230236
self._group_parallel_tools = group_parallel_tools
231237
self._function_call_timeout_secs = function_call_timeout_secs
232238
self._filter_incomplete_user_turns: bool = False
239+
self._async_cancellation_enabled: bool = False
233240
self._base_system_instruction: Optional[str] = None
234241
self._adapter = self.adapter_class()
235242
self._functions: Dict[Optional[str], FunctionCallRegistryItem] = {}
@@ -291,6 +298,8 @@ async def start(self, frame: StartFrame):
291298
await super().start(frame)
292299
if not self._run_in_parallel:
293300
await self._create_sequential_runner_task()
301+
if self._has_async_functions():
302+
self._setup_async_tool_cancellation()
294303

295304
async def stop(self, frame: EndFrame):
296305
"""Stop the LLM service.
@@ -315,17 +324,20 @@ async def cancel(self, frame: CancelFrame):
315324
await self._cancel_summary_task()
316325

317326
def _compose_system_instruction(self):
318-
"""Compose system_instruction by appending turn completion instructions.
327+
"""Compose system_instruction from the base and all active addon instructions.
319328
320329
Combines the base system instruction with turn completion instructions
321-
and writes the result to ``self._settings.system_instruction``.
330+
(when enabled) and async tool cancellation instructions (when enabled),
331+
writing the result to ``self._settings.system_instruction``.
322332
"""
323333
base = self._base_system_instruction
324-
completion_instructions = self._user_turn_completion_config.completion_instructions
325-
if base:
326-
self._settings.system_instruction = f"{base}\n\n{completion_instructions}"
327-
else:
328-
self._settings.system_instruction = completion_instructions
334+
parts = [base] if base else []
335+
if self._filter_incomplete_user_turns:
336+
parts.append(self._user_turn_completion_config.completion_instructions)
337+
if self._async_cancellation_enabled:
338+
parts.append(ASYNC_TOOL_CANCELLATION_INSTRUCTIONS)
339+
composed = "\n\n".join(p for p in parts if p)
340+
self._settings.system_instruction = composed or None
329341

330342
async def _update_settings(self, delta: LLMSettings) -> dict[str, Any]:
331343
"""Apply a settings delta, handling turn-completion fields.
@@ -361,10 +373,10 @@ async def _update_settings(self, delta: LLMSettings) -> dict[str, Any]:
361373

362374
if (
363375
"system_instruction" in changed
364-
and self._filter_incomplete_user_turns
376+
and (self._filter_incomplete_user_turns or self._async_cancellation_enabled)
365377
and "filter_incomplete_user_turns" not in changed
366378
):
367-
# system_instruction changed while turn completion is active.
379+
# system_instruction changed while composition is active.
368380
# Treat the new value as the new base and recompose.
369381
self._base_system_instruction = self._settings.system_instruction
370382
self._compose_system_instruction()
@@ -849,6 +861,91 @@ async def timeout_handler():
849861
if timeout_task and not timeout_task.done():
850862
await self.cancel_task(timeout_task)
851863

864+
def _has_async_functions(self) -> bool:
865+
"""Return True if at least one non-builtin async function is registered."""
866+
return any(
867+
not item.cancel_on_interruption
868+
for name, item in self._functions.items()
869+
if name != CANCEL_ASYNC_TOOL_NAME
870+
)
871+
872+
def _setup_async_tool_cancellation(self):
873+
"""Enable async tool cancellation.
874+
875+
Saves the base system instruction, recomposes to include cancellation
876+
instructions, registers the built-in ``cancel_async_tool_call`` handler,
877+
and injects its schema into the adapter's built-in tool list.
878+
"""
879+
logger.debug(f"{self}: Enabling async tool cancellation")
880+
881+
self._async_cancellation_enabled = True
882+
883+
if self._base_system_instruction is None:
884+
self._base_system_instruction = self._settings.system_instruction
885+
886+
self._compose_system_instruction()
887+
888+
if not any(t.name == CANCEL_ASYNC_TOOL_NAME for t in self._adapter.builtin_tools):
889+
self._adapter.builtin_tools.append(CANCEL_ASYNC_TOOL_SCHEMA)
890+
891+
if CANCEL_ASYNC_TOOL_NAME not in self._functions:
892+
self._functions[CANCEL_ASYNC_TOOL_NAME] = FunctionCallRegistryItem(
893+
function_name=CANCEL_ASYNC_TOOL_NAME,
894+
handler=self._cancel_async_tool_call_handler,
895+
cancel_on_interruption=True,
896+
)
897+
898+
async def _cancel_async_tool_call_handler(self, params: "FunctionCallParams"):
899+
"""Handle a ``cancel_async_tool_call`` invocation from the LLM.
900+
901+
Args:
902+
params: Function call parameters containing ``tool_call_id`` to cancel.
903+
"""
904+
logger.info("_cancel_async_tool_call_handler invoked!")
905+
906+
tool_call_id: Optional[str] = params.arguments.get("tool_call_id")
907+
if not tool_call_id:
908+
logger.warning(f"{self} cancel_async_tool_call called with no tool_call_id")
909+
await params.result_callback({"cancelled": None})
910+
return
911+
912+
await self._cancel_function_calls_by_tool_call_id(tool_call_id)
913+
await params.result_callback(
914+
{"cancelled": tool_call_id},
915+
properties=FunctionCallResultProperties(run_llm=True),
916+
)
917+
918+
async def _cancel_function_calls_by_tool_call_id(self, tool_call_id: str):
919+
"""Cancel in-progress function call tasks by their tool_call_id.
920+
921+
Args:
922+
tool_call_id: tool_call_id to cancel.
923+
"""
924+
cancelled_tasks = set()
925+
for task, runner_item in self._function_call_tasks.items():
926+
if runner_item.tool_call_id == tool_call_id:
927+
name = runner_item.function_name
928+
tool_call_id = runner_item.tool_call_id
929+
930+
logger.debug(
931+
f"{self} Cancelling async function call [{name}:{tool_call_id}] "
932+
"by LLM request..."
933+
)
934+
935+
if task:
936+
task.remove_done_callback(self._function_call_task_finished)
937+
await self.cancel_task(task)
938+
cancelled_tasks.add(task)
939+
940+
await self.broadcast_frame(
941+
FunctionCallCancelFrame, function_name=name, tool_call_id=tool_call_id
942+
)
943+
944+
logger.debug(f"{self} Async function call [{name}:{tool_call_id}] cancelled")
945+
946+
for task in cancelled_tasks:
947+
self._function_call_task_finished(task)
948+
852949
async def _cancel_function_call(self, function_name: Optional[str]):
853950
cancelled_tasks = set()
854951
for task, runner_item in self._function_call_tasks.items():
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#
2+
# Copyright (c) 2024-2026, Daily
3+
#
4+
# SPDX-License-Identifier: BSD 2-Clause License
5+
#
6+
7+
"""Constants for the built-in async tool cancellation feature.
8+
9+
When an ``LLMService`` has functions registered with
10+
``cancel_on_interruption=False`` (async tools), it automatically injects the
11+
``cancel_async_tool_call`` tool and the instructions below into every inference
12+
request so the LLM can cancel stale in-progress calls.
13+
"""
14+
15+
from pipecat.adapters.schemas.function_schema import FunctionSchema
16+
17+
CANCEL_ASYNC_TOOL_NAME = "cancel_async_tool_call"
18+
19+
ASYNC_TOOL_CANCELLATION_INSTRUCTIONS = """ASYNC TOOL CANCELLATION:
20+
Some tool calls run asynchronously in the background. When one starts, a tool response \
21+
is added to the conversation whose content is a JSON object with \
22+
"type": "tool", "status": "started", and a "tool_call_id" field containing the \
23+
exact ID of that call (e.g. {"type": "tool", "status": "started", "tool_call_id": "..."}).
24+
25+
If the user changes topic, explicitly says they no longer need the result, or the pending \
26+
result would clearly be stale, call cancel_async_tool_call. \
27+
To find the correct tool_call_id: locate the most recent tool response in the conversation \
28+
whose content has "status": "started" and whose call has NOT already been cancelled, \
29+
then copy the "tool_call_id" value from that content exactly as-is. \
30+
Never invent or guess a tool_call_id."""
31+
32+
CANCEL_ASYNC_TOOL_SCHEMA = FunctionSchema(
33+
name=CANCEL_ASYNC_TOOL_NAME,
34+
description=(
35+
"Cancel a single async tool call that is no longer needed. "
36+
"Use this when the user changes topic, indicates a pending result is "
37+
"no longer relevant, or when processing the result would produce a "
38+
"stale or confusing response. "
39+
"The tool_call_id must be the exact 'id' value from the assistant's "
40+
"tool call which we wish to cancel, visible in the conversation history."
41+
),
42+
properties={
43+
"tool_call_id": {
44+
"type": "string",
45+
"description": ("The exact id of the async call to cancel."),
46+
}
47+
},
48+
required=["tool_call_id"],
49+
)

0 commit comments

Comments
 (0)