Skip to content

Commit ed64fc3

Browse files
authored
chain_limit/before_call/after_call for conversations
* chain_limit/before_call/after_call for conversations, closes #1088 * Docs for before_call/after_call including for model.conversation
1 parent b5d1c5e commit ed64fc3

4 files changed

Lines changed: 155 additions & 51 deletions

File tree

docs/python-api.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,46 @@ for response in chain.responses():
148148
print(chunk, end="", flush=True)
149149
```
150150

151+
(python-api-tools-debug-hooks)=
152+
153+
#### Tool debugging hooks
154+
155+
Pass a function to the `before_call=` parameter of `model.chain()` to have that called before every tool call is executed. You can raise `llm.CancelToolCall()` to cancel that tool call.
156+
157+
The method signature is `def before_call(tool: llm.Tool, tool_call: llm.ToolCall)`. Here's an example:
158+
```python
159+
import llm
160+
161+
def upper(text: str) -> str:
162+
"Convert text to uppercase."
163+
return text.upper()
164+
165+
def before_call(tool: llm.Tool, tool_call: llm.ToolCall):
166+
print(f"About to call tool {tool.name} with arguments {tool_call.arguments}")
167+
if tool.name == "upper" and "bad" in repr(tool_call.arguments):
168+
raise llm.CancelToolCall("Not allowed to call upper on text containing 'bad'")
169+
170+
model = llm.get_model("gpt-4.1-mini")
171+
response = model.chain(
172+
"Convert panda to upper and badger to upper",
173+
tools=[upper],
174+
before_call=before_call,
175+
)
176+
print(response.text())
177+
```
178+
The `after_call=` parameter can be used to run a logging function after each tool call has been executed. The method signature is `def after_call(tool: llm.Tool, tool_call: llm.ToolCall, tool_result: llm.ToolResult)`. This continues the previous example:
179+
```python
180+
def after_call(tool: llm.Tool, tool_call: llm.ToolCall, tool_result: llm.ToolResult):
181+
print(f"Tool {tool.name} called with arguments {tool_call.arguments} returned {tool_result.output}")
182+
183+
response = model.chain(
184+
"Convert panda to upper and badger to upper",
185+
tools=[upper],
186+
after_call=after_call,
187+
)
188+
print(response.text())
189+
```
190+
151191
(python-api-tools-attachments)=
152192

153193
#### Tools can return attachments
@@ -575,6 +615,8 @@ print(conversation.chain(
575615
"Same with pangolin"
576616
).text())
577617
```
618+
The `before_call=` and `after_call=` parameters {ref}`described above <python-api-tools-debug-hooks>` can be passed directly to the `model.conversation()` method to set those options for all chained prompts in that conversation.
619+
578620

579621
(python-api-listing-models)=
580622

llm/cli.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,6 +1080,11 @@ def chat(
10801080
# Ensure it can see the API key
10811081
conversation.model = model
10821082

1083+
if tools_debug:
1084+
conversation.after_call = _debug_tool_call
1085+
if tools_approve:
1086+
conversation.before_call = _approve_tool_call
1087+
10831088
# Validate options
10841089
validated_options = get_model_options(model.model_id)
10851090
if options:
@@ -1100,10 +1105,6 @@ def chat(
11001105

11011106
if tool_functions:
11021107
kwargs["chain_limit"] = chain_limit
1103-
if tools_debug:
1104-
kwargs["after_call"] = _debug_tool_call
1105-
if tools_approve:
1106-
kwargs["before_call"] = _approve_tool_call
11071108
kwargs["tools"] = tool_functions
11081109

11091110
should_stream = model.can_stream and not no_stream

llm/models.py

Lines changed: 66 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -259,9 +259,6 @@ def introspect_methods(cls):
259259
return methods
260260

261261

262-
ToolDef = Union[Tool, Toolbox, Callable[..., Any]]
263-
264-
265262
@dataclass
266263
class ToolCall:
267264
name: str
@@ -286,6 +283,13 @@ class ToolOutput:
286283
attachments: List[Attachment] = field(default_factory=list)
287284

288285

286+
ToolDef = Union[Tool, Toolbox, Callable[..., Any]]
287+
BeforeCallSync = Callable[[Tool, ToolCall], None]
288+
AfterCallSync = Callable[[Tool, ToolCall, ToolResult], None]
289+
BeforeCallAsync = Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
290+
AfterCallAsync = Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
291+
292+
289293
class CancelToolCall(Exception):
290294
pass
291295

@@ -368,6 +372,7 @@ class _BaseConversation:
368372
name: Optional[str] = None
369373
responses: List["_BaseResponse"] = field(default_factory=list)
370374
tools: Optional[List[Tool]] = None
375+
chain_limit: Optional[int] = None
371376

372377
@classmethod
373378
@abstractmethod
@@ -377,6 +382,9 @@ def from_row(cls, row: Any) -> "_BaseConversation":
377382

378383
@dataclass
379384
class Conversation(_BaseConversation):
385+
before_call: Optional[BeforeCallSync] = None
386+
after_call: Optional[AfterCallSync] = None
387+
380388
def prompt(
381389
self,
382390
prompt: Optional[str] = None,
@@ -424,8 +432,8 @@ def chain(
424432
tools: Optional[List[Tool]] = None,
425433
tool_results: Optional[List[ToolResult]] = None,
426434
chain_limit: Optional[int] = None,
427-
before_call: Optional[Callable[[Tool, ToolCall], None]] = None,
428-
after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None,
435+
before_call: Optional[BeforeCallSync] = None,
436+
after_call: Optional[AfterCallSync] = None,
429437
key: Optional[str] = None,
430438
options: Optional[dict] = None,
431439
) -> "ChainResponse":
@@ -447,9 +455,9 @@ def chain(
447455
stream=stream,
448456
conversation=self,
449457
key=key,
450-
before_call=before_call,
451-
after_call=after_call,
452-
chain_limit=chain_limit,
458+
before_call=before_call or self.before_call,
459+
after_call=after_call or self.after_call,
460+
chain_limit=chain_limit if chain_limit is not None else self.chain_limit,
453461
)
454462

455463
@classmethod
@@ -470,6 +478,9 @@ def __repr__(self):
470478

471479
@dataclass
472480
class AsyncConversation(_BaseConversation):
481+
before_call: Optional[BeforeCallAsync] = None
482+
after_call: Optional[AfterCallAsync] = None
483+
473484
def chain(
474485
self,
475486
prompt: Optional[str] = None,
@@ -483,12 +494,8 @@ def chain(
483494
tools: Optional[List[Tool]] = None,
484495
tool_results: Optional[List[ToolResult]] = None,
485496
chain_limit: Optional[int] = None,
486-
before_call: Optional[
487-
Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
488-
] = None,
489-
after_call: Optional[
490-
Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
491-
] = None,
497+
before_call: Optional[BeforeCallAsync] = None,
498+
after_call: Optional[AfterCallAsync] = None,
492499
key: Optional[str] = None,
493500
options: Optional[dict] = None,
494501
) -> "AsyncChainResponse":
@@ -510,9 +517,9 @@ def chain(
510517
stream=stream,
511518
conversation=self,
512519
key=key,
513-
before_call=before_call,
514-
after_call=after_call,
515-
chain_limit=chain_limit,
520+
before_call=before_call or self.before_call,
521+
after_call=after_call or self.after_call,
522+
chain_limit=chain_limit if chain_limit is not None else self.chain_limit,
516523
)
517524

518525
def prompt(
@@ -975,12 +982,8 @@ def text_or_raise(self) -> str:
975982
def execute_tool_calls(
976983
self,
977984
*,
978-
before_call: Optional[
979-
Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
980-
] = None,
981-
after_call: Optional[
982-
Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
983-
] = None,
985+
before_call: Optional[BeforeCallSync] = None,
986+
after_call: Optional[AfterCallSync] = None,
984987
) -> List[ToolResult]:
985988
tool_results = []
986989
tools_by_name = {tool.name: tool for tool in self.prompt.tools}
@@ -1147,12 +1150,8 @@ async def _on_done(self):
11471150
async def execute_tool_calls(
11481151
self,
11491152
*,
1150-
before_call: Optional[
1151-
Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
1152-
] = None,
1153-
after_call: Optional[
1154-
Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
1155-
] = None,
1153+
before_call: Optional[BeforeCallAsync] = None,
1154+
after_call: Optional[AfterCallAsync] = None,
11561155
) -> List[ToolResult]:
11571156
tool_calls_list = await self.tool_calls()
11581157
tools_by_name = {tool.name: tool for tool in self.prompt.tools}
@@ -1437,12 +1436,8 @@ def __init__(
14371436
conversation: _BaseConversation,
14381437
key: Optional[str] = None,
14391438
chain_limit: Optional[int] = 10,
1440-
before_call: Optional[
1441-
Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
1442-
] = None,
1443-
after_call: Optional[
1444-
Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
1445-
] = None,
1439+
before_call: Optional[Union[BeforeCallSync, BeforeCallAsync]] = None,
1440+
after_call: Optional[Union[AfterCallSync, AfterCallAsync]] = None,
14461441
):
14471442
self.prompt = prompt
14481443
self.model = model
@@ -1467,6 +1462,8 @@ def log_to_db(self, db):
14671462

14681463
class ChainResponse(_BaseChainResponse):
14691464
_responses: List["Response"]
1465+
before_call: Optional[BeforeCallSync] = None
1466+
after_call: Optional[AfterCallSync] = None
14701467

14711468
def responses(self) -> Iterator[Response]:
14721469
prompt = self.prompt
@@ -1521,6 +1518,8 @@ def text(self) -> str:
15211518

15221519
class AsyncChainResponse(_BaseChainResponse):
15231520
_responses: List["AsyncResponse"]
1521+
before_call: Optional[BeforeCallAsync] = None
1522+
after_call: Optional[AfterCallAsync] = None
15241523

15251524
async def responses(self) -> AsyncIterator[AsyncResponse]:
15261525
prompt = self.prompt
@@ -1656,8 +1655,20 @@ def __repr__(self) -> str:
16561655

16571656

16581657
class _Model(_BaseModel):
1659-
def conversation(self, tools: Optional[List[Tool]] = None) -> Conversation:
1660-
return Conversation(model=self, tools=tools)
1658+
def conversation(
1659+
self,
1660+
tools: Optional[List[Tool]] = None,
1661+
before_call: Optional[BeforeCallSync] = None,
1662+
after_call: Optional[AfterCallSync] = None,
1663+
chain_limit: Optional[int] = None,
1664+
) -> Conversation:
1665+
return Conversation(
1666+
model=self,
1667+
tools=tools,
1668+
before_call=before_call,
1669+
after_call=after_call,
1670+
chain_limit=chain_limit,
1671+
)
16611672

16621673
def prompt(
16631674
self,
@@ -1705,8 +1716,8 @@ def chain(
17051716
schema: Optional[Union[dict, type[BaseModel]]] = None,
17061717
tools: Optional[List[Tool]] = None,
17071718
tool_results: Optional[List[ToolResult]] = None,
1708-
before_call: Optional[Callable[[Tool, ToolCall], None]] = None,
1709-
after_call: Optional[Callable[[Tool, ToolCall, ToolResult], None]] = None,
1719+
before_call: Optional[BeforeCallSync] = None,
1720+
after_call: Optional[AfterCallSync] = None,
17101721
key: Optional[str] = None,
17111722
options: Optional[dict] = None,
17121723
) -> ChainResponse:
@@ -1753,8 +1764,20 @@ def execute(
17531764

17541765

17551766
class _AsyncModel(_BaseModel):
1756-
def conversation(self, tools: Optional[List[Tool]] = None) -> AsyncConversation:
1757-
return AsyncConversation(model=self, tools=tools)
1767+
def conversation(
1768+
self,
1769+
tools: Optional[List[Tool]] = None,
1770+
before_call: Optional[BeforeCallAsync] = None,
1771+
after_call: Optional[AfterCallAsync] = None,
1772+
chain_limit: Optional[int] = None,
1773+
) -> AsyncConversation:
1774+
return AsyncConversation(
1775+
model=self,
1776+
tools=tools,
1777+
before_call=before_call,
1778+
after_call=after_call,
1779+
chain_limit=chain_limit,
1780+
)
17581781

17591782
def prompt(
17601783
self,
@@ -1802,12 +1825,8 @@ def chain(
18021825
schema: Optional[Union[dict, type[BaseModel]]] = None,
18031826
tools: Optional[List[Tool]] = None,
18041827
tool_results: Optional[List[ToolResult]] = None,
1805-
before_call: Optional[
1806-
Callable[[Tool, ToolCall], Union[None, Awaitable[None]]]
1807-
] = None,
1808-
after_call: Optional[
1809-
Callable[[Tool, ToolCall, ToolResult], Union[None, Awaitable[None]]]
1810-
] = None,
1828+
before_call: Optional[BeforeCallAsync] = None,
1829+
after_call: Optional[AfterCallAsync] = None,
18111830
key: Optional[str] = None,
18121831
options: Optional[dict] = None,
18131832
) -> AsyncChainResponse:

tests/test_tools.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,3 +343,45 @@ async def return_attachment() -> llm.Attachment:
343343
output = await chain_response.text()
344344
assert '"type": "image/png"' in output
345345
assert '"output": "Output"' in output
346+
347+
348+
def test_tool_conversation_settings():
349+
model = llm.get_model("echo")
350+
before_collected = []
351+
after_collected = []
352+
353+
def before(*args):
354+
before_collected.append(args)
355+
356+
def after(*args):
357+
after_collected.append(args)
358+
359+
conversation = model.conversation(
360+
tools=[llm_time], before_call=before, after_call=after
361+
)
362+
# Run two things
363+
conversation.chain(json.dumps({"tool_calls": [{"name": "llm_time"}]})).text()
364+
conversation.chain(json.dumps({"tool_calls": [{"name": "llm_time"}]})).text()
365+
assert len(before_collected) == 2
366+
assert len(after_collected) == 2
367+
368+
369+
@pytest.mark.asyncio
370+
async def test_tool_conversation_settings_async():
371+
model = llm.get_async_model("echo")
372+
before_collected = []
373+
after_collected = []
374+
375+
async def before(*args):
376+
before_collected.append(args)
377+
378+
async def after(*args):
379+
after_collected.append(args)
380+
381+
conversation = model.conversation(
382+
tools=[llm_time], before_call=before, after_call=after
383+
)
384+
await conversation.chain(json.dumps({"tool_calls": [{"name": "llm_time"}]})).text()
385+
await conversation.chain(json.dumps({"tool_calls": [{"name": "llm_time"}]})).text()
386+
assert len(before_collected) == 2
387+
assert len(after_collected) == 2

0 commit comments

Comments
 (0)