Skip to content

Commit 42c2b12

Browse files
committed
Add truncation for messages and tools. Version to 3.0.0.
1 parent 016dbc7 commit 42c2b12

4 files changed

Lines changed: 57 additions & 7 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "rigging"
3-
version = "3.0.0-rc.4"
3+
version = "3.0.0"
44
description = "LLM Interaction Framework"
55
authors = ["Nick Landers <monoxgas@gmail.com>"]
66
license = "MIT"

rigging/message.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from rigging.model import Model, ModelT
3131
from rigging.parsing import try_parse_many
3232
from rigging.tool.api import ApiToolCall
33-
from rigging.util import AudioFormat, identify_audio_format, truncate_string
33+
from rigging.util import AudioFormat, identify_audio_format, shorten_string, truncate_string
3434

3535
Role = t.Literal["system", "user", "assistant", "tool"]
3636
"""The role of a message. Can be 'system', 'user', 'assistant', or 'tool'."""
@@ -111,7 +111,7 @@ class ImageUrl(BaseModel):
111111
"""Cache control entry for prompt caching."""
112112

113113
def __str__(self) -> str:
114-
return f"<ContentImageUrl url='{truncate_string(self.image_url.url, 50)}'>"
114+
return f"<ContentImageUrl url='{shorten_string(self.image_url.url, 50)}'>"
115115

116116
@classmethod
117117
def from_file(
@@ -241,7 +241,7 @@ def __str__(self) -> str:
241241
return (
242242
f"<ContentAudioInput format='{self.input_audio.format}' "
243243
f"transcript='{self.input_audio.transcript}' "
244-
f"data='{truncate_string(self.input_audio.data, 50)}'>"
244+
f"data='{shorten_string(self.input_audio.data, 50)}'>"
245245
)
246246

247247
@classmethod
@@ -655,6 +655,8 @@ def clone(self) -> "Message":
655655
self.role,
656656
copy.deepcopy(self.content_parts),
657657
parts=copy.deepcopy(self.parts),
658+
tool_calls=copy.deepcopy(self.tool_calls),
659+
tool_call_id=self.tool_call_id,
658660
)
659661

660662
def cache(self, cache_control: dict[str, str] | bool = True) -> "Message": # noqa: FBT002
@@ -705,6 +707,20 @@ def apply(self, **kwargs: str) -> "Message":
705707
new.content = template.safe_substitute(**kwargs)
706708
return new
707709

710+
def truncate(self, max_length: int, suffix: str = "\n[truncated]") -> "Message":
711+
"""
712+
Truncates the message content to a maximum length.
713+
714+
Args:
715+
max_length: The maximum length of the message content.
716+
717+
Returns:
718+
The truncated message.
719+
"""
720+
new = self.clone()
721+
new.content = truncate_string(new.content, max_length, suf=suffix)
722+
return new
723+
708724
def strip(
709725
self,
710726
model_type: type[Model],

rigging/tool/base.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ class Tool(t.Generic[P, R]):
7373
- `True`: Catch all exceptions.
7474
- `list[type[Exception]]`: Catch only the specified exceptions.
7575
"""
76+
truncate: int | None = None
77+
"""If set, the maximum number of characters to truncate any tool output to."""
7678

7779
_signature: inspect.Signature | None = field(default=None, init=False, repr=False)
7880
_type_adapter: TypeAdapter[t.Any] | None = field(
@@ -100,6 +102,7 @@ def from_callable(
100102
name: str | None = None,
101103
description: str | None = None,
102104
catch: bool | t.Iterable[type[Exception]] = False,
105+
truncate: int | None = None,
103106
) -> te.Self:
104107
from rigging.prompt import Prompt
105108

@@ -194,6 +197,7 @@ def empty_func(*args, **kwargs): # type: ignore [no-untyped-def] # noqa: ARG001
194197
parameters_schema=schema,
195198
fn=fn,
196199
catch=catch if isinstance(catch, bool) else set(catch),
200+
truncate=truncate,
197201
)
198202

199203
self._signature = signature
@@ -367,6 +371,9 @@ async def handle_tool_call(
367371
else:
368372
message.content_parts = [ContentText(text=str(result))]
369373

374+
if self.truncate:
375+
message = message.truncate(self.truncate)
376+
370377
# If this is a native tool call, we should wrap up our
371378
# result in a NativeToolResult object to provide clarity to the
372379
# generator. Otherwise we can rely on the `tool` role and associated
@@ -402,6 +409,7 @@ def tool(
402409
name: str | None = None,
403410
description: str | None = None,
404411
catch: bool | t.Iterable[type[Exception]] = False,
412+
truncate: int | None = None,
405413
) -> t.Callable[[t.Callable[P, R]], Tool[P, R]]:
406414
...
407415

@@ -421,6 +429,7 @@ def tool(
421429
name: str | None = None,
422430
description: str | None = None,
423431
catch: bool | t.Iterable[type[Exception]] = False,
432+
truncate: int | None = None,
424433
) -> t.Callable[[t.Callable[P, R]], Tool[P, R]] | Tool[P, R]:
425434
"""
426435
Decorator for creating a Tool, useful for overriding a name or description.
@@ -433,6 +442,7 @@ def tool(
433442
- `False`: Do not catch exceptions.
434443
- `True`: Catch all exceptions.
435444
- `list[type[Exception]]`: Catch only the specified exceptions.
445+
truncate: If set, the maximum number of characters to truncate any tool output to.
436446
437447
Returns:
438448
The decorated Tool object.
@@ -453,7 +463,13 @@ def make_tool(func: t.Callable[..., t.Any]) -> Tool[P, R]:
453463
stacklevel=3,
454464
)
455465

456-
return Tool.from_callable(func, name=name, description=description, catch=catch)
466+
return Tool.from_callable(
467+
func,
468+
name=name,
469+
description=description,
470+
catch=catch,
471+
truncate=truncate,
472+
)
457473

458474
if func is not None:
459475
return make_tool(func)
@@ -496,6 +512,7 @@ def tool_method(
496512
name: str | None = None,
497513
description: str | None = None,
498514
catch: bool | t.Iterable[type[Exception]] = False,
515+
truncate: int | None = None,
499516
) -> t.Callable[[t.Callable[t.Concatenate[t.Any, P], R]], ToolMethod[P, R]]:
500517
...
501518

@@ -515,6 +532,7 @@ def tool_method(
515532
name: str | None = None,
516533
description: str | None = None,
517534
catch: bool | t.Iterable[type[Exception]] = False,
535+
truncate: int | None = None,
518536
) -> t.Callable[[t.Callable[t.Concatenate[t.Any, P], R]], ToolMethod[P, R]] | ToolMethod[P, R]:
519537
"""
520538
Decorator for creating a Tool from a class method.
@@ -570,6 +588,7 @@ def wrapper(self: t.Any, *args: P.args, **kwargs: P.kwargs) -> R:
570588
name=name,
571589
description=description,
572590
catch=catch,
591+
truncate=truncate,
573592
)
574593

575594
if func is not None:

rigging/util.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,16 +150,31 @@ def get_qualified_name(obj: t.Callable[..., t.Any]) -> str:
150150
# Formatting
151151

152152

153-
def truncate_string(content: str, max_length: int, *, sep: str = "...") -> str:
154-
"""Return a string at most max_length characters long."""
153+
def shorten_string(content: str, max_length: int, *, sep: str = "...") -> str:
154+
"""Return a string at most max_length characters long by removing the middle of the string."""
155155
if len(content) <= max_length:
156156
return content
157157

158158
remaining = max_length - len(sep)
159+
if remaining <= 0:
160+
return sep
161+
159162
middle = remaining // 2
160163
return content[:middle] + sep + content[-middle:]
161164

162165

166+
def truncate_string(content: str, max_length: int, *, suf: str = "...") -> str:
167+
"""Return a string at most max_length characters long by removing the end of the string."""
168+
if len(content) <= max_length:
169+
return content
170+
171+
remaining = max_length - len(suf)
172+
if remaining <= 0:
173+
return suf
174+
175+
return content[:remaining] + suf
176+
177+
163178
# List utilities
164179

165180

0 commit comments

Comments
 (0)