Skip to content

Commit 5cbf811

Browse files
authored
Merge pull request #206 from UiPath/fix/pydantic_ai_streaming_message_events
fix: pydantic-ai stream LLM tokens as UiPathConversationMessageEvent
2 parents 6c7914b + 21fddf6 commit 5cbf811

7 files changed

Lines changed: 285 additions & 38 deletions

File tree

packages/uipath-pydantic-ai/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "uipath-pydantic-ai"
3-
version = "0.0.2"
3+
version = "0.0.3"
44
description = "Python SDK that enables developers to build and deploy PydanticAI agents to the UiPath Cloud Platform"
55
readme = "README.md"
66
requires-python = ">=3.11"
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
flowchart TB
2+
__start__(__start__)
3+
weather_agent(weather_agent)
4+
weather_agent_tools(tools)
5+
__end__(__end__)
6+
weather_agent --> weather_agent_tools
7+
weather_agent_tools --> weather_agent
8+
__start__ --> |input|weather_agent
9+
weather_agent --> |output|__end__

packages/uipath-pydantic-ai/src/uipath_pydantic_ai/runtime/runtime.py

Lines changed: 81 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,24 @@
11
"""Runtime class for executing PydanticAI Agents within the UiPath framework."""
22

33
import json
4+
from datetime import datetime, timezone
45
from typing import Any, AsyncGenerator
56
from uuid import uuid4
67

78
from pydantic import BaseModel
89
from pydantic_ai import Agent, FunctionToolset
10+
from pydantic_ai.messages import ToolReturnPart
11+
from uipath.core.chat.content import (
12+
UiPathConversationContentPartChunkEvent,
13+
UiPathConversationContentPartEndEvent,
14+
UiPathConversationContentPartEvent,
15+
UiPathConversationContentPartStartEvent,
16+
)
17+
from uipath.core.chat.message import (
18+
UiPathConversationMessageEndEvent,
19+
UiPathConversationMessageEvent,
20+
UiPathConversationMessageStartEvent,
21+
)
922
from uipath.core.serialization import serialize_json
1023
from uipath.runtime import (
1124
UiPathExecuteOptions,
@@ -88,18 +101,65 @@ async def stream(
88101
)
89102

90103
model_node = node
91-
node = await agent_run.next(node)
92-
93-
yield UiPathRuntimeMessageEvent(
94-
payload=json.loads(serialize_json(model_node.request)),
95-
metadata={"event_name": "model_request"},
96-
)
104+
message_id = str(uuid4())
105+
content_part_id = f"chunk-{message_id}-0"
106+
has_text = False
107+
108+
async with model_node.stream(agent_run.ctx) as stream:
109+
async for text_chunk in stream.stream_text(
110+
delta=True, debounce_by=None
111+
):
112+
if not has_text:
113+
has_text = True
114+
yield UiPathRuntimeMessageEvent(
115+
payload=UiPathConversationMessageEvent(
116+
message_id=message_id,
117+
start=UiPathConversationMessageStartEvent(
118+
role="assistant",
119+
timestamp=self._get_timestamp(),
120+
),
121+
content_part=UiPathConversationContentPartEvent(
122+
content_part_id=content_part_id,
123+
start=UiPathConversationContentPartStartEvent(
124+
mime_type="text/plain",
125+
),
126+
),
127+
),
128+
)
129+
130+
yield UiPathRuntimeMessageEvent(
131+
payload=UiPathConversationMessageEvent(
132+
message_id=message_id,
133+
content_part=UiPathConversationContentPartEvent(
134+
content_part_id=content_part_id,
135+
chunk=UiPathConversationContentPartChunkEvent(
136+
data=text_chunk,
137+
),
138+
),
139+
),
140+
)
141+
142+
next_node = await agent_run.next(model_node)
143+
144+
if has_text:
145+
yield UiPathRuntimeMessageEvent(
146+
payload=UiPathConversationMessageEvent(
147+
message_id=message_id,
148+
end=UiPathConversationMessageEndEvent(),
149+
content_part=UiPathConversationContentPartEvent(
150+
content_part_id=content_part_id,
151+
end=UiPathConversationContentPartEndEvent(),
152+
),
153+
),
154+
)
97155

98-
yield UiPathRuntimeStateEvent(
99-
payload=self._model_response_payload(node),
100-
node_name=agent_name,
101-
phase=UiPathRuntimeStatePhase.COMPLETED,
102-
)
156+
if Agent.is_call_tools_node(next_node):
157+
yield UiPathRuntimeStateEvent(
158+
payload=self._model_response_payload(next_node),
159+
node_name=agent_name,
160+
phase=UiPathRuntimeStatePhase.COMPLETED,
161+
)
162+
node = next_node
103163

104164
elif Agent.is_call_tools_node(node):
105165
tool_calls = node.model_response.tool_calls if has_tools else []
@@ -115,14 +175,15 @@ async def stream(
115175
phase=UiPathRuntimeStatePhase.STARTED,
116176
)
117177

118-
node = await agent_run.next(node)
178+
next_node = await agent_run.next(node)
119179

120-
if tool_calls:
180+
if tool_calls and Agent.is_model_request_node(next_node):
121181
yield UiPathRuntimeStateEvent(
122-
payload=self._tool_results_payload(node),
182+
payload=self._tool_results_payload(next_node),
123183
node_name=tools_node_name,
124184
phase=UiPathRuntimeStatePhase.COMPLETED,
125185
)
186+
node = next_node
126187

127188
else:
128189
node = await agent_run.next(node)
@@ -135,6 +196,12 @@ async def stream(
135196
except Exception as e:
136197
raise self._create_runtime_error(e) from e
137198

199+
@staticmethod
200+
def _get_timestamp() -> str:
201+
"""Get current UTC timestamp in ISO 8601 format."""
202+
now = datetime.now(timezone.utc)
203+
return now.strftime("%Y-%m-%dT%H:%M:%S.") + f"{now.microsecond // 1000:03d}Z"
204+
138205
@staticmethod
139206
def _model_request_payload(node: Any) -> dict[str, Any]:
140207
"""Build payload for a ModelRequestNode STARTED event."""
@@ -181,8 +248,6 @@ def _tool_results_payload(next_node: Any) -> dict[str, Any]:
181248
After agent_run.next() the returned node is a ModelRequestNode
182249
whose request.parts contain ToolReturnPart objects with results.
183250
"""
184-
from pydantic_ai.messages import ToolReturnPart
185-
186251
payload: dict[str, Any] = {}
187252
try:
188253
parts = next_node.request.parts if next_node.request else []

packages/uipath-pydantic-ai/src/uipath_pydantic_ai/runtime/schema.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -287,16 +287,10 @@ def _conversation_message_item_schema() -> dict[str, Any]:
287287
},
288288
"required": ["inline"],
289289
},
290-
"citations": {
291-
"type": "array",
292-
"items": {"type": "object"},
293-
},
294290
},
295291
"required": ["data"],
296292
},
297293
},
298-
"toolCalls": {"type": "array", "items": {"type": "object"}},
299-
"interrupts": {"type": "array", "items": {"type": "object"}},
300294
},
301295
"required": ["role", "contentParts"],
302296
}

packages/uipath-pydantic-ai/tests/test_runtime.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import pytest
66
from pydantic import BaseModel
77
from pydantic_ai import Agent
8+
from uipath.core.chat.message import UiPathConversationMessageEvent
9+
from uipath.runtime.events import UiPathRuntimeMessageEvent
810

911
from uipath_pydantic_ai.runtime.errors import (
1012
UiPathPydanticAIErrorCode,
@@ -582,3 +584,188 @@ def my_tool(ctx, query: str) -> str:
582584
result = event.payload["tool_results"][0]
583585
assert "tool_name" in result
584586
assert "content" in result
587+
588+
589+
# ============= TOKEN STREAMING TESTS =============
590+
591+
592+
@pytest.mark.asyncio
593+
async def test_stream_emits_message_events_with_message_id():
594+
"""Streaming must emit UiPathConversationMessageEvent payloads with a message_id."""
595+
from pydantic_ai.models.test import TestModel
596+
597+
agent = Agent(TestModel(custom_output_text="Hi there"), name="msg_agent")
598+
runtime = UiPathPydanticAIRuntime(agent=agent, runtime_id="test", entrypoint="test")
599+
600+
msg_events: list[UiPathConversationMessageEvent] = []
601+
async for event in runtime.stream(input=_uipath_input("Hello")):
602+
if isinstance(event, UiPathRuntimeMessageEvent):
603+
payload = event.payload
604+
assert isinstance(payload, UiPathConversationMessageEvent)
605+
msg_events.append(payload)
606+
607+
assert len(msg_events) >= 3 # START + at least one CHUNK + END
608+
# All events share the same message_id
609+
ids = {e.message_id for e in msg_events}
610+
assert len(ids) == 1
611+
612+
613+
@pytest.mark.asyncio
614+
async def test_stream_message_lifecycle_start_chunks_end():
615+
"""Streaming follows START -> CHUNK(s) -> END lifecycle."""
616+
from pydantic_ai.models.test import TestModel
617+
618+
agent = Agent(TestModel(custom_output_text="Hello world"), name="lc_agent")
619+
runtime = UiPathPydanticAIRuntime(agent=agent, runtime_id="test", entrypoint="test")
620+
621+
msg_events: list[UiPathConversationMessageEvent] = []
622+
async for event in runtime.stream(input=_uipath_input("Say hello")):
623+
if isinstance(event, UiPathRuntimeMessageEvent):
624+
msg_events.append(event.payload)
625+
626+
# First event: START (has start + content_part.start)
627+
first = msg_events[0]
628+
assert first.start is not None
629+
assert first.start.role == "assistant"
630+
assert first.start.timestamp is not None
631+
assert first.content_part is not None
632+
assert first.content_part.start is not None
633+
assert first.content_part.start.mime_type == "text/plain"
634+
635+
# Middle events: CHUNK (has content_part.chunk)
636+
chunks = msg_events[1:-1]
637+
assert len(chunks) >= 1
638+
for chunk_event in chunks:
639+
assert chunk_event.content_part is not None
640+
assert chunk_event.content_part.chunk is not None
641+
assert isinstance(chunk_event.content_part.chunk.data, str)
642+
assert len(chunk_event.content_part.chunk.data) > 0
643+
644+
# Last event: END (has end + content_part.end)
645+
last = msg_events[-1]
646+
assert last.end is not None
647+
assert last.content_part is not None
648+
assert last.content_part.end is not None
649+
650+
651+
@pytest.mark.asyncio
652+
async def test_stream_token_chunks_reassemble_to_full_text():
653+
"""Concatenating all chunk data must produce the full response text."""
654+
from pydantic_ai.models.test import TestModel
655+
656+
expected_text = "The quick brown fox jumps over the lazy dog"
657+
agent = Agent(TestModel(custom_output_text=expected_text), name="concat_agent")
658+
runtime = UiPathPydanticAIRuntime(agent=agent, runtime_id="test", entrypoint="test")
659+
660+
chunk_texts: list[str] = []
661+
async for event in runtime.stream(input=_uipath_input("Tell me something")):
662+
if isinstance(event, UiPathRuntimeMessageEvent):
663+
payload = event.payload
664+
if payload.content_part and payload.content_part.chunk:
665+
chunk_texts.append(payload.content_part.chunk.data)
666+
667+
reassembled = "".join(chunk_texts)
668+
assert reassembled == expected_text
669+
670+
671+
@pytest.mark.asyncio
672+
async def test_stream_content_part_id_consistent():
673+
"""All content_part events in a message must share the same content_part_id."""
674+
from pydantic_ai.models.test import TestModel
675+
676+
agent = Agent(TestModel(custom_output_text="Consistent IDs"), name="cpid_agent")
677+
runtime = UiPathPydanticAIRuntime(agent=agent, runtime_id="test", entrypoint="test")
678+
679+
content_part_ids: set[str] = set()
680+
async for event in runtime.stream(input=_uipath_input("Check IDs")):
681+
if isinstance(event, UiPathRuntimeMessageEvent):
682+
payload = event.payload
683+
if payload.content_part:
684+
content_part_ids.add(payload.content_part.content_part_id)
685+
686+
assert len(content_part_ids) == 1
687+
688+
689+
@pytest.mark.asyncio
690+
async def test_stream_with_tools_emits_message_events():
691+
"""Streaming an agent with tools must emit message events for the final text response."""
692+
from pydantic_ai.models.test import TestModel
693+
694+
def my_tool(ctx, query: str) -> str:
695+
"""Search tool.
696+
697+
Args:
698+
ctx: The agent context.
699+
query: The search query.
700+
701+
Returns:
702+
Search results.
703+
"""
704+
return f"Result for {query}"
705+
706+
agent = Agent(TestModel(), name="tool_msg_agent", tools=[my_tool])
707+
runtime = UiPathPydanticAIRuntime(agent=agent, runtime_id="test", entrypoint="test")
708+
709+
msg_events: list[UiPathConversationMessageEvent] = []
710+
async for event in runtime.stream(input=_uipath_input("Search for cats")):
711+
if isinstance(event, UiPathRuntimeMessageEvent):
712+
msg_events.append(event.payload)
713+
714+
# Should have at least one message lifecycle (final response after tool call)
715+
assert len(msg_events) >= 3
716+
717+
# Verify START/END presence
718+
starts = [e for e in msg_events if e.start is not None]
719+
ends = [e for e in msg_events if e.end is not None]
720+
assert len(starts) >= 1
721+
assert len(ends) >= 1
722+
723+
# Text chunks should exist
724+
chunks = [e for e in msg_events if e.content_part and e.content_part.chunk]
725+
assert len(chunks) >= 1
726+
727+
728+
@pytest.mark.asyncio
729+
async def test_stream_tool_only_turn_skips_message_events():
730+
"""Model turns that produce only tool calls (no text) should not emit message events."""
731+
from pydantic_ai.models.test import TestModel
732+
from uipath.runtime.events import (
733+
UiPathRuntimeStateEvent,
734+
UiPathRuntimeStatePhase,
735+
)
736+
737+
def my_tool(ctx, query: str) -> str:
738+
"""A tool.
739+
740+
Args:
741+
ctx: The agent context.
742+
query: The query.
743+
744+
Returns:
745+
Results.
746+
"""
747+
return "result"
748+
749+
# TestModel with tools: first turn calls tool (no text), second turn returns text
750+
agent = Agent(TestModel(), name="skip_agent", tools=[my_tool])
751+
runtime = UiPathPydanticAIRuntime(agent=agent, runtime_id="test", entrypoint="test")
752+
753+
msg_events: list[UiPathConversationMessageEvent] = []
754+
state_events: list[UiPathRuntimeStateEvent] = []
755+
async for event in runtime.stream(input=_uipath_input("Do something")):
756+
if isinstance(event, UiPathRuntimeMessageEvent):
757+
msg_events.append(event.payload)
758+
elif isinstance(event, UiPathRuntimeStateEvent):
759+
state_events.append(event)
760+
761+
# Should have multiple model turns via state events (tool turn + final turn)
762+
agent_started = [
763+
e
764+
for e in state_events
765+
if e.node_name == "skip_agent" and e.phase == UiPathRuntimeStatePhase.STARTED
766+
]
767+
assert len(agent_started) >= 2 # at least 2 model request turns
768+
769+
# Message events only come from the text-producing turn(s)
770+
message_ids = {e.message_id for e in msg_events}
771+
assert len(message_ids) == 1 # only the final text response

0 commit comments

Comments
 (0)