Skip to content

Commit 8ff5d8e

Browse files
authored
chore(py): Add Tool type, update generate call signature, remove dynamic-tools (#5063)
1 parent 0aa818e commit 8ff5d8e

12 files changed

Lines changed: 189 additions & 298 deletions

File tree

py/packages/genkit/src/genkit/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
PromptGenerateOptions,
2424
ResumeOptions,
2525
)
26-
from genkit._ai._tools import ToolInterruptError, ToolRunContext, tool_response
26+
from genkit._ai._tools import Tool, ToolInterruptError, ToolRunContext, tool_response
2727
from genkit._core._action import Action, StreamResponse
2828
from genkit._core._error import GenkitError, PublicError
2929
from genkit._core._model import Document
@@ -93,7 +93,9 @@
9393
# Errors
9494
'GenkitError',
9595
'PublicError',
96+
'Tool',
9697
'ToolInterruptError',
98+
'tool_response',
9799
# Content types
98100
'Constrained',
99101
'CustomPart',
@@ -128,7 +130,6 @@
128130
'PromptGenerateOptions',
129131
'ResumeOptions',
130132
'ToolRunContext',
131-
'tool_response',
132133
'ModelRequest',
133134
'ModelResponse',
134135
'ModelResponseChunk',

py/packages/genkit/src/genkit/_ai/_aio.py

Lines changed: 24 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
import socket
2626
import threading
2727
import uuid
28-
from collections.abc import Awaitable, Callable, Coroutine
28+
from collections.abc import Awaitable, Callable, Coroutine, Sequence
2929
from pathlib import Path
30-
from typing import Any, ParamSpec, TypeVar, cast, overload
30+
from typing import Any, TypeVar, cast, overload
3131

3232
import anyio
3333
import uvicorn
@@ -71,7 +71,7 @@
7171
ResourceOptions,
7272
define_resource,
7373
)
74-
from genkit._ai._tools import define_tool
74+
from genkit._ai._tools import Tool, define_tool
7575
from genkit._core._action import Action, ActionKind, ActionRunContext
7676
from genkit._core._background import (
7777
BackgroundAction,
@@ -118,7 +118,7 @@
118118
InputT = TypeVar('InputT')
119119
OutputT = TypeVar('OutputT')
120120
ChunkT = TypeVar('ChunkT')
121-
P = ParamSpec('P')
121+
122122
R = TypeVar('R')
123123
T = TypeVar('T')
124124

@@ -260,12 +260,10 @@ def define_dynamic_action_provider(
260260
metadata=metadata,
261261
)
262262

263-
def tool(
264-
self, name: str | None = None, description: str | None = None
265-
) -> Callable[[Callable[P, T]], Callable[P, T]]:
263+
def tool(self, name: str | None = None, description: str | None = None) -> Callable[[Callable[..., Any]], Tool]:
266264
"""Decorator to register a function as a tool."""
267265

268-
def wrapper(func: Callable[P, T]) -> Callable[P, T]:
266+
def wrapper(func: Callable[..., Any]) -> Tool:
269267
return define_tool(self.registry, func, name, description)
270268

271269
return wrapper
@@ -393,7 +391,7 @@ def define_prompt(
393391
max_turns: int | None = None,
394392
return_tool_requests: bool | None = None,
395393
metadata: dict[str, object] | None = None,
396-
tools: list[str] | None = None,
394+
tools: Sequence[str | Tool] | None = None,
397395
tool_choice: ToolChoice | None = None,
398396
use: list[ModelMiddleware] | None = None,
399397
docs: list[Document] | None = None,
@@ -421,7 +419,7 @@ def define_prompt(
421419
max_turns: int | None = None,
422420
return_tool_requests: bool | None = None,
423421
metadata: dict[str, object] | None = None,
424-
tools: list[str] | None = None,
422+
tools: Sequence[str | Tool] | None = None,
425423
tool_choice: ToolChoice | None = None,
426424
use: list[ModelMiddleware] | None = None,
427425
docs: list[Document] | None = None,
@@ -449,7 +447,7 @@ def define_prompt(
449447
max_turns: int | None = None,
450448
return_tool_requests: bool | None = None,
451449
metadata: dict[str, object] | None = None,
452-
tools: list[str] | None = None,
450+
tools: Sequence[str | Tool] | None = None,
453451
tool_choice: ToolChoice | None = None,
454452
use: list[ModelMiddleware] | None = None,
455453
docs: list[Document] | None = None,
@@ -477,7 +475,7 @@ def define_prompt(
477475
max_turns: int | None = None,
478476
return_tool_requests: bool | None = None,
479477
metadata: dict[str, object] | None = None,
480-
tools: list[str] | None = None,
478+
tools: Sequence[str | Tool] | None = None,
481479
tool_choice: ToolChoice | None = None,
482480
use: list[ModelMiddleware] | None = None,
483481
docs: list[Document] | None = None,
@@ -503,7 +501,7 @@ def define_prompt(
503501
max_turns: int | None = None,
504502
return_tool_requests: bool | None = None,
505503
metadata: dict[str, object] | None = None,
506-
tools: list[str] | None = None,
504+
tools: Sequence[str | Tool] | None = None,
507505
tool_choice: ToolChoice | None = None,
508506
use: list[ModelMiddleware] | None = None,
509507
docs: list[Document] | None = None,
@@ -743,7 +741,7 @@ async def generate(
743741
prompt: str | list[Part] | None = None,
744742
system: str | list[Part] | None = None,
745743
messages: list[Message] | None = None,
746-
tools: list[str] | None = None,
744+
tools: Sequence[str | Tool] | None = None,
747745
return_tool_requests: bool | None = None,
748746
tool_choice: ToolChoice | None = None,
749747
tool_responses: list[Part] | None = None,
@@ -768,7 +766,7 @@ async def generate(
768766
prompt: str | list[Part] | None = None,
769767
system: str | list[Part] | None = None,
770768
messages: list[Message] | None = None,
771-
tools: list[str] | None = None,
769+
tools: Sequence[str | Tool] | None = None,
772770
return_tool_requests: bool | None = None,
773771
tool_choice: ToolChoice | None = None,
774772
tool_responses: list[Part] | None = None,
@@ -791,7 +789,7 @@ async def generate(
791789
prompt: str | list[Part] | None = None,
792790
system: str | list[Part] | None = None,
793791
messages: list[Message] | None = None,
794-
tools: list[str] | None = None,
792+
tools: Sequence[str | Tool] | None = None,
795793
return_tool_requests: bool | None = None,
796794
tool_choice: ToolChoice | None = None,
797795
tool_responses: list[Part] | None = None,
@@ -806,7 +804,12 @@ async def generate(
806804
use: list[ModelMiddleware] | None = None,
807805
docs: list[Document] | None = None,
808806
) -> ModelResponse[Any]:
809-
"""Generate text or structured data using a language model."""
807+
"""Generate text or structured data using a language model.
808+
809+
``tools`` is typed as ``Sequence`` rather than ``list`` because ``Sequence``
810+
is covariant: ``list[Tool]`` or ``list[str]`` are both assignable to
811+
``Sequence[str | Tool]``, but not to ``list[str | Tool]``.
812+
"""
810813
return await generate_action(
811814
self.registry,
812815
await to_generate_action_options(
@@ -843,7 +846,7 @@ def generate_stream(
843846
prompt: str | list[Part] | None = None,
844847
system: str | list[Part] | None = None,
845848
messages: list[Message] | None = None,
846-
tools: list[str] | None = None,
849+
tools: Sequence[str | Tool] | None = None,
847850
return_tool_requests: bool | None = None,
848851
tool_choice: ToolChoice | None = None,
849852
config: dict[str, object] | ModelConfig | None = None,
@@ -868,7 +871,7 @@ def generate_stream(
868871
prompt: str | list[Part] | None = None,
869872
system: str | list[Part] | None = None,
870873
messages: list[Message] | None = None,
871-
tools: list[str] | None = None,
874+
tools: Sequence[str | Tool] | None = None,
872875
return_tool_requests: bool | None = None,
873876
tool_choice: ToolChoice | None = None,
874877
config: dict[str, object] | ModelConfig | None = None,
@@ -891,7 +894,7 @@ def generate_stream(
891894
prompt: str | list[Part] | None = None,
892895
system: str | list[Part] | None = None,
893896
messages: list[Message] | None = None,
894-
tools: list[str] | None = None,
897+
tools: Sequence[str | Tool] | None = None,
895898
return_tool_requests: bool | None = None,
896899
tool_choice: ToolChoice | None = None,
897900
config: dict[str, object] | ModelConfig | None = None,
@@ -1055,26 +1058,6 @@ def current_context() -> dict[str, Any] | None:
10551058
"""Get the current execution context, or None if not in an action."""
10561059
return ActionRunContext._current_context() # pyright: ignore[reportPrivateUsage]
10571060

1058-
def dynamic_tool(
1059-
self,
1060-
*,
1061-
name: str,
1062-
fn: Callable[..., object],
1063-
description: str | None = None,
1064-
metadata: dict[str, object] | None = None,
1065-
) -> Action:
1066-
"""Create an unregistered tool action for passing directly to generate()."""
1067-
tool_meta: dict[str, object] = metadata.copy() if metadata else {}
1068-
tool_meta['type'] = 'tool'
1069-
tool_meta['dynamic'] = True
1070-
return Action(
1071-
kind=ActionKind.TOOL,
1072-
name=name,
1073-
fn=fn, # type: ignore[arg-type] # dynamic tools may be sync
1074-
description=description,
1075-
metadata=tool_meta,
1076-
)
1077-
10781061
async def flush_tracing(self) -> None:
10791062
"""Flush all pending trace spans to exporters."""
10801063
provider = trace_api.get_tracer_provider()
@@ -1132,7 +1115,7 @@ async def generate_operation(
11321115
prompt: str | list[Part] | None = None,
11331116
system: str | list[Part] | None = None,
11341117
messages: list[Message] | None = None,
1135-
tools: list[str] | None = None,
1118+
tools: Sequence[str | Tool] | None = None,
11361119
return_tool_requests: bool | None = None,
11371120
tool_choice: ToolChoice | None = None,
11381121
config: dict[str, object] | ModelConfig | None = None,

py/packages/genkit/src/genkit/_ai/_generate.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import copy
2222
import inspect
2323
import re
24-
from collections.abc import Callable
24+
from collections.abc import Callable, Sequence
2525
from typing import Any, cast
2626

2727
from pydantic import BaseModel
@@ -37,7 +37,7 @@
3737
ModelResponseChunk,
3838
)
3939
from genkit._ai._resource import ResourceArgument, ResourceInput, find_matching_resource, resolve_resources
40-
from genkit._ai._tools import ToolInterruptError
40+
from genkit._ai._tools import Tool, ToolInterruptError
4141
from genkit._core._action import Action, ActionKind, ActionRunContext
4242
from genkit._core._error import GenkitError
4343
from genkit._core._logger import get_logger
@@ -61,6 +61,21 @@
6161
logger = get_logger(__name__)
6262

6363

64+
def tools_to_action_names(
65+
tools: Sequence[str | Tool] | None,
66+
) -> list[str] | None:
67+
"""Normalize tool arguments to registry tool name strings for GenerateActionOptions."""
68+
if tools is None:
69+
return None
70+
names: list[str] = []
71+
for t in tools:
72+
if isinstance(t, str):
73+
names.append(t)
74+
else:
75+
names.append(t.name)
76+
return names
77+
78+
6479
# Matches data URIs: everything up to the first comma is the media-type +
6580
# parameters (e.g. "data:audio/L16;codec=pcm;rate=24000;base64,").
6681
_DATA_URI_RE = re.compile(r'data:[^,]{0,200},(?=.{100})', re.ASCII)
@@ -744,11 +759,17 @@ async def _resolve_tool_request(tool: Action, tool_request_part: ToolRequestPart
744759
raise e
745760

746761

747-
async def resolve_tool(registry: Registry, tool_name: str) -> Action:
748-
"""Resolve a tool by name from the registry."""
749-
tool = await registry.resolve_action(kind=ActionKind.TOOL, name=tool_name)
762+
async def resolve_tool(registry: Registry, tool_ref: str | Tool) -> Action:
763+
"""Resolve a tool from a registry name or a Tool instance.
764+
765+
Used when building ModelRequest (for example from to_generate_request).
766+
"""
767+
if isinstance(tool_ref, Tool):
768+
return tool_ref.action
769+
770+
tool = await registry.resolve_action(kind=ActionKind.TOOL, name=tool_ref)
750771
if tool is None:
751-
raise ValueError(f'Unable to resolve tool {tool_name}')
772+
raise GenkitError(status='NOT_FOUND', message=f'Unable to resolve tool {tool_ref}')
752773
return tool
753774

754775

0 commit comments

Comments
 (0)