|
6 | 6 | from typing import Any |
7 | 7 | from unittest.mock import AsyncMock, MagicMock |
8 | 8 |
|
| 9 | +import httpx |
9 | 10 | import pytest |
10 | 11 | from fastapi import HTTPException, Request |
| 12 | +from llama_stack_client import APIConnectionError |
11 | 13 | from pytest_mock import MockerFixture |
12 | 14 |
|
13 | 15 | from a2a.types import ( |
@@ -654,6 +656,145 @@ async def test_process_task_streaming_no_input( |
654 | 656 | call_args = task_updater.update_status.call_args |
655 | 657 | assert call_args[0][0] == TaskState.input_required |
656 | 658 |
|
| 659 | + @pytest.mark.asyncio |
| 660 | + async def test_process_task_streaming_handles_api_connection_error_on_models_list( |
| 661 | + self, |
| 662 | + mocker: MockerFixture, |
| 663 | + setup_configuration: AppConfig, # pylint: disable=unused-argument |
| 664 | + ) -> None: |
| 665 | + """Test _process_task_streaming handles APIConnectionError from models.list().""" |
| 666 | + executor = A2AAgentExecutor(auth_token="test-token") |
| 667 | + |
| 668 | + # Mock the context with valid input |
| 669 | + mock_message = MagicMock() |
| 670 | + mock_message.role = "user" |
| 671 | + mock_message.parts = [Part(root=TextPart(text="Hello"))] |
| 672 | + mock_message.metadata = {} |
| 673 | + |
| 674 | + context = MagicMock(spec=RequestContext) |
| 675 | + context.task_id = "task-123" |
| 676 | + context.context_id = "ctx-456" |
| 677 | + context.message = mock_message |
| 678 | + context.get_user_input.return_value = "Hello" |
| 679 | + |
| 680 | + # Mock event queue |
| 681 | + event_queue = AsyncMock(spec=EventQueue) |
| 682 | + |
| 683 | + # Create task updater mock |
| 684 | + task_updater = MagicMock() |
| 685 | + task_updater.update_status = AsyncMock() |
| 686 | + task_updater.event_queue = event_queue |
| 687 | + |
| 688 | + # Mock the context store |
| 689 | + mock_context_store = AsyncMock() |
| 690 | + mock_context_store.get.return_value = None |
| 691 | + mocker.patch( |
| 692 | + "app.endpoints.a2a._get_context_store", return_value=mock_context_store |
| 693 | + ) |
| 694 | + |
| 695 | + # Mock the client to raise APIConnectionError on models.list() |
| 696 | + mock_client = AsyncMock() |
| 697 | + # Create a mock httpx.Request for APIConnectionError |
| 698 | + mock_request = httpx.Request("GET", "http://test-llama-stack/models") |
| 699 | + mock_client.models.list.side_effect = APIConnectionError( |
| 700 | + message="Connection refused: unable to reach Llama Stack", |
| 701 | + request=mock_request, |
| 702 | + ) |
| 703 | + mocker.patch( |
| 704 | + "app.endpoints.a2a.AsyncLlamaStackClientHolder" |
| 705 | + ).return_value.get_client.return_value = mock_client |
| 706 | + |
| 707 | + await executor._process_task_streaming( |
| 708 | + context, task_updater, context.task_id, context.context_id |
| 709 | + ) |
| 710 | + |
| 711 | + # Verify failure status was sent |
| 712 | + task_updater.update_status.assert_called_once() |
| 713 | + call_args = task_updater.update_status.call_args |
| 714 | + assert call_args[0][0] == TaskState.failed |
| 715 | + assert call_args[1]["final"] is True |
| 716 | + # Verify error message contains helpful info |
| 717 | + error_message = call_args[1]["message"] |
| 718 | + assert "Unable to connect to Llama Stack backend service" in str(error_message) |
| 719 | + |
| 720 | + @pytest.mark.asyncio |
| 721 | + async def test_process_task_streaming_handles_api_connection_error_on_retrieve_response( |
| 722 | + self, |
| 723 | + mocker: MockerFixture, |
| 724 | + setup_configuration: AppConfig, # pylint: disable=unused-argument |
| 725 | + ) -> None: |
| 726 | + """Test _process_task_streaming handles APIConnectionError from retrieve_response().""" |
| 727 | + executor = A2AAgentExecutor(auth_token="test-token") |
| 728 | + |
| 729 | + # Mock the context with valid input |
| 730 | + mock_message = MagicMock() |
| 731 | + mock_message.role = "user" |
| 732 | + mock_message.parts = [Part(root=TextPart(text="Hello"))] |
| 733 | + mock_message.metadata = {} |
| 734 | + |
| 735 | + context = MagicMock(spec=RequestContext) |
| 736 | + context.task_id = "task-123" |
| 737 | + context.context_id = "ctx-456" |
| 738 | + context.message = mock_message |
| 739 | + context.get_user_input.return_value = "Hello" |
| 740 | + |
| 741 | + # Mock event queue |
| 742 | + event_queue = AsyncMock(spec=EventQueue) |
| 743 | + |
| 744 | + # Create task updater mock |
| 745 | + task_updater = MagicMock() |
| 746 | + task_updater.update_status = AsyncMock() |
| 747 | + task_updater.event_queue = event_queue |
| 748 | + |
| 749 | + # Mock the context store |
| 750 | + mock_context_store = AsyncMock() |
| 751 | + mock_context_store.get.return_value = None |
| 752 | + mocker.patch( |
| 753 | + "app.endpoints.a2a._get_context_store", return_value=mock_context_store |
| 754 | + ) |
| 755 | + |
| 756 | + # Mock the client to succeed on models.list() |
| 757 | + mock_client = AsyncMock() |
| 758 | + mock_models = MagicMock() |
| 759 | + mock_models.models = [] |
| 760 | + mock_client.models.list.return_value = mock_models |
| 761 | + mocker.patch( |
| 762 | + "app.endpoints.a2a.AsyncLlamaStackClientHolder" |
| 763 | + ).return_value.get_client.return_value = mock_client |
| 764 | + |
| 765 | + # Mock select_model_and_provider_id |
| 766 | + mocker.patch( |
| 767 | + "app.endpoints.a2a.select_model_and_provider_id", |
| 768 | + return_value=("model-id", "model-id", "provider-id"), |
| 769 | + ) |
| 770 | + |
| 771 | + # Mock evaluate_model_hints |
| 772 | + mocker.patch( |
| 773 | + "app.endpoints.a2a.evaluate_model_hints", return_value=(None, None) |
| 774 | + ) |
| 775 | + |
| 776 | + # Mock retrieve_response to raise APIConnectionError |
| 777 | + mock_request = httpx.Request("POST", "http://test-llama-stack/responses") |
| 778 | + mocker.patch( |
| 779 | + "app.endpoints.a2a.retrieve_response", |
| 780 | + side_effect=APIConnectionError( |
| 781 | + message="Connection timeout during streaming", request=mock_request |
| 782 | + ), |
| 783 | + ) |
| 784 | + |
| 785 | + await executor._process_task_streaming( |
| 786 | + context, task_updater, context.task_id, context.context_id |
| 787 | + ) |
| 788 | + |
| 789 | + # Verify failure status was sent |
| 790 | + task_updater.update_status.assert_called_once() |
| 791 | + call_args = task_updater.update_status.call_args |
| 792 | + assert call_args[0][0] == TaskState.failed |
| 793 | + assert call_args[1]["final"] is True |
| 794 | + # Verify error message contains helpful info |
| 795 | + error_message = call_args[1]["message"] |
| 796 | + assert "Unable to connect to Llama Stack backend service" in str(error_message) |
| 797 | + |
657 | 798 | @pytest.mark.asyncio |
658 | 799 | async def test_cancel_raises_not_implemented(self) -> None: |
659 | 800 | """Test that cancel raises NotImplementedError.""" |
|
0 commit comments