Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions rigging/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
MapCompletionCallback,
ThenCompletionCallback,
)
from rigging.error import Stop
from rigging.generator import (
GeneratedMessage,
GeneratedText,
Expand Down Expand Up @@ -64,6 +65,7 @@
"PipelineStepContextManager",
"PipelineStepGenerator",
"Prompt",
"Stop",
"ThenChatCallback",
"ThenCompletionCallback",
"Tool",
Expand Down
57 changes: 36 additions & 21 deletions rigging/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def __init__(
self.tool_mode: ToolMode = "auto"
self.api_tool_choice: ApiToolChoice | None = None
self.inject_tool_prompt = True
self.stop_on_tool_calls = True
self.add_tool_stop_token = True
self.then_callbacks: list[tuple[ThenChatCallback, int]] = []
self.map_callbacks: list[tuple[MapChatCallback, int]] = []
self.watch_callbacks: list[WatchChatCallback] = watch_callbacks or []
Expand Down Expand Up @@ -880,6 +880,7 @@ def clone(
*,
only_messages: bool = False,
chat: Chat | None = None,
callbacks: bool | t.Sequence[MapChatCallback | ThenChatCallback] = True,
) -> "ChatPipeline":
"""
Creates a clone of the current `ChatPipeline` instance.
Expand All @@ -890,6 +891,8 @@ def clone(
including until callbacks, types, tools, metadata, etc.
chat: An optional chat object clone for use in the new pipeline, otherwise the current
internal chat object will be cloned.
callbacks: If True (default), all callbacks will be cloned. If False, no callbacks will be cloned.
Otherwise provide a sequence of callbacks which should be maintained in the new pipeline.

Returns:
The cloned ChatPipeline.
Expand All @@ -906,16 +909,20 @@ def clone(
new.tools = self.tools.copy()
new.tool_mode = self.tool_mode
new.metadata = deepcopy(self.metadata)
new.map_callbacks = self.map_callbacks.copy()
new.on_failed = self.on_failed
new.errors_to_catch = self.errors_to_catch.copy()
new.errors_to_exclude = self.errors_to_exclude.copy()
new.caching = self.caching

new.watch_callbacks = self.watch_callbacks.copy()

# Check if any of our callbacks are bound methods to a ChatPipline.
# If so, we should rebind them to `self` to ensure they work correctly
# and aren't operating with old state.

if callbacks is False:
return new

new.then_callbacks = [
(callback, max_depth)
if not hasattr(callback, "__self__")
Expand All @@ -931,6 +938,18 @@ def clone(
for callback, max_depth in self.map_callbacks.copy()
]

if not isinstance(callbacks, bool):
new.then_callbacks = [
(callback, max_depth)
for callback, max_depth in self.then_callbacks
if callback in callbacks
]
new.map_callbacks = [
(callback, max_depth)
for callback, max_depth in self.map_callbacks
if callback in callbacks
]

return new

def meta(self, **kwargs: t.Any) -> "ChatPipeline":
Expand Down Expand Up @@ -1105,7 +1124,7 @@ def using(
mode: ToolMode | None = None,
choice: ApiToolChoice | None = None,
max_depth: int = DEFAULT_MAX_DEPTH,
stop_on_tool_calls: bool | None = None,
add_stop_token: bool | None = None,
) -> "ChatPipeline":
"""
Adds a tool or a sequence of tools to participate in the generation process.
Expand All @@ -1119,7 +1138,8 @@ def using(
mode: The tool calling mode to use (e.g., "xml", "json-in-xml", "api").
choice: The API tool choice to use. This is only relevant when using the "api" tool mode.
max_depth: The maximum depth for recursive tool calls (this is shared between all tools).
stop_on_tool_calls: When using natively parsed tools, whether to stop generation when a tool call block is observed.
add_stop_token: When using natively parsed tools ("xml", "json-in-xml"), use stop tokens to
immediately process a tool call when observed.

Returns:
The updated pipeline.
Expand Down Expand Up @@ -1172,8 +1192,8 @@ async def get_weather(city: Annotated[str, "The city name to get weather for"])
if choice is not None:
self.api_tool_choice = choice

if stop_on_tool_calls is not None:
self.stop_on_tool_calls = stop_on_tool_calls
if add_stop_token is not None:
self.add_tool_stop_token = add_stop_token

return self

Expand Down Expand Up @@ -1237,7 +1257,7 @@ def until_parsed_as(

async def _then_tools(self, chat: Chat) -> PipelineStepContextManager | None:
if (
self.stop_on_tool_calls
self.add_tool_stop_token
and self.tool_mode in ["xml", "json-in-xml"]
and chat.stop_reason == "stop"
):
Expand Down Expand Up @@ -1270,54 +1290,49 @@ async def _then_tools(self, chat: Chat) -> PipelineStepContextManager | None:
if not tool_calls:
return None

next_pipeline = self.clone(chat=chat)
next_pipeline = self.clone(chat=chat, callbacks=[self._then_tools])

should_continue = True
stop = False

for tool_call in tool_calls:
tool = next((t for t in self.tools if t.name == tool_call.name), None)
if tool is None:
raise UnknownToolError(tool_call.name)

message, _should_continue = await tool.handle_tool_call(tool_call)
message, _stop = await tool.handle_tool_call(tool_call)
stop = _stop if not _stop else stop
next_pipeline.add(message)

# If the tool returns none, we should resolve tool calls, but
# not continue the pipeline.

if not _should_continue:
should_continue = _should_continue

# Need to prevent infinite loops and treat tool_choice like
# an ephemeral setting which resets after the first tool call.

if self.tool_mode == "api" and next_pipeline.params:
next_pipeline.params.tool_choice = None

if not should_continue:
if stop:
# TODO(nick): Type hints here stop us from mixing step generators
# and basic chat returns.
return next_pipeline.chat # type: ignore [return-value]

return next_pipeline.step()

async def _then_parse(self, chat: Chat) -> PipelineStepContextManager | None:
next_pipeline = self.clone(chat=chat)
next_pipeline = self.clone(chat=chat, callbacks=[self._then_parse])

try:
chat.last.parse_many(*self.until_types)
except ValidationError as e:
next_pipeline.add(
Message.from_model(
ValidationErrorModel(content=str(e)),
suffix="Rewrite your entire message with all the required xml structure.",
suffix="Rewrite your entire message with all of the required xml elements.",
),
)
except Exception as e: # noqa: BLE001
next_pipeline.add(
Message.from_model(
SystemErrorModel(content=str(e)),
suffix="Rewrite your entire message with all the required xml structure.",
suffix="Rewrite your entire message with all of the required xml elements.",
),
)
else: # parsed successfully
Expand All @@ -1336,7 +1351,7 @@ async def _pre_run(self) -> None:
self.chat.inject_tool_prompt(self.tools, self.tool_mode)
self.inject_native_tool_prompt = False

if self.stop_on_tool_calls:
if self.add_tool_stop_token:
self.params = self.params = GenerateParams()
self.params.stop = self.params.stop or []
self.params.stop.append(f"</{TOOL_CALLS_TAG}>")
Expand Down
32 changes: 32 additions & 0 deletions rigging/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,38 @@
from rigging.message import Message


# User Throwable Exceptions


class Stop(Exception): # noqa: N818
"""
Raise inside a pipeline to indicate a stopping condition.

Example:
```
import rigging as rg

async def read_file(path: str) -> str:
"Read the contents of a file."

if no_more_files(path):
raise rg.Stop("There are no more files to read.")

...

chat = await pipeline.using(read_file).run()
```
"""

def __init__(self, message: str):
super().__init__(message)
self.message = message
"""The message associated with the stop."""


# System Exceptions


class UnknownToolError(Exception):
"""
Raised when the an api tool call is made for an unknown tool.
Expand Down
24 changes: 12 additions & 12 deletions rigging/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,14 +566,14 @@ async def _then_parse(self, chat: Chat) -> PipelineStepContextManager | None:
next_pipeline.add(
Message.from_model(
ValidationErrorModel(content=str(e)),
suffix="Rewrite your entire message with all the required xml structure.",
suffix="Rewrite your entire message with all of the required xml elements.",
),
)
except Exception as e: # noqa: BLE001
next_pipeline.add(
Message.from_model(
SystemErrorModel(content=str(e)),
suffix="Rewrite your entire message with all the required xml structure.",
suffix="Rewrite your entire message with all of the required xml elements.",
),
)
else: # parsed successfully
Expand Down Expand Up @@ -1084,8 +1084,7 @@ def prompt(
generator_id: str | None = None,
tools: list[Tool[..., t.Any] | t.Callable[..., t.Any]] | None = None,
system_prompt: str | None = None,
) -> t.Callable[[t.Callable[P, t.Coroutine[t.Any, t.Any, R]] | t.Callable[P, R]], Prompt[P, R]]:
...
) -> t.Callable[[t.Callable[P, t.Coroutine[t.Any, t.Any, R]] | t.Callable[P, R]], Prompt[P, R]]: ...


@t.overload
Expand All @@ -1098,8 +1097,7 @@ def prompt(
generator_id: str | None = None,
tools: list[Tool[..., t.Any] | t.Callable[..., t.Any]] | None = None,
system_prompt: str | None = None,
) -> Prompt[P, R]:
...
) -> Prompt[P, R]: ...


@t.overload
Expand All @@ -1112,8 +1110,7 @@ def prompt(
generator_id: str | None = None,
tools: list[Tool[..., t.Any] | t.Callable[..., t.Any]] | None = None,
system_prompt: str | None = None,
) -> Prompt[P, R]:
...
) -> Prompt[P, R]: ...


def prompt(
Expand Down Expand Up @@ -1214,8 +1211,12 @@ def make_prompt(


@t.overload
def make_prompt(content: str, return_type: type[R], *, ctx: Ctx | None = None) -> Prompt[..., R]:
...
def make_prompt(
content: str,
return_type: type[R],
*,
ctx: Ctx | None = None,
) -> Prompt[..., R]: ...


@t.overload
Expand All @@ -1224,8 +1225,7 @@ def make_prompt(
return_type: None = None,
*,
ctx: Ctx | None = None,
) -> Prompt[..., str]:
...
) -> Prompt[..., str]: ...


def make_prompt(
Expand Down
20 changes: 11 additions & 9 deletions rigging/tool/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import typing_extensions as te
from pydantic import TypeAdapter

from rigging.error import ToolDefinitionError
from rigging.error import Stop, ToolDefinitionError
from rigging.model import Model, make_from_schema, make_from_signature
from rigging.tool.api import ApiFunctionDefinition, ApiToolCall, ApiToolDefinition
from rigging.tool.native import (
Expand Down Expand Up @@ -258,7 +258,7 @@ def json_definition(self) -> JsonInXmlToolDefinition:
parameters=json.dumps(self.parameters_schema),
)

async def handle_tool_call( # noqa: PLR0912
async def handle_tool_call( # noqa: PLR0912, PLR0915
self,
tool_call: ApiToolCall | XmlToolCall | JsonInXmlToolCall,
) -> tuple["Message", bool]:
Expand All @@ -269,7 +269,8 @@ async def handle_tool_call( # noqa: PLR0912
tool_call: The tool call to handle.

Returns:
The message to send back to the generator or `None` if iterative tool calling should not proceed any further.
A tuple containing the message to send back to the generator and a
boolean indicating whether tool calling should stop.
"""

from rigging.message import ContentText, ContentTypes, Message
Expand Down Expand Up @@ -330,10 +331,16 @@ async def handle_tool_call( # noqa: PLR0912

# Call the function

stop = False

try:
result: t.Any = self.fn(**kwargs) # type: ignore [call-arg]
if inspect.isawaitable(result):
result = await result
except Stop as e:
result = f"<rg:stop>{e.message}</rg:stop>"
span.set_attribute("stop", True)
stop = True
except Exception as e:
if self.catch is True or (
not isinstance(self.catch, bool) and isinstance(e, tuple(self.catch))
Expand All @@ -350,11 +357,6 @@ async def handle_tool_call( # noqa: PLR0912
else Message("user")
)

# If the tool returns nothing back to us, we'll assume that
# they do not want to proceed with additional tool calling

should_continue = result is not None

# If the tool gave us back anything that looks like a message, we'll
# just pass it along. Otherwise we need to box up the result.

Expand Down Expand Up @@ -395,7 +397,7 @@ async def handle_tool_call( # noqa: PLR0912
result=message.content_parts[0].text,
).to_pretty_xml()

return message, should_continue
return message, stop

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
return self.fn(*args, **kwargs)
Expand Down
Loading
Loading