Skip to content

Commit a36bb43

Browse files
authored
fix: pass-through server capbilities (#264)
1 parent 3c5909c commit a36bb43

2 files changed

Lines changed: 152 additions & 0 deletions

File tree

mcp_proxy_for_aws/middleware/initialize_middleware.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,23 @@ def __init__(self, client_factory: AWSMCPProxyClientFactory) -> None:
3131
super().__init__()
3232
self._client_factory = client_factory
3333

34+
def _overwrite_init_options(
35+
self, context: MiddlewareContext, init_result: mt.InitializeResult
36+
):
37+
"""Overwrite the session's _init_options with the backend server's info.
38+
39+
The MCP SDK builds the InitializeResult from session._init_options
40+
inside call_next. By modifying _init_options before call_next runs,
41+
the response sent to the client will contain the backend server's
42+
info instead of the proxy's defaults.
43+
"""
44+
fastmcp_ctx = context.fastmcp_context
45+
if fastmcp_ctx is None or fastmcp_ctx._session is None:
46+
logger.debug('No session available, skipping init_options overwrite.')
47+
return
48+
49+
fastmcp_ctx._session._init_options.capabilities = init_result.capabilities
50+
3451
@override
3552
async def on_initialize(
3653
self,
@@ -59,6 +76,12 @@ async def on_initialize(
5976
# the list_tool call will require the client to be connected again, so the mcp error
6077
# will be displayed in the q cli logs.
6178
await client._connect()
79+
80+
# Overwrite the proxy's init_options with the backend server's info
81+
# so the InitializeResult sent to the client reflects the backend.
82+
if client.initialize_result is not None:
83+
self._overwrite_init_options(context, client.initialize_result)
84+
6285
return await call_next(context)
6386
except Exception:
6487
logger.exception('Initialize failed in middleware.')

tests/unit/test_initialize_middleware.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,27 @@ async def test_on_initialize_connects_client():
3535
"""Test that on_initialize calls client._connect()."""
3636
mock_client = Mock()
3737
mock_client._connect = AsyncMock()
38+
mock_client.initialize_result = mt.InitializeResult(
39+
protocolVersion='2024-11-05',
40+
capabilities=mt.ServerCapabilities(),
41+
serverInfo=mt.Implementation(name='backend-server', version='2.0'),
42+
)
3843

3944
mock_factory = Mock()
4045
mock_factory.set_init_params = Mock()
4146
mock_factory.get_client = AsyncMock(return_value=mock_client)
4247

4348
middleware = InitializeMiddleware(mock_factory)
4449

50+
mock_init_options = Mock()
51+
mock_session = Mock()
52+
mock_session._init_options = mock_init_options
53+
mock_fastmcp_ctx = Mock()
54+
mock_fastmcp_ctx._session = mock_session
55+
4556
mock_context = Mock()
4657
mock_context.message = create_initialize_request('test-client')
58+
mock_context.fastmcp_context = mock_fastmcp_ctx
4759

4860
mock_call_next = AsyncMock()
4961

@@ -54,6 +66,9 @@ async def test_on_initialize_connects_client():
5466
mock_client._connect.assert_called_once()
5567
mock_call_next.assert_called_once_with(mock_context)
5668

69+
# Verify init_options capabilities were overwritten with backend server info
70+
assert mock_init_options.capabilities == mt.ServerCapabilities()
71+
5772

5873
@pytest.mark.asyncio
5974
async def test_on_initialize_fails_if_connect_fails():
@@ -94,6 +109,7 @@ async def test_on_initialize_skips_connect_for_special_clients(client_name):
94109
"""Test that on_initialize skips _connect() for Kiro CLI and Q Dev CLI."""
95110
mock_client = Mock()
96111
mock_client._connect = AsyncMock()
112+
mock_client.initialize_result = None
97113

98114
mock_factory = Mock()
99115
mock_factory.set_init_params = Mock()
@@ -110,3 +126,116 @@ async def test_on_initialize_skips_connect_for_special_clients(client_name):
110126

111127
mock_client._connect.assert_not_called()
112128
mock_call_next.assert_called_once_with(mock_context)
129+
130+
131+
@pytest.mark.asyncio
132+
async def test_on_initialize_overwrites_init_options_with_backend_info():
133+
"""Test that on_initialize overwrites session init_options with backend server info."""
134+
backend_capabilities = mt.ServerCapabilities(
135+
logging=mt.LoggingCapability(),
136+
)
137+
backend_result = mt.InitializeResult(
138+
protocolVersion='2024-11-05',
139+
capabilities=backend_capabilities,
140+
serverInfo=mt.Implementation(name='backend-mcp', version='3.1'),
141+
)
142+
143+
mock_client = Mock()
144+
mock_client._connect = AsyncMock()
145+
mock_client.initialize_result = backend_result
146+
147+
mock_factory = Mock()
148+
mock_factory.set_init_params = Mock()
149+
mock_factory.get_client = AsyncMock(return_value=mock_client)
150+
151+
middleware = InitializeMiddleware(mock_factory)
152+
153+
mock_init_options = Mock()
154+
mock_init_options.server_name = 'proxy-name'
155+
mock_init_options.server_version = '1.0'
156+
mock_init_options.capabilities = mt.ServerCapabilities()
157+
mock_session = Mock()
158+
mock_session._init_options = mock_init_options
159+
mock_fastmcp_ctx = Mock()
160+
mock_fastmcp_ctx._session = mock_session
161+
162+
mock_context = Mock()
163+
mock_context.message = create_initialize_request('test-client')
164+
mock_context.fastmcp_context = mock_fastmcp_ctx
165+
166+
mock_call_next = AsyncMock()
167+
168+
await middleware.on_initialize(mock_context, mock_call_next)
169+
170+
assert mock_init_options.capabilities == backend_capabilities
171+
172+
173+
@pytest.mark.asyncio
174+
async def test_on_initialize_disables_prompts_and_resources():
175+
"""Test that prompts and resources capabilities are disabled even if backend supports them."""
176+
backend_capabilities = mt.ServerCapabilities(
177+
tools=mt.ToolsCapability(),
178+
prompts=mt.PromptsCapability(),
179+
resources=mt.ResourcesCapability(),
180+
)
181+
backend_result = mt.InitializeResult(
182+
protocolVersion='2024-11-05',
183+
capabilities=backend_capabilities,
184+
serverInfo=mt.Implementation(name='backend', version='1.0'),
185+
)
186+
187+
mock_client = Mock()
188+
mock_client._connect = AsyncMock()
189+
mock_client.initialize_result = backend_result
190+
191+
mock_factory = Mock()
192+
mock_factory.set_init_params = Mock()
193+
mock_factory.get_client = AsyncMock(return_value=mock_client)
194+
195+
middleware = InitializeMiddleware(mock_factory)
196+
197+
mock_init_options = Mock()
198+
mock_session = Mock()
199+
mock_session._init_options = mock_init_options
200+
mock_fastmcp_ctx = Mock()
201+
mock_fastmcp_ctx._session = mock_session
202+
203+
mock_context = Mock()
204+
mock_context.message = create_initialize_request('test-client')
205+
mock_context.fastmcp_context = mock_fastmcp_ctx
206+
207+
mock_call_next = AsyncMock()
208+
209+
await middleware.on_initialize(mock_context, mock_call_next)
210+
211+
assert mock_init_options.capabilities.prompts is not None
212+
assert mock_init_options.capabilities.resources is not None
213+
assert mock_init_options.capabilities.tools is not None
214+
215+
216+
@pytest.mark.asyncio
217+
async def test_on_initialize_skips_overwrite_when_no_session():
218+
"""Test that overwrite is skipped when no session is available."""
219+
mock_client = Mock()
220+
mock_client._connect = AsyncMock()
221+
mock_client.initialize_result = mt.InitializeResult(
222+
protocolVersion='2024-11-05',
223+
capabilities=mt.ServerCapabilities(),
224+
serverInfo=mt.Implementation(name='backend', version='1.0'),
225+
)
226+
227+
mock_factory = Mock()
228+
mock_factory.set_init_params = Mock()
229+
mock_factory.get_client = AsyncMock(return_value=mock_client)
230+
231+
middleware = InitializeMiddleware(mock_factory)
232+
233+
mock_context = Mock()
234+
mock_context.message = create_initialize_request('test-client')
235+
mock_context.fastmcp_context = None
236+
237+
mock_call_next = AsyncMock()
238+
239+
# Should not raise, just skip overwrite
240+
await middleware.on_initialize(mock_context, mock_call_next)
241+
mock_call_next.assert_called_once_with(mock_context)

0 commit comments

Comments
 (0)