diff --git a/python/packages/kagent-adk/src/kagent/adk/_session_service.py b/python/packages/kagent-adk/src/kagent/adk/_session_service.py index da08895a5..a47d60f9b 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_session_service.py +++ b/python/packages/kagent-adk/src/kagent/adk/_session_service.py @@ -157,6 +157,97 @@ async def delete_session(self, *, app_name: str, user_id: str, session_id: str) ) response.raise_for_status() + async def _recreate_session(self, session: Session) -> None: + """Recreate a session that was not found (404). + + This handles the case where a session expires or is cleaned up + during a long-running operation. + + Session State Preservation: + - session_id: Preserved (same ID used for recreation) + - user_id: Preserved + - agent_ref: Preserved + - session_name: Preserved (from session.state["session_name"]) + - source: Preserved (from session.state["source"]) + - Other session.state fields: NOT preserved (lost on recreation) + + If additional state fields are added in the future, they must be + explicitly preserved here and added to _PRESERVED_STATE_FIELDS. + + Args: + session: The session object to recreate + + Raises: + httpx.HTTPStatusError: If recreation fails + """ + _PRESERVED_STATE_FIELDS = {"session_name", "source"} + + request_data = { + "id": session.id, + "user_id": session.user_id, + "agent_ref": session.app_name, + } + if session.state and session.state.get("session_name"): + request_data["name"] = session.state["session_name"] + if session.state and session.state.get("source"): + request_data["source"] = session.state["source"] + + # Warn if session has unknown state fields that won't be preserved + if session.state: + extra_fields = [k for k in session.state.keys() if k not in _PRESERVED_STATE_FIELDS] + if extra_fields: + logger.warning( + "Session %s has additional state fields that will not be preserved during recreation: %s. " + "Update _recreate_session() if these fields are critical.", + session.id, + extra_fields, + ) + + response = await self.client.post( + "/api/sessions", + json=request_data, + headers={"X-User-ID": session.user_id}, + ) + if response.status_code == 409: + # Session was already recreated by a concurrent call — treat as success. + logger.info( + "Session %s already exists (409 Conflict) during recreation, " + "likely recreated by a concurrent request. Proceeding with retry.", + session.id, + ) + else: + response.raise_for_status() + logger.info("Successfully recreated session %s", session.id) + + # Fetch existing tasks for this session to log in-flight work for observability + tasks_response = await self.client.get( + f"/api/sessions/{session.id}/tasks?user_id={session.user_id}", + headers={"X-User-ID": session.user_id}, + ) + if tasks_response.status_code == 200: + tasks_data = tasks_response.json() + if tasks_data.get("data"): + logger.debug( + "Session %s has %d existing task(s) after recreation", + session.id, + len(tasks_data["data"]), + ) + for task in tasks_data["data"]: + task_state = task.get("status", {}).get("state", "unknown") + if task_state in ("working", "submitted"): + logger.debug( + "Session %s has in-flight task %s in state '%s'", + session.id, + task.get("id"), + task_state, + ) + else: + logger.warning( + "Failed to fetch tasks for recreated session %s (HTTP %d)", + session.id, + tasks_response.status_code, + ) + @override async def append_event(self, session: Session, event: Event) -> Event: if event.partial: @@ -174,6 +265,29 @@ async def append_event(self, session: Session, event: Event) -> Event: json=event_data, headers={"X-User-ID": session.user_id}, ) + + # Handle 404 by recreating session and retrying once + if response.status_code == 404: + logger.warning( + "Session %s not found (404), attempting to recreate before retry", + session.id, + ) + try: + await self._recreate_session(session) + except Exception as e: + raise RuntimeError( + f"Session {session.id} not found (404) and recreation failed" + ) from e + + # Retry the append ONCE. If this retry also fails (including another 404), + # raise_for_status() below will propagate the error without further attempts. + # This prevents infinite recursion while allowing recovery from transient deletion. + response = await self.client.post( + f"/api/sessions/{session.id}/events?user_id={session.user_id}", + json=event_data, + headers={"X-User-ID": session.user_id}, + ) + response.raise_for_status() # TODO: potentially pull and update the session from the server diff --git a/python/packages/kagent-adk/tests/unittests/test_session_service.py b/python/packages/kagent-adk/tests/unittests/test_session_service.py index 293e88d76..3a4dd48e7 100644 --- a/python/packages/kagent-adk/tests/unittests/test_session_service.py +++ b/python/packages/kagent-adk/tests/unittests/test_session_service.py @@ -1,33 +1,61 @@ -"""Tests for KAgentSessionService.""" - +from unittest import mock from unittest.mock import AsyncMock, MagicMock import httpx import pytest from google.adk.events.event import Event, EventActions +from google.adk.sessions import Session from kagent.adk._session_service import KAgentSessionService +# --------------------------------------------------------------------------- +# Shared fixtures for append_event / _recreate_session tests +# --------------------------------------------------------------------------- + @pytest.fixture -def make_event(): - """Factory fixture: make_event(author, state_delta) -> Event.""" +def mock_client(): + """Simple AsyncMock client for sequential multi-call tests.""" + return mock.AsyncMock(spec=httpx.AsyncClient) + + +@pytest.fixture +def session_service(mock_client): + return KAgentSessionService(client=mock_client) + +@pytest.fixture +def sample_session(): + return Session( + id="test-session-123", + user_id="test-user", + app_name="test-app", + state={"session_name": "Test Session"}, + ) + + +@pytest.fixture +def sample_event(): + return Event(invocation_id="test-invocation", author="user") + + +# --------------------------------------------------------------------------- +# Fixtures for get_session tests (factory-style, mirrors original test file) +# --------------------------------------------------------------------------- + +@pytest.fixture +def make_event(): + """Factory: make_event(author, state_delta) -> Event.""" def _factory(author: str = "user", state_delta: dict | None = None) -> Event: if state_delta: return Event(author=author, invocation_id="inv1", actions=EventActions(state_delta=state_delta)) return Event(author=author, invocation_id="inv1") - return _factory @pytest.fixture def session_response(): - """Factory fixture: session_response(events, session_id, user_id) -> dict. - - Builds the JSON envelope that the KAgent API returns for GET /api/sessions/{id}. - """ - + """Factory: build the JSON envelope returned by GET /api/sessions/{id}.""" def _factory(events: list[Event], session_id: str = "s1", user_id: str = "u1") -> dict: return { "data": { @@ -35,14 +63,12 @@ def _factory(events: list[Event], session_id: str = "s1", user_id: str = "u1") - "events": [{"id": e.id, "data": e.model_dump_json()} for e in events], } } - return _factory @pytest.fixture -def mock_client(): - """Factory fixture: mock_client(response_json, status_code) -> MagicMock httpx.AsyncClient.""" - +def get_client(): + """Factory: get_client(response_json, status_code) -> MagicMock AsyncClient (GET only).""" def _factory(response_json: dict | None, status_code: int = 200) -> MagicMock: mock_response = MagicMock(spec=httpx.Response) mock_response.status_code = status_code @@ -52,115 +78,419 @@ def _factory(response_json: dict | None, status_code: int = 200) -> MagicMock: client = MagicMock(spec=httpx.AsyncClient) client.get = AsyncMock(return_value=mock_response) return client - return _factory @pytest.fixture -def service(mock_client): - """Factory fixture: service(response_json, status_code) -> KAgentSessionService.""" - +def get_session_svc(get_client): + """Factory: get_session_svc(response_json, status_code) -> KAgentSessionService.""" def _factory(response_json: dict | None, status_code: int = 200) -> KAgentSessionService: - return KAgentSessionService(mock_client(response_json, status_code)) - + return KAgentSessionService(get_client(response_json, status_code)) return _factory -@pytest.mark.asyncio -async def test_get_session_returns_none_on_404(mock_client): - """A 404 response returns None without raising.""" - svc = KAgentSessionService(mock_client(response_json=None, status_code=404)) - session = await svc.get_session(app_name="app", user_id="u1", session_id="missing") - - assert session is None - - -@pytest.mark.asyncio -async def test_get_session_returns_none_when_no_data(service): - """An empty data envelope returns None.""" - session = await service({"data": None}).get_session(app_name="app", user_id="u1", session_id="s1") - - assert session is None - - -@pytest.mark.asyncio -async def test_get_session_event_ids_preserved(make_event, session_response, service): - """Event identity (id) is preserved after loading from the API.""" - events = [make_event("user"), make_event("assistant")] - original_ids = [e.id for e in events] - - session = await service(session_response(events)).get_session(app_name="app", user_id="u1", session_id="s1") - - assert session is not None - assert [e.id for e in session.events] == original_ids - - -@pytest.mark.asyncio -async def test_get_session_events_not_duplicated(make_event, session_response, service): - """Each event from the API must appear exactly once in session.events. - - Regression test for the bug where Session(events=events) pre-populated - session.events and super().append_event() then appended each event again. - """ - events = [make_event("user"), make_event("assistant"), make_event("tool")] - session = await service(session_response(events)).get_session(app_name="app", user_id="u1", session_id="s1") - - assert session is not None - assert len(session.events) == len(events), ( - f"Expected {len(events)} events but got {len(session.events)} — possible event duplication in get_session" - ) - - -@pytest.mark.asyncio -async def test_get_session_single_event_not_duplicated(make_event, session_response, service): - """Single-event case: still only one event in session.events.""" - events = [make_event("user")] - session = await service(session_response(events)).get_session(app_name="app", user_id="u1", session_id="s1") - - assert session is not None - assert len(session.events) == 1 - - -@pytest.mark.asyncio -async def test_get_session_empty_events(session_response, service): - """Zero events from the API yields an empty session.events list.""" - session = await service(session_response([])).get_session(app_name="app", user_id="u1", session_id="s1") - - assert session is not None - assert len(session.events) == 0 - +# --------------------------------------------------------------------------- +# get_session tests (restored from original — including regression guards) +# --------------------------------------------------------------------------- + +class TestGetSession: + + @pytest.mark.asyncio + async def test_returns_none_on_404(self, get_client): + """A 404 response returns None without raising.""" + svc = KAgentSessionService(get_client(response_json=None, status_code=404)) + session = await svc.get_session(app_name="app", user_id="u1", session_id="missing") + assert session is None + + @pytest.mark.asyncio + async def test_returns_none_when_no_data(self, get_session_svc): + """An empty data envelope returns None.""" + session = await get_session_svc({"data": None}).get_session(app_name="app", user_id="u1", session_id="s1") + assert session is None + + @pytest.mark.asyncio + async def test_event_ids_preserved(self, make_event, session_response, get_session_svc): + """Event identity (id) is preserved after loading from the API.""" + events = [make_event("user"), make_event("assistant")] + original_ids = [e.id for e in events] + session = await get_session_svc(session_response(events)).get_session(app_name="app", user_id="u1", session_id="s1") + assert session is not None + assert [e.id for e in session.events] == original_ids + + @pytest.mark.asyncio + async def test_events_not_duplicated(self, make_event, session_response, get_session_svc): + """Each event from the API must appear exactly once in session.events. + + Regression guard: Session(events=events) pre-populates session.events, + and super().append_event() then appends again — causing duplication. + """ + events = [make_event("user"), make_event("assistant"), make_event("tool")] + session = await get_session_svc(session_response(events)).get_session(app_name="app", user_id="u1", session_id="s1") + assert session is not None + assert len(session.events) == len(events), ( + f"Expected {len(events)} events but got {len(session.events)} — possible event duplication in get_session" + ) + + @pytest.mark.asyncio + async def test_single_event_not_duplicated(self, make_event, session_response, get_session_svc): + """Single-event case: still only one event in session.events.""" + events = [make_event("user")] + session = await get_session_svc(session_response(events)).get_session(app_name="app", user_id="u1", session_id="s1") + assert session is not None + assert len(session.events) == 1 + + @pytest.mark.asyncio + async def test_empty_events(self, session_response, get_session_svc): + """Zero events from the API yields an empty session.events list.""" + session = await get_session_svc(session_response([])).get_session(app_name="app", user_id="u1", session_id="s1") + assert session is not None + assert len(session.events) == 0 + + @pytest.mark.asyncio + async def test_state_delta_applied_once(self, make_event, session_response, get_session_svc): + """State deltas from events must be applied exactly once to session.state. + + Regression guard: double-appending events caused _update_session_state() + to be called twice per event. + """ + events = [make_event("assistant", state_delta={"counter": 7})] + session = await get_session_svc(session_response(events)).get_session(app_name="app", user_id="u1", session_id="s1") + assert session is not None + assert session.state.get("counter") == 7, ( + f"Expected state['counter'] == 7, got {session.state.get('counter')} — " + "state_delta may have been applied more than once" + ) + + @pytest.mark.asyncio + async def test_multiple_state_deltas_applied_once(self, make_event, session_response, get_session_svc): + """Multiple events each contributing a state key are each applied once.""" + events = [ + make_event("assistant", state_delta={"key_a": "value_a"}), + make_event("tool", state_delta={"key_b": "value_b"}), + ] + session = await get_session_svc(session_response(events)).get_session(app_name="app", user_id="u1", session_id="s1") + assert session is not None + assert session.state.get("key_a") == "value_a" + assert session.state.get("key_b") == "value_b" + + +# --------------------------------------------------------------------------- +# append_event tests +# --------------------------------------------------------------------------- + +class TestAppendEvent: + """Tests for append_event method.""" + + @pytest.mark.asyncio + async def test_append_event_success(self, session_service, mock_client, sample_session, sample_event): + """Test successful event append.""" + mock_response = mock.MagicMock() + mock_response.status_code = 201 + mock_response.raise_for_status = mock.MagicMock() + mock_client.post.return_value = mock_response + + result = await session_service.append_event(sample_session, sample_event) + + assert result == sample_event + mock_client.post.assert_called_once() + assert f"/api/sessions/{sample_session.id}/events" in mock_client.post.call_args[0][0] + + @pytest.mark.asyncio + async def test_append_event_404_recovery(self, session_service, mock_client, sample_session, sample_event): + """Test 404 triggers session recreation and retry.""" + mock_response_404 = mock.MagicMock() + mock_response_404.status_code = 404 + + mock_response_create = mock.MagicMock() + mock_response_create.status_code = 201 + mock_response_create.raise_for_status = mock.MagicMock() + + mock_response_success = mock.MagicMock() + mock_response_success.status_code = 201 + mock_response_success.raise_for_status = mock.MagicMock() + + mock_tasks_response = mock.MagicMock() + mock_tasks_response.status_code = 200 + mock_tasks_response.json.return_value = {"data": []} + + mock_client.post.side_effect = [mock_response_404, mock_response_create, mock_response_success] + mock_client.get.return_value = mock_tasks_response + + result = await session_service.append_event(sample_session, sample_event) + + assert result == sample_event + assert mock_client.post.call_count == 3 + assert mock_client.get.call_count == 1 + calls = mock_client.post.call_args_list + assert f"/api/sessions/{sample_session.id}/events" in calls[0][0][0] + assert "/api/sessions" == calls[1][0][0] + assert f"/api/sessions/{sample_session.id}/events" in calls[2][0][0] + + @pytest.mark.asyncio + async def test_append_event_404_retry_also_404(self, session_service, mock_client, sample_session, sample_event): + """Test that a second 404 on retry raises without infinite recursion.""" + mock_response_404 = mock.MagicMock() + mock_response_404.status_code = 404 + + mock_response_create = mock.MagicMock() + mock_response_create.status_code = 201 + mock_response_create.raise_for_status = mock.MagicMock() + + mock_response_retry_404 = mock.MagicMock() + mock_response_retry_404.status_code = 404 + mock_response_retry_404.raise_for_status.side_effect = httpx.HTTPStatusError( + "Not found", request=mock.MagicMock(), response=mock_response_retry_404 + ) + + mock_tasks_response = mock.MagicMock() + mock_tasks_response.status_code = 200 + mock_tasks_response.json.return_value = {"data": []} + + mock_client.post.side_effect = [mock_response_404, mock_response_create, mock_response_retry_404] + mock_client.get.return_value = mock_tasks_response + + with pytest.raises(httpx.HTTPStatusError): + await session_service.append_event(sample_session, sample_event) + + # Exactly 3 POST calls: initial, recreation, retry — no infinite loop + assert mock_client.post.call_count == 3 + + @pytest.mark.asyncio + async def test_append_event_404_recreation_fails_raises_runtime_error( + self, session_service, mock_client, sample_session, sample_event + ): + """Test 404 + recreation failure raises RuntimeError with 404 context.""" + mock_response_404 = mock.MagicMock() + mock_response_404.status_code = 404 + + mock_response_create_500 = mock.MagicMock() + mock_response_create_500.status_code = 500 + mock_response_create_500.raise_for_status.side_effect = httpx.HTTPStatusError( + "Server error", request=mock.MagicMock(), response=mock_response_create_500 + ) + + mock_client.post.side_effect = [mock_response_404, mock_response_create_500] + + with pytest.raises(RuntimeError, match="not found.*recreation failed"): + await session_service.append_event(sample_session, sample_event) + + @pytest.mark.asyncio + async def test_append_event_404_recovery_failure(self, session_service, mock_client, sample_session, sample_event): + """Test 404 recovery where retry fails with a 500.""" + mock_response_404 = mock.MagicMock() + mock_response_404.status_code = 404 + + mock_response_create = mock.MagicMock() + mock_response_create.status_code = 201 + mock_response_create.raise_for_status = mock.MagicMock() + + mock_response_retry_fail = mock.MagicMock() + mock_response_retry_fail.status_code = 500 + mock_response_retry_fail.raise_for_status.side_effect = httpx.HTTPStatusError( + "Server error", request=mock.MagicMock(), response=mock_response_retry_fail + ) + + mock_tasks_response = mock.MagicMock() + mock_tasks_response.status_code = 200 + mock_tasks_response.json.return_value = {"data": []} + + mock_client.post.side_effect = [mock_response_404, mock_response_create, mock_response_retry_fail] + mock_client.get.return_value = mock_tasks_response + + with pytest.raises(httpx.HTTPStatusError): + await session_service.append_event(sample_session, sample_event) + + assert mock_client.post.call_count == 3 + + @pytest.mark.asyncio + async def test_append_event_non_404_error(self, session_service, mock_client, sample_session, sample_event): + """Test non-404 errors are raised immediately without recovery.""" + mock_response = mock.MagicMock() + mock_response.status_code = 500 + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Server error", request=mock.MagicMock(), response=mock_response + ) + mock_client.post.return_value = mock_response + + with pytest.raises(httpx.HTTPStatusError): + await session_service.append_event(sample_session, sample_event) + + mock_client.post.assert_called_once() + + +# --------------------------------------------------------------------------- +# _recreate_session tests +# --------------------------------------------------------------------------- + +class TestRecreateSession: + """Tests for _recreate_session method.""" + + @pytest.mark.asyncio + async def test_recreate_session_success(self, session_service, mock_client, sample_session): + """Test successful session recreation.""" + mock_response = mock.MagicMock() + mock_response.status_code = 201 + mock_response.raise_for_status = mock.MagicMock() + mock_client.post.return_value = mock_response + + mock_tasks_response = mock.MagicMock() + mock_tasks_response.status_code = 200 + mock_tasks_response.json.return_value = {"data": []} + mock_client.get.return_value = mock_tasks_response + + await session_service._recreate_session(sample_session) + + mock_client.post.assert_called_once() + request_data = mock_client.post.call_args[1]["json"] + assert request_data["id"] == sample_session.id + assert request_data["user_id"] == sample_session.user_id + assert request_data["agent_ref"] == sample_session.app_name + mock_client.get.assert_called_once() + + @pytest.mark.asyncio + async def test_recreate_session_with_session_name(self, session_service, mock_client, sample_session): + """Test session recreation includes session_name from state.""" + mock_response = mock.MagicMock() + mock_response.status_code = 201 + mock_response.raise_for_status = mock.MagicMock() + mock_client.post.return_value = mock_response + + mock_tasks_response = mock.MagicMock() + mock_tasks_response.status_code = 200 + mock_tasks_response.json.return_value = {"data": []} + mock_client.get.return_value = mock_tasks_response + + await session_service._recreate_session(sample_session) + + request_data = mock_client.post.call_args[1]["json"] + assert request_data["name"] == "Test Session" + + @pytest.mark.asyncio + async def test_recreate_session_preserves_source(self, session_service, mock_client): + """Test session recreation preserves the source field from state.""" + session_with_source = Session( + id="sess-src", + user_id="u1", + app_name="app", + state={"session_name": "My Session", "source": "web"}, + ) + mock_response = mock.MagicMock() + mock_response.status_code = 201 + mock_response.raise_for_status = mock.MagicMock() + mock_client.post.return_value = mock_response + + mock_tasks_response = mock.MagicMock() + mock_tasks_response.status_code = 200 + mock_tasks_response.json.return_value = {"data": []} + mock_client.get.return_value = mock_tasks_response + + await session_service._recreate_session(session_with_source) + + request_data = mock_client.post.call_args[1]["json"] + assert request_data["source"] == "web" + assert request_data["name"] == "My Session" + + @pytest.mark.asyncio + async def test_recreate_session_state_loss_warning_for_unknown_fields( + self, session_service, mock_client, caplog + ): + """Warning fires for state fields beyond session_name and source.""" + session_with_extra = Session( + id="sess-extra", + user_id="u1", + app_name="app", + state={"session_name": "S", "source": "web", "custom_key": "custom_value"}, + ) + mock_response = mock.MagicMock() + mock_response.status_code = 201 + mock_response.raise_for_status = mock.MagicMock() + mock_client.post.return_value = mock_response + + mock_tasks_response = mock.MagicMock() + mock_tasks_response.status_code = 200 + mock_tasks_response.json.return_value = {"data": []} + mock_client.get.return_value = mock_tasks_response + + import logging + with caplog.at_level(logging.WARNING): + await session_service._recreate_session(session_with_extra) + + assert any("custom_key" in record.message for record in caplog.records) + + @pytest.mark.asyncio + async def test_recreate_session_no_warning_for_known_fields_only( + self, session_service, mock_client, caplog + ): + """No warning fires when only known fields (session_name, source) are in state.""" + session_known_only = Session( + id="sess-known", + user_id="u1", + app_name="app", + state={"session_name": "S", "source": "web"}, + ) + mock_response = mock.MagicMock() + mock_response.status_code = 201 + mock_response.raise_for_status = mock.MagicMock() + mock_client.post.return_value = mock_response + + mock_tasks_response = mock.MagicMock() + mock_tasks_response.status_code = 200 + mock_tasks_response.json.return_value = {"data": []} + mock_client.get.return_value = mock_tasks_response + + import logging + with caplog.at_level(logging.WARNING): + await session_service._recreate_session(session_known_only) + + assert not any("additional state fields" in record.message for record in caplog.records) + + @pytest.mark.asyncio + async def test_recreate_session_failure(self, session_service, mock_client, sample_session): + """Test session recreation failure raises error.""" + mock_response = mock.MagicMock() + mock_response.status_code = 500 + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Server error", request=mock.MagicMock(), response=mock_response + ) + mock_client.post.return_value = mock_response + + with pytest.raises(httpx.HTTPStatusError): + await session_service._recreate_session(sample_session) + + @pytest.mark.asyncio + async def test_recreate_session_with_inflight_task(self, session_service, mock_client, sample_session): + """Test session recreation detects in-flight tasks.""" + mock_create_response = mock.MagicMock() + mock_create_response.status_code = 201 + mock_create_response.raise_for_status = mock.MagicMock() + mock_client.post.return_value = mock_create_response + + mock_tasks_response = mock.MagicMock() + mock_tasks_response.status_code = 200 + mock_tasks_response.json.return_value = { + "data": [{"id": "task-123", "status": {"state": "working", "message": "Processing..."}}] + } + mock_client.get.return_value = mock_tasks_response -@pytest.mark.asyncio -async def test_get_session_state_delta_applied_once(make_event, session_response, service): - """State deltas from events must be applied exactly once to session.state. + await session_service._recreate_session(sample_session) - Regression test: when events were double-appended, _update_session_state() - was called twice per event, so numeric or overwrite-based state deltas - would be applied twice. - """ - events = [make_event("assistant", state_delta={"counter": 7})] - session = await service(session_response(events)).get_session(app_name="app", user_id="u1", session_id="s1") + assert mock_client.post.call_count == 1 + assert mock_client.get.call_count == 1 + assert f"/api/sessions/{sample_session.id}/tasks" in mock_client.get.call_args[0][0] - assert session is not None - # State must reflect exactly one application of the delta. - # (BaseSessionService._update_session_state does session.state.update({key: value}), - # so for an idempotent string the bug was silent; here we use a distinct value - # and just verify the key is present with the correct value.) - assert session.state.get("counter") == 7, ( - f"Expected state['counter'] == 7, got {session.state.get('counter')} — " - "state_delta may have been applied more than once" - ) + @pytest.mark.asyncio + async def test_recreate_session_409_treated_as_success(self, session_service, mock_client, sample_session): + """409 Conflict during recreation is treated as success (concurrent recreation).""" + mock_response_409 = mock.MagicMock() + mock_response_409.status_code = 409 + mock_client.post.return_value = mock_response_409 + mock_tasks_response = mock.MagicMock() + mock_tasks_response.status_code = 200 + mock_tasks_response.json.return_value = {"data": []} + mock_client.get.return_value = mock_tasks_response -@pytest.mark.asyncio -async def test_get_session_multiple_state_deltas_applied_once(make_event, session_response, service): - """Multiple events each contributing a state key are each applied once.""" - events = [ - make_event("assistant", state_delta={"key_a": "value_a"}), - make_event("tool", state_delta={"key_b": "value_b"}), - ] - session = await service(session_response(events)).get_session(app_name="app", user_id="u1", session_id="s1") + await session_service._recreate_session(sample_session) - assert session is not None - assert session.state.get("key_a") == "value_a" - assert session.state.get("key_b") == "value_b" + mock_client.post.assert_called_once() + mock_client.get.assert_called_once() diff --git a/ui/src/components/chat/AgentCallDisplay.tsx b/ui/src/components/chat/AgentCallDisplay.tsx index 0c8dcd1be..f9d852a9c 100644 --- a/ui/src/components/chat/AgentCallDisplay.tsx +++ b/ui/src/components/chat/AgentCallDisplay.tsx @@ -57,7 +57,7 @@ function SubagentActivityPanel({ sessionId, isComplete }: SubagentActivityPanelP } } else { const tasks: Task[] = resp.data.tasks; - const extracted = extractMessagesFromTasks(tasks); + const { messages: extracted } = extractMessagesFromTasks(tasks); setMessages(extracted); setWaiting(extracted.length === 0 && !isComplete); setError(null); diff --git a/ui/src/components/chat/ChatInterface.tsx b/ui/src/components/chat/ChatInterface.tsx index 9d5bf5cd4..ac717e1cc 100644 --- a/ui/src/components/chat/ChatInterface.tsx +++ b/ui/src/components/chat/ChatInterface.tsx @@ -26,7 +26,7 @@ import { createMessageHandlers, extractMessagesFromTasks, extractApprovalMessage import { kagentA2AClient } from "@/lib/a2aClient"; import { useChatAgentType, useChatRunInSandbox } from "@/components/chat/ChatAgentContext"; import { v4 as uuidv4 } from "uuid"; -import { getStatusPlaceholder } from "@/lib/statusUtils"; +import { getStatusPlaceholder, mapA2AStateToStatus } from "@/lib/statusUtils"; import { Message, DataPart } from "@a2a-js/sdk"; interface ChatInterfaceProps { @@ -141,7 +141,7 @@ export default function ChatInterface({ selectedAgentName, selectedNamespace, se setSessionStats({ total: 0, prompt: 0, completion: 0 }); } else { - const extractedMessages = extractMessagesFromTasks(messagesResponse.data); + const { messages: extractedMessages, pendingTask } = extractMessagesFromTasks(messagesResponse.data); setSessionStats(extractTokenStatsFromTasks(messagesResponse.data)); // Resolved approvals are already inline in extractedMessages (with @@ -157,6 +157,62 @@ export default function ChatInterface({ selectedAgentName, selectedNamespace, se if (hasPendingApproval) { setChatStatus("input_required"); } + + // If there's a pending task (and no pending approval taking priority), reconnect to its stream + if (pendingTask && !hasPendingApproval) { + setIsLoading(false); + setChatStatus(mapA2AStateToStatus(pendingTask.state)); + + try { + abortControllerRef.current = new AbortController(); + + const stream = await kagentA2AClient.resubscribeTask( + selectedNamespace, + selectedAgentName, + pendingTask.taskId, + abortControllerRef.current.signal + ); + + let timeoutTimer: NodeJS.Timeout | null = null; + const streamTimeout = 600000; // 10 minutes + const startTimeout = () => { + if (timeoutTimer) clearTimeout(timeoutTimer); + timeoutTimer = setTimeout(() => { + toast.error("Reconnection timed out - no events received for 10 minutes"); + abortControllerRef.current?.abort(); + }, streamTimeout); + }; + startTimeout(); + + try { + for await (const event of stream) { + startTimeout(); // reset on each event + handleMessageEvent(event); + if (abortControllerRef.current?.signal.aborted) break; + } + } finally { + if (timeoutTimer) clearTimeout(timeoutTimer); + } + } catch (error: unknown) { + if (error instanceof Error && error.name !== 'AbortError') { + toast.error(`Reconnection failed: ${error.message}`); + try { + const refreshedTasks = await getSessionTasks(sessionId); + if (refreshedTasks.data) { + const { messages } = extractMessagesFromTasks(refreshedTasks.data); + setStoredMessages(messages); + } + } catch (refreshError) { + console.error('Failed to refresh tasks after reconnection failure:', refreshError); + } + } + } finally { + // Only reset to ready from transient states; preserve input_required / error set by stream events + setChatStatus(prev => (prev === 'working' || prev === 'submitted' || prev === 'thinking') ? 'ready' : prev); + abortControllerRef.current = null; + } + return; + } } } catch (error) { console.error("Error loading messages:", error); @@ -167,7 +223,11 @@ export default function ChatInterface({ selectedAgentName, selectedNamespace, se } initializeChat(); - }, [sessionId, selectedAgentName, selectedNamespace, isFirstMessage]); + + return () => { + abortControllerRef.current?.abort(); + }; + }, [sessionId, selectedAgentName, selectedNamespace, isFirstMessage, handleMessageEvent]); useEffect(() => { if (containerRef.current) { diff --git a/ui/src/lib/__tests__/a2aClient.test.ts b/ui/src/lib/__tests__/a2aClient.test.ts new file mode 100644 index 000000000..10c74c4b6 --- /dev/null +++ b/ui/src/lib/__tests__/a2aClient.test.ts @@ -0,0 +1,203 @@ +import { describe, expect, it, jest, beforeEach } from '@jest/globals'; +import { KagentA2AClient } from '../a2aClient'; + +// Mock fetch globally +const mockFetch = jest.fn() as jest.MockedFunction; +global.fetch = mockFetch; + +// Mock utils +jest.mock('../utils', () => ({ + getBackendUrl: () => 'http://localhost:8083/api', +})); + +describe('KagentA2AClient', () => { + let client: KagentA2AClient; + + beforeEach(() => { + client = new KagentA2AClient(); + mockFetch.mockClear(); + }); + + describe('getAgentUrl', () => { + it('should construct correct agent URL', () => { + const url = client.getAgentUrl('test-namespace', 'test-agent'); + expect(url).toBe('http://localhost:8083/api/a2a/test-namespace/test-agent'); + }); + }); + + describe('createStreamingRequest', () => { + it('should create valid JSON-RPC request structure', () => { + const params = { + message: { + kind: 'message' as const, + messageId: 'msg-1', + role: 'user' as const, + parts: [{ kind: 'text' as const, text: 'Hello' }], + }, + }; + + const request = client.createStreamingRequest(params); + + expect(request.jsonrpc).toBe('2.0'); + expect(request.method).toBe('message/stream'); + expect(request.params).toEqual(params); + expect(typeof request.id).toBe('string'); + expect(request.id).toBeTruthy(); + }); + }); + + describe('resubscribeTask', () => { + it('should throw error on non-ok response', async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 404, + statusText: 'Not Found', + text: () => Promise.resolve('Task not found'), + } as Response); + + await expect( + client.resubscribeTask('test-ns', 'test-agent', 'task-123') + ).rejects.toThrow('A2A resubscribe failed: 404 Not Found - Task not found'); + + expect(mockFetch).toHaveBeenCalledWith( + '/a2a/test-ns/test-agent', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Accept': 'text/event-stream', + }, + }) + ); + + // Verify the request body contains resubscribe params + const callArgs = mockFetch.mock.calls[0]; + const requestBody = JSON.parse(callArgs[1]?.body as string); + expect(requestBody.jsonrpc).toBe('2.0'); + expect(requestBody.method).toBe('tasks/resubscribe'); + expect(requestBody.params).toEqual({ id: 'task-123' }); + }); + + it('should throw error when response body is null', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + body: null, + } as Response); + + await expect( + client.resubscribeTask('test-ns', 'test-agent', 'task-123') + ).rejects.toThrow('Response body is null'); + }); + + it('should pass abort signal to fetch', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + body: null, + } as Response); + + const abortController = new AbortController(); + + await expect( + client.resubscribeTask('test-ns', 'test-agent', 'task-123', abortController.signal) + ).rejects.toThrow('Response body is null'); + + expect(mockFetch).toHaveBeenCalledWith( + expect.any(String), + expect.objectContaining({ + signal: abortController.signal, + }) + ); + }); + }); + + describe('sendMessageStream', () => { + it('should throw error on non-ok response', async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 500, + statusText: 'Internal Server Error', + text: () => Promise.resolve('Server error'), + } as Response); + + const params = { + message: { + kind: 'message' as const, + messageId: 'msg-1', + role: 'user' as const, + parts: [{ kind: 'text' as const, text: 'Hello' }], + }, + }; + + await expect( + client.sendMessageStream('test-ns', 'test-agent', params) + ).rejects.toThrow('A2A proxy request failed: 500 Internal Server Error - Server error'); + + expect(mockFetch).toHaveBeenCalledWith( + '/a2a/test-ns/test-agent', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Accept': 'text/event-stream', + }, + }) + ); + + // Verify the request body contains message/stream method + const callArgs = mockFetch.mock.calls[0]; + const requestBody = JSON.parse(callArgs[1]?.body as string); + expect(requestBody.jsonrpc).toBe('2.0'); + expect(requestBody.method).toBe('message/stream'); + expect(requestBody.params).toEqual(params); + }); + + it('should throw error when response body is null', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + body: null, + } as Response); + + const params = { + message: { + kind: 'message' as const, + messageId: 'msg-1', + role: 'user' as const, + parts: [{ kind: 'text' as const, text: 'Hello' }], + }, + }; + + await expect( + client.sendMessageStream('test-ns', 'test-agent', params) + ).rejects.toThrow('Response body is null'); + }); + + it('should pass abort signal to fetch', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + body: null, + } as Response); + + const params = { + message: { + kind: 'message' as const, + messageId: 'msg-1', + role: 'user' as const, + parts: [{ kind: 'text' as const, text: 'Hello' }], + }, + }; + + const abortController = new AbortController(); + + await expect( + client.sendMessageStream('test-ns', 'test-agent', params, abortController.signal) + ).rejects.toThrow('Response body is null'); + + expect(mockFetch).toHaveBeenCalledWith( + expect.any(String), + expect.objectContaining({ + signal: abortController.signal, + }) + ); + }); + }); +}); diff --git a/ui/src/lib/__tests__/messageHandlers.test.ts b/ui/src/lib/__tests__/messageHandlers.test.ts index de95321ab..ce9e24d7f 100644 --- a/ui/src/lib/__tests__/messageHandlers.test.ts +++ b/ui/src/lib/__tests__/messageHandlers.test.ts @@ -43,11 +43,39 @@ describe('messageHandlers helpers', () => { const tasks: any = [ { history: [{ kind: 'message', messageId: mId }, { kind: 'message', messageId: mId }] }, ]; - const out = extractMessagesFromTasks(tasks); + const { messages: out } = extractMessagesFromTasks(tasks); expect(out.length).toBe(1); expect(out[0].messageId).toBe(mId); }); + test('extractMessagesFromTasks detects pending task in working state', () => { + const tasks: any = [ + { id: 'task-1', status: { state: 'working' }, history: [] }, + ]; + const { pendingTask } = extractMessagesFromTasks(tasks); + expect(pendingTask).toBeDefined(); + expect(pendingTask?.taskId).toBe('task-1'); + expect(pendingTask?.state).toBe('working'); + }); + + test('extractMessagesFromTasks detects pending task in submitted state', () => { + const tasks: any = [ + { id: 'task-1', status: { state: 'submitted' }, history: [] }, + ]; + const { pendingTask } = extractMessagesFromTasks(tasks); + expect(pendingTask).toBeDefined(); + expect(pendingTask?.taskId).toBe('task-1'); + expect(pendingTask?.state).toBe('submitted'); + }); + + test('extractMessagesFromTasks returns no pending task for completed tasks', () => { + const tasks: any = [ + { id: 'task-1', status: { state: 'completed' }, history: [] }, + ]; + const { pendingTask } = extractMessagesFromTasks(tasks); + expect(pendingTask).toBeUndefined(); + }); + test('extractMessagesFromTasks injects tokenStats into non-user agent messages only', () => { const tasks = [ { @@ -60,7 +88,7 @@ describe('messageHandlers helpers', () => { ], }, ] as unknown as Task[]; - const messages = extractMessagesFromTasks(tasks); + const { messages } = extractMessagesFromTasks(tasks); // Agent message with usage metadata gets tokenStats injected expect((messages[0].metadata as ADKMetadata & { tokenStats?: TokenStats })?.tokenStats) .toEqual({ total: 10, prompt: 3, completion: 7 }); @@ -445,7 +473,7 @@ describe('subagent_session_id propagation', () => { }], }] as unknown as Task[]; - const messages = extractMessagesFromTasks(tasks); + const { messages } = extractMessagesFromTasks(tasks); expect(messages).toHaveLength(1); const meta = messages[0].metadata as ADKMetadata; expect(meta.originalType).toBe('ToolCallRequestEvent'); @@ -474,7 +502,7 @@ describe('subagent_session_id propagation', () => { }], }] as unknown as Task[]; - const messages = extractMessagesFromTasks(tasks); + const { messages } = extractMessagesFromTasks(tasks); expect(messages).toHaveLength(1); const meta = messages[0].metadata as ADKMetadata; expect(meta.originalType).toBe('ToolCallExecutionEvent'); diff --git a/ui/src/lib/a2aClient.ts b/ui/src/lib/a2aClient.ts index d62f21a1d..323d1f2a5 100644 --- a/ui/src/lib/a2aClient.ts +++ b/ui/src/lib/a2aClient.ts @@ -6,7 +6,7 @@ import { MessageSendParams } from '@a2a-js/sdk'; export interface A2AJsonRpcRequest { jsonrpc: "2.0"; method: string; - params: MessageSendParams; + params: MessageSendParams | { id: string }; id: string | number; } @@ -77,6 +77,47 @@ export class KagentA2AClient { return this.processSSEStream(response.body); } + /** + * Resubscribe to an in-flight task to resume receiving streaming events + * Uses the A2A tasks/resubscribe JSON-RPC method + */ + async resubscribeTask( + namespace: string, + agentName: string, + taskId: string, + signal?: AbortSignal + ): Promise> { + const request: A2AJsonRpcRequest = { + jsonrpc: "2.0", + method: "tasks/resubscribe", + params: { id: taskId }, + id: uuidv4(), + }; + + const proxyUrl = `/a2a/${namespace}/${agentName}`; + + const response = await fetch(proxyUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Accept': 'text/event-stream', + }, + body: JSON.stringify(request), + signal, + }); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error(`A2A resubscribe failed: ${response.status} ${response.statusText} - ${errorText}`); + } + + if (!response.body) { + throw new Error('Response body is null'); + } + + return this.processSSEStream(response.body); + } + /** * Process Server-Sent Events stream with proper event boundary detection */ diff --git a/ui/src/lib/messageHandlers.ts b/ui/src/lib/messageHandlers.ts index 638337ae9..9329cd4a8 100644 --- a/ui/src/lib/messageHandlers.ts +++ b/ui/src/lib/messageHandlers.ts @@ -4,12 +4,27 @@ import { convertToUserFriendlyName, isAgentToolName, messageUtils } from "@/lib/ import { ApprovalDecision, AdkRequestConfirmationData, HitlPartInfo, ToolDecision, TokenStats, ChatStatus } from "@/types"; import { mapA2AStateToStatus } from "@/lib/statusUtils"; +// Result type for extractMessagesFromTasks +export interface TaskExtractionResult { + messages: Message[]; + pendingTask?: { taskId: string; state: 'working' | 'submitted' }; +} + // Helper functions for extracting data from stored tasks -export function extractMessagesFromTasks(tasks: Task[]): Message[] { +export function extractMessagesFromTasks(tasks: Task[]): TaskExtractionResult { const messages: Message[] = []; const seenMessageIds = new Set(); + let pendingTask: { taskId: string; state: 'working' | 'submitted' } | undefined; for (const task of tasks) { + // Detect in-flight tasks for stream reconnection. + // If multiple tasks are in-flight (unusual), the last one wins — sessions + // are expected to have at most one concurrent active task. + const taskState = task.status?.state; + if (taskState === 'working' || taskState === 'submitted') { + pendingTask = { taskId: task.id, state: taskState }; + } + if (!task.history) continue; // Track the most recent LLM usage seen so far within this task so we can @@ -159,7 +174,7 @@ export function extractMessagesFromTasks(tasks: Task[]): Message[] { } } - return messages; + return { messages, pendingTask }; } /** Returns true if the message is a user HITL decision (approve/reject) or ask-user answer. */