From db463488ece36fc1232787748fa5c1fb817e975c Mon Sep 17 00:00:00 2001 From: monoxgas Date: Tue, 20 May 2025 18:24:46 -0600 Subject: [PATCH 1/2] Added new Stop exception for breaking from recursive tool calls. --- rigging/__init__.py | 2 ++ rigging/chat.py | 57 ++++++++++++++++++++++++++++---------------- rigging/error.py | 32 +++++++++++++++++++++++++ rigging/prompt.py | 21 +++++++--------- rigging/tool/base.py | 20 +++++++++------- tests/test_tool.py | 16 ++++++------- 6 files changed, 98 insertions(+), 50 deletions(-) diff --git a/rigging/__init__.py b/rigging/__init__.py index 03ff1fa..7324336 100644 --- a/rigging/__init__.py +++ b/rigging/__init__.py @@ -14,6 +14,7 @@ MapCompletionCallback, ThenCompletionCallback, ) +from rigging.error import Stop from rigging.generator import ( GeneratedMessage, GeneratedText, @@ -64,6 +65,7 @@ "PipelineStepContextManager", "PipelineStepGenerator", "Prompt", + "Stop", "ThenChatCallback", "ThenCompletionCallback", "Tool", diff --git a/rigging/chat.py b/rigging/chat.py index 1ab091b..25008a6 100644 --- a/rigging/chat.py +++ b/rigging/chat.py @@ -720,7 +720,7 @@ def __init__( self.tool_mode: ToolMode = "auto" self.api_tool_choice: ApiToolChoice | None = None self.inject_tool_prompt = True - self.stop_on_tool_calls = True + self.add_tool_stop_token = True self.then_callbacks: list[tuple[ThenChatCallback, int]] = [] self.map_callbacks: list[tuple[MapChatCallback, int]] = [] self.watch_callbacks: list[WatchChatCallback] = watch_callbacks or [] @@ -880,6 +880,7 @@ def clone( *, only_messages: bool = False, chat: Chat | None = None, + callbacks: bool | t.Sequence[MapChatCallback | ThenChatCallback] = True, ) -> "ChatPipeline": """ Creates a clone of the current `ChatPipeline` instance. @@ -890,6 +891,8 @@ def clone( including until callbacks, types, tools, metadata, etc. chat: An optional chat object clone for use in the new pipeline, otherwise the current internal chat object will be cloned. + callbacks: If True (default), all callbacks will be cloned. If False, no callbacks will be cloned. + Otherwise provide a sequence of callbacks which should be maintained in the new pipeline. Returns: The cloned ChatPipeline. @@ -906,16 +909,20 @@ def clone( new.tools = self.tools.copy() new.tool_mode = self.tool_mode new.metadata = deepcopy(self.metadata) - new.map_callbacks = self.map_callbacks.copy() new.on_failed = self.on_failed new.errors_to_catch = self.errors_to_catch.copy() new.errors_to_exclude = self.errors_to_exclude.copy() new.caching = self.caching + new.watch_callbacks = self.watch_callbacks.copy() + # Check if any of our callbacks are bound methods to a ChatPipline. # If so, we should rebind them to `self` to ensure they work correctly # and aren't operating with old state. + if callbacks is False: + return new + new.then_callbacks = [ (callback, max_depth) if not hasattr(callback, "__self__") @@ -931,6 +938,18 @@ def clone( for callback, max_depth in self.map_callbacks.copy() ] + if not isinstance(callbacks, bool): + new.then_callbacks = [ + (callback, max_depth) + for callback, max_depth in self.then_callbacks + if callback in callbacks + ] + new.map_callbacks = [ + (callback, max_depth) + for callback, max_depth in self.map_callbacks + if callback in callbacks + ] + return new def meta(self, **kwargs: t.Any) -> "ChatPipeline": @@ -1105,7 +1124,7 @@ def using( mode: ToolMode | None = None, choice: ApiToolChoice | None = None, max_depth: int = DEFAULT_MAX_DEPTH, - stop_on_tool_calls: bool | None = None, + add_stop_token: bool | None = None, ) -> "ChatPipeline": """ Adds a tool or a sequence of tools to participate in the generation process. @@ -1119,7 +1138,8 @@ def using( mode: The tool calling mode to use (e.g., "xml", "json-in-xml", "api"). choice: The API tool choice to use. This is only relevant when using the "api" tool mode. max_depth: The maximum depth for recursive tool calls (this is shared between all tools). - stop_on_tool_calls: When using natively parsed tools, whether to stop generation when a tool call block is observed. + add_stop_token: When using natively parsed tools ("xml", "json-in-xml"), use stop tokens to + immediately process a tool call when observed. Returns: The updated pipeline. @@ -1172,8 +1192,8 @@ async def get_weather(city: Annotated[str, "The city name to get weather for"]) if choice is not None: self.api_tool_choice = choice - if stop_on_tool_calls is not None: - self.stop_on_tool_calls = stop_on_tool_calls + if add_stop_token is not None: + self.add_tool_stop_token = add_stop_token return self @@ -1237,7 +1257,7 @@ def until_parsed_as( async def _then_tools(self, chat: Chat) -> PipelineStepContextManager | None: if ( - self.stop_on_tool_calls + self.add_tool_stop_token and self.tool_mode in ["xml", "json-in-xml"] and chat.stop_reason == "stop" ): @@ -1270,31 +1290,26 @@ async def _then_tools(self, chat: Chat) -> PipelineStepContextManager | None: if not tool_calls: return None - next_pipeline = self.clone(chat=chat) + next_pipeline = self.clone(chat=chat, callbacks=[self._then_tools]) - should_continue = True + stop = False for tool_call in tool_calls: tool = next((t for t in self.tools if t.name == tool_call.name), None) if tool is None: raise UnknownToolError(tool_call.name) - message, _should_continue = await tool.handle_tool_call(tool_call) + message, _stop = await tool.handle_tool_call(tool_call) + stop = _stop if not _stop else stop next_pipeline.add(message) - # If the tool returns none, we should resolve tool calls, but - # not continue the pipeline. - - if not _should_continue: - should_continue = _should_continue - # Need to prevent infinite loops and treat tool_choice like # an ephemeral setting which resets after the first tool call. if self.tool_mode == "api" and next_pipeline.params: next_pipeline.params.tool_choice = None - if not should_continue: + if stop: # TODO(nick): Type hints here stop us from mixing step generators # and basic chat returns. return next_pipeline.chat # type: ignore [return-value] @@ -1302,7 +1317,7 @@ async def _then_tools(self, chat: Chat) -> PipelineStepContextManager | None: return next_pipeline.step() async def _then_parse(self, chat: Chat) -> PipelineStepContextManager | None: - next_pipeline = self.clone(chat=chat) + next_pipeline = self.clone(chat=chat, callbacks=[self._then_parse]) try: chat.last.parse_many(*self.until_types) @@ -1310,14 +1325,14 @@ async def _then_parse(self, chat: Chat) -> PipelineStepContextManager | None: next_pipeline.add( Message.from_model( ValidationErrorModel(content=str(e)), - suffix="Rewrite your entire message with all the required xml structure.", + suffix="Rewrite your entire message with all of the required xml elements.", ), ) except Exception as e: # noqa: BLE001 next_pipeline.add( Message.from_model( SystemErrorModel(content=str(e)), - suffix="Rewrite your entire message with all the required xml structure.", + suffix="Rewrite your entire message with all of the required xml elements.", ), ) else: # parsed successfully @@ -1336,7 +1351,7 @@ async def _pre_run(self) -> None: self.chat.inject_tool_prompt(self.tools, self.tool_mode) self.inject_native_tool_prompt = False - if self.stop_on_tool_calls: + if self.add_tool_stop_token: self.params = self.params = GenerateParams() self.params.stop = self.params.stop or [] self.params.stop.append(f"") diff --git a/rigging/error.py b/rigging/error.py index cc0f5b5..1dc1acf 100644 --- a/rigging/error.py +++ b/rigging/error.py @@ -14,6 +14,38 @@ from rigging.message import Message +# User Throwable Exceptions + + +class Stop(Exception): # noqa: N818 + """ + Raise inside a pipeline to indicate a stopping condition. + + Example: + ``` + import rigging as rg + + async def read_file(path: str) -> str: + "Read the contents of a file." + + if no_more_files(path): + raise rg.Stop("There are no more files to read.") + + ... + + chat = await pipeline.using(read_file).run() + ``` + """ + + def __init__(self, message: str): + super().__init__(message) + self.message = message + """The message associated with the stop.""" + + +# System Exceptions + + class UnknownToolError(Exception): """ Raised when the an api tool call is made for an unknown tool. diff --git a/rigging/prompt.py b/rigging/prompt.py index 4ad7727..9add2d6 100644 --- a/rigging/prompt.py +++ b/rigging/prompt.py @@ -566,14 +566,14 @@ async def _then_parse(self, chat: Chat) -> PipelineStepContextManager | None: next_pipeline.add( Message.from_model( ValidationErrorModel(content=str(e)), - suffix="Rewrite your entire message with all the required xml structure.", + suffix="Rewrite your entire message with all of the required xml elements.", ), ) except Exception as e: # noqa: BLE001 next_pipeline.add( Message.from_model( SystemErrorModel(content=str(e)), - suffix="Rewrite your entire message with all the required xml structure.", + suffix="Rewrite your entire message with all of the required xml elements.", ), ) else: # parsed successfully @@ -1084,8 +1084,7 @@ def prompt( generator_id: str | None = None, tools: list[Tool[..., t.Any] | t.Callable[..., t.Any]] | None = None, system_prompt: str | None = None, -) -> t.Callable[[t.Callable[P, t.Coroutine[t.Any, t.Any, R]] | t.Callable[P, R]], Prompt[P, R]]: - ... +) -> t.Callable[[t.Callable[P, t.Coroutine[t.Any, t.Any, R]] | t.Callable[P, R]], Prompt[P, R]]: ... @t.overload @@ -1098,8 +1097,7 @@ def prompt( generator_id: str | None = None, tools: list[Tool[..., t.Any] | t.Callable[..., t.Any]] | None = None, system_prompt: str | None = None, -) -> Prompt[P, R]: - ... +) -> Prompt[P, R]: ... @t.overload @@ -1112,8 +1110,7 @@ def prompt( generator_id: str | None = None, tools: list[Tool[..., t.Any] | t.Callable[..., t.Any]] | None = None, system_prompt: str | None = None, -) -> Prompt[P, R]: - ... +) -> Prompt[P, R]: ... def prompt( @@ -1214,8 +1211,9 @@ def make_prompt( @t.overload -def make_prompt(content: str, return_type: type[R], *, ctx: Ctx | None = None) -> Prompt[..., R]: - ... +def make_prompt( + content: str, return_type: type[R], *, ctx: Ctx | None = None +) -> Prompt[..., R]: ... @t.overload @@ -1224,8 +1222,7 @@ def make_prompt( return_type: None = None, *, ctx: Ctx | None = None, -) -> Prompt[..., str]: - ... +) -> Prompt[..., str]: ... def make_prompt( diff --git a/rigging/tool/base.py b/rigging/tool/base.py index 55e4844..c868d87 100644 --- a/rigging/tool/base.py +++ b/rigging/tool/base.py @@ -14,7 +14,7 @@ import typing_extensions as te from pydantic import TypeAdapter -from rigging.error import ToolDefinitionError +from rigging.error import Stop, ToolDefinitionError from rigging.model import Model, make_from_schema, make_from_signature from rigging.tool.api import ApiFunctionDefinition, ApiToolCall, ApiToolDefinition from rigging.tool.native import ( @@ -258,7 +258,7 @@ def json_definition(self) -> JsonInXmlToolDefinition: parameters=json.dumps(self.parameters_schema), ) - async def handle_tool_call( # noqa: PLR0912 + async def handle_tool_call( # noqa: PLR0912, PLR0915 self, tool_call: ApiToolCall | XmlToolCall | JsonInXmlToolCall, ) -> tuple["Message", bool]: @@ -269,7 +269,8 @@ async def handle_tool_call( # noqa: PLR0912 tool_call: The tool call to handle. Returns: - The message to send back to the generator or `None` if iterative tool calling should not proceed any further. + A tuple containing the message to send back to the generator and a + boolean indicating whether tool calling should stop. """ from rigging.message import ContentText, ContentTypes, Message @@ -330,10 +331,16 @@ async def handle_tool_call( # noqa: PLR0912 # Call the function + stop = False + try: result: t.Any = self.fn(**kwargs) # type: ignore [call-arg] if inspect.isawaitable(result): result = await result + except Stop as e: + result = f"{e.message}" + span.set_attribute("stop", True) + stop = True except Exception as e: if self.catch is True or ( not isinstance(self.catch, bool) and isinstance(e, tuple(self.catch)) @@ -350,11 +357,6 @@ async def handle_tool_call( # noqa: PLR0912 else Message("user") ) - # If the tool returns nothing back to us, we'll assume that - # they do not want to proceed with additional tool calling - - should_continue = result is not None - # If the tool gave us back anything that looks like a message, we'll # just pass it along. Otherwise we need to box up the result. @@ -395,7 +397,7 @@ async def handle_tool_call( # noqa: PLR0912 result=message.content_parts[0].text, ).to_pretty_xml() - return message, should_continue + return message, stop def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: return self.fn(*args, **kwargs) diff --git a/tests/test_tool.py b/tests/test_tool.py index 7919893..2563016 100644 --- a/tests/test_tool.py +++ b/tests/test_tool.py @@ -223,9 +223,9 @@ async def test_handle_api_tool_call(self, sample_tool: Tool[..., t.Any]) -> None ), ) - message, should_continue = await sample_tool.handle_tool_call(tool_call) + message, stop = await sample_tool.handle_tool_call(tool_call) - assert should_continue is True + assert stop is False assert message is not None assert message.role == "tool" assert message.tool_call_id == "call123" @@ -245,9 +245,9 @@ async def test_handle_xml_tool_call(self, sample_tool: Tool[..., t.Any]) -> None ).strip(), ) - message, should_continue = await sample_tool.handle_tool_call(tool_call) + message, stop = await sample_tool.handle_tool_call(tool_call) - assert should_continue is True + assert stop is False assert message is not None assert message.role == "user" assert message.content == '8' @@ -260,9 +260,9 @@ async def test_handle_json_xml_tool_call(self, sample_tool: Tool[..., t.Any]) -> parameters=json.dumps({"a": 4, "b": 4, "operation": "add"}), ) - message, should_continue = await sample_tool.handle_tool_call(tool_call) + message, stop = await sample_tool.handle_tool_call(tool_call) - assert should_continue is True + assert stop is False assert message is not None assert message.role == "user" assert message.content == '8' @@ -378,7 +378,7 @@ def faulty_function(x: int) -> int: tool = Tool.from_callable(faulty_function, catch={ValueError}) - message, should_continue = await tool.handle_tool_call(tool_call) + message, stop = await tool.handle_tool_call(tool_call) - assert should_continue is True + assert stop is False assert "This is a test error" in message.content From 8a0f4aa467a2d418c52778612ee8226dcc26c9af Mon Sep 17 00:00:00 2001 From: monoxgas Date: Tue, 20 May 2025 22:02:04 -0600 Subject: [PATCH 2/2] fix formatting --- rigging/prompt.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/rigging/prompt.py b/rigging/prompt.py index 9add2d6..3a6431b 100644 --- a/rigging/prompt.py +++ b/rigging/prompt.py @@ -1212,7 +1212,10 @@ def make_prompt( @t.overload def make_prompt( - content: str, return_type: type[R], *, ctx: Ctx | None = None + content: str, + return_type: type[R], + *, + ctx: Ctx | None = None, ) -> Prompt[..., R]: ...