@@ -35,15 +35,27 @@ async def test_on_initialize_connects_client():
3535 """Test that on_initialize calls client._connect()."""
3636 mock_client = Mock ()
3737 mock_client ._connect = AsyncMock ()
38+ mock_client .initialize_result = mt .InitializeResult (
39+ protocolVersion = '2024-11-05' ,
40+ capabilities = mt .ServerCapabilities (),
41+ serverInfo = mt .Implementation (name = 'backend-server' , version = '2.0' ),
42+ )
3843
3944 mock_factory = Mock ()
4045 mock_factory .set_init_params = Mock ()
4146 mock_factory .get_client = AsyncMock (return_value = mock_client )
4247
4348 middleware = InitializeMiddleware (mock_factory )
4449
50+ mock_init_options = Mock ()
51+ mock_session = Mock ()
52+ mock_session ._init_options = mock_init_options
53+ mock_fastmcp_ctx = Mock ()
54+ mock_fastmcp_ctx ._session = mock_session
55+
4556 mock_context = Mock ()
4657 mock_context .message = create_initialize_request ('test-client' )
58+ mock_context .fastmcp_context = mock_fastmcp_ctx
4759
4860 mock_call_next = AsyncMock ()
4961
@@ -54,6 +66,9 @@ async def test_on_initialize_connects_client():
5466 mock_client ._connect .assert_called_once ()
5567 mock_call_next .assert_called_once_with (mock_context )
5668
69+ # Verify init_options capabilities were overwritten with backend server info
70+ assert mock_init_options .capabilities == mt .ServerCapabilities ()
71+
5772
5873@pytest .mark .asyncio
5974async def test_on_initialize_fails_if_connect_fails ():
@@ -94,6 +109,7 @@ async def test_on_initialize_skips_connect_for_special_clients(client_name):
94109 """Test that on_initialize skips _connect() for Kiro CLI and Q Dev CLI."""
95110 mock_client = Mock ()
96111 mock_client ._connect = AsyncMock ()
112+ mock_client .initialize_result = None
97113
98114 mock_factory = Mock ()
99115 mock_factory .set_init_params = Mock ()
@@ -110,3 +126,116 @@ async def test_on_initialize_skips_connect_for_special_clients(client_name):
110126
111127 mock_client ._connect .assert_not_called ()
112128 mock_call_next .assert_called_once_with (mock_context )
129+
130+
131+ @pytest .mark .asyncio
132+ async def test_on_initialize_overwrites_init_options_with_backend_info ():
133+ """Test that on_initialize overwrites session init_options with backend server info."""
134+ backend_capabilities = mt .ServerCapabilities (
135+ logging = mt .LoggingCapability (),
136+ )
137+ backend_result = mt .InitializeResult (
138+ protocolVersion = '2024-11-05' ,
139+ capabilities = backend_capabilities ,
140+ serverInfo = mt .Implementation (name = 'backend-mcp' , version = '3.1' ),
141+ )
142+
143+ mock_client = Mock ()
144+ mock_client ._connect = AsyncMock ()
145+ mock_client .initialize_result = backend_result
146+
147+ mock_factory = Mock ()
148+ mock_factory .set_init_params = Mock ()
149+ mock_factory .get_client = AsyncMock (return_value = mock_client )
150+
151+ middleware = InitializeMiddleware (mock_factory )
152+
153+ mock_init_options = Mock ()
154+ mock_init_options .server_name = 'proxy-name'
155+ mock_init_options .server_version = '1.0'
156+ mock_init_options .capabilities = mt .ServerCapabilities ()
157+ mock_session = Mock ()
158+ mock_session ._init_options = mock_init_options
159+ mock_fastmcp_ctx = Mock ()
160+ mock_fastmcp_ctx ._session = mock_session
161+
162+ mock_context = Mock ()
163+ mock_context .message = create_initialize_request ('test-client' )
164+ mock_context .fastmcp_context = mock_fastmcp_ctx
165+
166+ mock_call_next = AsyncMock ()
167+
168+ await middleware .on_initialize (mock_context , mock_call_next )
169+
170+ assert mock_init_options .capabilities == backend_capabilities
171+
172+
173+ @pytest .mark .asyncio
174+ async def test_on_initialize_disables_prompts_and_resources ():
175+ """Test that prompts and resources capabilities are disabled even if backend supports them."""
176+ backend_capabilities = mt .ServerCapabilities (
177+ tools = mt .ToolsCapability (),
178+ prompts = mt .PromptsCapability (),
179+ resources = mt .ResourcesCapability (),
180+ )
181+ backend_result = mt .InitializeResult (
182+ protocolVersion = '2024-11-05' ,
183+ capabilities = backend_capabilities ,
184+ serverInfo = mt .Implementation (name = 'backend' , version = '1.0' ),
185+ )
186+
187+ mock_client = Mock ()
188+ mock_client ._connect = AsyncMock ()
189+ mock_client .initialize_result = backend_result
190+
191+ mock_factory = Mock ()
192+ mock_factory .set_init_params = Mock ()
193+ mock_factory .get_client = AsyncMock (return_value = mock_client )
194+
195+ middleware = InitializeMiddleware (mock_factory )
196+
197+ mock_init_options = Mock ()
198+ mock_session = Mock ()
199+ mock_session ._init_options = mock_init_options
200+ mock_fastmcp_ctx = Mock ()
201+ mock_fastmcp_ctx ._session = mock_session
202+
203+ mock_context = Mock ()
204+ mock_context .message = create_initialize_request ('test-client' )
205+ mock_context .fastmcp_context = mock_fastmcp_ctx
206+
207+ mock_call_next = AsyncMock ()
208+
209+ await middleware .on_initialize (mock_context , mock_call_next )
210+
211+ assert mock_init_options .capabilities .prompts is not None
212+ assert mock_init_options .capabilities .resources is not None
213+ assert mock_init_options .capabilities .tools is not None
214+
215+
216+ @pytest .mark .asyncio
217+ async def test_on_initialize_skips_overwrite_when_no_session ():
218+ """Test that overwrite is skipped when no session is available."""
219+ mock_client = Mock ()
220+ mock_client ._connect = AsyncMock ()
221+ mock_client .initialize_result = mt .InitializeResult (
222+ protocolVersion = '2024-11-05' ,
223+ capabilities = mt .ServerCapabilities (),
224+ serverInfo = mt .Implementation (name = 'backend' , version = '1.0' ),
225+ )
226+
227+ mock_factory = Mock ()
228+ mock_factory .set_init_params = Mock ()
229+ mock_factory .get_client = AsyncMock (return_value = mock_client )
230+
231+ middleware = InitializeMiddleware (mock_factory )
232+
233+ mock_context = Mock ()
234+ mock_context .message = create_initialize_request ('test-client' )
235+ mock_context .fastmcp_context = None
236+
237+ mock_call_next = AsyncMock ()
238+
239+ # Should not raise, just skip overwrite
240+ await middleware .on_initialize (mock_context , mock_call_next )
241+ mock_call_next .assert_called_once_with (mock_context )
0 commit comments