Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions homeassistant/components/mcp/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,11 @@ async def validate_input(
except vol.Invalid as error:
raise InvalidUrl from error
try:
async with mcp_client(hass, url, token_manager=token_manager) as session:
response = await session.initialize()
async with mcp_client(hass, url, token_manager=token_manager) as (
_session,
response,
):
pass
Comment on lines +154 to +158
except httpx.TimeoutException as error:
_LOGGER.info("Timeout connecting to MCP server: %s", error)
raise TimeoutConnectError from error
Expand Down
20 changes: 11 additions & 9 deletions homeassistant/components/mcp/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamable_http_client
from mcp.types import InitializeResult
import voluptuous as vol
from voluptuous_openapi import convert_to_voluptuous

Expand Down Expand Up @@ -38,7 +39,7 @@ async def mcp_client(
hass: HomeAssistant,
url: str,
token_manager: TokenManager | None = None,
) -> AsyncGenerator[ClientSession]:
) -> AsyncGenerator[tuple[ClientSession, InitializeResult]]:
"""Create an MCP client.

This is an asynccontext manager that exists to wrap other async context managers
Comment on lines +42 to 45
Expand All @@ -57,8 +58,8 @@ async def mcp_client(
) as (read_stream, write_stream, _),
ClientSession(read_stream, write_stream) as session,
):
await session.initialize()
yield session
result = await session.initialize()
yield session, result
Comment on lines +61 to +62
except ExceptionGroup as streamable_err:
main_error = streamable_err.exceptions[0]
# Method not Allowed likely means this is not a streamable HTTP server,
Expand All @@ -78,8 +79,8 @@ async def mcp_client(
sse_client(url=url, headers=headers) as streams,
ClientSession(*streams) as session,
):
await session.initialize()
yield session
result = await session.initialize()
yield session, result
except ExceptionGroup as sse_err:
_LOGGER.debug("Error creating SSE MCP client: %s", sse_err)
raise sse_err.exceptions[0] from sse_err
Expand Down Expand Up @@ -115,9 +116,10 @@ async def async_call(
"""Call the tool."""
try:
async with asyncio.timeout(TIMEOUT):
async with mcp_client(
hass, self.server_url, self.token_manager
) as session:
async with mcp_client(hass, self.server_url, self.token_manager) as (
session,
_,
):
result = await session.call_tool(
tool_input.tool_name, tool_input.tool_args
)
Expand Down Expand Up @@ -161,7 +163,7 @@ async def _async_update_data(self) -> list[llm.Tool]:
async with asyncio.timeout(TIMEOUT):
async with mcp_client(
self.hass, self.config_entry.data[CONF_URL], self.token_manager
) as session:
) as (session, _):
result = await session.list_tools()
except TimeoutError as error:
_LOGGER.debug("Timeout when listing tools: %s", error)
Expand Down
26 changes: 26 additions & 0 deletions tests/components/mcp/test_config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,32 @@ async def test_form(
assert len(mock_setup_entry.mock_calls) == 1


async def test_initialize_called_once_in_config_flow(
hass: HomeAssistant, mock_setup_entry: AsyncMock, mock_mcp_client: Mock
) -> None:
"""Test that initialize is called exactly once during the config flow.

Regression test for the double-initialize bug: the mcp_client context
manager calls initialize() before yielding the session, and validate_input
previously called it a second time. MCP servers reject the second call with
-32600 "Invalid Request: Server already initialized".
"""
response = Mock()
response.serverInfo.name = TEST_API_NAME
mock_mcp_client.return_value.initialize.return_value = response
Comment on lines +133 to +135

result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{CONF_URL: MCP_SERVER_URL},
)

assert result["type"] is FlowResultType.CREATE_ENTRY
mock_mcp_client.return_value.initialize.assert_called_once()


@pytest.mark.parametrize(
("side_effect", "expected_error"),
[
Expand Down