Skip to content

Commit 24a5355

Browse files
authored
Persist User Message ID For HTTP Connections (#696)
Fixes a bug where conversation_id was not being properly set for WebSocket connections and adds support for persisting user_message_id from HTTP connections.

Closes: [Issue 658](#658)

## By Submitting this PR I confirm:

- I am familiar with the [Contributing Guidelines](https://github.com/NVIDIA/NeMo-Agent-Toolkit/blob/develop/docs/source/resources/contributing.md).
- We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license.
- Any contribution which contains commits that are not Signed-Off will not be accepted.
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:

- Eric Evans II (https://github.com/ericevans-nv)

Approvers:

- Will Killian (https://github.com/willkill07)

URL: #696
1 parent ad75058 commit 24a5355

5 files changed

Lines changed: 83 additions & 58 deletions

File tree

external/nat-ui

src/nat/builder/context.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ class ContextState(metaclass=Singleton):
6565

6666
def __init__(self):
6767
self.conversation_id: ContextVar[str | None] = ContextVar("conversation_id", default=None)
68+
self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None)
6869
self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
6970
self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
7071
self.metadata: ContextVar[RequestAttributes] = ContextVar("request_attributes", default=RequestAttributes())
@@ -165,6 +166,13 @@ def conversation_id(self) -> str | None:
165166
"""
166167
return self._context_state.conversation_id.get()
167168

169+
@property
170+
def user_message_id(self) -> str | None:
171+
"""
172+
This property retrieves the user message ID, which is the unique identifier for the current user message.
173+
"""
174+
return self._context_state.user_message_id.get()
175+
168176
@contextmanager
169177
def push_active_function(self, function_name: str, input_data: typing.Any | None):
170178
"""

src/nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py

Lines changed: 45 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ async def run_evaluation(job_id: str, config_file: str, reps: int, session_manag
307307
async def start_evaluation(request: EvaluateRequest, background_tasks: BackgroundTasks, http_request: Request):
308308
"""Handle evaluation requests."""
309309

310-
async with session_manager.session(request=http_request):
310+
async with session_manager.session(http_connection=http_request):
311311

312312
# if job_id is present and already exists return the job info
313313
if request.job_id:
@@ -336,7 +336,7 @@ async def get_job_status(job_id: str, http_request: Request) -> EvaluateStatusRe
336336
"""Get the status of an evaluation job."""
337337
logger.info("Getting status for job %s", job_id)
338338

339-
async with session_manager.session(request=http_request):
339+
async with session_manager.session(http_connection=http_request):
340340

341341
job = job_store.get_job(job_id)
342342
if not job:
@@ -349,7 +349,7 @@ async def get_last_job_status(http_request: Request) -> EvaluateStatusResponse:
349349
"""Get the status of the last created evaluation job."""
350350
logger.info("Getting last job status")
351351

352-
async with session_manager.session(request=http_request):
352+
async with session_manager.session(http_connection=http_request):
353353

354354
job = job_store.get_last_job()
355355
if not job:
@@ -361,7 +361,7 @@ async def get_last_job_status(http_request: Request) -> EvaluateStatusResponse:
361361
async def get_jobs(http_request: Request, status: str | None = None) -> list[EvaluateStatusResponse]:
362362
"""Get all jobs, optionally filtered by status."""
363363

364-
async with session_manager.session(request=http_request):
364+
async with session_manager.session(http_connection=http_request):
365365

366366
if status is None:
367367
logger.info("Getting all jobs")
@@ -572,7 +572,7 @@ async def get_single(response: Response, request: Request):
572572

573573
response.headers["Content-Type"] = "application/json"
574574

575-
async with session_manager.session(request=request,
575+
async with session_manager.session(http_connection=request,
576576
user_authentication_callback=self._http_flow_handler.authenticate):
577577

578578
return await generate_single_response(None, session_manager, result_type=result_type)
@@ -583,7 +583,7 @@ def get_streaming_endpoint(streaming: bool, result_type: type | None, output_typ
583583

584584
async def get_stream(request: Request):
585585

586-
async with session_manager.session(request=request,
586+
async with session_manager.session(http_connection=request,
587587
user_authentication_callback=self._http_flow_handler.authenticate):
588588

589589
return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -618,7 +618,7 @@ async def post_single(response: Response, request: Request, payload: request_typ
618618

619619
response.headers["Content-Type"] = "application/json"
620620

621-
async with session_manager.session(request=request,
621+
async with session_manager.session(http_connection=request,
622622
user_authentication_callback=self._http_flow_handler.authenticate):
623623

624624
return await generate_single_response(payload, session_manager, result_type=result_type)
@@ -632,7 +632,7 @@ def post_streaming_endpoint(request_type: type,
632632

633633
async def post_stream(request: Request, payload: request_type):
634634

635-
async with session_manager.session(request=request,
635+
async with session_manager.session(http_connection=request,
636636
user_authentication_callback=self._http_flow_handler.authenticate):
637637

638638
return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -677,7 +677,7 @@ async def post_openai_api_compatible(response: Response, request: Request, paylo
677677
# Check if streaming is requested
678678
stream_requested = getattr(payload, 'stream', False)
679679

680-
async with session_manager.session(request=request):
680+
async with session_manager.session(http_connection=request):
681681
if stream_requested:
682682
# Return streaming response
683683
return StreamingResponse(headers={"Content-Type": "text/event-stream; charset=utf-8"},
@@ -688,42 +688,41 @@ async def post_openai_api_compatible(response: Response, request: Request, paylo
688688
step_adaptor=self.get_step_adaptor(),
689689
result_type=ChatResponseChunk,
690690
output_type=ChatResponseChunk))
691-
else:
692-
# Return single response - check if workflow supports non-streaming
693-
try:
691+
692+
# Return single response - check if workflow supports non-streaming
693+
try:
694+
response.headers["Content-Type"] = "application/json"
695+
return await generate_single_response(payload, session_manager, result_type=ChatResponse)
696+
except ValueError as e:
697+
if "Cannot get a single output value for streaming workflows" in str(e):
698+
# Workflow only supports streaming, but client requested non-streaming
699+
# Fall back to streaming and collect the result
700+
chunks = []
701+
async for chunk_str in generate_streaming_response_as_str(
702+
payload,
703+
session_manager=session_manager,
704+
streaming=True,
705+
step_adaptor=self.get_step_adaptor(),
706+
result_type=ChatResponseChunk,
707+
output_type=ChatResponseChunk):
708+
if chunk_str.startswith("data: ") and not chunk_str.startswith("data: [DONE]"):
709+
chunk_data = chunk_str[6:].strip() # Remove "data: " prefix
710+
if chunk_data:
711+
try:
712+
chunk_json = ChatResponseChunk.model_validate_json(chunk_data)
713+
if (chunk_json.choices and len(chunk_json.choices) > 0
714+
and chunk_json.choices[0].delta
715+
and chunk_json.choices[0].delta.content is not None):
716+
chunks.append(chunk_json.choices[0].delta.content)
717+
except Exception:
718+
continue
719+
720+
# Create a single response from collected chunks
721+
content = "".join(chunks)
722+
single_response = ChatResponse.from_string(content)
694723
response.headers["Content-Type"] = "application/json"
695-
return await generate_single_response(payload, session_manager, result_type=ChatResponse)
696-
except ValueError as e:
697-
if "Cannot get a single output value for streaming workflows" in str(e):
698-
# Workflow only supports streaming, but client requested non-streaming
699-
# Fall back to streaming and collect the result
700-
chunks = []
701-
async for chunk_str in generate_streaming_response_as_str(
702-
payload,
703-
session_manager=session_manager,
704-
streaming=True,
705-
step_adaptor=self.get_step_adaptor(),
706-
result_type=ChatResponseChunk,
707-
output_type=ChatResponseChunk):
708-
if chunk_str.startswith("data: ") and not chunk_str.startswith("data: [DONE]"):
709-
chunk_data = chunk_str[6:].strip() # Remove "data: " prefix
710-
if chunk_data:
711-
try:
712-
chunk_json = ChatResponseChunk.model_validate_json(chunk_data)
713-
if (chunk_json.choices and len(chunk_json.choices) > 0
714-
and chunk_json.choices[0].delta
715-
and chunk_json.choices[0].delta.content is not None):
716-
chunks.append(chunk_json.choices[0].delta.content)
717-
except Exception:
718-
continue
719-
720-
# Create a single response from collected chunks
721-
content = "".join(chunks)
722-
single_response = ChatResponse.from_string(content)
723-
response.headers["Content-Type"] = "application/json"
724-
return single_response
725-
else:
726-
raise
724+
return single_response
725+
raise
727726

728727
return post_openai_api_compatible
729728

@@ -758,7 +757,7 @@ async def start_async_generation(
758757
http_request: Request) -> AsyncGenerateResponse | AsyncGenerationStatusResponse:
759758
"""Handle async generation requests."""
760759

761-
async with session_manager.session(request=http_request):
760+
async with session_manager.session(http_connection=http_request):
762761

763762
# if job_id is present and already exists return the job info
764763
if request.job_id:
@@ -804,7 +803,7 @@ async def get_async_job_status(job_id: str, http_request: Request) -> AsyncGener
804803
"""Get the status of an async job."""
805804
logger.info("Getting status for job %s", job_id)
806805

807-
async with session_manager.session(request=http_request):
806+
async with session_manager.session(http_connection=http_request):
808807

809808
job = job_store.get_job(job_id)
810809
if not job:

src/nat/front_ends/fastapi/message_handler.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,9 @@ def _done_callback(task: asyncio.Task): # pylint: disable=unused-argument
166166
self._running_workflow_task = None
167167

168168
self._running_workflow_task = asyncio.create_task(
169-
self._run_workflow(content.text,
170-
self._conversation_id,
169+
self._run_workflow(payload=content.text,
170+
user_message_id=self._message_parent_id,
171+
conversation_id=self._conversation_id,
171172
result_type=self._schema_output_mapping[self._workflow_schema_type],
172173
output_type=self._schema_output_mapping[
173174
self._workflow_schema_type])).add_done_callback(_done_callback)
@@ -290,14 +291,16 @@ async def human_interaction_callback(self, prompt: InteractionPrompt) -> HumanRe
290291

291292
async def _run_workflow(self,
292293
payload: typing.Any,
294+
user_message_id: str | None = None,
293295
conversation_id: str | None = None,
294296
result_type: type | None = None,
295297
output_type: type | None = None) -> None:
296298

297299
try:
298300
async with self._session_manager.session(
301+
user_message_id=user_message_id,
299302
conversation_id=conversation_id,
300-
request=self._socket,
303+
http_connection=self._socket,
301304
user_input_callback=self.human_interaction_callback,
302305
user_authentication_callback=(self._flow_handler.authenticate
303306
if self._flow_handler else None)) as session:

src/nat/runtime/session.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
from contextlib import asynccontextmanager
2222
from contextlib import nullcontext
2323

24+
from fastapi import WebSocket
2425
from starlette.requests import HTTPConnection
26+
from starlette.requests import Request
2527

2628
from nat.builder.context import Context
2729
from nat.builder.context import ContextState
@@ -89,7 +91,8 @@ def context(self) -> Context:
8991
@asynccontextmanager
9092
async def session(self,
9193
user_manager=None,
92-
request: HTTPConnection | None = None,
94+
http_connection: HTTPConnection | None = None,
95+
user_message_id: str | None = None,
9396
conversation_id: str | None = None,
9497
user_input_callback: Callable[[InteractionPrompt], Awaitable[HumanResponse]] = None,
9598
user_authentication_callback: Callable[[AuthProviderBaseConfig, AuthFlowType],
@@ -107,10 +110,11 @@ async def session(self,
107110
if user_authentication_callback is not None:
108111
token_user_authentication = self._context_state.user_auth_callback.set(user_authentication_callback)
109112

110-
if conversation_id is not None and request is None:
111-
self._context_state.conversation_id.set(conversation_id)
113+
if isinstance(http_connection, WebSocket):
114+
self.set_metadata_from_websocket(user_message_id, conversation_id)
112115

113-
self.set_metadata_from_http_request(request)
116+
if isinstance(http_connection, Request):
117+
self.set_metadata_from_http_request(http_connection)
114118

115119
try:
116120
yield self
@@ -135,14 +139,11 @@ async def run(self, message):
135139
async with self._workflow.run(message) as runner:
136140
yield runner
137141

138-
def set_metadata_from_http_request(self, request: HTTPConnection | None) -> None:
142+
def set_metadata_from_http_request(self, request: Request) -> None:
139143
"""
140144
Extracts and sets user metadata request attributes from an HTTP request.
141145
The request must not be None; the caller is expected to pass only a Request instance.
142146
"""
143-
if request is None:
144-
return
145-
146147
self._context.metadata._request.method = getattr(request, "method", None)
147148
self._context.metadata._request.url_path = request.url.path
148149
self._context.metadata._request.url_port = request.url.port
@@ -157,6 +158,20 @@ def set_metadata_from_http_request(self, request: HTTPConnection | None) -> None
157158
if request.headers.get("conversation-id"):
158159
self._context_state.conversation_id.set(request.headers["conversation-id"])
159160

161+
if request.headers.get("user-message-id"):
162+
self._context_state.user_message_id.set(request.headers["user-message-id"])
163+
164+
def set_metadata_from_websocket(self, user_message_id: str | None, conversation_id: str | None) -> None:
165+
"""
166+
Sets the conversation ID and user message ID for WebSocket connections; arguments that are None are left unchanged.
167+
"""
168+
169+
if conversation_id is not None:
170+
self._context_state.conversation_id.set(conversation_id)
171+
172+
if user_message_id is not None:
173+
self._context_state.user_message_id.set(user_message_id)
174+
160175

161176
# Compatibility aliases with previous releases
162177
AIQSessionManager = SessionManager

0 commit comments

Comments
 (0)