Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 98 additions & 19 deletions python/composio/core/models/custom_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@
import inspect
import typing as t

from composio_client import Omit, omit
from pydantic import BaseModel

from composio.client import HttpClient, NotGiven
from composio.client import HttpClient
from composio.client.types import (
Tool,
tool_list_response,
Expand All @@ -45,15 +46,22 @@


class ExecuteRequestFn(t.Protocol):
"""Proxy callable handed to custom tools.

``connected_account_id`` is intentionally absent. The SDK binds the
proxy call to the same connected account that supplied
``auth_credentials``, so the proxy and credentials always address the
same trusted account; a caller-supplied ``connected_account_id`` would
introduce a credential/identity mismatch and is rejected with
``TypeError`` at the call site.
"""

def __call__(
self,
endpoint: str,
method: t.Literal["GET", "POST", "PUT", "DELETE", "PATCH"],
body: t.Dict | NotGiven = HttpClient.not_given,
connected_account_id: str | NotGiven = HttpClient.not_given,
parameters: (
t.Iterable[tool_proxy_params.Parameter] | NotGiven
) = HttpClient.not_given,
body: t.Dict | Omit = omit,
parameters: t.Iterable[tool_proxy_params.Parameter] | Omit = omit,
) -> tool_proxy_response.ToolProxyResponse: ...


Expand Down Expand Up @@ -134,11 +142,41 @@ def __parse_info(self) -> Tool:
tags=[],
)

def __get_auth_credentials(self, user_id: str) -> dict:
"""Get the auth config for the custom tool."""
def __resolve_connected_account(
self,
user_id: str,
connected_account_id: t.Optional[str] = None,
) -> t.Tuple[str, dict]:
"""Resolve the connected account for this tool call.

Returns ``(account_id, credentials)``. When ``connected_account_id``
is provided, the resolver list-filters by the trusted
``(toolkit, user_id, connected_account_id)`` envelope so the
explicit id is server-side bound to the trust principal — an
upstream caller forwarding another tenant's id cannot grant
cross-tenant credential access (CWE-639). Without an explicit
id, the resolver falls back to the most recently created account
for ``(toolkit, user_id)``. The returned ``account_id`` is bound
onto ``execute_request`` so the proxy call carries the same auth
context as ``auth_credentials``.
"""
if self.toolkit is None:
raise ValueError("Toolkit is required for custom tools")

if connected_account_id is not None:
response = self.client.connected_accounts.list(
toolkit_slugs=[self.toolkit],
user_ids=[user_id],
connected_account_ids=[connected_account_id],
)
if len(response.items) == 0:
raise ValueError(
f"Connected account {connected_account_id} not found "
f"for toolkit {self.toolkit} and user {user_id}"
)
account = response.items[0]
return account.id, account.state.val.model_dump()

connected_accounts = self.client.connected_accounts.list(
toolkit_slugs=[self.toolkit],
user_ids=[user_id],
Expand All @@ -154,7 +192,7 @@ def __get_auth_credentials(self, user_id: str) -> dict:
key=lambda x: x.created_at,
reverse=True,
)
return account.state.val.model_dump()
return account.id, account.state.val.model_dump()

def __call__(self, **kwargs: t.Any) -> t.Any:
"""Call the custom tool with the default ``user_id``.
Expand All @@ -163,34 +201,73 @@ def __call__(self, **kwargs: t.Any) -> t.Any:
silently dropping it — see the module docstring for why. Use
``CustomTools.execute(slug, request, user_id=...)`` (the trusted
entry point) when you need a non-default ``user_id``.

``connected_account_id`` is popped from ``kwargs`` before request
validation so it never reaches the user's tool function as a
request field.
"""
if "user_id" in kwargs:
raise TypeError(
"CustomTool.__call__ does not accept user_id; "
"use CustomTools.execute(slug, request, user_id=...) instead."
)
return self.invoke_trusted(user_id="default", request_kwargs=kwargs)
connected_account_id = kwargs.pop("connected_account_id", None)
return self.invoke_trusted(
user_id="default",
request_kwargs=kwargs,
connected_account_id=connected_account_id,
)

def invoke_trusted(
self,
user_id: str,
request_kwargs: t.Dict[str, t.Any],
connected_account_id: t.Optional[str] = None,
) -> t.Any:
"""Trusted entry point used by ``CustomTools.execute``.

``user_id`` is taken as a structurally separate parameter, so an
LLM-controlled ``user_id`` key sitting inside ``request_kwargs``
cannot override it (the request model's default ``extra='ignore'``
means any such key is also dropped during validation).
``user_id`` and ``connected_account_id`` are taken as structurally
separate parameters, so an LLM-controlled key sitting inside
``request_kwargs`` cannot override either (the request model's
default ``extra='ignore'`` means any such key is also dropped
during validation).
"""
request = self.request_model.model_validate(request_kwargs)
if self.toolkit is None:
return t.cast(CustomToolProtocol, self.f)(request=request)

account_id, credentials = self.__resolve_connected_account(
user_id=user_id,
connected_account_id=connected_account_id,
)

# Bind the proxy to the resolved account via a wrapper rather than
# ``functools.partial``: ``partial`` lets a caller-supplied keyword
# override the pre-bound value, so a tool that does
# ``execute_request(connected_account_id="ca_other")`` could run
# the proxy under one account while ``auth_credentials`` came
# from another (CWE-639). The wrapper omits ``connected_account_id``
# from its signature so any attempt to override it raises TypeError.
client = self.client

def execute_request(
endpoint: str,
method: t.Literal["GET", "POST", "PUT", "DELETE", "PATCH"],
body: t.Dict | Omit = omit,
parameters: t.Iterable[tool_proxy_params.Parameter] | Omit = omit,
) -> tool_proxy_response.ToolProxyResponse:
return client.tools.proxy(
endpoint=endpoint,
method=method,
body=body,
parameters=parameters,
connected_account_id=account_id,
)

return t.cast(CustomToolWithProxyProtocol, self.f)(
request=request,
execute_request=t.cast(ExecuteRequestFn, self.client.tools.proxy),
auth_credentials=self.__get_auth_credentials(user_id),
execute_request=t.cast(ExecuteRequestFn, execute_request),
auth_credentials=credentials,
)


Expand Down Expand Up @@ -262,16 +339,17 @@ def execute(
slug: str,
request: t.Dict,
user_id: t.Optional[str] = None,
connected_account_id: t.Optional[str] = None,
) -> t.Any:
"""Execute a custom tool — the trust boundary for LLM-supplied args.

``request`` is filtered through an allowlist of fields declared on
the tool's Pydantic ``request_model`` (canonical names + aliases).
Anything else — including the historical ``user_id`` smuggling
vector and any future identity-bearing keys — is dropped before
the call reaches credential lookup. ``user_id`` is forwarded as a
structurally separate parameter; see the module docstring for the
full security model.
the call reaches credential lookup. ``user_id`` and
``connected_account_id`` are forwarded as structurally separate
parameters; see the module docstring for the full security model.
"""
custom_tool = self.get(slug)
if custom_tool is None:
Expand All @@ -281,4 +359,5 @@ def execute(
return custom_tool.invoke_trusted(
user_id=user_id or "default",
request_kwargs=sanitized_request,
connected_account_id=connected_account_id,
)
3 changes: 3 additions & 0 deletions python/composio/core/models/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ def _execute_custom_tool(
self,
slug: str,
arguments: t.Dict,
connected_account_id: t.Optional[str] = None,
user_id: t.Optional[str] = None,
) -> ToolExecutionResponse:
"""Execute a custom tool"""
Expand All @@ -559,6 +560,7 @@ def _execute_custom_tool(
slug=slug,
request=arguments,
user_id=user_id,
connected_account_id=connected_account_id,
),
"error": None,
"successful": True,
Expand Down Expand Up @@ -737,6 +739,7 @@ def execute(
self._execute_custom_tool(
slug=slug,
arguments=arguments,
connected_account_id=connected_account_id,
user_id=user_id,
)
if self._custom_tools.get(slug) is not None
Expand Down
Loading
Loading