Skip to content

Commit 7faa470

Browse files
committed
Better handling A2A issues when connecting to LlamaStack
1 parent 3421bf3 commit 7faa470

2 files changed

Lines changed: 179 additions & 12 deletions

File tree

src/app/endpoints/a2a.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from llama_stack.apis.agents.openai_responses import (
1212
OpenAIResponseObjectStream,
1313
)
14+
from llama_stack_client import APIConnectionError
1415
from starlette.responses import Response, StreamingResponse
1516

1617
from a2a.types import (
@@ -310,19 +311,44 @@ async def _process_task_streaming( # pylint: disable=too-many-locals
310311

311312
# Get LLM client and select model
312313
client = AsyncLlamaStackClientHolder().get_client()
313-
llama_stack_model_id, _model_id, _provider_id = select_model_and_provider_id(
314-
await client.models.list(),
315-
*evaluate_model_hints(user_conversation=None, query_request=query_request),
316-
)
314+
try:
315+
llama_stack_model_id, _model_id, _provider_id = (
316+
select_model_and_provider_id(
317+
await client.models.list(),
318+
*evaluate_model_hints(
319+
user_conversation=None, query_request=query_request
320+
),
321+
)
322+
)
317323

318-
# Stream response from LLM using the Responses API
319-
stream, conversation_id = await retrieve_response(
320-
client,
321-
llama_stack_model_id,
322-
query_request,
323-
self.auth_token,
324-
mcp_headers=self.mcp_headers,
325-
)
324+
# Stream response from LLM using the Responses API
325+
stream, conversation_id = await retrieve_response(
326+
client,
327+
llama_stack_model_id,
328+
query_request,
329+
self.auth_token,
330+
mcp_headers=self.mcp_headers,
331+
)
332+
except APIConnectionError as e:
333+
error_message = (
334+
f"Unable to connect to Llama Stack backend service: {str(e)}. "
335+
"The service may be temporarily unavailable. Please try again later."
336+
)
337+
logger.error(
338+
"APIConnectionError in A2A request: %s",
339+
str(e),
340+
exc_info=True,
341+
)
342+
await task_updater.update_status(
343+
TaskState.failed,
344+
message=new_agent_text_message(
345+
error_message,
346+
context_id=context_id,
347+
task_id=task_id,
348+
),
349+
final=True,
350+
)
351+
return
326352

327353
# Persist conversation_id for next turn in same A2A context
328354
if conversation_id:

tests/unit/app/endpoints/test_a2a.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
from typing import Any
77
from unittest.mock import AsyncMock, MagicMock
88

9+
import httpx
910
import pytest
1011
from fastapi import HTTPException, Request
12+
from llama_stack_client import APIConnectionError
1113
from pytest_mock import MockerFixture
1214

1315
from a2a.types import (
@@ -654,6 +656,145 @@ async def test_process_task_streaming_no_input(
654656
call_args = task_updater.update_status.call_args
655657
assert call_args[0][0] == TaskState.input_required
656658

659+
@pytest.mark.asyncio
660+
async def test_process_task_streaming_handles_api_connection_error_on_models_list(
661+
self,
662+
mocker: MockerFixture,
663+
setup_configuration: AppConfig, # pylint: disable=unused-argument
664+
) -> None:
665+
"""Test _process_task_streaming handles APIConnectionError from models.list()."""
666+
executor = A2AAgentExecutor(auth_token="test-token")
667+
668+
# Mock the context with valid input
669+
mock_message = MagicMock()
670+
mock_message.role = "user"
671+
mock_message.parts = [Part(root=TextPart(text="Hello"))]
672+
mock_message.metadata = {}
673+
674+
context = MagicMock(spec=RequestContext)
675+
context.task_id = "task-123"
676+
context.context_id = "ctx-456"
677+
context.message = mock_message
678+
context.get_user_input.return_value = "Hello"
679+
680+
# Mock event queue
681+
event_queue = AsyncMock(spec=EventQueue)
682+
683+
# Create task updater mock
684+
task_updater = MagicMock()
685+
task_updater.update_status = AsyncMock()
686+
task_updater.event_queue = event_queue
687+
688+
# Mock the context store
689+
mock_context_store = AsyncMock()
690+
mock_context_store.get.return_value = None
691+
mocker.patch(
692+
"app.endpoints.a2a._get_context_store", return_value=mock_context_store
693+
)
694+
695+
# Mock the client to raise APIConnectionError on models.list()
696+
mock_client = AsyncMock()
697+
# Create a mock httpx.Request for APIConnectionError
698+
mock_request = httpx.Request("GET", "http://test-llama-stack/models")
699+
mock_client.models.list.side_effect = APIConnectionError(
700+
message="Connection refused: unable to reach Llama Stack",
701+
request=mock_request,
702+
)
703+
mocker.patch(
704+
"app.endpoints.a2a.AsyncLlamaStackClientHolder"
705+
).return_value.get_client.return_value = mock_client
706+
707+
await executor._process_task_streaming(
708+
context, task_updater, context.task_id, context.context_id
709+
)
710+
711+
# Verify failure status was sent
712+
task_updater.update_status.assert_called_once()
713+
call_args = task_updater.update_status.call_args
714+
assert call_args[0][0] == TaskState.failed
715+
assert call_args[1]["final"] is True
716+
# Verify error message contains helpful info
717+
error_message = call_args[1]["message"]
718+
assert "Unable to connect to Llama Stack backend service" in str(error_message)
719+
720+
@pytest.mark.asyncio
721+
async def test_process_task_streaming_handles_api_connection_error_on_retrieve_response(
722+
self,
723+
mocker: MockerFixture,
724+
setup_configuration: AppConfig, # pylint: disable=unused-argument
725+
) -> None:
726+
"""Test _process_task_streaming handles APIConnectionError from retrieve_response()."""
727+
executor = A2AAgentExecutor(auth_token="test-token")
728+
729+
# Mock the context with valid input
730+
mock_message = MagicMock()
731+
mock_message.role = "user"
732+
mock_message.parts = [Part(root=TextPart(text="Hello"))]
733+
mock_message.metadata = {}
734+
735+
context = MagicMock(spec=RequestContext)
736+
context.task_id = "task-123"
737+
context.context_id = "ctx-456"
738+
context.message = mock_message
739+
context.get_user_input.return_value = "Hello"
740+
741+
# Mock event queue
742+
event_queue = AsyncMock(spec=EventQueue)
743+
744+
# Create task updater mock
745+
task_updater = MagicMock()
746+
task_updater.update_status = AsyncMock()
747+
task_updater.event_queue = event_queue
748+
749+
# Mock the context store
750+
mock_context_store = AsyncMock()
751+
mock_context_store.get.return_value = None
752+
mocker.patch(
753+
"app.endpoints.a2a._get_context_store", return_value=mock_context_store
754+
)
755+
756+
# Mock the client to succeed on models.list()
757+
mock_client = AsyncMock()
758+
mock_models = MagicMock()
759+
mock_models.models = []
760+
mock_client.models.list.return_value = mock_models
761+
mocker.patch(
762+
"app.endpoints.a2a.AsyncLlamaStackClientHolder"
763+
).return_value.get_client.return_value = mock_client
764+
765+
# Mock select_model_and_provider_id
766+
mocker.patch(
767+
"app.endpoints.a2a.select_model_and_provider_id",
768+
return_value=("model-id", "model-id", "provider-id"),
769+
)
770+
771+
# Mock evaluate_model_hints
772+
mocker.patch(
773+
"app.endpoints.a2a.evaluate_model_hints", return_value=(None, None)
774+
)
775+
776+
# Mock retrieve_response to raise APIConnectionError
777+
mock_request = httpx.Request("POST", "http://test-llama-stack/responses")
778+
mocker.patch(
779+
"app.endpoints.a2a.retrieve_response",
780+
side_effect=APIConnectionError(
781+
message="Connection timeout during streaming", request=mock_request
782+
),
783+
)
784+
785+
await executor._process_task_streaming(
786+
context, task_updater, context.task_id, context.context_id
787+
)
788+
789+
# Verify failure status was sent
790+
task_updater.update_status.assert_called_once()
791+
call_args = task_updater.update_status.call_args
792+
assert call_args[0][0] == TaskState.failed
793+
assert call_args[1]["final"] is True
794+
# Verify error message contains helpful info
795+
error_message = call_args[1]["message"]
796+
assert "Unable to connect to Llama Stack backend service" in str(error_message)
797+
657798
@pytest.mark.asyncio
658799
async def test_cancel_raises_not_implemented(self) -> None:
659800
"""Test that cancel raises NotImplementedError."""

0 commit comments

Comments
 (0)