Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions py/packages/genkit/src/genkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
PromptGenerateOptions,
ResumeOptions,
)
from genkit._ai._tools import ToolInterruptError, ToolRunContext, tool_response
from genkit._ai._tools import Tool, ToolInterruptError, ToolRunContext, tool_response
from genkit._core._action import Action, StreamResponse
from genkit._core._error import GenkitError, PublicError
from genkit._core._model import Document
Expand Down Expand Up @@ -93,7 +93,9 @@
# Errors
'GenkitError',
'PublicError',
'Tool',
'ToolInterruptError',
'tool_response',
# Content types
'Constrained',
'CustomPart',
Expand Down Expand Up @@ -128,7 +130,6 @@
'PromptGenerateOptions',
'ResumeOptions',
'ToolRunContext',
'tool_response',
'ModelRequest',
'ModelResponse',
'ModelResponseChunk',
Expand Down
65 changes: 24 additions & 41 deletions py/packages/genkit/src/genkit/_ai/_aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
import socket
import threading
import uuid
from collections.abc import Awaitable, Callable, Coroutine
from collections.abc import Awaitable, Callable, Coroutine, Sequence
from pathlib import Path
from typing import Any, ParamSpec, TypeVar, cast, overload
from typing import Any, TypeVar, cast, overload

import anyio
import uvicorn
Expand Down Expand Up @@ -71,7 +71,7 @@
ResourceOptions,
define_resource,
)
from genkit._ai._tools import define_tool
from genkit._ai._tools import Tool, define_tool
from genkit._core._action import Action, ActionKind, ActionRunContext
from genkit._core._background import (
BackgroundAction,
Expand Down Expand Up @@ -118,7 +118,7 @@
InputT = TypeVar('InputT')
OutputT = TypeVar('OutputT')
ChunkT = TypeVar('ChunkT')
P = ParamSpec('P')

R = TypeVar('R')
T = TypeVar('T')

Expand Down Expand Up @@ -260,12 +260,10 @@ def define_dynamic_action_provider(
metadata=metadata,
)

def tool(
self, name: str | None = None, description: str | None = None
) -> Callable[[Callable[P, T]], Callable[P, T]]:
def tool(self, name: str | None = None, description: str | None = None) -> Callable[[Callable[..., Any]], Tool]:
"""Decorator to register a function as a tool."""

def wrapper(func: Callable[P, T]) -> Callable[P, T]:
def wrapper(func: Callable[..., Any]) -> Tool:
return define_tool(self.registry, func, name, description)

return wrapper
Expand Down Expand Up @@ -393,7 +391,7 @@ def define_prompt(
max_turns: int | None = None,
return_tool_requests: bool | None = None,
metadata: dict[str, object] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
tool_choice: ToolChoice | None = None,
use: list[ModelMiddleware] | None = None,
docs: list[Document] | None = None,
Expand Down Expand Up @@ -421,7 +419,7 @@ def define_prompt(
max_turns: int | None = None,
return_tool_requests: bool | None = None,
metadata: dict[str, object] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
tool_choice: ToolChoice | None = None,
use: list[ModelMiddleware] | None = None,
docs: list[Document] | None = None,
Expand Down Expand Up @@ -449,7 +447,7 @@ def define_prompt(
max_turns: int | None = None,
return_tool_requests: bool | None = None,
metadata: dict[str, object] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
tool_choice: ToolChoice | None = None,
use: list[ModelMiddleware] | None = None,
docs: list[Document] | None = None,
Expand Down Expand Up @@ -477,7 +475,7 @@ def define_prompt(
max_turns: int | None = None,
return_tool_requests: bool | None = None,
metadata: dict[str, object] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
tool_choice: ToolChoice | None = None,
use: list[ModelMiddleware] | None = None,
docs: list[Document] | None = None,
Expand All @@ -503,7 +501,7 @@ def define_prompt(
max_turns: int | None = None,
return_tool_requests: bool | None = None,
metadata: dict[str, object] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
tool_choice: ToolChoice | None = None,
use: list[ModelMiddleware] | None = None,
docs: list[Document] | None = None,
Expand Down Expand Up @@ -743,7 +741,7 @@ async def generate(
prompt: str | list[Part] | None = None,
system: str | list[Part] | None = None,
messages: list[Message] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
return_tool_requests: bool | None = None,
tool_choice: ToolChoice | None = None,
tool_responses: list[Part] | None = None,
Expand All @@ -768,7 +766,7 @@ async def generate(
prompt: str | list[Part] | None = None,
system: str | list[Part] | None = None,
messages: list[Message] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
return_tool_requests: bool | None = None,
tool_choice: ToolChoice | None = None,
tool_responses: list[Part] | None = None,
Expand All @@ -791,7 +789,7 @@ async def generate(
prompt: str | list[Part] | None = None,
system: str | list[Part] | None = None,
messages: list[Message] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
return_tool_requests: bool | None = None,
tool_choice: ToolChoice | None = None,
tool_responses: list[Part] | None = None,
Expand All @@ -806,7 +804,12 @@ async def generate(
use: list[ModelMiddleware] | None = None,
docs: list[Document] | None = None,
) -> ModelResponse[Any]:
"""Generate text or structured data using a language model."""
"""Generate text or structured data using a language model.

``tools`` is typed as ``Sequence`` rather than ``list`` because ``Sequence``
is covariant: ``list[Tool]`` or ``list[str]`` are both assignable to
``Sequence[str | Tool]``, but not to ``list[str | Tool]``.
"""
return await generate_action(
self.registry,
await to_generate_action_options(
Expand Down Expand Up @@ -843,7 +846,7 @@ def generate_stream(
prompt: str | list[Part] | None = None,
system: str | list[Part] | None = None,
messages: list[Message] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
return_tool_requests: bool | None = None,
tool_choice: ToolChoice | None = None,
config: dict[str, object] | ModelConfig | None = None,
Expand All @@ -868,7 +871,7 @@ def generate_stream(
prompt: str | list[Part] | None = None,
system: str | list[Part] | None = None,
messages: list[Message] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
return_tool_requests: bool | None = None,
tool_choice: ToolChoice | None = None,
config: dict[str, object] | ModelConfig | None = None,
Expand All @@ -891,7 +894,7 @@ def generate_stream(
prompt: str | list[Part] | None = None,
system: str | list[Part] | None = None,
messages: list[Message] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
return_tool_requests: bool | None = None,
tool_choice: ToolChoice | None = None,
config: dict[str, object] | ModelConfig | None = None,
Expand Down Expand Up @@ -1055,26 +1058,6 @@ def current_context() -> dict[str, Any] | None:
"""Get the current execution context, or None if not in an action."""
return ActionRunContext._current_context() # pyright: ignore[reportPrivateUsage]

def dynamic_tool(
self,
*,
name: str,
fn: Callable[..., object],
description: str | None = None,
metadata: dict[str, object] | None = None,
) -> Action:
"""Create an unregistered tool action for passing directly to generate()."""
tool_meta: dict[str, object] = metadata.copy() if metadata else {}
tool_meta['type'] = 'tool'
tool_meta['dynamic'] = True
return Action(
kind=ActionKind.TOOL,
name=name,
fn=fn, # type: ignore[arg-type] # dynamic tools may be sync
description=description,
metadata=tool_meta,
)

async def flush_tracing(self) -> None:
"""Flush all pending trace spans to exporters."""
provider = trace_api.get_tracer_provider()
Expand Down Expand Up @@ -1132,7 +1115,7 @@ async def generate_operation(
prompt: str | list[Part] | None = None,
system: str | list[Part] | None = None,
messages: list[Message] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
return_tool_requests: bool | None = None,
tool_choice: ToolChoice | None = None,
config: dict[str, object] | ModelConfig | None = None,
Expand Down
33 changes: 27 additions & 6 deletions py/packages/genkit/src/genkit/_ai/_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import copy
import inspect
import re
from collections.abc import Callable
from collections.abc import Callable, Sequence
from typing import Any, cast

from pydantic import BaseModel
Expand All @@ -37,7 +37,7 @@
ModelResponseChunk,
)
from genkit._ai._resource import ResourceArgument, ResourceInput, find_matching_resource, resolve_resources
from genkit._ai._tools import ToolInterruptError
from genkit._ai._tools import Tool, ToolInterruptError
from genkit._core._action import Action, ActionKind, ActionRunContext
from genkit._core._error import GenkitError
from genkit._core._logger import get_logger
Expand All @@ -61,6 +61,21 @@
logger = get_logger(__name__)


def tools_to_action_names(
tools: Sequence[str | Tool] | None,
) -> list[str] | None:
"""Normalize tool arguments to registry tool name strings for GenerateActionOptions."""
if tools is None:
return None
names: list[str] = []
for t in tools:
if isinstance(t, str):
names.append(t)
else:
names.append(t.name)
return names


# Matches data URIs: everything up to the first comma is the media-type +
# parameters (e.g. "data:audio/L16;codec=pcm;rate=24000;base64,").
_DATA_URI_RE = re.compile(r'data:[^,]{0,200},(?=.{100})', re.ASCII)
Expand Down Expand Up @@ -744,11 +759,17 @@ async def _resolve_tool_request(tool: Action, tool_request_part: ToolRequestPart
raise e


async def resolve_tool(registry: Registry, tool_name: str) -> Action:
"""Resolve a tool by name from the registry."""
tool = await registry.resolve_action(kind=ActionKind.TOOL, name=tool_name)
async def resolve_tool(registry: Registry, tool_ref: str | Tool) -> Action:
"""Resolve a tool from a registry name or a Tool instance.

Used when building ModelRequest (for example from to_generate_request).
"""
if isinstance(tool_ref, Tool):
return tool_ref.action

tool = await registry.resolve_action(kind=ActionKind.TOOL, name=tool_ref)
if tool is None:
raise ValueError(f'Unable to resolve tool {tool_name}')
raise GenkitError(status='NOT_FOUND', message=f'Unable to resolve tool {tool_ref}')
return tool


Expand Down
Loading
Loading