From fb3d20728116c866fd6286ec6b6827ca0d228a3e Mon Sep 17 00:00:00 2001 From: Josh Bell Date: Wed, 1 Apr 2026 17:41:01 -0600 Subject: [PATCH 1/3] improve session resilience Signed-off-by: Josh Bell Signed-off-by: jobell --- .../src/kagent/adk/_session_service.py | 108 ++++- .../tests/unittests/test_session_service.py | 388 +++++++++++------- ui/src/components/chat/AgentCallDisplay.tsx | 2 +- ui/src/components/chat/ChatInterface.tsx | 51 ++- ui/src/lib/__tests__/a2aClient.test.ts | 203 +++++++++ ui/src/lib/__tests__/messageHandlers.test.ts | 36 +- ui/src/lib/a2aClient.ts | 43 +- ui/src/lib/messageHandlers.ts | 17 +- 8 files changed, 696 insertions(+), 152 deletions(-) create mode 100644 ui/src/lib/__tests__/a2aClient.test.ts diff --git a/python/packages/kagent-adk/src/kagent/adk/_session_service.py b/python/packages/kagent-adk/src/kagent/adk/_session_service.py index da08895a5..631b8e278 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_session_service.py +++ b/python/packages/kagent-adk/src/kagent/adk/_session_service.py @@ -114,7 +114,7 @@ async def get_session( session = Session( id=session_data["id"], user_id=session_data["user_id"], - events=[], + events=events, app_name=app_name, state={}, ) @@ -157,6 +157,94 @@ async def delete_session(self, *, app_name: str, user_id: str, session_id: str) ) response.raise_for_status() + async def _recreate_session(self, session: Session) -> None: + """Recreate a session that was not found (404). + + This handles the case where a session expires or is cleaned up + during a long-running operation. + + Session State Preservation: + - session_id: Preserved (same ID used for recreation) + - user_id: Preserved + - agent_ref: Preserved + - session_name: Preserved (from session.state["session_name"]) + - Other session.state fields: NOT preserved (lost on recreation) + + Note: Only session_name is currently used by the application. + If additional state fields are added in the future, they must be + explicitly preserved here. + + Args: + session: The session object to recreate + + Raises: + httpx.HTTPStatusError: If recreation fails + """ + request_data = { + "id": session.id, + "user_id": session.user_id, + "agent_ref": session.app_name, + } + if session.state and session.state.get("session_name"): + request_data["name"] = session.state["session_name"] + + # Warn if session has additional state fields that won't be preserved + if session.state and len(session.state) > 1: + extra_fields = [k for k in session.state.keys() if k != "session_name"] + logger.warning( + "Session %s has additional state fields that will not be preserved during recreation: %s. " + "Update _recreate_session() if these fields are critical.", + session.id, + extra_fields, + ) + + response = await self.client.post( + "/api/sessions", + json=request_data, + headers={"X-User-ID": session.user_id}, + ) + if response.status_code == 409: + # Session was already recreated by a concurrent call — treat as success. + logger.info( + "Session %s already exists (409 Conflict) during recreation, " + "likely recreated by a concurrent request. Proceeding with retry.", + session.id, + ) + else: + response.raise_for_status() + logger.info("Successfully recreated session %s", session.id) + + # Fetch existing tasks for this session to check for in-flight work + tasks_response = await self.client.get( + f"/api/sessions/{session.id}/tasks?user_id={session.user_id}", + headers={"X-User-ID": session.user_id}, + ) + if tasks_response.status_code == 200: + tasks_data = tasks_response.json() + if tasks_data.get("data"): + logger.info( + "Session %s has %d existing task(s) after recreation", + session.id, + len(tasks_data["data"]), + ) + # Log info about in-flight tasks + for task in tasks_data["data"]: + task_status = task.get("status", {}) + task_state = task_status.get("state", "unknown") + if task_state in ("working", "submitted"): + logger.info( + "Found in-flight task %s in state '%s' - UI should resubscribe to continue receiving updates", + task.get("id"), + task_state, + ) + else: + logger.warning( + "Failed to fetch tasks for recreated session %s (HTTP %d). " + "In-flight task detection unavailable - UI may not auto-reconnect to active tasks.", + session.id, + tasks_response.status_code, + ) + @override async def append_event(self, session: Session, event: Event) -> Event: if event.partial: @@ -174,6 +262,24 @@ async def append_event(self, session: Session, event: Event) -> Event: json=event_data, headers={"X-User-ID": session.user_id}, ) + + # Handle 404 by recreating session and retrying once + if response.status_code == 404: + logger.warning( + "Session %s not found (404), attempting to recreate before retry", + session.id, + ) + await self._recreate_session(session) + + # Retry the append ONCE. If this retry also fails (including another 404), + # raise_for_status() below will propagate the error without further attempts. + # This prevents infinite recursion while allowing recovery from transient deletion. + response = await self.client.post( + f"/api/sessions/{session.id}/events?user_id={session.user_id}", + json=event_data, + headers={"X-User-ID": session.user_id}, + ) + response.raise_for_status() # TODO: potentially pull and update the session from the server diff --git a/python/packages/kagent-adk/tests/unittests/test_session_service.py b/python/packages/kagent-adk/tests/unittests/test_session_service.py index 293e88d76..e77411cf3 100644 --- a/python/packages/kagent-adk/tests/unittests/test_session_service.py +++ b/python/packages/kagent-adk/tests/unittests/test_session_service.py @@ -1,166 +1,276 @@ -"""Tests for KAgentSessionService.""" - -from unittest.mock import AsyncMock, MagicMock +from unittest import mock import httpx import pytest -from google.adk.events.event import Event, EventActions +from google.adk.events.event import Event +from google.adk.sessions import Session from kagent.adk._session_service import KAgentSessionService @pytest.fixture -def make_event(): - """Factory fixture: make_event(author, state_delta) -> Event.""" - - def _factory(author: str = "user", state_delta: dict | None = None) -> Event: - if state_delta: - return Event(author=author, invocation_id="inv1", actions=EventActions(state_delta=state_delta)) - return Event(author=author, invocation_id="inv1") - - return _factory +def mock_client(): + """Create a mock httpx.AsyncClient.""" + return mock.AsyncMock(spec=httpx.AsyncClient) @pytest.fixture -def session_response(): - """Factory fixture: session_response(events, session_id, user_id) -> dict. - - Builds the JSON envelope that the KAgent API returns for GET /api/sessions/{id}. - """ - - def _factory(events: list[Event], session_id: str = "s1", user_id: str = "u1") -> dict: - return { - "data": { - "session": {"id": session_id, "user_id": user_id}, - "events": [{"id": e.id, "data": e.model_dump_json()} for e in events], - } - } - - return _factory +def session_service(mock_client): + """Create a KAgentSessionService with mocked client.""" + return KAgentSessionService(client=mock_client) @pytest.fixture -def mock_client(): - """Factory fixture: mock_client(response_json, status_code) -> MagicMock httpx.AsyncClient.""" - - def _factory(response_json: dict | None, status_code: int = 200) -> MagicMock: - mock_response = MagicMock(spec=httpx.Response) - mock_response.status_code = status_code - mock_response.json.return_value = response_json - mock_response.raise_for_status = MagicMock() - - client = MagicMock(spec=httpx.AsyncClient) - client.get = AsyncMock(return_value=mock_response) - return client - - return _factory +def sample_session(): + """Create a sample session for testing.""" + return Session( + id="test-session-123", + user_id="test-user", + app_name="test-app", + state={"session_name": "Test Session"}, + ) @pytest.fixture -def service(mock_client): - """Factory fixture: service(response_json, status_code) -> KAgentSessionService.""" - - def _factory(response_json: dict | None, status_code: int = 200) -> KAgentSessionService: - return KAgentSessionService(mock_client(response_json, status_code)) - - return _factory - - -@pytest.mark.asyncio -async def test_get_session_returns_none_on_404(mock_client): - """A 404 response returns None without raising.""" - svc = KAgentSessionService(mock_client(response_json=None, status_code=404)) - session = await svc.get_session(app_name="app", user_id="u1", session_id="missing") - - assert session is None - - -@pytest.mark.asyncio -async def test_get_session_returns_none_when_no_data(service): - """An empty data envelope returns None.""" - session = await service({"data": None}).get_session(app_name="app", user_id="u1", session_id="s1") - - assert session is None - - -@pytest.mark.asyncio -async def test_get_session_event_ids_preserved(make_event, session_response, service): - """Event identity (id) is preserved after loading from the API.""" - events = [make_event("user"), make_event("assistant")] - original_ids = [e.id for e in events] - - session = await service(session_response(events)).get_session(app_name="app", user_id="u1", session_id="s1") - - assert session is not None - assert [e.id for e in session.events] == original_ids - - -@pytest.mark.asyncio -async def test_get_session_events_not_duplicated(make_event, session_response, service): - """Each event from the API must appear exactly once in session.events. - - Regression test for the bug where Session(events=events) pre-populated - session.events and super().append_event() then appended each event again. - """ - events = [make_event("user"), make_event("assistant"), make_event("tool")] - session = await service(session_response(events)).get_session(app_name="app", user_id="u1", session_id="s1") - - assert session is not None - assert len(session.events) == len(events), ( - f"Expected {len(events)} events but got {len(session.events)} — possible event duplication in get_session" +def sample_event(): + """Create a sample event for testing.""" + return Event( + invocation_id="test-invocation", + author="user", ) -@pytest.mark.asyncio -async def test_get_session_single_event_not_duplicated(make_event, session_response, service): - """Single-event case: still only one event in session.events.""" - events = [make_event("user")] - session = await service(session_response(events)).get_session(app_name="app", user_id="u1", session_id="s1") - - assert session is not None - assert len(session.events) == 1 - - -@pytest.mark.asyncio -async def test_get_session_empty_events(session_response, service): - """Zero events from the API yields an empty session.events list.""" - session = await service(session_response([])).get_session(app_name="app", user_id="u1", session_id="s1") - - assert session is not None - assert len(session.events) == 0 +class TestAppendEvent: + """Tests for append_event method.""" + + @pytest.mark.asyncio + async def test_append_event_success(self, session_service, mock_client, sample_session, sample_event): + """Test successful event append.""" + # Mock successful response + mock_response = mock.MagicMock() + mock_response.status_code = 201 + mock_response.raise_for_status = mock.MagicMock() + mock_client.post.return_value = mock_response + + result = await session_service.append_event(sample_session, sample_event) + + assert result == sample_event + mock_client.post.assert_called_once() + call_args = mock_client.post.call_args + assert f"/api/sessions/{sample_session.id}/events" in call_args[0][0] + + @pytest.mark.asyncio + async def test_append_event_404_recovery(self, session_service, mock_client, sample_session, sample_event): + """Test 404 triggers session recreation and retry.""" + # First call returns 404, second call (after recreation) succeeds + mock_response_404 = mock.MagicMock() + mock_response_404.status_code = 404 + + mock_response_success = mock.MagicMock() + mock_response_success.status_code = 201 + mock_response_success.raise_for_status = mock.MagicMock() + + mock_response_create = mock.MagicMock() + mock_response_create.status_code = 201 + mock_response_create.raise_for_status = mock.MagicMock() + + # Mock tasks fetch response (no tasks found) + mock_tasks_response = mock.MagicMock() + mock_tasks_response.status_code = 200 + mock_tasks_response.json.return_value = {"data": []} + + # Configure mock to return different responses + mock_client.post.side_effect = [ + mock_response_404, # First append attempt (404) + mock_response_create, # Session recreation + mock_response_success, # Retry append (success) + ] + mock_client.get.return_value = mock_tasks_response # Tasks fetch + + result = await session_service.append_event(sample_session, sample_event) + + assert result == sample_event + # Should be called 3 times: initial append, session create, retry append + assert mock_client.post.call_count == 3 + # Should fetch tasks once after recreation + assert mock_client.get.call_count == 1 + + # Verify the calls + calls = mock_client.post.call_args_list + assert f"/api/sessions/{sample_session.id}/events" in calls[0][0][0] # First append + assert "/api/sessions" == calls[1][0][0] # Session recreation + assert f"/api/sessions/{sample_session.id}/events" in calls[2][0][0] # Retry append + + # Verify tasks fetch + get_call = mock_client.get.call_args + assert f"/api/sessions/{sample_session.id}/tasks" in get_call[0][0] + + @pytest.mark.asyncio + async def test_append_event_404_recovery_failure(self, session_service, mock_client, sample_session, sample_event): + """Test 404 recovery fails if retry also fails.""" + # First call returns 404, recreation succeeds, but retry fails + mock_response_404 = mock.MagicMock() + mock_response_404.status_code = 404 + + mock_response_create = mock.MagicMock() + mock_response_create.status_code = 201 + mock_response_create.raise_for_status = mock.MagicMock() + + mock_response_retry_fail = mock.MagicMock() + mock_response_retry_fail.status_code = 500 + mock_response_retry_fail.raise_for_status.side_effect = httpx.HTTPStatusError( + "Server error", + request=mock.MagicMock(), + response=mock_response_retry_fail, + ) + + # Mock tasks fetch + mock_tasks_response = mock.MagicMock() + mock_tasks_response.status_code = 200 + mock_tasks_response.json.return_value = {"data": []} + + mock_client.post.side_effect = [ + mock_response_404, # First append (404) + mock_response_create, # Session recreation + mock_response_retry_fail, # Retry append (500) + ] + mock_client.get.return_value = mock_tasks_response + + with pytest.raises(httpx.HTTPStatusError): + await session_service.append_event(sample_session, sample_event) + + assert mock_client.post.call_count == 3 + + @pytest.mark.asyncio + async def test_append_event_non_404_error(self, session_service, mock_client, sample_session, sample_event): + """Test non-404 errors are raised immediately without recovery.""" + mock_response = mock.MagicMock() + mock_response.status_code = 500 + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Server error", + request=mock.MagicMock(), + response=mock_response, + ) + mock_client.post.return_value = mock_response + + with pytest.raises(httpx.HTTPStatusError): + await session_service.append_event(sample_session, sample_event) + + # Should only be called once (no retry for non-404 errors) + mock_client.post.assert_called_once() + + +class TestRecreateSession: + """Tests for _recreate_session method.""" + + @pytest.mark.asyncio + async def test_recreate_session_success(self, session_service, mock_client, sample_session): + """Test successful session recreation.""" + mock_response = mock.MagicMock() + mock_response.status_code = 201 + mock_response.raise_for_status = mock.MagicMock() + mock_client.post.return_value = mock_response + + # Mock tasks fetch + mock_tasks_response = mock.MagicMock() + mock_tasks_response.status_code = 200 + mock_tasks_response.json.return_value = {"data": []} + mock_client.get.return_value = mock_tasks_response + + # Should not raise + await session_service._recreate_session(sample_session) + + mock_client.post.assert_called_once() + call_args = mock_client.post.call_args + assert call_args[0][0] == "/api/sessions" + + # Verify request data + request_data = call_args[1]["json"] + assert request_data["id"] == sample_session.id + assert request_data["user_id"] == sample_session.user_id + assert request_data["agent_ref"] == sample_session.app_name + + # Verify tasks were fetched + mock_client.get.assert_called_once() + + @pytest.mark.asyncio + async def test_recreate_session_with_session_name(self, session_service, mock_client, sample_session): + """Test session recreation includes session name from state.""" + mock_response = mock.MagicMock() + mock_response.status_code = 201 + mock_response.raise_for_status = mock.MagicMock() + mock_client.post.return_value = mock_response + + # Mock tasks fetch + mock_tasks_response = mock.MagicMock() + mock_tasks_response.status_code = 200 + mock_tasks_response.json.return_value = {"data": []} + mock_client.get.return_value = mock_tasks_response + + await session_service._recreate_session(sample_session) + + call_args = mock_client.post.call_args + request_data = call_args[1]["json"] + assert request_data["name"] == "Test Session" + + @pytest.mark.asyncio + async def test_recreate_session_failure(self, session_service, mock_client, sample_session): + """Test session recreation failure raises error.""" + mock_response = mock.MagicMock() + mock_response.status_code = 500 + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Server error", + request=mock.MagicMock(), + response=mock_response, + ) + mock_client.post.return_value = mock_response + + with pytest.raises(httpx.HTTPStatusError): + await session_service._recreate_session(sample_session) + + @pytest.mark.asyncio + async def test_recreate_session_with_inflight_task(self, session_service, mock_client, sample_session): + """Test session recreation detects in-flight tasks.""" + # Mock successful session creation + mock_create_response = mock.MagicMock() + mock_create_response.status_code = 201 + mock_create_response.raise_for_status = mock.MagicMock() + mock_client.post.return_value = mock_create_response + + # Mock tasks response with an in-flight task + mock_tasks_response = mock.MagicMock() + mock_tasks_response.status_code = 200 + mock_tasks_response.json.return_value = { + "data": [{"id": "task-123", "status": {"state": "working", "message": "Processing..."}}] + } + mock_client.get.return_value = mock_tasks_response + await session_service._recreate_session(sample_session) -@pytest.mark.asyncio -async def test_get_session_state_delta_applied_once(make_event, session_response, service): - """State deltas from events must be applied exactly once to session.state. + # Verify session creation call + assert mock_client.post.call_count == 1 + # Verify tasks fetch call + assert mock_client.get.call_count == 1 + get_call_url = mock_client.get.call_args[0][0] + assert f"/api/sessions/{sample_session.id}/tasks" in get_call_url - Regression test: when events were double-appended, _update_session_state() - was called twice per event, so numeric or overwrite-based state deltas - would be applied twice. - """ - events = [make_event("assistant", state_delta={"counter": 7})] - session = await service(session_response(events)).get_session(app_name="app", user_id="u1", session_id="s1") + @pytest.mark.asyncio + async def test_recreate_session_409_treated_as_success(self, session_service, mock_client, sample_session): + """Test that 409 Conflict during recreation is treated as success (concurrent recreation).""" + mock_response_409 = mock.MagicMock() + mock_response_409.status_code = 409 - assert session is not None - # State must reflect exactly one application of the delta. - # (BaseSessionService._update_session_state does session.state.update({key: value}), - # so for an idempotent string the bug was silent; here we use a distinct value - # and just verify the key is present with the correct value.) - assert session.state.get("counter") == 7, ( - f"Expected state['counter'] == 7, got {session.state.get('counter')} — " - "state_delta may have been applied more than once" - ) + mock_client.post.return_value = mock_response_409 + # Mock tasks fetch (empty) + mock_tasks_response = mock.MagicMock() + mock_tasks_response.status_code = 200 + mock_tasks_response.json.return_value = {"data": []} + mock_client.get.return_value = mock_tasks_response -@pytest.mark.asyncio -async def test_get_session_multiple_state_deltas_applied_once(make_event, session_response, service): - """Multiple events each contributing a state key are each applied once.""" - events = [ - make_event("assistant", state_delta={"key_a": "value_a"}), - make_event("tool", state_delta={"key_b": "value_b"}), - ] - session = await service(session_response(events)).get_session(app_name="app", user_id="u1", session_id="s1") + # Should not raise even though POST returned 409 + await session_service._recreate_session(sample_session) - assert session is not None - assert session.state.get("key_a") == "value_a" - assert session.state.get("key_b") == "value_b" + mock_client.post.assert_called_once() + # Tasks fetch should still proceed after 409 + mock_client.get.assert_called_once() diff --git a/ui/src/components/chat/AgentCallDisplay.tsx b/ui/src/components/chat/AgentCallDisplay.tsx index 0c8dcd1be..f9d852a9c 100644 --- a/ui/src/components/chat/AgentCallDisplay.tsx +++ b/ui/src/components/chat/AgentCallDisplay.tsx @@ -57,7 +57,7 @@ function SubagentActivityPanel({ sessionId, isComplete }: SubagentActivityPanelP } } else { const tasks: Task[] = resp.data.tasks; - const extracted = extractMessagesFromTasks(tasks); + const { messages: extracted } = extractMessagesFromTasks(tasks); setMessages(extracted); setWaiting(extracted.length === 0 && !isComplete); setError(null); diff --git a/ui/src/components/chat/ChatInterface.tsx b/ui/src/components/chat/ChatInterface.tsx index c1be21249..21948f647 100644 --- a/ui/src/components/chat/ChatInterface.tsx +++ b/ui/src/components/chat/ChatInterface.tsx @@ -25,8 +25,8 @@ import { useRouter } from "next/navigation"; import { createMessageHandlers, extractMessagesFromTasks, extractApprovalMessagesFromTasks, extractTokenStatsFromTasks, createMessage, ADKMetadata, ProcessedToolCallData } from "@/lib/messageHandlers"; import { kagentA2AClient } from "@/lib/a2aClient"; import { v4 as uuidv4 } from "uuid"; -import { getStatusPlaceholder } from "@/lib/statusUtils"; -import { Message, DataPart } from "@a2a-js/sdk"; +import { getStatusPlaceholder, mapA2AStateToStatus } from "@/lib/statusUtils"; +import { Message, DataPart, TaskState } from "@a2a-js/sdk"; interface ChatInterfaceProps { selectedAgentName: string; @@ -139,7 +139,7 @@ export default function ChatInterface({ selectedAgentName, selectedNamespace, se setSessionStats({ total: 0, prompt: 0, completion: 0 }); } else { - const extractedMessages = extractMessagesFromTasks(messagesResponse.data); + const { messages: extractedMessages, pendingTask } = extractMessagesFromTasks(messagesResponse.data); setSessionStats(extractTokenStatsFromTasks(messagesResponse.data)); // Resolved approvals are already inline in extractedMessages (with @@ -155,6 +155,45 @@ export default function ChatInterface({ selectedAgentName, selectedNamespace, se if (hasPendingApproval) { setChatStatus("input_required"); } + + // If there's a pending task (and no pending approval taking priority), reconnect to its stream + if (pendingTask && !hasPendingApproval) { + setIsLoading(false); + setChatStatus(mapA2AStateToStatus(pendingTask.state as TaskState)); + + try { + abortControllerRef.current = new AbortController(); + + const stream = await kagentA2AClient.resubscribeTask( + selectedNamespace, + selectedAgentName, + pendingTask.taskId, + abortControllerRef.current.signal + ); + + for await (const event of stream) { + handleMessageEvent(event); + if (abortControllerRef.current?.signal.aborted) break; + } + } catch (error: unknown) { + if (error instanceof Error && error.name !== 'AbortError') { + toast.error(`Reconnection failed: ${error.message}`); + try { + const refreshedTasks = await getSessionTasks(sessionId); + if (refreshedTasks.data) { + const { messages } = extractMessagesFromTasks(refreshedTasks.data); + setStoredMessages(messages); + } + } catch (refreshError) { + console.error('Failed to refresh tasks after reconnection failure:', refreshError); + } + } + } finally { + setChatStatus('ready'); + abortControllerRef.current = null; + } + return; + } } } catch (error) { console.error("Error loading messages:", error); @@ -165,7 +204,11 @@ export default function ChatInterface({ selectedAgentName, selectedNamespace, se } initializeChat(); - }, [sessionId, selectedAgentName, selectedNamespace, isFirstMessage]); + + return () => { + abortControllerRef.current?.abort(); + }; + }, [sessionId, selectedAgentName, selectedNamespace, isFirstMessage, handleMessageEvent]); useEffect(() => { if (containerRef.current) { diff --git a/ui/src/lib/__tests__/a2aClient.test.ts b/ui/src/lib/__tests__/a2aClient.test.ts new file mode 100644 index 000000000..10c74c4b6 --- /dev/null +++ b/ui/src/lib/__tests__/a2aClient.test.ts @@ -0,0 +1,203 @@ +import { describe, expect, it, jest, beforeEach } from '@jest/globals'; +import { KagentA2AClient } from '../a2aClient'; + +// Mock fetch globally +const mockFetch = jest.fn() as jest.MockedFunction; +global.fetch = mockFetch; + +// Mock utils +jest.mock('../utils', () => ({ + getBackendUrl: () => 'http://localhost:8083/api', +})); + +describe('KagentA2AClient', () => { + let client: KagentA2AClient; + + beforeEach(() => { + client = new KagentA2AClient(); + mockFetch.mockClear(); + }); + + describe('getAgentUrl', () => { + it('should construct correct agent URL', () => { + const url = client.getAgentUrl('test-namespace', 'test-agent'); + expect(url).toBe('http://localhost:8083/api/a2a/test-namespace/test-agent'); + }); + }); + + describe('createStreamingRequest', () => { + it('should create valid JSON-RPC request structure', () => { + const params = { + message: { + kind: 'message' as const, + messageId: 'msg-1', + role: 'user' as const, + parts: [{ kind: 'text' as const, text: 'Hello' }], + }, + }; + + const request = client.createStreamingRequest(params); + + expect(request.jsonrpc).toBe('2.0'); + expect(request.method).toBe('message/stream'); + expect(request.params).toEqual(params); + expect(typeof request.id).toBe('string'); + expect(request.id).toBeTruthy(); + }); + }); + + describe('resubscribeTask', () => { + it('should throw error on non-ok response', async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404, + statusText: 'Not Found', + text: () => Promise.resolve('Task not found'), + } as Response); + + await expect( + client.resubscribeTask('test-ns', 'test-agent', 'task-123') + ).rejects.toThrow('A2A resubscribe failed: 404 Not Found - Task not found'); + + expect(mockFetch).toHaveBeenCalledWith( + '/a2a/test-ns/test-agent', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Accept': 'text/event-stream', + }, + }) + ); + + // Verify the request body contains resubscribe params + const callArgs = mockFetch.mock.calls[0]; + const requestBody = JSON.parse(callArgs[1]?.body as string); + expect(requestBody.jsonrpc).toBe('2.0'); + expect(requestBody.method).toBe('tasks/resubscribe'); + expect(requestBody.params).toEqual({ id: 'task-123' }); + }); + + it('should throw error when response body is null', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + body: null, + } as Response); + + await expect( + client.resubscribeTask('test-ns', 'test-agent', 'task-123') + ).rejects.toThrow('Response body is null'); + }); + + it('should pass abort signal to fetch', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + body: null, + } as Response); + + const abortController = new AbortController(); + + await expect( + client.resubscribeTask('test-ns', 'test-agent', 'task-123', abortController.signal) + ).rejects.toThrow('Response body is null'); + + expect(mockFetch).toHaveBeenCalledWith( + expect.any(String), + expect.objectContaining({ + signal: abortController.signal, + }) + ); + }); + }); + + describe('sendMessageStream', () => { + it('should throw error on non-ok response', async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 500, + statusText: 'Internal Server Error', + text: () => Promise.resolve('Server error'), + } as Response); + + const params = { + message: { + kind: 'message' as const, + messageId: 'msg-1', + role: 'user' as const, + parts: [{ kind: 'text' as const, text: 'Hello' }], + }, + }; + + await expect( + client.sendMessageStream('test-ns', 'test-agent', params) + ).rejects.toThrow('A2A proxy request failed: 500 Internal Server Error - Server error'); + + expect(mockFetch).toHaveBeenCalledWith( + '/a2a/test-ns/test-agent', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Accept': 'text/event-stream', + }, + }) + ); + + // Verify the request body contains message/stream method + const callArgs = mockFetch.mock.calls[0]; + const requestBody = JSON.parse(callArgs[1]?.body as string); + expect(requestBody.jsonrpc).toBe('2.0'); + expect(requestBody.method).toBe('message/stream'); + expect(requestBody.params).toEqual(params); + }); + + it('should throw error when response body is null', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + body: null, + } as Response); + + const params = { + message: { + kind: 'message' as const, + messageId: 'msg-1', + role: 'user' as const, + parts: [{ kind: 'text' as const, text: 'Hello' }], + }, + }; + + await expect( + client.sendMessageStream('test-ns', 'test-agent', params) + ).rejects.toThrow('Response body is null'); + }); + + it('should pass abort signal to fetch', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + body: null, + } as Response); + + const params = { + message: { + kind: 'message' as const, + messageId: 'msg-1', + role: 'user' as const, + parts: [{ kind: 'text' as const, text: 'Hello' }], + }, + }; + + const abortController = new AbortController(); + + await expect( + client.sendMessageStream('test-ns', 'test-agent', params, abortController.signal) + ).rejects.toThrow('Response body is null'); + + expect(mockFetch).toHaveBeenCalledWith( + expect.any(String), + expect.objectContaining({ + signal: abortController.signal, + }) + ); + }); + }); +}); diff --git a/ui/src/lib/__tests__/messageHandlers.test.ts b/ui/src/lib/__tests__/messageHandlers.test.ts index de95321ab..ce9e24d7f 100644 --- a/ui/src/lib/__tests__/messageHandlers.test.ts +++ b/ui/src/lib/__tests__/messageHandlers.test.ts @@ -43,11 +43,39 @@ describe('messageHandlers helpers', () => { const tasks: any = [ { history: [{ kind: 'message', messageId: mId }, { kind: 'message', messageId: mId }] }, ]; - const out = extractMessagesFromTasks(tasks); + const { messages: out } = extractMessagesFromTasks(tasks); expect(out.length).toBe(1); expect(out[0].messageId).toBe(mId); }); + test('extractMessagesFromTasks detects pending task in working state', () => { + const tasks: any = [ + { id: 'task-1', status: { state: 'working' }, history: [] }, + ]; + const { pendingTask } = extractMessagesFromTasks(tasks); + expect(pendingTask).toBeDefined(); + expect(pendingTask?.taskId).toBe('task-1'); + expect(pendingTask?.state).toBe('working'); + }); + + test('extractMessagesFromTasks detects pending task in submitted state', () => { + const tasks: any = [ + { id: 'task-1', status: { state: 'submitted' }, history: [] }, + ]; + const { pendingTask } = extractMessagesFromTasks(tasks); + expect(pendingTask).toBeDefined(); + expect(pendingTask?.taskId).toBe('task-1'); + expect(pendingTask?.state).toBe('submitted'); + }); + + test('extractMessagesFromTasks returns no pending task for completed tasks', () => { + const tasks: any = [ + { id: 'task-1', status: { state: 'completed' }, history: [] }, + ]; + const { pendingTask } = extractMessagesFromTasks(tasks); + expect(pendingTask).toBeUndefined(); + }); + test('extractMessagesFromTasks injects tokenStats into non-user agent messages only', () => { const tasks = [ { @@ -60,7 +88,7 @@ describe('messageHandlers helpers', () => { ], }, ] as unknown as Task[]; - const messages = extractMessagesFromTasks(tasks); + const { messages } = extractMessagesFromTasks(tasks); // Agent message with usage metadata gets tokenStats injected expect((messages[0].metadata as ADKMetadata & { tokenStats?: TokenStats })?.tokenStats) .toEqual({ total: 10, prompt: 3, completion: 7 }); @@ -445,7 +473,7 @@ describe('subagent_session_id propagation', () => { }], }] as unknown as Task[]; - const messages = extractMessagesFromTasks(tasks); + const { messages } = extractMessagesFromTasks(tasks); expect(messages).toHaveLength(1); const meta = messages[0].metadata as ADKMetadata; expect(meta.originalType).toBe('ToolCallRequestEvent'); @@ -474,7 +502,7 @@ describe('subagent_session_id propagation', () => { }], }] as unknown as Task[]; - const messages = extractMessagesFromTasks(tasks); + const { messages } = extractMessagesFromTasks(tasks); expect(messages).toHaveLength(1); const meta = messages[0].metadata as ADKMetadata; expect(meta.originalType).toBe('ToolCallExecutionEvent'); diff --git a/ui/src/lib/a2aClient.ts b/ui/src/lib/a2aClient.ts index 1942f4840..4ec97f4ce 100644 --- a/ui/src/lib/a2aClient.ts +++ b/ui/src/lib/a2aClient.ts @@ -6,7 +6,7 @@ import { MessageSendParams } from '@a2a-js/sdk'; export interface A2AJsonRpcRequest { jsonrpc: "2.0"; method: string; - params: MessageSendParams; + params: MessageSendParams | { id: string }; id: string | number; } @@ -76,6 +76,47 @@ export class KagentA2AClient { return this.processSSEStream(response.body); } + /** + * Resubscribe to an in-flight task to resume receiving streaming events + * Uses the A2A tasks/resubscribe JSON-RPC method + */ + async resubscribeTask( + namespace: string, + agentName: string, + taskId: string, + signal?: AbortSignal + ): Promise> { + const request: A2AJsonRpcRequest = { + jsonrpc: "2.0", + method: "tasks/resubscribe", + params: { id: taskId }, + id: uuidv4(), + }; + + const proxyUrl = `/a2a/${namespace}/${agentName}`; + + const response = await fetch(proxyUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Accept': 'text/event-stream', + }, + body: JSON.stringify(request), + signal, + }); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error(`A2A resubscribe failed: ${response.status} ${response.statusText} - ${errorText}`); + } + + if (!response.body) { + throw new Error('Response body is null'); + } + + return this.processSSEStream(response.body); + } + /** * Process Server-Sent Events stream with proper event boundary detection */ diff --git a/ui/src/lib/messageHandlers.ts b/ui/src/lib/messageHandlers.ts index 638337ae9..51b62438d 100644 --- a/ui/src/lib/messageHandlers.ts +++ b/ui/src/lib/messageHandlers.ts @@ -4,12 +4,25 @@ import { convertToUserFriendlyName, isAgentToolName, messageUtils } from "@/lib/ import { ApprovalDecision, AdkRequestConfirmationData, HitlPartInfo, ToolDecision, TokenStats, ChatStatus } from "@/types"; import { mapA2AStateToStatus } from "@/lib/statusUtils"; +// Result type for extractMessagesFromTasks +export interface TaskExtractionResult { + messages: Message[]; + pendingTask?: { taskId: string; state: string }; +} + // Helper functions for extracting data from stored tasks -export function extractMessagesFromTasks(tasks: Task[]): Message[] { +export function extractMessagesFromTasks(tasks: Task[]): TaskExtractionResult { const messages: Message[] = []; const seenMessageIds = new Set(); + let pendingTask: { taskId: string; state: string } | undefined; for (const task of tasks) { + // Detect in-flight tasks for stream reconnection + const taskState = task.status?.state; + if (taskState === 'working' || taskState === 'submitted') { + pendingTask = { taskId: task.id, state: taskState }; + } + if (!task.history) continue; // Track the most recent LLM usage seen so far within this task so we can @@ -159,7 +172,7 @@ export function extractMessagesFromTasks(tasks: Task[]): Message[] { } } - return messages; + return { messages, pendingTask }; } /** Returns true if the message is a user HITL decision (approve/reject) or ask-user answer. */ From 8ae1de72b20aa50e83a42bfb4a06bcdbecbcb466 Mon Sep 17 00:00:00 2001 From: jobell Date: Mon, 13 Apr 2026 13:04:24 -0600 Subject: [PATCH 2/3] fix: address review feedback on session resilience PR - Revert events=events to events=[] to fix event duplication regression (Session constructor pre-populates; append_event loop would double-append) - Restore all 8 deleted get_session tests including duplication regression guards - Preserve source field in _recreate_session (alongside session_name) - Update state-loss warning to only fire for truly unknown fields - Wrap _recreate_session failure with RuntimeError preserving 404 context - Add stream timeout (10min) to resubscribe loop matching sendA2AMessage behavior - Fix setChatStatus in finally to be conditional (preserve input_required/error) - Change resubscribeTask return type from any to unknown - Narrow pendingTask.state type to 'working' | 'submitted' union; removes as TaskState cast - Add tests: retry-also-404 (no infinite recursion), recreation-fails RuntimeError, source field preservation, state-loss warning, known-fields-only no-warning Signed-off-by: jobell --- .../src/kagent/adk/_session_service.py | 38 +- .../tests/unittests/test_session_service.py | 372 ++++++++++++++---- ui/src/components/chat/ChatInterface.tsx | 29 +- ui/src/lib/a2aClient.ts | 2 +- ui/src/lib/messageHandlers.ts | 4 +- 5 files changed, 346 insertions(+), 99 deletions(-) diff --git a/python/packages/kagent-adk/src/kagent/adk/_session_service.py b/python/packages/kagent-adk/src/kagent/adk/_session_service.py index 631b8e278..3195d2c55 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_session_service.py +++ b/python/packages/kagent-adk/src/kagent/adk/_session_service.py @@ -114,7 +114,7 @@ async def get_session( session = Session( id=session_data["id"], user_id=session_data["user_id"], - events=events, + events=[], app_name=app_name, state={}, ) @@ -168,11 +168,11 @@ async def _recreate_session(self, session: Session) -> None: - user_id: Preserved - agent_ref: Preserved - session_name: Preserved (from session.state["session_name"]) + - source: Preserved (from session.state["source"]) - Other session.state fields: NOT preserved (lost on recreation) - Note: Only session_name is currently used by the application. If additional state fields are added in the future, they must be - explicitly preserved here. + explicitly preserved here and added to _PRESERVED_STATE_FIELDS. Args: session: The session object to recreate @@ -180,6 +180,8 @@ async def _recreate_session(self, session: Session) -> None: Raises: httpx.HTTPStatusError: If recreation fails """ + _PRESERVED_STATE_FIELDS = {"session_name", "source"} + request_data = { "id": session.id, "user_id": session.user_id, @@ -187,16 +189,19 @@ async def _recreate_session(self, session: Session) -> None: } if session.state and session.state.get("session_name"): request_data["name"] = session.state["session_name"] - - # Warn if session has additional state fields that won't be preserved - if session.state and len(session.state) > 1: - extra_fields = [k for k in session.state.keys() if k != "session_name"] - logger.warning( - "Session %s has additional state fields that will not be preserved during recreation: %s. " - "Update _recreate_session() if these fields are critical.", - session.id, - extra_fields, - ) + if session.state and session.state.get("source"): + request_data["source"] = session.state["source"] + + # Warn if session has unknown state fields that won't be preserved + if session.state: + extra_fields = [k for k in session.state.keys() if k not in _PRESERVED_STATE_FIELDS] + if extra_fields: + logger.warning( + "Session %s has additional state fields that will not be preserved during recreation: %s. " + "Update _recreate_session() if these fields are critical.", + session.id, + extra_fields, + ) response = await self.client.post( "/api/sessions", @@ -269,7 +274,12 @@ async def append_event(self, session: Session, event: Event) -> Event: "Session %s not found (404), attempting to recreate before retry", session.id, ) - await self._recreate_session(session) + try: + await self._recreate_session(session) + except Exception as e: + raise RuntimeError( + f"Session {session.id} not found (404) and recreation failed" + ) from e # Retry the append ONCE. If this retry also fails (including another 404), # raise_for_status() below will propagate the error without further attempts. diff --git a/python/packages/kagent-adk/tests/unittests/test_session_service.py b/python/packages/kagent-adk/tests/unittests/test_session_service.py index e77411cf3..3a4dd48e7 100644 --- a/python/packages/kagent-adk/tests/unittests/test_session_service.py +++ b/python/packages/kagent-adk/tests/unittests/test_session_service.py @@ -1,28 +1,31 @@ from unittest import mock +from unittest.mock import AsyncMock, MagicMock import httpx import pytest -from google.adk.events.event import Event +from google.adk.events.event import Event, EventActions from google.adk.sessions import Session from kagent.adk._session_service import KAgentSessionService +# --------------------------------------------------------------------------- +# Shared fixtures for append_event / _recreate_session tests +# --------------------------------------------------------------------------- + @pytest.fixture def mock_client(): - """Create a mock httpx.AsyncClient.""" + """Simple AsyncMock client for sequential multi-call tests.""" return mock.AsyncMock(spec=httpx.AsyncClient) @pytest.fixture def session_service(mock_client): - """Create a KAgentSessionService with mocked client.""" return KAgentSessionService(client=mock_client) @pytest.fixture def sample_session(): - """Create a sample session for testing.""" return Session( id="test-session-123", user_id="test-user", @@ -33,12 +36,147 @@ def sample_session(): @pytest.fixture def sample_event(): - """Create a sample event for testing.""" - return Event( - invocation_id="test-invocation", - author="user", - ) + return Event(invocation_id="test-invocation", author="user") + + +# --------------------------------------------------------------------------- +# Fixtures for get_session tests (factory-style, mirrors original test file) +# --------------------------------------------------------------------------- + +@pytest.fixture +def make_event(): + """Factory: make_event(author, state_delta) -> Event.""" + def _factory(author: str = "user", state_delta: dict | None = None) -> Event: + if state_delta: + return Event(author=author, invocation_id="inv1", actions=EventActions(state_delta=state_delta)) + return Event(author=author, invocation_id="inv1") + return _factory + + +@pytest.fixture +def session_response(): + """Factory: build the JSON envelope returned by GET /api/sessions/{id}.""" + def _factory(events: list[Event], session_id: str = "s1", user_id: str = "u1") -> dict: + return { + "data": { + "session": {"id": session_id, "user_id": user_id}, + "events": [{"id": e.id, "data": e.model_dump_json()} for e in events], + } + } + return _factory + + +@pytest.fixture +def get_client(): + """Factory: get_client(response_json, status_code) -> MagicMock AsyncClient (GET only).""" + def _factory(response_json: dict | None, status_code: int = 200) -> MagicMock: + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = status_code + mock_response.json.return_value = response_json + mock_response.raise_for_status = MagicMock() + + client = MagicMock(spec=httpx.AsyncClient) + client.get = AsyncMock(return_value=mock_response) + return client + return _factory + + +@pytest.fixture +def get_session_svc(get_client): + """Factory: get_session_svc(response_json, status_code) -> KAgentSessionService.""" + def _factory(response_json: dict | None, status_code: int = 200) -> KAgentSessionService: + return KAgentSessionService(get_client(response_json, status_code)) + return _factory + + +# --------------------------------------------------------------------------- +# get_session tests (restored from original — including regression guards) +# --------------------------------------------------------------------------- + +class TestGetSession: + + @pytest.mark.asyncio + async def test_returns_none_on_404(self, get_client): + """A 404 response returns None without raising.""" + svc = KAgentSessionService(get_client(response_json=None, status_code=404)) + session = await svc.get_session(app_name="app", user_id="u1", session_id="missing") + assert session is None + + @pytest.mark.asyncio + async def test_returns_none_when_no_data(self, get_session_svc): + """An empty data envelope returns None.""" + session = await get_session_svc({"data": None}).get_session(app_name="app", user_id="u1", session_id="s1") + assert session is None + @pytest.mark.asyncio + async def test_event_ids_preserved(self, make_event, session_response, get_session_svc): + """Event identity (id) is preserved after loading from the API.""" + events = [make_event("user"), make_event("assistant")] + original_ids = [e.id for e in events] + session = await get_session_svc(session_response(events)).get_session(app_name="app", user_id="u1", session_id="s1") + assert session is not None + assert [e.id for e in session.events] == original_ids + + @pytest.mark.asyncio + async def test_events_not_duplicated(self, make_event, session_response, get_session_svc): + """Each event from the API must appear exactly once in session.events. + + Regression guard: Session(events=events) pre-populates session.events, + and super().append_event() then appends again — causing duplication. + """ + events = [make_event("user"), make_event("assistant"), make_event("tool")] + session = await get_session_svc(session_response(events)).get_session(app_name="app", user_id="u1", session_id="s1") + assert session is not None + assert len(session.events) == len(events), ( + f"Expected {len(events)} events but got {len(session.events)} — possible event duplication in get_session" + ) + + @pytest.mark.asyncio + async def test_single_event_not_duplicated(self, make_event, session_response, get_session_svc): + """Single-event case: still only one event in session.events.""" + events = [make_event("user")] + session = await get_session_svc(session_response(events)).get_session(app_name="app", user_id="u1", session_id="s1") + assert session is not None + assert len(session.events) == 1 + + @pytest.mark.asyncio + async def test_empty_events(self, session_response, get_session_svc): + """Zero events from the API yields an empty session.events list.""" + session = await get_session_svc(session_response([])).get_session(app_name="app", user_id="u1", session_id="s1") + assert session is not None + assert len(session.events) == 0 + + @pytest.mark.asyncio + async def test_state_delta_applied_once(self, make_event, session_response, get_session_svc): + """State deltas from events must be applied exactly once to session.state. + + Regression guard: double-appending events caused _update_session_state() + to be called twice per event. + """ + events = [make_event("assistant", state_delta={"counter": 7})] + session = await get_session_svc(session_response(events)).get_session(app_name="app", user_id="u1", session_id="s1") + assert session is not None + assert session.state.get("counter") == 7, ( + f"Expected state['counter'] == 7, got {session.state.get('counter')} — " + "state_delta may have been applied more than once" + ) + + @pytest.mark.asyncio + async def test_multiple_state_deltas_applied_once(self, make_event, session_response, get_session_svc): + """Multiple events each contributing a state key are each applied once.""" + events = [ + make_event("assistant", state_delta={"key_a": "value_a"}), + make_event("tool", state_delta={"key_b": "value_b"}), + ] + session = await get_session_svc(session_response(events)).get_session(app_name="app", user_id="u1", session_id="s1") + assert session is not None + assert session.state.get("key_a") == "value_a" + assert session.state.get("key_b") == "value_b" + + +# --------------------------------------------------------------------------- +# append_event tests +# --------------------------------------------------------------------------- class TestAppendEvent: """Tests for append_event method.""" @@ -46,7 +184,6 @@ class TestAppendEvent: @pytest.mark.asyncio async def test_append_event_success(self, session_service, mock_client, sample_session, sample_event): """Test successful event append.""" - # Mock successful response mock_response = mock.MagicMock() mock_response.status_code = 201 mock_response.raise_for_status = mock.MagicMock() @@ -56,59 +193,90 @@ async def test_append_event_success(self, session_service, mock_client, sample_s assert result == sample_event mock_client.post.assert_called_once() - call_args = mock_client.post.call_args - assert f"/api/sessions/{sample_session.id}/events" in call_args[0][0] + assert f"/api/sessions/{sample_session.id}/events" in mock_client.post.call_args[0][0] @pytest.mark.asyncio async def test_append_event_404_recovery(self, session_service, mock_client, sample_session, sample_event): """Test 404 triggers session recreation and retry.""" - # First call returns 404, second call (after recreation) succeeds mock_response_404 = mock.MagicMock() mock_response_404.status_code = 404 + mock_response_create = mock.MagicMock() + mock_response_create.status_code = 201 + mock_response_create.raise_for_status = mock.MagicMock() + mock_response_success = mock.MagicMock() mock_response_success.status_code = 201 mock_response_success.raise_for_status = mock.MagicMock() + mock_tasks_response = mock.MagicMock() + mock_tasks_response.status_code = 200 + mock_tasks_response.json.return_value = {"data": []} + + mock_client.post.side_effect = [mock_response_404, mock_response_create, mock_response_success] + mock_client.get.return_value = mock_tasks_response + + result = await session_service.append_event(sample_session, sample_event) + + assert result == sample_event + assert mock_client.post.call_count == 3 + assert mock_client.get.call_count == 1 + calls = mock_client.post.call_args_list + assert f"/api/sessions/{sample_session.id}/events" in calls[0][0][0] + assert "/api/sessions" == calls[1][0][0] + assert f"/api/sessions/{sample_session.id}/events" in calls[2][0][0] + + @pytest.mark.asyncio + async def test_append_event_404_retry_also_404(self, session_service, mock_client, sample_session, sample_event): + """Test that a second 404 on retry raises without infinite recursion.""" + mock_response_404 = mock.MagicMock() + mock_response_404.status_code = 404 + mock_response_create = mock.MagicMock() mock_response_create.status_code = 201 mock_response_create.raise_for_status = mock.MagicMock() - # Mock tasks fetch response (no tasks found) + mock_response_retry_404 = mock.MagicMock() + mock_response_retry_404.status_code = 404 + mock_response_retry_404.raise_for_status.side_effect = httpx.HTTPStatusError( + "Not found", request=mock.MagicMock(), response=mock_response_retry_404 + ) + mock_tasks_response = mock.MagicMock() mock_tasks_response.status_code = 200 mock_tasks_response.json.return_value = {"data": []} - # Configure mock to return different responses - mock_client.post.side_effect = [ - mock_response_404, # First append attempt (404) - mock_response_create, # Session recreation - mock_response_success, # Retry append (success) - ] - mock_client.get.return_value = mock_tasks_response # Tasks fetch + mock_client.post.side_effect = [mock_response_404, mock_response_create, mock_response_retry_404] + mock_client.get.return_value = mock_tasks_response - result = await session_service.append_event(sample_session, sample_event) + with pytest.raises(httpx.HTTPStatusError): + await session_service.append_event(sample_session, sample_event) - assert result == sample_event - # Should be called 3 times: initial append, session create, retry append + # Exactly 3 POST calls: initial, recreation, retry — no infinite loop assert mock_client.post.call_count == 3 - # Should fetch tasks once after recreation - assert mock_client.get.call_count == 1 - # Verify the calls - calls = mock_client.post.call_args_list - assert f"/api/sessions/{sample_session.id}/events" in calls[0][0][0] # First append - assert "/api/sessions" == calls[1][0][0] # Session recreation - assert f"/api/sessions/{sample_session.id}/events" in calls[2][0][0] # Retry append + @pytest.mark.asyncio + async def test_append_event_404_recreation_fails_raises_runtime_error( + self, session_service, mock_client, sample_session, sample_event + ): + """Test 404 + recreation failure raises RuntimeError with 404 context.""" + mock_response_404 = mock.MagicMock() + mock_response_404.status_code = 404 + + mock_response_create_500 = mock.MagicMock() + mock_response_create_500.status_code = 500 + mock_response_create_500.raise_for_status.side_effect = httpx.HTTPStatusError( + "Server error", request=mock.MagicMock(), response=mock_response_create_500 + ) + + mock_client.post.side_effect = [mock_response_404, mock_response_create_500] - # Verify tasks fetch - get_call = mock_client.get.call_args - assert f"/api/sessions/{sample_session.id}/tasks" in get_call[0][0] + with pytest.raises(RuntimeError, match="not found.*recreation failed"): + await session_service.append_event(sample_session, sample_event) @pytest.mark.asyncio async def test_append_event_404_recovery_failure(self, session_service, mock_client, sample_session, sample_event): - """Test 404 recovery fails if retry also fails.""" - # First call returns 404, recreation succeeds, but retry fails + """Test 404 recovery where retry fails with a 500.""" mock_response_404 = mock.MagicMock() mock_response_404.status_code = 404 @@ -119,21 +287,14 @@ async def test_append_event_404_recovery_failure(self, session_service, mock_cli mock_response_retry_fail = mock.MagicMock() mock_response_retry_fail.status_code = 500 mock_response_retry_fail.raise_for_status.side_effect = httpx.HTTPStatusError( - "Server error", - request=mock.MagicMock(), - response=mock_response_retry_fail, + "Server error", request=mock.MagicMock(), response=mock_response_retry_fail ) - # Mock tasks fetch mock_tasks_response = mock.MagicMock() mock_tasks_response.status_code = 200 mock_tasks_response.json.return_value = {"data": []} - mock_client.post.side_effect = [ - mock_response_404, # First append (404) - mock_response_create, # Session recreation - mock_response_retry_fail, # Retry append (500) - ] + mock_client.post.side_effect = [mock_response_404, mock_response_create, mock_response_retry_fail] mock_client.get.return_value = mock_tasks_response with pytest.raises(httpx.HTTPStatusError): @@ -147,19 +308,20 @@ async def test_append_event_non_404_error(self, session_service, mock_client, sa mock_response = mock.MagicMock() mock_response.status_code = 500 mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( - "Server error", - request=mock.MagicMock(), - response=mock_response, + "Server error", request=mock.MagicMock(), response=mock_response ) mock_client.post.return_value = mock_response with pytest.raises(httpx.HTTPStatusError): await session_service.append_event(sample_session, sample_event) - # Should only be called once (no retry for non-404 errors) mock_client.post.assert_called_once() +# --------------------------------------------------------------------------- +# _recreate_session tests +# --------------------------------------------------------------------------- + class TestRecreateSession: """Tests for _recreate_session method.""" @@ -171,37 +333,28 @@ async def test_recreate_session_success(self, session_service, mock_client, samp mock_response.raise_for_status = mock.MagicMock() mock_client.post.return_value = mock_response - # Mock tasks fetch mock_tasks_response = mock.MagicMock() mock_tasks_response.status_code = 200 mock_tasks_response.json.return_value = {"data": []} mock_client.get.return_value = mock_tasks_response - # Should not raise await session_service._recreate_session(sample_session) mock_client.post.assert_called_once() - call_args = mock_client.post.call_args - assert call_args[0][0] == "/api/sessions" - - # Verify request data - request_data = call_args[1]["json"] + request_data = mock_client.post.call_args[1]["json"] assert request_data["id"] == sample_session.id assert request_data["user_id"] == sample_session.user_id assert request_data["agent_ref"] == sample_session.app_name - - # Verify tasks were fetched mock_client.get.assert_called_once() @pytest.mark.asyncio async def test_recreate_session_with_session_name(self, session_service, mock_client, sample_session): - """Test session recreation includes session name from state.""" + """Test session recreation includes session_name from state.""" mock_response = mock.MagicMock() mock_response.status_code = 201 mock_response.raise_for_status = mock.MagicMock() mock_client.post.return_value = mock_response - # Mock tasks fetch mock_tasks_response = mock.MagicMock() mock_tasks_response.status_code = 200 mock_tasks_response.json.return_value = {"data": []} @@ -209,19 +362,95 @@ async def test_recreate_session_with_session_name(self, session_service, mock_cl await session_service._recreate_session(sample_session) - call_args = mock_client.post.call_args - request_data = call_args[1]["json"] + request_data = mock_client.post.call_args[1]["json"] assert request_data["name"] == "Test Session" + @pytest.mark.asyncio + async def test_recreate_session_preserves_source(self, session_service, mock_client): + """Test session recreation preserves the source field from state.""" + session_with_source = Session( + id="sess-src", + user_id="u1", + app_name="app", + state={"session_name": "My Session", "source": "web"}, + ) + mock_response = mock.MagicMock() + mock_response.status_code = 201 + mock_response.raise_for_status = mock.MagicMock() + mock_client.post.return_value = mock_response + + mock_tasks_response = mock.MagicMock() + mock_tasks_response.status_code = 200 + mock_tasks_response.json.return_value = {"data": []} + mock_client.get.return_value = mock_tasks_response + + await session_service._recreate_session(session_with_source) + + request_data = mock_client.post.call_args[1]["json"] + assert request_data["source"] == "web" + assert request_data["name"] == "My Session" + + @pytest.mark.asyncio + async def test_recreate_session_state_loss_warning_for_unknown_fields( + self, session_service, mock_client, caplog + ): + """Warning fires for state fields beyond session_name and source.""" + session_with_extra = Session( + id="sess-extra", + user_id="u1", + app_name="app", + state={"session_name": "S", "source": "web", "custom_key": "custom_value"}, + ) + mock_response = mock.MagicMock() + mock_response.status_code = 201 + mock_response.raise_for_status = mock.MagicMock() + mock_client.post.return_value = mock_response + + mock_tasks_response = mock.MagicMock() + mock_tasks_response.status_code = 200 + mock_tasks_response.json.return_value = {"data": []} + mock_client.get.return_value = mock_tasks_response + + import logging + with caplog.at_level(logging.WARNING): + await session_service._recreate_session(session_with_extra) + + assert any("custom_key" in record.message for record in caplog.records) + + @pytest.mark.asyncio + async def test_recreate_session_no_warning_for_known_fields_only( + self, session_service, mock_client, caplog + ): + """No warning fires when only known fields (session_name, source) are in state.""" + session_known_only = Session( + id="sess-known", + user_id="u1", + app_name="app", + state={"session_name": "S", "source": "web"}, + ) + mock_response = mock.MagicMock() + mock_response.status_code = 201 + mock_response.raise_for_status = mock.MagicMock() + mock_client.post.return_value = mock_response + + mock_tasks_response = mock.MagicMock() + mock_tasks_response.status_code = 200 + mock_tasks_response.json.return_value = {"data": []} + mock_client.get.return_value = mock_tasks_response + + import logging + with caplog.at_level(logging.WARNING): + await session_service._recreate_session(session_known_only) + + assert not any("additional state fields" in record.message for record in caplog.records) + @pytest.mark.asyncio async def test_recreate_session_failure(self, session_service, mock_client, sample_session): """Test session recreation failure raises error.""" mock_response = mock.MagicMock() mock_response.status_code = 500 mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( - "Server error", - request=mock.MagicMock(), - response=mock_response, + "Server error", request=mock.MagicMock(), response=mock_response ) mock_client.post.return_value = mock_response @@ -231,13 +460,11 @@ async def test_recreate_session_failure(self, session_service, mock_client, samp @pytest.mark.asyncio async def test_recreate_session_with_inflight_task(self, session_service, mock_client, sample_session): """Test session recreation detects in-flight tasks.""" - # Mock successful session creation mock_create_response = mock.MagicMock() mock_create_response.status_code = 201 mock_create_response.raise_for_status = mock.MagicMock() mock_client.post.return_value = mock_create_response - # Mock tasks response with an in-flight task mock_tasks_response = mock.MagicMock() mock_tasks_response.status_code = 200 mock_tasks_response.json.return_value = { @@ -247,30 +474,23 @@ async def test_recreate_session_with_inflight_task(self, session_service, mock_c await session_service._recreate_session(sample_session) - # Verify session creation call assert mock_client.post.call_count == 1 - # Verify tasks fetch call assert mock_client.get.call_count == 1 - get_call_url = mock_client.get.call_args[0][0] - assert f"/api/sessions/{sample_session.id}/tasks" in get_call_url + assert f"/api/sessions/{sample_session.id}/tasks" in mock_client.get.call_args[0][0] @pytest.mark.asyncio async def test_recreate_session_409_treated_as_success(self, session_service, mock_client, sample_session): - """Test that 409 Conflict during recreation is treated as success (concurrent recreation).""" + """409 Conflict during recreation is treated as success (concurrent recreation).""" mock_response_409 = mock.MagicMock() mock_response_409.status_code = 409 - mock_client.post.return_value = mock_response_409 - # Mock tasks fetch (empty) mock_tasks_response = mock.MagicMock() mock_tasks_response.status_code = 200 mock_tasks_response.json.return_value = {"data": []} mock_client.get.return_value = mock_tasks_response - # Should not raise even though POST returned 409 await session_service._recreate_session(sample_session) mock_client.post.assert_called_once() - # Tasks fetch should still proceed after 409 mock_client.get.assert_called_once() diff --git a/ui/src/components/chat/ChatInterface.tsx b/ui/src/components/chat/ChatInterface.tsx index 21948f647..9a7e1f45d 100644 --- a/ui/src/components/chat/ChatInterface.tsx +++ b/ui/src/components/chat/ChatInterface.tsx @@ -26,7 +26,7 @@ import { createMessageHandlers, extractMessagesFromTasks, extractApprovalMessage import { kagentA2AClient } from "@/lib/a2aClient"; import { v4 as uuidv4 } from "uuid"; import { getStatusPlaceholder, mapA2AStateToStatus } from "@/lib/statusUtils"; -import { Message, DataPart, TaskState } from "@a2a-js/sdk"; +import { Message, DataPart } from "@a2a-js/sdk"; interface ChatInterfaceProps { selectedAgentName: string; @@ -159,7 +159,7 @@ export default function ChatInterface({ selectedAgentName, selectedNamespace, se // If there's a pending task (and no pending approval taking priority), reconnect to its stream if (pendingTask && !hasPendingApproval) { setIsLoading(false); - setChatStatus(mapA2AStateToStatus(pendingTask.state as TaskState)); + setChatStatus(mapA2AStateToStatus(pendingTask.state)); try { abortControllerRef.current = new AbortController(); @@ -171,9 +171,25 @@ export default function ChatInterface({ selectedAgentName, selectedNamespace, se abortControllerRef.current.signal ); - for await (const event of stream) { - handleMessageEvent(event); - if (abortControllerRef.current?.signal.aborted) break; + let timeoutTimer: NodeJS.Timeout | null = null; + const streamTimeout = 600000; // 10 minutes + const startTimeout = () => { + if (timeoutTimer) clearTimeout(timeoutTimer); + timeoutTimer = setTimeout(() => { + toast.error("Reconnection timed out - no events received for 10 minutes"); + abortControllerRef.current?.abort(); + }, streamTimeout); + }; + startTimeout(); + + try { + for await (const event of stream) { + startTimeout(); // reset on each event + handleMessageEvent(event); + if (abortControllerRef.current?.signal.aborted) break; + } + } finally { + if (timeoutTimer) clearTimeout(timeoutTimer); } } catch (error: unknown) { if (error instanceof Error && error.name !== 'AbortError') { @@ -189,7 +205,8 @@ export default function ChatInterface({ selectedAgentName, selectedNamespace, se } } } finally { - setChatStatus('ready'); + // Only reset to ready from transient states; preserve input_required / error set by stream events + setChatStatus(prev => (prev === 'working' || prev === 'submitted' || prev === 'thinking') ? 'ready' : prev); abortControllerRef.current = null; } return; diff --git a/ui/src/lib/a2aClient.ts b/ui/src/lib/a2aClient.ts index 4ec97f4ce..d340ade54 100644 --- a/ui/src/lib/a2aClient.ts +++ b/ui/src/lib/a2aClient.ts @@ -85,7 +85,7 @@ export class KagentA2AClient { agentName: string, taskId: string, signal?: AbortSignal - ): Promise> { + ): Promise> { const request: A2AJsonRpcRequest = { jsonrpc: "2.0", method: "tasks/resubscribe", diff --git a/ui/src/lib/messageHandlers.ts b/ui/src/lib/messageHandlers.ts index 51b62438d..a0355f2d2 100644 --- a/ui/src/lib/messageHandlers.ts +++ b/ui/src/lib/messageHandlers.ts @@ -7,14 +7,14 @@ import { mapA2AStateToStatus } from "@/lib/statusUtils"; // Result type for extractMessagesFromTasks export interface TaskExtractionResult { messages: Message[]; - pendingTask?: { taskId: string; state: string }; + pendingTask?: { taskId: string; state: 'working' | 'submitted' }; } // Helper functions for extracting data from stored tasks export function extractMessagesFromTasks(tasks: Task[]): TaskExtractionResult { const messages: Message[] = []; const seenMessageIds = new Set(); - let pendingTask: { taskId: string; state: string } | undefined; + let pendingTask: { taskId: string; state: 'working' | 'submitted' } | undefined; for (const task of tasks) { // Detect in-flight tasks for stream reconnection From d62800f4cd432b5974ec096622c3613ae62a3a06 Mon Sep 17 00:00:00 2001 From: jobell Date: Mon, 13 Apr 2026 13:28:01 -0600 Subject: [PATCH 3/3] fix: address remaining review feedback (items 11 and 13) - Document last-wins behavior for multiple pending tasks in extractMessagesFromTasks - Downgrade post-recreation task fetch logging to debug level - Remove UI-prescriptive language ("UI should resubscribe") from backend session service Signed-off-by: jobell --- .../src/kagent/adk/_session_service.py | 16 +++++++--------- ui/src/lib/messageHandlers.ts | 4 +++- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/python/packages/kagent-adk/src/kagent/adk/_session_service.py b/python/packages/kagent-adk/src/kagent/adk/_session_service.py index 3195d2c55..a47d60f9b 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_session_service.py +++ b/python/packages/kagent-adk/src/kagent/adk/_session_service.py @@ -219,7 +219,7 @@ async def _recreate_session(self, session: Session) -> None: response.raise_for_status() logger.info("Successfully recreated session %s", session.id) - # Fetch existing tasks for this session to check for in-flight work + # Fetch existing tasks for this session to log in-flight work for observability tasks_response = await self.client.get( f"/api/sessions/{session.id}/tasks?user_id={session.user_id}", headers={"X-User-ID": session.user_id}, @@ -227,25 +227,23 @@ async def _recreate_session(self, session: Session) -> None: if tasks_response.status_code == 200: tasks_data = tasks_response.json() if tasks_data.get("data"): - logger.info( + logger.debug( "Session %s has %d existing task(s) after recreation", session.id, len(tasks_data["data"]), ) - # Log info about in-flight tasks for task in tasks_data["data"]: - task_status = task.get("status", {}) - task_state = task_status.get("state", "unknown") + task_state = task.get("status", {}).get("state", "unknown") if task_state in ("working", "submitted"): - logger.info( - "Found in-flight task %s in state '%s' - UI should resubscribe to continue receiving updates", + logger.debug( + "Session %s has in-flight task %s in state '%s'", + session.id, task.get("id"), task_state, ) else: logger.warning( - "Failed to fetch tasks for recreated session %s (HTTP %d). " - "In-flight task detection unavailable - UI may not auto-reconnect to active tasks.", + "Failed to fetch tasks for recreated session %s (HTTP %d)", session.id, tasks_response.status_code, ) diff --git a/ui/src/lib/messageHandlers.ts b/ui/src/lib/messageHandlers.ts index a0355f2d2..9329cd4a8 100644 --- a/ui/src/lib/messageHandlers.ts +++ b/ui/src/lib/messageHandlers.ts @@ -17,7 +17,9 @@ export function extractMessagesFromTasks(tasks: Task[]): TaskExtractionResult { let pendingTask: { taskId: string; state: 'working' | 'submitted' } | undefined; for (const task of tasks) { - // Detect in-flight tasks for stream reconnection + // Detect in-flight tasks for stream reconnection. + // If multiple tasks are in-flight (unusual), the last one wins — sessions + // are expected to have at most one concurrent active task. const taskState = task.status?.state; if (taskState === 'working' || taskState === 'submitted') { pendingTask = { taskId: task.id, state: taskState };