1717import asyncio
1818import httpx
1919import pytest
20+ from fastmcp .exceptions import ToolError
2021from fastmcp .server .middleware import MiddlewareContext
2122from mcp_proxy_for_aws .middleware .profile_switcher import ProfileOverrideMiddleware
2223from unittest .mock import AsyncMock , MagicMock , Mock , patch
@@ -51,8 +52,8 @@ class TestOnListTools:
5152 """Tests for the on_list_tools method."""
5253
5354 @pytest .mark .asyncio
54- async def test_injects_profile_property_into_tool_schemas (self , middleware , mock_context ):
55- """Every proxied tool gets a profile property in its schema."""
55+ async def test_injects_proxy_profile_property_into_tool_schemas (self , middleware , mock_context ):
56+ """Every proxied tool gets a proxy_profile property in its schema."""
5657 tool = Mock ()
5758 tool .name = 'some_tool'
5859 tool .parameters = {'type' : 'object' , 'properties' : {'arg' : {'type' : 'string' }}}
@@ -62,7 +63,7 @@ async def test_injects_profile_property_into_tool_schemas(self, middleware, mock
6263
6364 assert len (result ) == 1
6465 assert result [0 ].name == 'some_tool'
65- profile_schema = result [0 ].parameters ['properties' ]['profile ' ]
66+ profile_schema = result [0 ].parameters ['properties' ]['proxy_profile ' ]
6667 assert profile_schema ['type' ] == 'string'
6768 assert 'AWS CLI profile' in profile_schema ['description' ]
6869 assert profile_schema ['enum' ] == sorted (ALLOWED_PROFILES )
@@ -101,15 +102,15 @@ async def test_adds_properties_key_when_missing(self, middleware, mock_context):
101102 result = await middleware .on_list_tools (mock_context , call_next )
102103
103104 assert 'properties' in result [0 ].parameters
104- assert 'profile ' in result [0 ].parameters ['properties' ]
105+ assert 'proxy_profile ' in result [0 ].parameters ['properties' ]
105106
106107
107108class TestOnCallTool :
108109 """Tests for the on_call_tool method."""
109110
110111 @pytest .mark .asyncio
111- async def test_passes_through_calls_without_profile (self , middleware , mock_context ):
112- """Tool calls without profile are forwarded unchanged."""
112+ async def test_passes_through_calls_without_proxy_profile (self , middleware , mock_context ):
113+ """Tool calls without proxy_profile are forwarded unchanged."""
113114 mock_context .message = Mock ()
114115 mock_context .message .name = 'some_tool'
115116 mock_context .message .arguments = {'arg' : 'value' }
@@ -141,20 +142,20 @@ class TestPerCallProfileOverride:
141142
142143 @pytest .mark .asyncio
143144 async def test_profile_override_disallowed (self , middleware , mock_context ):
144- """Profile with a disallowed profile returns an error ."""
145+ """Disallowed profile raises ToolError ."""
145146 mock_context .message = Mock ()
146147 mock_context .message .name = 'some_tool'
147- mock_context .message .arguments = {'arg' : 'value' , 'profile ' : 'evil-profile' }
148+ mock_context .message .arguments = {'arg' : 'value' , 'proxy_profile ' : 'evil-profile' }
148149 call_next = AsyncMock ()
149150
150- result = await middleware .on_call_tool (mock_context , call_next )
151+ with pytest .raises (ToolError , match = 'not in the allowed list' ):
152+ await middleware .on_call_tool (mock_context , call_next )
151153
152- assert 'not in the allowed list' in result .content [0 ].text
153154 call_next .assert_not_called ()
154155
155156 @pytest .mark .asyncio
156- async def test_profile_override_strips_profile_arg (self , middleware , mock_context ):
157- """Profile is stripped before forwarding to the backend."""
157+ async def test_profile_override_strips_proxy_profile_arg (self , middleware , mock_context ):
158+ """proxy_profile is stripped before forwarding to the backend."""
158159 mock_client = AsyncMock ()
159160 mock_call_result = MagicMock ()
160161 mock_call_result .content = 'result'
@@ -164,7 +165,7 @@ async def test_profile_override_strips_profile_arg(self, middleware, mock_contex
164165
165166 mock_context .message = Mock ()
166167 mock_context .message .name = 'some_tool'
167- mock_context .message .arguments = {'arg' : 'value' , 'profile ' : 'dev-profile' }
168+ mock_context .message .arguments = {'arg' : 'value' , 'proxy_profile ' : 'dev-profile' }
168169 call_next = AsyncMock ()
169170
170171 with patch .object (middleware , '_get_profile_client' , return_value = mock_client ):
@@ -175,36 +176,36 @@ async def test_profile_override_strips_profile_arg(self, middleware, mock_contex
175176
176177 @pytest .mark .asyncio
177178 async def test_profile_override_connection_failure (self , middleware , mock_context ):
178- """Connection failure returns a sanitized error ."""
179+ """Connection failure raises ToolError with sanitized message ."""
179180 mock_context .message = Mock ()
180181 mock_context .message .name = 'some_tool'
181- mock_context .message .arguments = {'arg' : 'value' , 'profile ' : 'dev-profile' }
182+ mock_context .message .arguments = {'arg' : 'value' , 'proxy_profile ' : 'dev-profile' }
182183 call_next = AsyncMock ()
183184
184185 with patch .object (
185186 middleware , '_get_profile_client' , side_effect = Exception ('connection refused' )
186187 ):
187- result = await middleware .on_call_tool (mock_context , call_next )
188+ with pytest .raises (ToolError , match = 'Failed to create connection' ) as exc_info :
189+ await middleware .on_call_tool (mock_context , call_next )
188190
189- assert 'failed to create connection' in result .content [0 ].text
190- assert 'connection refused' not in result .content [0 ].text
191+ assert 'connection refused' not in str (exc_info .value )
191192
192193 @pytest .mark .asyncio
193194 async def test_profile_override_tool_call_failure (self , middleware , mock_context ):
194- """Tool call failure returns a sanitized error ."""
195+ """Tool call failure raises ToolError with sanitized message ."""
195196 mock_client = AsyncMock ()
196197 mock_client .call_tool .side_effect = Exception ('backend error' )
197198
198199 mock_context .message = Mock ()
199200 mock_context .message .name = 'some_tool'
200- mock_context .message .arguments = {'arg' : 'value' , 'profile ' : 'dev-profile' }
201+ mock_context .message .arguments = {'arg' : 'value' , 'proxy_profile ' : 'dev-profile' }
201202 call_next = AsyncMock ()
202203
203204 with patch .object (middleware , '_get_profile_client' , return_value = mock_client ):
204- result = await middleware .on_call_tool (mock_context , call_next )
205+ with pytest .raises (ToolError , match = 'Tool call failed' ) as exc_info :
206+ await middleware .on_call_tool (mock_context , call_next )
205207
206- assert 'tool call failed' in result .content [0 ].text
207- assert 'backend error' not in result .content [0 ].text
208+ assert 'backend error' not in str (exc_info .value )
208209
209210
210211class TestGetProfileClient :
0 commit comments