Skip to content

Commit b23a615

Browse files
committed
refactor(middleware): rename profile to proxy_profile and raise ToolError on failures
Avoids collisions with backend tool arguments by using a namespaced proxy_profile parameter. Errors now raise ToolError instead of returning ToolResult for proper MCP error propagation. Deep-copies tool parameters to prevent mutating shared upstream dicts. Extracts profile override middleware setup into a dedicated helper.
1 parent b4051ef commit b23a615

3 files changed

Lines changed: 88 additions & 54 deletions

File tree

mcp_proxy_for_aws/middleware/profile_switcher.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Middleware that enables per-call AWS profile overrides via a ``profile`` argument.
15+
"""Middleware that enables per-call AWS profile overrides via a ``proxy_profile`` argument.
1616
17-
Pass ``profile`` as an extra argument on any tool call to route that single request
17+
Pass ``proxy_profile`` as an extra argument on any tool call to route that single request
1818
through a dedicated transport signed with the specified profile's credentials. The
1919
argument is stripped before forwarding to the backend.
2020
@@ -23,12 +23,14 @@
2323
"""
2424

2525
import asyncio
26+
import copy
2627
import httpx
2728
import logging
2829
import mcp.types as mt
2930
from collections.abc import Sequence
3031
from fastmcp import Client
3132
from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext
33+
from fastmcp.exceptions import ToolError
3234
from fastmcp.tools.tool import Tool, ToolResult
3335
from mcp_proxy_for_aws.utils import create_transport_with_sigv4
3436
from typing import Any, cast
@@ -39,12 +41,12 @@
3941

4042

4143
class ProfileOverrideMiddleware(Middleware):
42-
"""Middleware that intercepts ``profile`` on any tool call for per-request AWS identity switching.
44+
"""Middleware that intercepts ``proxy_profile`` on any tool call for per-request AWS identity switching.
4345
44-
When a tool call includes a ``profile`` argument, the middleware:
46+
When a tool call includes a ``proxy_profile`` argument, the middleware:
4547
4648
1. Validates the profile against the allowed list
47-
2. Strips ``profile`` from the arguments
49+
2. Strips ``proxy_profile`` from the arguments
4850
3. Forwards the call through a dedicated per-profile MCP client
4951
5052
Each profile gets its own transport and session to the backend so that
@@ -79,22 +81,24 @@ async def on_list_tools(
7981
context: MiddlewareContext[mt.ListToolsRequest],
8082
call_next: CallNext[mt.ListToolsRequest, Sequence[Tool]],
8183
) -> Sequence[Tool]:
82-
"""Inject ``profile`` into every tool's schema."""
84+
"""Inject ``proxy_profile`` into every tool's schema."""
8385
tools = await call_next(context)
8486

8587
for tool in tools:
86-
params = tool.parameters
87-
if not isinstance(params, dict):
88+
if not isinstance(tool.parameters, dict):
8889
continue
90+
# Deep-copy to avoid mutating upstream cached/shared dicts
91+
params = copy.deepcopy(tool.parameters)
8992
if 'properties' not in params:
9093
params['properties'] = {}
91-
params['properties']['profile'] = {
94+
params['properties']['proxy_profile'] = {
9295
'type': 'string',
9396
'description': (
9497
'AWS CLI profile to sign this request with. Omit to use the default profile.'
9598
),
9699
'enum': sorted(self._allowed_profiles),
97100
}
101+
tool.parameters = params
98102

99103
return list(tools)
100104

@@ -106,10 +110,10 @@ async def on_call_tool(
106110
context: MiddlewareContext[mt.CallToolRequestParams],
107111
call_next: CallNext[mt.CallToolRequestParams, ToolResult],
108112
) -> ToolResult:
109-
"""Intercept ``profile`` and route through a dedicated per-profile client."""
113+
"""Intercept ``proxy_profile`` and route through a dedicated per-profile client."""
110114
arguments = context.message.arguments
111-
if isinstance(arguments, dict) and 'profile' in arguments:
112-
profile = arguments['profile']
115+
if isinstance(arguments, dict) and 'proxy_profile' in arguments:
116+
profile = arguments['proxy_profile']
113117
return await self._call_with_profile(profile, context, call_next)
114118

115119
return await call_next(context)
@@ -157,14 +161,14 @@ async def _call_with_profile(
157161
"""Forward a tool call through a dedicated per-profile connection."""
158162
if profile not in self._allowed_profiles:
159163
allowed = ', '.join(sorted(self._allowed_profiles))
160-
return ToolResult(
161-
content=f'Error: profile {profile!r} is not in the allowed list. '
164+
raise ToolError(
165+
f'Profile {profile!r} is not in the allowed list. '
162166
f'Allowed profiles: {allowed}'
163167
)
164168

165-
# Strip profile before forwarding to the backend
169+
# Strip proxy_profile before forwarding to the backend
166170
arguments: dict[str, Any] = dict(cast(dict[str, Any], context.message.arguments))
167-
arguments.pop('profile', None)
171+
arguments.pop('proxy_profile', None)
168172

169173
logger.info(
170174
'Per-call profile override: routing through dedicated connection for %s', profile
@@ -174,8 +178,8 @@ async def _call_with_profile(
174178
client = await self._get_profile_client(profile)
175179
except Exception:
176180
logger.exception('Failed to create connection for profile %s', profile)
177-
return ToolResult(
178-
content=f'Error: failed to create connection for profile {profile!r}. '
181+
raise ToolError(
182+
f'Failed to create connection for profile {profile!r}. '
179183
'Check that the profile is configured and credentials are valid.'
180184
)
181185

@@ -188,7 +192,7 @@ async def _call_with_profile(
188192
)
189193
except Exception:
190194
logger.exception('Error calling tool via profile %s', profile)
191-
return ToolResult(
192-
content=f'Error: tool call failed using profile {profile!r}. '
195+
raise ToolError(
196+
f'Tool call failed using profile {profile!r}. '
193197
'The request could not be completed with the specified profile.'
194198
)

mcp_proxy_for_aws/server.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -105,17 +105,9 @@ async def run_proxy(args) -> None:
105105
add_logging_middleware(proxy, args.log_level)
106106
add_tool_filtering_middleware(proxy, args.read_only)
107107

108-
allowed_profiles = getattr(args, 'allow_switch_profile', None)
109-
if isinstance(allowed_profiles, list) and allowed_profiles:
110-
profile_middleware = ProfileOverrideMiddleware(
111-
allowed_profiles=allowed_profiles,
112-
service=service,
113-
region=region,
114-
metadata=metadata,
115-
timeout=timeout,
116-
endpoint=args.endpoint,
117-
)
118-
proxy.add_middleware(profile_middleware)
108+
profile_middleware = add_profile_override_middleware(
109+
proxy, args, service, region, metadata, timeout
110+
)
119111

120112
if args.retries:
121113
add_retry_middleware(proxy, args.retries)
@@ -140,6 +132,43 @@ def add_tool_error_middleware(mcp: FastMCP, tool_timeout: float) -> None:
140132
mcp.add_middleware(ToolErrorMiddleware(tool_call_timeout=tool_timeout))
141133

142134

135+
def add_profile_override_middleware(
136+
mcp: FastMCP,
137+
args,
138+
service: str,
139+
region: str,
140+
metadata: dict,
141+
timeout: httpx.Timeout,
142+
) -> ProfileOverrideMiddleware | None:
143+
"""Add profile override middleware to target MCP server.
144+
145+
Args:
146+
mcp: The FastMCP instance to add profile override to
147+
args: The parsed CLI arguments
148+
service: The AWS service name
149+
region: The AWS region
150+
metadata: The metadata dictionary
151+
timeout: The httpx timeout configuration
152+
153+
Returns:
154+
The ProfileOverrideMiddleware instance if added, None otherwise
155+
"""
156+
allowed_profiles = getattr(args, 'allow_switch_profile', None)
157+
if not isinstance(allowed_profiles, list) or not allowed_profiles:
158+
return None
159+
logger.info('Adding profile override middleware')
160+
middleware = ProfileOverrideMiddleware(
161+
allowed_profiles=allowed_profiles,
162+
service=service,
163+
region=region,
164+
metadata=metadata,
165+
timeout=timeout,
166+
endpoint=args.endpoint,
167+
)
168+
mcp.add_middleware(middleware)
169+
return middleware
170+
171+
143172
def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None:
144173
"""Add tool filtering middleware to target MCP server.
145174

tests/unit/test_profile_switcher.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import asyncio
1818
import httpx
1919
import pytest
20+
from fastmcp.exceptions import ToolError
2021
from fastmcp.server.middleware import MiddlewareContext
2122
from mcp_proxy_for_aws.middleware.profile_switcher import ProfileOverrideMiddleware
2223
from unittest.mock import AsyncMock, MagicMock, Mock, patch
@@ -51,8 +52,8 @@ class TestOnListTools:
5152
"""Tests for the on_list_tools method."""
5253

5354
@pytest.mark.asyncio
54-
async def test_injects_profile_property_into_tool_schemas(self, middleware, mock_context):
55-
"""Every proxied tool gets a profile property in its schema."""
55+
async def test_injects_proxy_profile_property_into_tool_schemas(self, middleware, mock_context):
56+
"""Every proxied tool gets a proxy_profile property in its schema."""
5657
tool = Mock()
5758
tool.name = 'some_tool'
5859
tool.parameters = {'type': 'object', 'properties': {'arg': {'type': 'string'}}}
@@ -62,7 +63,7 @@ async def test_injects_profile_property_into_tool_schemas(self, middleware, mock
6263

6364
assert len(result) == 1
6465
assert result[0].name == 'some_tool'
65-
profile_schema = result[0].parameters['properties']['profile']
66+
profile_schema = result[0].parameters['properties']['proxy_profile']
6667
assert profile_schema['type'] == 'string'
6768
assert 'AWS CLI profile' in profile_schema['description']
6869
assert profile_schema['enum'] == sorted(ALLOWED_PROFILES)
@@ -101,15 +102,15 @@ async def test_adds_properties_key_when_missing(self, middleware, mock_context):
101102
result = await middleware.on_list_tools(mock_context, call_next)
102103

103104
assert 'properties' in result[0].parameters
104-
assert 'profile' in result[0].parameters['properties']
105+
assert 'proxy_profile' in result[0].parameters['properties']
105106

106107

107108
class TestOnCallTool:
108109
"""Tests for the on_call_tool method."""
109110

110111
@pytest.mark.asyncio
111-
async def test_passes_through_calls_without_profile(self, middleware, mock_context):
112-
"""Tool calls without profile are forwarded unchanged."""
112+
async def test_passes_through_calls_without_proxy_profile(self, middleware, mock_context):
113+
"""Tool calls without proxy_profile are forwarded unchanged."""
113114
mock_context.message = Mock()
114115
mock_context.message.name = 'some_tool'
115116
mock_context.message.arguments = {'arg': 'value'}
@@ -141,20 +142,20 @@ class TestPerCallProfileOverride:
141142

142143
@pytest.mark.asyncio
143144
async def test_profile_override_disallowed(self, middleware, mock_context):
144-
"""Profile with a disallowed profile returns an error."""
145+
"""Disallowed profile raises ToolError."""
145146
mock_context.message = Mock()
146147
mock_context.message.name = 'some_tool'
147-
mock_context.message.arguments = {'arg': 'value', 'profile': 'evil-profile'}
148+
mock_context.message.arguments = {'arg': 'value', 'proxy_profile': 'evil-profile'}
148149
call_next = AsyncMock()
149150

150-
result = await middleware.on_call_tool(mock_context, call_next)
151+
with pytest.raises(ToolError, match='not in the allowed list'):
152+
await middleware.on_call_tool(mock_context, call_next)
151153

152-
assert 'not in the allowed list' in result.content[0].text
153154
call_next.assert_not_called()
154155

155156
@pytest.mark.asyncio
156-
async def test_profile_override_strips_profile_arg(self, middleware, mock_context):
157-
"""Profile is stripped before forwarding to the backend."""
157+
async def test_profile_override_strips_proxy_profile_arg(self, middleware, mock_context):
158+
"""proxy_profile is stripped before forwarding to the backend."""
158159
mock_client = AsyncMock()
159160
mock_call_result = MagicMock()
160161
mock_call_result.content = 'result'
@@ -164,7 +165,7 @@ async def test_profile_override_strips_profile_arg(self, middleware, mock_contex
164165

165166
mock_context.message = Mock()
166167
mock_context.message.name = 'some_tool'
167-
mock_context.message.arguments = {'arg': 'value', 'profile': 'dev-profile'}
168+
mock_context.message.arguments = {'arg': 'value', 'proxy_profile': 'dev-profile'}
168169
call_next = AsyncMock()
169170

170171
with patch.object(middleware, '_get_profile_client', return_value=mock_client):
@@ -175,36 +176,36 @@ async def test_profile_override_strips_profile_arg(self, middleware, mock_contex
175176

176177
@pytest.mark.asyncio
177178
async def test_profile_override_connection_failure(self, middleware, mock_context):
178-
"""Connection failure returns a sanitized error."""
179+
"""Connection failure raises ToolError with sanitized message."""
179180
mock_context.message = Mock()
180181
mock_context.message.name = 'some_tool'
181-
mock_context.message.arguments = {'arg': 'value', 'profile': 'dev-profile'}
182+
mock_context.message.arguments = {'arg': 'value', 'proxy_profile': 'dev-profile'}
182183
call_next = AsyncMock()
183184

184185
with patch.object(
185186
middleware, '_get_profile_client', side_effect=Exception('connection refused')
186187
):
187-
result = await middleware.on_call_tool(mock_context, call_next)
188+
with pytest.raises(ToolError, match='Failed to create connection') as exc_info:
189+
await middleware.on_call_tool(mock_context, call_next)
188190

189-
assert 'failed to create connection' in result.content[0].text
190-
assert 'connection refused' not in result.content[0].text
191+
assert 'connection refused' not in str(exc_info.value)
191192

192193
@pytest.mark.asyncio
193194
async def test_profile_override_tool_call_failure(self, middleware, mock_context):
194-
"""Tool call failure returns a sanitized error."""
195+
"""Tool call failure raises ToolError with sanitized message."""
195196
mock_client = AsyncMock()
196197
mock_client.call_tool.side_effect = Exception('backend error')
197198

198199
mock_context.message = Mock()
199200
mock_context.message.name = 'some_tool'
200-
mock_context.message.arguments = {'arg': 'value', 'profile': 'dev-profile'}
201+
mock_context.message.arguments = {'arg': 'value', 'proxy_profile': 'dev-profile'}
201202
call_next = AsyncMock()
202203

203204
with patch.object(middleware, '_get_profile_client', return_value=mock_client):
204-
result = await middleware.on_call_tool(mock_context, call_next)
205+
with pytest.raises(ToolError, match='Tool call failed') as exc_info:
206+
await middleware.on_call_tool(mock_context, call_next)
205207

206-
assert 'tool call failed' in result.content[0].text
207-
assert 'backend error' not in result.content[0].text
208+
assert 'backend error' not in str(exc_info.value)
208209

209210

210211
class TestGetProfileClient:

0 commit comments

Comments
 (0)