diff --git a/python/composio/core/models/custom_tools.py b/python/composio/core/models/custom_tools.py index 92d7b2d2d7..1bd709b012 100644 --- a/python/composio/core/models/custom_tools.py +++ b/python/composio/core/models/custom_tools.py @@ -30,9 +30,10 @@ import inspect import typing as t +from composio_client import Omit, omit from pydantic import BaseModel -from composio.client import HttpClient, NotGiven +from composio.client import HttpClient from composio.client.types import ( Tool, tool_list_response, @@ -45,15 +46,22 @@ class ExecuteRequestFn(t.Protocol): + """Proxy callable handed to custom tools. + + ``connected_account_id`` is intentionally absent. The SDK binds the + proxy call to the same connected account that supplied + ``auth_credentials``, so the proxy and credentials always address the + same trusted account; a caller-supplied ``connected_account_id`` would + introduce a credential/identity mismatch and is rejected with + ``TypeError`` at the call site. + """ + def __call__( self, endpoint: str, method: t.Literal["GET", "POST", "PUT", "DELETE", "PATCH"], - body: t.Dict | NotGiven = HttpClient.not_given, - connected_account_id: str | NotGiven = HttpClient.not_given, - parameters: ( - t.Iterable[tool_proxy_params.Parameter] | NotGiven - ) = HttpClient.not_given, + body: t.Dict | Omit = omit, + parameters: t.Iterable[tool_proxy_params.Parameter] | Omit = omit, ) -> tool_proxy_response.ToolProxyResponse: ... @@ -134,11 +142,41 @@ def __parse_info(self) -> Tool: tags=[], ) - def __get_auth_credentials(self, user_id: str) -> dict: - """Get the auth config for the custom tool.""" + def __resolve_connected_account( + self, + user_id: str, + connected_account_id: t.Optional[str] = None, + ) -> t.Tuple[str, dict]: + """Resolve the connected account for this tool call. + + Returns ``(account_id, credentials)``. When ``connected_account_id`` + is provided, the resolver list-filters by the trusted + ``(toolkit, user_id, connected_account_id)`` envelope so the + explicit id is server-side bound to the trust principal — an + upstream caller forwarding another tenant's id cannot grant + cross-tenant credential access (CWE-639). Without an explicit + id, the resolver falls back to the most recently created account + for ``(toolkit, user_id)``. The returned ``account_id`` is bound + onto ``execute_request`` so the proxy call carries the same auth + context as ``auth_credentials``. + """ if self.toolkit is None: raise ValueError("Toolkit is required for custom tools") + if connected_account_id is not None: + response = self.client.connected_accounts.list( + toolkit_slugs=[self.toolkit], + user_ids=[user_id], + connected_account_ids=[connected_account_id], + ) + if len(response.items) == 0: + raise ValueError( + f"Connected account {connected_account_id} not found " + f"for toolkit {self.toolkit} and user {user_id}" + ) + account = response.items[0] + return account.id, account.state.val.model_dump() + connected_accounts = self.client.connected_accounts.list( toolkit_slugs=[self.toolkit], user_ids=[user_id], @@ -154,7 +192,7 @@ def __get_auth_credentials(self, user_id: str) -> dict: key=lambda x: x.created_at, reverse=True, ) - return account.state.val.model_dump() + return account.id, account.state.val.model_dump() def __call__(self, **kwargs: t.Any) -> t.Any: """Call the custom tool with the default ``user_id``. @@ -163,34 +201,73 @@ def __call__(self, **kwargs: t.Any) -> t.Any: silently dropping it — see the module docstring for why. Use ``CustomTools.execute(slug, request, user_id=...)`` (the trusted entry point) when you need a non-default ``user_id``. + + ``connected_account_id`` is popped from ``kwargs`` before request + validation so it never reaches the user's tool function as a + request field. """ if "user_id" in kwargs: raise TypeError( "CustomTool.__call__ does not accept user_id; " "use CustomTools.execute(slug, request, user_id=...) instead." ) - return self.invoke_trusted(user_id="default", request_kwargs=kwargs) + connected_account_id = kwargs.pop("connected_account_id", None) + return self.invoke_trusted( + user_id="default", + request_kwargs=kwargs, + connected_account_id=connected_account_id, + ) def invoke_trusted( self, user_id: str, request_kwargs: t.Dict[str, t.Any], + connected_account_id: t.Optional[str] = None, ) -> t.Any: """Trusted entry point used by ``CustomTools.execute``. - ``user_id`` is taken as a structurally separate parameter, so an - LLM-controlled ``user_id`` key sitting inside ``request_kwargs`` - cannot override it (the request model's default ``extra='ignore'`` - means any such key is also dropped during validation). + ``user_id`` and ``connected_account_id`` are taken as structurally + separate parameters, so an LLM-controlled key sitting inside + ``request_kwargs`` cannot override either (the request model's + default ``extra='ignore'`` means any such key is also dropped + during validation). """ request = self.request_model.model_validate(request_kwargs) if self.toolkit is None: return t.cast(CustomToolProtocol, self.f)(request=request) + account_id, credentials = self.__resolve_connected_account( + user_id=user_id, + connected_account_id=connected_account_id, + ) + + # Bind the proxy to the resolved account via a wrapper rather than + # ``functools.partial``: ``partial`` lets a caller-supplied keyword + # override the pre-bound value, so a tool that does + # ``execute_request(connected_account_id="ca_other")`` could run + # the proxy under one account while ``auth_credentials`` came + # from another (CWE-639). The wrapper omits ``connected_account_id`` + # from its signature so any attempt to override it raises TypeError. + client = self.client + + def execute_request( + endpoint: str, + method: t.Literal["GET", "POST", "PUT", "DELETE", "PATCH"], + body: t.Dict | Omit = omit, + parameters: t.Iterable[tool_proxy_params.Parameter] | Omit = omit, + ) -> tool_proxy_response.ToolProxyResponse: + return client.tools.proxy( + endpoint=endpoint, + method=method, + body=body, + parameters=parameters, + connected_account_id=account_id, + ) + return t.cast(CustomToolWithProxyProtocol, self.f)( request=request, - execute_request=t.cast(ExecuteRequestFn, self.client.tools.proxy), - auth_credentials=self.__get_auth_credentials(user_id), + execute_request=t.cast(ExecuteRequestFn, execute_request), + auth_credentials=credentials, ) @@ -262,6 +339,7 @@ def execute( slug: str, request: t.Dict, user_id: t.Optional[str] = None, + connected_account_id: t.Optional[str] = None, ) -> t.Any: """Execute a custom tool — the trust boundary for LLM-supplied args. @@ -269,9 +347,9 @@ def execute( the tool's Pydantic ``request_model`` (canonical names + aliases). Anything else — including the historical ``user_id`` smuggling vector and any future identity-bearing keys — is dropped before - the call reaches credential lookup. ``user_id`` is forwarded as a - structurally separate parameter; see the module docstring for the - full security model. + the call reaches credential lookup. ``user_id`` and + ``connected_account_id`` are forwarded as structurally separate + parameters; see the module docstring for the full security model. """ custom_tool = self.get(slug) if custom_tool is None: @@ -281,4 +359,5 @@ def execute( return custom_tool.invoke_trusted( user_id=user_id or "default", request_kwargs=sanitized_request, + connected_account_id=connected_account_id, ) diff --git a/python/composio/core/models/tools.py b/python/composio/core/models/tools.py index 22b7a9d289..55d717f72c 100644 --- a/python/composio/core/models/tools.py +++ b/python/composio/core/models/tools.py @@ -576,6 +576,7 @@ def _execute_custom_tool( self, slug: str, arguments: t.Dict, + connected_account_id: t.Optional[str] = None, user_id: t.Optional[str] = None, ) -> ToolExecutionResponse: """Execute a custom tool""" @@ -586,6 +587,7 @@ def _execute_custom_tool( slug=slug, request=arguments, user_id=user_id, + connected_account_id=connected_account_id, ), "error": None, "successful": True, @@ -764,6 +766,7 @@ def execute( self._execute_custom_tool( slug=slug, arguments=arguments, + connected_account_id=connected_account_id, user_id=user_id, ) if self._custom_tools.get(slug) is not None diff --git a/python/tests/test_custom_tools_security.py b/python/tests/test_custom_tools_security.py index 80584206f3..89cda4fdd4 100644 --- a/python/tests/test_custom_tools_security.py +++ b/python/tests/test_custom_tools_security.py @@ -272,3 +272,360 @@ def test_tools_execute_e2e_strips_user_id_through_full_stack( toolkit_slugs=["github"], user_ids=["trusted-user"], ) + + +# ──────────────────────────────────────────────────────────────── +# PLEN-2345: explicit ``connected_account_id`` plumbing +# +# Before PLEN-2345 ``connected_account_id`` was silently dropped on the +# custom-tool path: ``Tools._execute_custom_tool`` did not accept it, and +# ``execute_request`` was a bare reference to ``client.tools.proxy`` with +# no auth context bound — so every proxy call from a custom tool 400'd +# with ``ExternalProxy_MissingAuthContext`` after the backend stopped +# accepting empty-auth proxies. The pins below cover the new contract. +# ──────────────────────────────────────────────────────────────── + + +@pytest.fixture +def mock_http_client_with_explicit_account(mock_http_client: MagicMock) -> MagicMock: + """``mock_http_client`` extended for the explicit-id resolver path. + + The resolver list-filters by ``(toolkit, user, connected_account_ids)`` + when an explicit id is provided — the side_effect routes: + + * ``connected_account_ids=["ca_explicit"]`` → returns ``[ca_explicit]`` + (server-side authorized: id belongs to this toolkit + user). + * ``connected_account_ids=["ca_attacker"]`` → returns ``[]`` + (server-side unauthorized: id outside the trusted envelope). + * No ``connected_account_ids`` filter → falls back to the existing + "most recently created" account (``ca_fallback``). + + Lets tests pin both the auth-context fix and the CWE-639 boundary. + """ + explicit_state = MagicMock() + explicit_state.val.model_dump.return_value = {"access_token": "explicit-token"} + + explicit_account = MagicMock() + explicit_account.id = "ca_explicit" + explicit_account.state = explicit_state + + fallback_response = mock_http_client.connected_accounts.list.return_value + fallback_response.items[0].id = "ca_fallback" + + explicit_response = MagicMock() + explicit_response.items = [explicit_account] + + empty_response = MagicMock() + empty_response.items = [] + + def list_side_effect( + *, + toolkit_slugs=None, + user_ids=None, + connected_account_ids=None, + **_, + ): + if connected_account_ids is None: + return fallback_response + if ( + connected_account_ids == ["ca_explicit"] + and toolkit_slugs == ["github"] + and user_ids == ["trusted-user"] + ): + return explicit_response + return empty_response + + mock_http_client.connected_accounts.list.side_effect = list_side_effect + return mock_http_client + + +def test_execute_with_connected_account_id_filters_list_by_envelope( + mock_http_client_with_explicit_account: MagicMock, + custom_tools: CustomTools, + github_tool: CustomTool, +) -> None: + """An explicit ``connected_account_id`` MUST be list-filtered by ``(toolkit, user, id)``. + + Pins that the explicit id is server-side bound to the trusted envelope + so an upstream caller forwarding another tenant's id can't grant + cross-tenant credential access (CWE-639). Also pins that the explicit + path overrides the "most recently created" fallback. + """ + result = custom_tools.execute( + slug=github_tool.slug, + request={"issue_number": 7}, + user_id="trusted-user", + connected_account_id="ca_explicit", + ) + + mock_http_client_with_explicit_account.connected_accounts.list.assert_called_once_with( + toolkit_slugs=["github"], + user_ids=["trusted-user"], + connected_account_ids=["ca_explicit"], + ) + mock_http_client_with_explicit_account.connected_accounts.retrieve.assert_not_called() + assert result == {"issue_number": 7, "token": "explicit-token"} + + +def test_execute_request_rejects_caller_supplied_connected_account_id( + mock_http_client_with_explicit_account: MagicMock, +) -> None: + """``execute_request`` MUST refuse a caller-supplied ``connected_account_id``. + + ``functools.partial`` would let a caller-supplied keyword override the + SDK-bound id while ``auth_credentials`` still came from the trusted + account — that would re-introduce a credential/identity mismatch + (CWE-639) right at the proxy call site. The wrapper omits + ``connected_account_id`` from its signature so any attempt to set it + raises ``TypeError`` rather than silently swapping the account. + """ + captured: dict = {} + + def proxy_tool(request: _IssueInput, execute_request, auth_credentials): + """Try to override the SDK-bound account id from inside the tool.""" + try: + execute_request( + endpoint="/api/x", + method="GET", + connected_account_id="ca_other", # type: ignore[call-arg] + ) + except TypeError as exc: + captured["error"] = str(exc) + return {"rejected": True} + return {"rejected": False} + + tool = CustomTool( + f=proxy_tool, + client=mock_http_client_with_explicit_account, + toolkit="github", + ) + tools = CustomTools(client=mock_http_client_with_explicit_account) + tools.custom_tools_registry[tool.slug] = tool + + result = tools.execute( + slug=tool.slug, + request={"issue_number": 1}, + user_id="trusted-user", + connected_account_id="ca_explicit", + ) + + assert result == {"rejected": True} + assert "connected_account_id" in captured["error"] + # Proxy was NOT called — wrapper rejected the override before reaching it. + mock_http_client_with_explicit_account.tools.proxy.assert_not_called() + + +def test_explicit_connected_account_id_outside_envelope_raises( + mock_http_client_with_explicit_account: MagicMock, + custom_tools: CustomTools, + github_tool: CustomTool, +) -> None: + """An explicit id outside the trusted ``(toolkit, user)`` envelope MUST raise. + + Pins the CWE-639 boundary: passing an id that does not belong to the + trusted user/toolkit cannot return credentials. The list call returns + no items because the server-side filter rejects the id, and the SDK + raises rather than running with the wrong credentials. + + Without this guard, an upstream caller forwarding an attacker-influenced + id (e.g., from prompt injection tracked at a different layer) could + silently use another tenant's OAuth tokens. + """ + with pytest.raises(ValueError, match="ca_attacker"): + custom_tools.execute( + slug=github_tool.slug, + request={"issue_number": 1}, + user_id="trusted-user", + connected_account_id="ca_attacker", + ) + + # No proxy call attempted — bail out happens before f() is invoked. + mock_http_client_with_explicit_account.tools.proxy.assert_not_called() + + +def test_execute_request_wrapper_binds_resolved_account_id( + mock_http_client_with_explicit_account: MagicMock, +) -> None: + """``execute_request`` MUST pre-bind the resolved account id. + + This is the user-facing fix. Before PLEN-2345 ``execute_request`` + was a bare reference to ``client.tools.proxy``; the proxy endpoint + now requires auth context and 400'd every custom-tool proxy call. + The binding is a wrapper closure (not ``functools.partial``) so a + caller-supplied ``connected_account_id`` cannot override the bound + value — see ``test_execute_request_rejects_caller_supplied_...``. + """ + captured: dict = {} + + def proxy_tool(request: _IssueInput, execute_request, auth_credentials): + """Make a proxy call from inside a custom tool.""" + execute_request(endpoint="/api/x", method="GET") + captured["credentials"] = auth_credentials + return {"ok": True} + + tool = CustomTool( + f=proxy_tool, + client=mock_http_client_with_explicit_account, + toolkit="github", + ) + tools = CustomTools(client=mock_http_client_with_explicit_account) + tools.custom_tools_registry[tool.slug] = tool + + tools.execute( + slug=tool.slug, + request={"issue_number": 1}, + user_id="trusted-user", + connected_account_id="ca_explicit", + ) + + # The wrapper forwards body/parameters as the Stainless ``NOT_GIVEN`` + # sentinel when the user's tool omits them — same semantics as if the + # user passed the bare ``client.tools.proxy``, plus the SDK-bound id. + proxy_call = mock_http_client_with_explicit_account.tools.proxy.call_args + assert proxy_call.kwargs["endpoint"] == "/api/x" + assert proxy_call.kwargs["method"] == "GET" + assert proxy_call.kwargs["connected_account_id"] == "ca_explicit" + assert captured["credentials"] == {"access_token": "explicit-token"} + + +def test_execute_request_wrapper_binds_listed_account_on_fallback( + mock_http_client_with_explicit_account: MagicMock, +) -> None: + """On the list-fallback path, the wrapper MUST still bind an id. + + The auth-context fix has to apply when callers do not pass + ``connected_account_id`` — otherwise legacy code that relied on the + "default account" behaviour stays broken on the proxy path. + """ + + def proxy_tool(request: _IssueInput, execute_request, auth_credentials): + """Proxy from inside a custom tool with no explicit account id.""" + execute_request(endpoint="/api/y", method="GET") + return {"ok": True} + + tool = CustomTool( + f=proxy_tool, + client=mock_http_client_with_explicit_account, + toolkit="github", + ) + tools = CustomTools(client=mock_http_client_with_explicit_account) + tools.custom_tools_registry[tool.slug] = tool + + tools.execute( + slug=tool.slug, + request={"issue_number": 1}, + user_id="trusted-user", + ) + + # No connected_account_ids filter — fallback path + mock_http_client_with_explicit_account.connected_accounts.list.assert_called_once_with( + toolkit_slugs=["github"], + user_ids=["trusted-user"], + ) + proxy_call = mock_http_client_with_explicit_account.tools.proxy.call_args + assert proxy_call.kwargs["endpoint"] == "/api/y" + assert proxy_call.kwargs["method"] == "GET" + assert proxy_call.kwargs["connected_account_id"] == "ca_fallback" + + +def test_explicit_connected_account_id_wins_over_smuggled_one( + mock_http_client_with_explicit_account: MagicMock, + custom_tools: CustomTools, + github_tool: CustomTool, +) -> None: + """The trusted ``connected_account_id`` parameter MUST beat a smuggled one. + + The allowlist drops ``connected_account_id`` from the LLM-supplied + ``request`` (existing SEC-365 contract); the explicit parameter on + ``CustomTools.execute`` is the only way to influence which account + is used. The list filter MUST receive the trusted id, never the + smuggled one. + """ + custom_tools.execute( + slug=github_tool.slug, + request={"issue_number": 1, "connected_account_id": "ca_evil"}, + user_id="trusted-user", + connected_account_id="ca_explicit", + ) + + mock_http_client_with_explicit_account.connected_accounts.list.assert_called_once_with( + toolkit_slugs=["github"], + user_ids=["trusted-user"], + connected_account_ids=["ca_explicit"], + ) + + +def test_tools_execute_e2e_forwards_connected_account_id( + mock_http_client_with_explicit_account: MagicMock, + custom_tools: CustomTools, + github_tool: CustomTool, +) -> None: + """``Tools.execute(connected_account_id=...)`` MUST reach the custom-tool branch. + + End-to-end pin on the public SDK entry point: before PLEN-2345 the + custom-tool branch did not accept ``connected_account_id``, so the + parameter was silently dropped at the routing fork in ``Tools.execute``. + """ + from composio.core.models.tools import Tools + + provider = MagicMock() + provider.name = "test" + + tools = Tools(client=mock_http_client_with_explicit_account, provider=provider) + tools._custom_tools = custom_tools + + response = tools.execute( + slug=github_tool.slug, + arguments={"issue_number": 1}, + user_id="trusted-user", + connected_account_id="ca_explicit", + ) + + assert response["successful"] is True + assert response["data"]["token"] == "explicit-token" + mock_http_client_with_explicit_account.connected_accounts.list.assert_called_once_with( + toolkit_slugs=["github"], + user_ids=["trusted-user"], + connected_account_ids=["ca_explicit"], + ) + + +def test_call_pops_connected_account_id_before_request_validation( + mock_http_client_with_explicit_account: MagicMock, +) -> None: + """``CustomTool.__call__`` MUST extract ``connected_account_id`` from kwargs. + + Mirrors how ``__call__`` already refuses ``user_id``: auth-path keys + are structurally separate from request kwargs and never reach the + user's tool function as a request field. The ``__call__`` shortcut + uses ``user_id="default"``, so the resolver hits the "default" + id + envelope branch in the side_effect — pinned via the empty fallback. + """ + captured: dict = {} + + def echo_tool(request: _IssueInput, execute_request, auth_credentials): + """Echo the validated request.""" + captured["fields"] = request.model_dump() + return {"ok": True} + + tool = CustomTool( + f=echo_tool, + client=mock_http_client_with_explicit_account, + toolkit="github", + ) + + # Default user_id "default" + ca_explicit is outside the trusted envelope + # in the fixture (which only authorizes ("github", "trusted-user")), so + # this raises — but ONLY because connected_account_id reached the resolver + # as a structurally separate parameter rather than landing in request_kwargs. + with pytest.raises(ValueError, match="ca_explicit"): + tool(issue_number=1, connected_account_id="ca_explicit") + + # Resolver was called with the trusted parameter, not via request_kwargs. + mock_http_client_with_explicit_account.connected_accounts.list.assert_called_once_with( + toolkit_slugs=["github"], + user_ids=["default"], + connected_account_ids=["ca_explicit"], + ) + # The user's tool function was never invoked (resolver raised first). + assert captured == {} diff --git a/python/tests/test_provider.py b/python/tests/test_provider.py index 0bbe443b07..1d2589ff50 100644 --- a/python/tests/test_provider.py +++ b/python/tests/test_provider.py @@ -859,7 +859,7 @@ def mock_get(slug): tools._custom_tools.get = Mock(side_effect=mock_get) - def mock_execute(slug, request, user_id): + def mock_execute(slug, request, user_id, connected_account_id=None): return {"custom_result": "success", "slug": slug} tools._custom_tools.execute = Mock(side_effect=mock_execute) diff --git a/python/tests/test_tool_execution.py b/python/tests/test_tool_execution.py index d0e8c6b5d7..4211b892c7 100644 --- a/python/tests/test_tool_execution.py +++ b/python/tests/test_tool_execution.py @@ -745,7 +745,7 @@ def mock_get(slug): tools._custom_tools.get = Mock(side_effect=mock_get) # Mock the execute method of custom tool - def mock_execute(slug, request, user_id): + def mock_execute(slug, request, user_id, connected_account_id=None): return {"custom_result": "success"} tools._custom_tools.execute = Mock(side_effect=mock_execute) @@ -766,6 +766,7 @@ def mock_execute(slug, request, user_id): slug="CUSTOM_TOOL", request={"param": "value"}, user_id="user-123", + connected_account_id=None, ) def test_execute_with_modifiers_before_execute(self):