|
18 | 18 | import json |
19 | 19 | import os |
20 | 20 | import re |
| 21 | +from contextlib import asynccontextmanager |
21 | 22 | from unittest.mock import AsyncMock |
22 | 23 | from unittest.mock import MagicMock |
23 | 24 | from unittest.mock import patch |
|
43 | 44 | from nat.data_models.api_server import ErrorTypes |
44 | 45 | from nat.data_models.api_server import ObservabilityTraceContent |
45 | 46 | from nat.data_models.api_server import ResponseIntermediateStep |
| 47 | +from nat.data_models.api_server import ResponseObservabilityTrace |
46 | 48 | from nat.data_models.api_server import ResponsePayloadOutput |
47 | 49 | from nat.data_models.api_server import SystemIntermediateStepContent |
48 | 50 | from nat.data_models.api_server import SystemResponseContent |
@@ -1053,3 +1055,130 @@ async def test_restore_execution_state_sends_prompt_with_remaining_timeout(): |
1053 | 1055 | call_kwargs = handler.create_websocket_message.call_args[1] |
1054 | 1056 | sent_content = call_kwargs["data_model"] |
1055 | 1057 | assert sent_content.timeout == 7 |
| 1058 | + |
| 1059 | + |
| 1060 | +async def test_process_workflow_request_cancels_in_flight_task(): |
| 1061 | + """A new workflow request cancels any in-flight task before creating a replacement.""" |
| 1062 | + mock_socket = AsyncMock() |
| 1063 | + mock_session_manager = MagicMock() |
| 1064 | + mock_step_adaptor = MagicMock() |
| 1065 | + mock_worker = MagicMock() |
| 1066 | + mock_worker.get_conversation_handler.return_value = None |
| 1067 | + |
| 1068 | + handler = WebSocketMessageHandler( |
| 1069 | + socket=mock_socket, |
| 1070 | + session_manager=mock_session_manager, |
| 1071 | + step_adaptor=mock_step_adaptor, |
| 1072 | + worker=mock_worker, |
| 1073 | + ) |
| 1074 | + |
| 1075 | + existing_task = asyncio.create_task(asyncio.sleep(100)) |
| 1076 | + handler._running_workflow_task = existing_task |
| 1077 | + |
| 1078 | + msg = WebSocketUserMessage.model_validate({**user_message, "type": "user_message"}) |
| 1079 | + |
| 1080 | + async def _noop_workflow(*args, **kwargs): |
| 1081 | + await asyncio.sleep(0) |
| 1082 | + |
| 1083 | + with patch.object(handler, "_run_workflow", _noop_workflow): |
| 1084 | + await handler.process_workflow_request(msg) |
| 1085 | + |
| 1086 | + assert existing_task.cancelled() |
| 1087 | + assert handler._running_workflow_task is not None |
| 1088 | + assert handler._running_workflow_task is not existing_task |
| 1089 | + |
| 1090 | + new_task = handler._running_workflow_task |
| 1091 | + new_task.cancel() |
| 1092 | + try: |
| 1093 | + await new_task |
| 1094 | + except (asyncio.CancelledError, Exception): |
| 1095 | + pass |
| 1096 | + |
| 1097 | + |
| 1098 | +async def test_done_callback_guards_against_stale_task(): |
| 1099 | + """_done_callback does not clear _running_workflow_task when the task has been replaced.""" |
| 1100 | + mock_socket = AsyncMock() |
| 1101 | + mock_session_manager = MagicMock() |
| 1102 | + mock_step_adaptor = MagicMock() |
| 1103 | + mock_worker = MagicMock() |
| 1104 | + mock_worker.get_conversation_handler.return_value = None |
| 1105 | + |
| 1106 | + handler = WebSocketMessageHandler( |
| 1107 | + socket=mock_socket, |
| 1108 | + session_manager=mock_session_manager, |
| 1109 | + step_adaptor=mock_step_adaptor, |
| 1110 | + worker=mock_worker, |
| 1111 | + ) |
| 1112 | + |
| 1113 | + msg = WebSocketUserMessage.model_validate({**user_message, "type": "user_message"}) |
| 1114 | + completed = asyncio.Event() |
| 1115 | + |
| 1116 | + async def _quick_workflow(*args, **kwargs): |
| 1117 | + await asyncio.sleep(0) |
| 1118 | + completed.set() |
| 1119 | + |
| 1120 | + with patch.object(handler, "_run_workflow", _quick_workflow): |
| 1121 | + await handler.process_workflow_request(msg) |
| 1122 | + |
| 1123 | + # Simulate a second request replacing the first task reference |
| 1124 | + second_task = asyncio.create_task(asyncio.sleep(100)) |
| 1125 | + handler._running_workflow_task = second_task |
| 1126 | + |
| 1127 | + # Let the first task complete and fire its done callback |
| 1128 | + await completed.wait() |
| 1129 | + await asyncio.sleep(0) |
| 1130 | + |
| 1131 | + # second_task must remain untouched by the first task's callback |
| 1132 | + assert handler._running_workflow_task is second_task |
| 1133 | + mock_worker.remove_conversation_handler.assert_not_called() |
| 1134 | + |
| 1135 | + second_task.cancel() |
| 1136 | + try: |
| 1137 | + await second_task |
| 1138 | + except asyncio.CancelledError: |
| 1139 | + pass |
| 1140 | + |
| 1141 | + |
| 1142 | +async def test_run_workflow_skips_response_on_cancellation(): |
| 1143 | + """When _run_workflow is cancelled, RESPONSE_MESSAGE is not sent and pending trace is cleared.""" |
| 1144 | + mock_socket = AsyncMock() |
| 1145 | + mock_session_manager = MagicMock() |
| 1146 | + mock_step_adaptor = MagicMock() |
| 1147 | + mock_worker = MagicMock() |
| 1148 | + |
| 1149 | + handler = WebSocketMessageHandler( |
| 1150 | + socket=mock_socket, |
| 1151 | + session_manager=mock_session_manager, |
| 1152 | + step_adaptor=mock_step_adaptor, |
| 1153 | + worker=mock_worker, |
| 1154 | + ) |
| 1155 | + handler.create_websocket_message = AsyncMock() |
| 1156 | + handler._user_message_payload = {} |
| 1157 | + |
| 1158 | + handler._pending_observability_trace = ResponseObservabilityTrace(observability_trace_id="trace-to-clear") |
| 1159 | + |
| 1160 | + blocking_event = asyncio.Event() |
| 1161 | + |
| 1162 | + @asynccontextmanager |
| 1163 | + async def _mock_session(*args, **kwargs): |
| 1164 | + yield MagicMock() |
| 1165 | + |
| 1166 | + mock_session_manager.session = _mock_session |
| 1167 | + |
| 1168 | + async def _blocking_generator(*args, **kwargs): |
| 1169 | + blocking_event.set() |
| 1170 | + await asyncio.sleep(100) |
| 1171 | + yield # pragma: no cover |
| 1172 | + |
| 1173 | + with patch("nat.front_ends.fastapi.message_handler.generate_streaming_response", _blocking_generator): |
| 1174 | + task = asyncio.create_task(handler._run_workflow(payload="test")) |
| 1175 | + await blocking_event.wait() |
| 1176 | + task.cancel() |
| 1177 | + with pytest.raises(asyncio.CancelledError): |
| 1178 | + await task |
| 1179 | + |
| 1180 | + for call in handler.create_websocket_message.call_args_list: |
| 1181 | + msg_type = call.kwargs.get("message_type") or (call.args[1] if len(call.args) > 1 else None) |
| 1182 | + assert msg_type != WebSocketMessageType.RESPONSE_MESSAGE |
| 1183 | + |
| 1184 | + assert handler._pending_observability_trace is None |
0 commit comments