diff --git a/rigging/__init__.py b/rigging/__init__.py
index 03ff1fa..7324336 100644
--- a/rigging/__init__.py
+++ b/rigging/__init__.py
@@ -14,6 +14,7 @@
MapCompletionCallback,
ThenCompletionCallback,
)
+from rigging.error import Stop
from rigging.generator import (
GeneratedMessage,
GeneratedText,
@@ -64,6 +65,7 @@
"PipelineStepContextManager",
"PipelineStepGenerator",
"Prompt",
+ "Stop",
"ThenChatCallback",
"ThenCompletionCallback",
"Tool",
diff --git a/rigging/chat.py b/rigging/chat.py
index 1ab091b..25008a6 100644
--- a/rigging/chat.py
+++ b/rigging/chat.py
@@ -720,7 +720,7 @@ def __init__(
self.tool_mode: ToolMode = "auto"
self.api_tool_choice: ApiToolChoice | None = None
self.inject_tool_prompt = True
- self.stop_on_tool_calls = True
+ self.add_tool_stop_token = True
self.then_callbacks: list[tuple[ThenChatCallback, int]] = []
self.map_callbacks: list[tuple[MapChatCallback, int]] = []
self.watch_callbacks: list[WatchChatCallback] = watch_callbacks or []
@@ -880,6 +880,7 @@ def clone(
*,
only_messages: bool = False,
chat: Chat | None = None,
+ callbacks: bool | t.Sequence[MapChatCallback | ThenChatCallback] = True,
) -> "ChatPipeline":
"""
Creates a clone of the current `ChatPipeline` instance.
@@ -890,6 +891,8 @@ def clone(
including until callbacks, types, tools, metadata, etc.
chat: An optional chat object clone for use in the new pipeline, otherwise the current
internal chat object will be cloned.
+ callbacks: If True (default), all callbacks will be cloned. If False, no callbacks will be cloned.
+ Otherwise provide a sequence of callbacks which should be maintained in the new pipeline.
Returns:
The cloned ChatPipeline.
@@ -906,16 +909,20 @@ def clone(
new.tools = self.tools.copy()
new.tool_mode = self.tool_mode
new.metadata = deepcopy(self.metadata)
- new.map_callbacks = self.map_callbacks.copy()
new.on_failed = self.on_failed
new.errors_to_catch = self.errors_to_catch.copy()
new.errors_to_exclude = self.errors_to_exclude.copy()
new.caching = self.caching
+ new.watch_callbacks = self.watch_callbacks.copy()
+
# Check if any of our callbacks are bound methods to a ChatPipline.
# If so, we should rebind them to `self` to ensure they work correctly
# and aren't operating with old state.
+ if callbacks is False:
+ return new
+
new.then_callbacks = [
(callback, max_depth)
if not hasattr(callback, "__self__")
@@ -931,6 +938,18 @@ def clone(
for callback, max_depth in self.map_callbacks.copy()
]
+ if not isinstance(callbacks, bool):
+ new.then_callbacks = [
+ (callback, max_depth)
+ for callback, max_depth in self.then_callbacks
+ if callback in callbacks
+ ]
+ new.map_callbacks = [
+ (callback, max_depth)
+ for callback, max_depth in self.map_callbacks
+ if callback in callbacks
+ ]
+
return new
def meta(self, **kwargs: t.Any) -> "ChatPipeline":
@@ -1105,7 +1124,7 @@ def using(
mode: ToolMode | None = None,
choice: ApiToolChoice | None = None,
max_depth: int = DEFAULT_MAX_DEPTH,
- stop_on_tool_calls: bool | None = None,
+ add_stop_token: bool | None = None,
) -> "ChatPipeline":
"""
Adds a tool or a sequence of tools to participate in the generation process.
@@ -1119,7 +1138,8 @@ def using(
mode: The tool calling mode to use (e.g., "xml", "json-in-xml", "api").
choice: The API tool choice to use. This is only relevant when using the "api" tool mode.
max_depth: The maximum depth for recursive tool calls (this is shared between all tools).
- stop_on_tool_calls: When using natively parsed tools, whether to stop generation when a tool call block is observed.
+ add_stop_token: When using natively parsed tools ("xml", "json-in-xml"), use stop tokens to
+ immediately process a tool call when observed.
Returns:
The updated pipeline.
@@ -1172,8 +1192,8 @@ async def get_weather(city: Annotated[str, "The city name to get weather for"])
if choice is not None:
self.api_tool_choice = choice
- if stop_on_tool_calls is not None:
- self.stop_on_tool_calls = stop_on_tool_calls
+ if add_stop_token is not None:
+ self.add_tool_stop_token = add_stop_token
return self
@@ -1237,7 +1257,7 @@ def until_parsed_as(
async def _then_tools(self, chat: Chat) -> PipelineStepContextManager | None:
if (
- self.stop_on_tool_calls
+ self.add_tool_stop_token
and self.tool_mode in ["xml", "json-in-xml"]
and chat.stop_reason == "stop"
):
@@ -1270,31 +1290,26 @@ async def _then_tools(self, chat: Chat) -> PipelineStepContextManager | None:
if not tool_calls:
return None
- next_pipeline = self.clone(chat=chat)
+ next_pipeline = self.clone(chat=chat, callbacks=[self._then_tools])
- should_continue = True
+ stop = False
for tool_call in tool_calls:
tool = next((t for t in self.tools if t.name == tool_call.name), None)
if tool is None:
raise UnknownToolError(tool_call.name)
- message, _should_continue = await tool.handle_tool_call(tool_call)
+ message, _stop = await tool.handle_tool_call(tool_call)
+ stop = _stop if not _stop else stop
next_pipeline.add(message)
- # If the tool returns none, we should resolve tool calls, but
- # not continue the pipeline.
-
- if not _should_continue:
- should_continue = _should_continue
-
# Need to prevent infinite loops and treat tool_choice like
# an ephemeral setting which resets after the first tool call.
if self.tool_mode == "api" and next_pipeline.params:
next_pipeline.params.tool_choice = None
- if not should_continue:
+ if stop:
# TODO(nick): Type hints here stop us from mixing step generators
# and basic chat returns.
return next_pipeline.chat # type: ignore [return-value]
@@ -1302,7 +1317,7 @@ async def _then_tools(self, chat: Chat) -> PipelineStepContextManager | None:
return next_pipeline.step()
async def _then_parse(self, chat: Chat) -> PipelineStepContextManager | None:
- next_pipeline = self.clone(chat=chat)
+ next_pipeline = self.clone(chat=chat, callbacks=[self._then_parse])
try:
chat.last.parse_many(*self.until_types)
@@ -1310,14 +1325,14 @@ async def _then_parse(self, chat: Chat) -> PipelineStepContextManager | None:
next_pipeline.add(
Message.from_model(
ValidationErrorModel(content=str(e)),
- suffix="Rewrite your entire message with all the required xml structure.",
+ suffix="Rewrite your entire message with all of the required xml elements.",
),
)
except Exception as e: # noqa: BLE001
next_pipeline.add(
Message.from_model(
SystemErrorModel(content=str(e)),
- suffix="Rewrite your entire message with all the required xml structure.",
+ suffix="Rewrite your entire message with all of the required xml elements.",
),
)
else: # parsed successfully
@@ -1336,7 +1351,7 @@ async def _pre_run(self) -> None:
self.chat.inject_tool_prompt(self.tools, self.tool_mode)
self.inject_native_tool_prompt = False
- if self.stop_on_tool_calls:
+ if self.add_tool_stop_token:
self.params = self.params = GenerateParams()
self.params.stop = self.params.stop or []
self.params.stop.append(f"{TOOL_CALLS_TAG}>")
diff --git a/rigging/error.py b/rigging/error.py
index cc0f5b5..1dc1acf 100644
--- a/rigging/error.py
+++ b/rigging/error.py
@@ -14,6 +14,38 @@
from rigging.message import Message
+# User Throwable Exceptions
+
+
+class Stop(Exception): # noqa: N818
+ """
+ Raise inside a pipeline to indicate a stopping condition.
+
+ Example:
+ ```
+ import rigging as rg
+
+ async def read_file(path: str) -> str:
+ "Read the contents of a file."
+
+ if no_more_files(path):
+ raise rg.Stop("There are no more files to read.")
+
+ ...
+
+ chat = await pipeline.using(read_file).run()
+ ```
+ """
+
+ def __init__(self, message: str):
+ super().__init__(message)
+ self.message = message
+ """The message associated with the stop."""
+
+
+# System Exceptions
+
+
class UnknownToolError(Exception):
"""
Raised when the an api tool call is made for an unknown tool.
diff --git a/rigging/prompt.py b/rigging/prompt.py
index 4ad7727..3a6431b 100644
--- a/rigging/prompt.py
+++ b/rigging/prompt.py
@@ -566,14 +566,14 @@ async def _then_parse(self, chat: Chat) -> PipelineStepContextManager | None:
next_pipeline.add(
Message.from_model(
ValidationErrorModel(content=str(e)),
- suffix="Rewrite your entire message with all the required xml structure.",
+ suffix="Rewrite your entire message with all of the required xml elements.",
),
)
except Exception as e: # noqa: BLE001
next_pipeline.add(
Message.from_model(
SystemErrorModel(content=str(e)),
- suffix="Rewrite your entire message with all the required xml structure.",
+ suffix="Rewrite your entire message with all of the required xml elements.",
),
)
else: # parsed successfully
@@ -1084,8 +1084,7 @@ def prompt(
generator_id: str | None = None,
tools: list[Tool[..., t.Any] | t.Callable[..., t.Any]] | None = None,
system_prompt: str | None = None,
-) -> t.Callable[[t.Callable[P, t.Coroutine[t.Any, t.Any, R]] | t.Callable[P, R]], Prompt[P, R]]:
- ...
+) -> t.Callable[[t.Callable[P, t.Coroutine[t.Any, t.Any, R]] | t.Callable[P, R]], Prompt[P, R]]: ...
@t.overload
@@ -1098,8 +1097,7 @@ def prompt(
generator_id: str | None = None,
tools: list[Tool[..., t.Any] | t.Callable[..., t.Any]] | None = None,
system_prompt: str | None = None,
-) -> Prompt[P, R]:
- ...
+) -> Prompt[P, R]: ...
@t.overload
@@ -1112,8 +1110,7 @@ def prompt(
generator_id: str | None = None,
tools: list[Tool[..., t.Any] | t.Callable[..., t.Any]] | None = None,
system_prompt: str | None = None,
-) -> Prompt[P, R]:
- ...
+) -> Prompt[P, R]: ...
def prompt(
@@ -1214,8 +1211,12 @@ def make_prompt(
@t.overload
-def make_prompt(content: str, return_type: type[R], *, ctx: Ctx | None = None) -> Prompt[..., R]:
- ...
+def make_prompt(
+ content: str,
+ return_type: type[R],
+ *,
+ ctx: Ctx | None = None,
+) -> Prompt[..., R]: ...
@t.overload
@@ -1224,8 +1225,7 @@ def make_prompt(
return_type: None = None,
*,
ctx: Ctx | None = None,
-) -> Prompt[..., str]:
- ...
+) -> Prompt[..., str]: ...
def make_prompt(
diff --git a/rigging/tool/base.py b/rigging/tool/base.py
index 55e4844..c868d87 100644
--- a/rigging/tool/base.py
+++ b/rigging/tool/base.py
@@ -14,7 +14,7 @@
import typing_extensions as te
from pydantic import TypeAdapter
-from rigging.error import ToolDefinitionError
+from rigging.error import Stop, ToolDefinitionError
from rigging.model import Model, make_from_schema, make_from_signature
from rigging.tool.api import ApiFunctionDefinition, ApiToolCall, ApiToolDefinition
from rigging.tool.native import (
@@ -258,7 +258,7 @@ def json_definition(self) -> JsonInXmlToolDefinition:
parameters=json.dumps(self.parameters_schema),
)
- async def handle_tool_call( # noqa: PLR0912
+ async def handle_tool_call( # noqa: PLR0912, PLR0915
self,
tool_call: ApiToolCall | XmlToolCall | JsonInXmlToolCall,
) -> tuple["Message", bool]:
@@ -269,7 +269,8 @@ async def handle_tool_call( # noqa: PLR0912
tool_call: The tool call to handle.
Returns:
- The message to send back to the generator or `None` if iterative tool calling should not proceed any further.
+ A tuple containing the message to send back to the generator and a
+ boolean indicating whether tool calling should stop.
"""
from rigging.message import ContentText, ContentTypes, Message
@@ -330,10 +331,16 @@ async def handle_tool_call( # noqa: PLR0912
# Call the function
+ stop = False
+
try:
result: t.Any = self.fn(**kwargs) # type: ignore [call-arg]
if inspect.isawaitable(result):
result = await result
+ except Stop as e:
+ result = f"{e.message}"
+ span.set_attribute("stop", True)
+ stop = True
except Exception as e:
if self.catch is True or (
not isinstance(self.catch, bool) and isinstance(e, tuple(self.catch))
@@ -350,11 +357,6 @@ async def handle_tool_call( # noqa: PLR0912
else Message("user")
)
- # If the tool returns nothing back to us, we'll assume that
- # they do not want to proceed with additional tool calling
-
- should_continue = result is not None
-
# If the tool gave us back anything that looks like a message, we'll
# just pass it along. Otherwise we need to box up the result.
@@ -395,7 +397,7 @@ async def handle_tool_call( # noqa: PLR0912
result=message.content_parts[0].text,
).to_pretty_xml()
- return message, should_continue
+ return message, stop
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
return self.fn(*args, **kwargs)
diff --git a/tests/test_tool.py b/tests/test_tool.py
index 7919893..2563016 100644
--- a/tests/test_tool.py
+++ b/tests/test_tool.py
@@ -223,9 +223,9 @@ async def test_handle_api_tool_call(self, sample_tool: Tool[..., t.Any]) -> None
),
)
- message, should_continue = await sample_tool.handle_tool_call(tool_call)
+ message, stop = await sample_tool.handle_tool_call(tool_call)
- assert should_continue is True
+ assert stop is False
assert message is not None
assert message.role == "tool"
assert message.tool_call_id == "call123"
@@ -245,9 +245,9 @@ async def test_handle_xml_tool_call(self, sample_tool: Tool[..., t.Any]) -> None
).strip(),
)
- message, should_continue = await sample_tool.handle_tool_call(tool_call)
+ message, stop = await sample_tool.handle_tool_call(tool_call)
- assert should_continue is True
+ assert stop is False
assert message is not None
assert message.role == "user"
assert message.content == '8'
@@ -260,9 +260,9 @@ async def test_handle_json_xml_tool_call(self, sample_tool: Tool[..., t.Any]) ->
parameters=json.dumps({"a": 4, "b": 4, "operation": "add"}),
)
- message, should_continue = await sample_tool.handle_tool_call(tool_call)
+ message, stop = await sample_tool.handle_tool_call(tool_call)
- assert should_continue is True
+ assert stop is False
assert message is not None
assert message.role == "user"
assert message.content == '8'
@@ -378,7 +378,7 @@ def faulty_function(x: int) -> int:
tool = Tool.from_callable(faulty_function, catch={ValueError})
- message, should_continue = await tool.handle_tool_call(tool_call)
+ message, stop = await tool.handle_tool_call(tool_call)
- assert should_continue is True
+ assert stop is False
assert "This is a test error" in message.content