diff --git a/homeassistant/components/mcp/config_flow.py b/homeassistant/components/mcp/config_flow.py index e3a176dd6d7d6..feae358857404 100644 --- a/homeassistant/components/mcp/config_flow.py +++ b/homeassistant/components/mcp/config_flow.py @@ -151,8 +151,11 @@ async def validate_input( except vol.Invalid as error: raise InvalidUrl from error try: - async with mcp_client(hass, url, token_manager=token_manager) as session: - response = await session.initialize() + async with mcp_client(hass, url, token_manager=token_manager) as ( + _session, + response, + ): + pass except httpx.TimeoutException as error: _LOGGER.info("Timeout connecting to MCP server: %s", error) raise TimeoutConnectError from error diff --git a/homeassistant/components/mcp/coordinator.py b/homeassistant/components/mcp/coordinator.py index 2e299a0660553..79b732da0ab78 100644 --- a/homeassistant/components/mcp/coordinator.py +++ b/homeassistant/components/mcp/coordinator.py @@ -11,6 +11,7 @@ from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamable_http_client +from mcp.types import InitializeResult import voluptuous as vol from voluptuous_openapi import convert_to_voluptuous @@ -38,7 +39,7 @@ async def mcp_client( hass: HomeAssistant, url: str, token_manager: TokenManager | None = None, -) -> AsyncGenerator[ClientSession]: +) -> AsyncGenerator[tuple[ClientSession, InitializeResult]]: """Create an MCP client. This is an asynccontext manager that exists to wrap other async context managers @@ -57,8 +58,8 @@ async def mcp_client( ) as (read_stream, write_stream, _), ClientSession(read_stream, write_stream) as session, ): - await session.initialize() - yield session + result = await session.initialize() + yield session, result except ExceptionGroup as streamable_err: main_error = streamable_err.exceptions[0] # Method not Allowed likely means this is not a streamable HTTP server, @@ -78,8 +79,8 @@ async def mcp_client( sse_client(url=url, headers=headers) as streams, ClientSession(*streams) as session, ): - await session.initialize() - yield session + result = await session.initialize() + yield session, result except ExceptionGroup as sse_err: _LOGGER.debug("Error creating SSE MCP client: %s", sse_err) raise sse_err.exceptions[0] from sse_err @@ -115,9 +116,10 @@ async def async_call( """Call the tool.""" try: async with asyncio.timeout(TIMEOUT): - async with mcp_client( - hass, self.server_url, self.token_manager - ) as session: + async with mcp_client(hass, self.server_url, self.token_manager) as ( + session, + _, + ): result = await session.call_tool( tool_input.tool_name, tool_input.tool_args ) @@ -161,7 +163,7 @@ async def _async_update_data(self) -> list[llm.Tool]: async with asyncio.timeout(TIMEOUT): async with mcp_client( self.hass, self.config_entry.data[CONF_URL], self.token_manager - ) as session: + ) as (session, _): result = await session.list_tools() except TimeoutError as error: _LOGGER.debug("Timeout when listing tools: %s", error) diff --git a/tests/components/mcp/test_config_flow.py b/tests/components/mcp/test_config_flow.py index 3dc43daaaf70c..08235a2fabc28 100644 --- a/tests/components/mcp/test_config_flow.py +++ b/tests/components/mcp/test_config_flow.py @@ -120,6 +120,32 @@ async def test_form( assert len(mock_setup_entry.mock_calls) == 1 +async def test_initialize_called_once_in_config_flow( + hass: HomeAssistant, mock_setup_entry: AsyncMock, mock_mcp_client: Mock +) -> None: + """Test that initialize is called exactly once during the config flow. + + Regression test for the double-initialize bug: the mcp_client context + manager calls initialize() before yielding the session, and validate_input + previously called it a second time. MCP servers reject the second call with + -32600 "Invalid Request: Server already initialized". + """ + response = Mock() + response.serverInfo.name = TEST_API_NAME + mock_mcp_client.return_value.initialize.return_value = response + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + {CONF_URL: MCP_SERVER_URL}, + ) + + assert result["type"] is FlowResultType.CREATE_ENTRY + mock_mcp_client.return_value.initialize.assert_called_once() + + @pytest.mark.parametrize( ("side_effect", "expected_error"), [