Skip to content

Commit 7f9dfc0

Browse files
authored
Merge branch 'develop' into vasiliyr/03152026
2 parents 7c10f0b + adc0373 commit 7f9dfc0

4 files changed

Lines changed: 247 additions & 23 deletions

File tree

packages/nvidia_nat_core/src/nat/front_ends/fastapi/message_handler.py

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -272,22 +272,33 @@ async def process_workflow_request(self, user_message_as_validated_type: WebSock
272272
self._initialize_workflow_request(user_message_as_validated_type)
273273
message_content: typing.Any = await self._process_websocket_user_message(user_message_as_validated_type)
274274

275-
if (self._running_workflow_task is None):
275+
if self._workflow_schema_type is None:
276+
raise RuntimeError("Workflow schema type is not initialized")
276277

277-
def _done_callback(_task: asyncio.Task):
278+
if self._running_workflow_task is not None:
279+
self._running_workflow_task.cancel()
280+
try:
281+
await self._running_workflow_task
282+
except (asyncio.CancelledError, Exception):
283+
pass
284+
self._running_workflow_task = None
285+
286+
_conversation_id = self._conversation_id
287+
288+
def _done_callback(_task: asyncio.Task):
289+
if self._running_workflow_task is _task:
278290
self._running_workflow_task = None
279-
if self._conversation_id:
280-
self._worker.remove_conversation_handler(self._conversation_id)
281-
282-
if self._workflow_schema_type is None:
283-
raise RuntimeError("Workflow schema type is not initialized")
284-
self._running_workflow_task = asyncio.create_task(
285-
self._run_workflow(payload=message_content,
286-
user_message_id=self._message_parent_id,
287-
conversation_id=self._conversation_id,
288-
result_type=self._schema_output_mapping[self._workflow_schema_type],
289-
output_type=self._schema_output_mapping[self._workflow_schema_type]))
290-
self._running_workflow_task.add_done_callback(_done_callback)
291+
if self._running_workflow_task is None and _conversation_id and \
292+
self._worker.get_conversation_handler(_conversation_id) is self:
293+
self._worker.remove_conversation_handler(_conversation_id)
294+
295+
self._running_workflow_task = asyncio.create_task(
296+
self._run_workflow(payload=message_content,
297+
user_message_id=self._message_parent_id,
298+
conversation_id=self._conversation_id,
299+
result_type=self._schema_output_mapping[self._workflow_schema_type],
300+
output_type=self._schema_output_mapping[self._workflow_schema_type]))
301+
self._running_workflow_task.add_done_callback(_done_callback)
291302

292303
except ValueError as e:
293304
logger.exception("User message content not found: %s", str(e))
@@ -439,6 +450,7 @@ async def _run_workflow(self,
439450
result_type: type | None = None,
440451
output_type: type | None = None) -> None:
441452

453+
_cancelled = False
442454
try:
443455
auth_callback = self._flow_handler.authenticate if self._flow_handler else None
444456
async with self._session_manager.session(user_id=self._user_id,
@@ -466,6 +478,10 @@ async def _run_workflow(self,
466478

467479
await self.create_websocket_message(data_model=value, status=WebSocketMessageStatus.IN_PROGRESS)
468480

481+
except asyncio.CancelledError:
482+
_cancelled = True
483+
raise
484+
469485
except Exception as e:
470486
logger.exception("Unhandled workflow error")
471487
await self.create_websocket_message(data_model=Error(code=ErrorTypes.WORKFLOW_ERROR,
@@ -475,12 +491,16 @@ async def _run_workflow(self,
475491
status=WebSocketMessageStatus.IN_PROGRESS)
476492

477493
finally:
478-
await self.create_websocket_message(data_model=SystemResponseContent(),
479-
message_type=WebSocketMessageType.RESPONSE_MESSAGE,
480-
status=WebSocketMessageStatus.COMPLETE)
481-
482-
# Send observability trace after completion message
483-
if self._pending_observability_trace is not None:
484-
await self.create_websocket_message(data_model=self._pending_observability_trace,
485-
message_type=WebSocketMessageType.OBSERVABILITY_TRACE_MESSAGE)
494+
try:
495+
if not _cancelled:
496+
await self.create_websocket_message(data_model=SystemResponseContent(),
497+
message_type=WebSocketMessageType.RESPONSE_MESSAGE,
498+
status=WebSocketMessageStatus.COMPLETE)
499+
500+
# Send observability trace after completion message
501+
if self._pending_observability_trace is not None:
502+
await self.create_websocket_message(
503+
data_model=self._pending_observability_trace,
504+
message_type=WebSocketMessageType.OBSERVABILITY_TRACE_MESSAGE)
505+
finally:
486506
self._pending_observability_trace = None

packages/nvidia_nat_core/src/nat/runtime/runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ async def __aexit__(self, exc_type, exc_value, traceback):
150150

151151
self._context_state.runtime_type.reset(self._runtime_type_token)
152152

153-
if (self._state not in (RunnerState.COMPLETED, RunnerState.FAILED)):
153+
if (self._state not in (RunnerState.COMPLETED, RunnerState.FAILED)) and exc_type is None:
154154
raise ValueError("Cannot exit the context without completing the workflow")
155155

156156
@typing.overload

packages/nvidia_nat_core/tests/nat/runtime/test_runner.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,3 +282,78 @@ async def test_runner_state_management():
282282
async with runner:
283283
result = await runner.result()
284284
assert result == "test!"
285+
286+
287+
async def test_runner_aexit_raises_on_incomplete_clean_exit():
288+
"""Test that Runner raises ValueError when exited cleanly without completing the workflow."""
289+
290+
async with WorkflowBuilder() as builder:
291+
entry_fn = await builder.add_function(name="test_function", config=SingleOutputConfig())
292+
293+
context_state = ContextState()
294+
exporter_manager = ExporterManager()
295+
296+
with pytest.raises(ValueError, match="Cannot exit the context without completing the workflow"):
297+
async with Runner(input_message="test",
298+
entry_fn=entry_fn,
299+
context_state=context_state,
300+
exporter_manager=exporter_manager):
301+
pass # exit without calling result()
302+
303+
304+
async def test_runner_aexit_allows_cancelled_error_to_propagate():
305+
"""Test that Runner does not mask CancelledError with a ValueError on exit."""
306+
import asyncio
307+
308+
async with WorkflowBuilder() as builder:
309+
entry_fn = await builder.add_function(name="test_function", config=SingleOutputConfig())
310+
311+
context_state = ContextState()
312+
exporter_manager = ExporterManager()
313+
314+
with pytest.raises(asyncio.CancelledError):
315+
async with Runner(input_message="test",
316+
entry_fn=entry_fn,
317+
context_state=context_state,
318+
exporter_manager=exporter_manager):
319+
raise asyncio.CancelledError()
320+
321+
322+
async def test_runner_workflow_replacement_handoff():
323+
"""Test the workflow-replacement handoff path tied to the message_handler regression.
324+
325+
Cancelling the first in-flight Runner via task.cancel() must not mask CancelledError
326+
with a ValueError (runner.py fix), and the immediately-following second Runner must
327+
run to completion on the same context_state/exporter_manager.
328+
"""
329+
import asyncio
330+
331+
async with WorkflowBuilder() as builder:
332+
entry_fn = await builder.add_function(name="test_function", config=SingleOutputConfig())
333+
334+
context_state = ContextState()
335+
exporter_manager = ExporterManager()
336+
337+
# Simulate message_handler cancelling an in-flight task when a new message arrives.
338+
async def _first_workflow():
339+
async with Runner(input_message="first",
340+
entry_fn=entry_fn,
341+
context_state=context_state,
342+
exporter_manager=exporter_manager):
343+
await asyncio.sleep(0) # yield so external cancel can be delivered
344+
345+
first_task = asyncio.create_task(_first_workflow())
346+
await asyncio.sleep(0) # let the task enter the Runner context
347+
first_task.cancel()
348+
349+
# CancelledError must propagate cleanly — not be masked by ValueError.
350+
with pytest.raises(asyncio.CancelledError):
351+
await first_task
352+
353+
# Handoff: second Runner starts immediately on the same context and runs to completion.
354+
async with Runner(input_message="second",
355+
entry_fn=entry_fn,
356+
context_state=context_state,
357+
exporter_manager=exporter_manager) as runner2:
358+
result = await runner2.result()
359+
assert result == "second!"

packages/nvidia_nat_core/tests/nat/server/test_unified_api_server.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import json
1919
import os
2020
import re
21+
from contextlib import asynccontextmanager
2122
from unittest.mock import AsyncMock
2223
from unittest.mock import MagicMock
2324
from unittest.mock import patch
@@ -43,6 +44,7 @@
4344
from nat.data_models.api_server import ErrorTypes
4445
from nat.data_models.api_server import ObservabilityTraceContent
4546
from nat.data_models.api_server import ResponseIntermediateStep
47+
from nat.data_models.api_server import ResponseObservabilityTrace
4648
from nat.data_models.api_server import ResponsePayloadOutput
4749
from nat.data_models.api_server import SystemIntermediateStepContent
4850
from nat.data_models.api_server import SystemResponseContent
@@ -1053,3 +1055,130 @@ async def test_restore_execution_state_sends_prompt_with_remaining_timeout():
10531055
call_kwargs = handler.create_websocket_message.call_args[1]
10541056
sent_content = call_kwargs["data_model"]
10551057
assert sent_content.timeout == 7
1058+
1059+
1060+
async def test_process_workflow_request_cancels_in_flight_task():
1061+
"""A new workflow request cancels any in-flight task before creating a replacement."""
1062+
mock_socket = AsyncMock()
1063+
mock_session_manager = MagicMock()
1064+
mock_step_adaptor = MagicMock()
1065+
mock_worker = MagicMock()
1066+
mock_worker.get_conversation_handler.return_value = None
1067+
1068+
handler = WebSocketMessageHandler(
1069+
socket=mock_socket,
1070+
session_manager=mock_session_manager,
1071+
step_adaptor=mock_step_adaptor,
1072+
worker=mock_worker,
1073+
)
1074+
1075+
existing_task = asyncio.create_task(asyncio.sleep(100))
1076+
handler._running_workflow_task = existing_task
1077+
1078+
msg = WebSocketUserMessage.model_validate({**user_message, "type": "user_message"})
1079+
1080+
async def _noop_workflow(*args, **kwargs):
1081+
await asyncio.sleep(0)
1082+
1083+
with patch.object(handler, "_run_workflow", _noop_workflow):
1084+
await handler.process_workflow_request(msg)
1085+
1086+
assert existing_task.cancelled()
1087+
assert handler._running_workflow_task is not None
1088+
assert handler._running_workflow_task is not existing_task
1089+
1090+
new_task = handler._running_workflow_task
1091+
new_task.cancel()
1092+
try:
1093+
await new_task
1094+
except (asyncio.CancelledError, Exception):
1095+
pass
1096+
1097+
1098+
async def test_done_callback_guards_against_stale_task():
1099+
"""_done_callback does not clear _running_workflow_task when the task has been replaced."""
1100+
mock_socket = AsyncMock()
1101+
mock_session_manager = MagicMock()
1102+
mock_step_adaptor = MagicMock()
1103+
mock_worker = MagicMock()
1104+
mock_worker.get_conversation_handler.return_value = None
1105+
1106+
handler = WebSocketMessageHandler(
1107+
socket=mock_socket,
1108+
session_manager=mock_session_manager,
1109+
step_adaptor=mock_step_adaptor,
1110+
worker=mock_worker,
1111+
)
1112+
1113+
msg = WebSocketUserMessage.model_validate({**user_message, "type": "user_message"})
1114+
completed = asyncio.Event()
1115+
1116+
async def _quick_workflow(*args, **kwargs):
1117+
await asyncio.sleep(0)
1118+
completed.set()
1119+
1120+
with patch.object(handler, "_run_workflow", _quick_workflow):
1121+
await handler.process_workflow_request(msg)
1122+
1123+
# Simulate a second request replacing the first task reference
1124+
second_task = asyncio.create_task(asyncio.sleep(100))
1125+
handler._running_workflow_task = second_task
1126+
1127+
# Let the first task complete and fire its done callback
1128+
await completed.wait()
1129+
await asyncio.sleep(0)
1130+
1131+
# second_task must remain untouched by the first task's callback
1132+
assert handler._running_workflow_task is second_task
1133+
mock_worker.remove_conversation_handler.assert_not_called()
1134+
1135+
second_task.cancel()
1136+
try:
1137+
await second_task
1138+
except asyncio.CancelledError:
1139+
pass
1140+
1141+
1142+
async def test_run_workflow_skips_response_on_cancellation():
1143+
"""When _run_workflow is cancelled, RESPONSE_MESSAGE is not sent and pending trace is cleared."""
1144+
mock_socket = AsyncMock()
1145+
mock_session_manager = MagicMock()
1146+
mock_step_adaptor = MagicMock()
1147+
mock_worker = MagicMock()
1148+
1149+
handler = WebSocketMessageHandler(
1150+
socket=mock_socket,
1151+
session_manager=mock_session_manager,
1152+
step_adaptor=mock_step_adaptor,
1153+
worker=mock_worker,
1154+
)
1155+
handler.create_websocket_message = AsyncMock()
1156+
handler._user_message_payload = {}
1157+
1158+
handler._pending_observability_trace = ResponseObservabilityTrace(observability_trace_id="trace-to-clear")
1159+
1160+
blocking_event = asyncio.Event()
1161+
1162+
@asynccontextmanager
1163+
async def _mock_session(*args, **kwargs):
1164+
yield MagicMock()
1165+
1166+
mock_session_manager.session = _mock_session
1167+
1168+
async def _blocking_generator(*args, **kwargs):
1169+
blocking_event.set()
1170+
await asyncio.sleep(100)
1171+
yield # pragma: no cover
1172+
1173+
with patch("nat.front_ends.fastapi.message_handler.generate_streaming_response", _blocking_generator):
1174+
task = asyncio.create_task(handler._run_workflow(payload="test"))
1175+
await blocking_event.wait()
1176+
task.cancel()
1177+
with pytest.raises(asyncio.CancelledError):
1178+
await task
1179+
1180+
for call in handler.create_websocket_message.call_args_list:
1181+
msg_type = call.kwargs.get("message_type") or (call.args[1] if len(call.args) > 1 else None)
1182+
assert msg_type != WebSocketMessageType.RESPONSE_MESSAGE
1183+
1184+
assert handler._pending_observability_trace is None

0 commit comments

Comments
 (0)