|
5 | 5 | import pytest |
6 | 6 | from pydantic import BaseModel |
7 | 7 | from pydantic_ai import Agent |
| 8 | +from uipath.core.chat.message import UiPathConversationMessageEvent |
| 9 | +from uipath.runtime.events import UiPathRuntimeMessageEvent |
8 | 10 |
|
9 | 11 | from uipath_pydantic_ai.runtime.errors import ( |
10 | 12 | UiPathPydanticAIErrorCode, |
@@ -582,3 +584,188 @@ def my_tool(ctx, query: str) -> str: |
582 | 584 | result = event.payload["tool_results"][0] |
583 | 585 | assert "tool_name" in result |
584 | 586 | assert "content" in result |
| 587 | + |
| 588 | + |
| 589 | +# ============= TOKEN STREAMING TESTS ============= |
| 590 | + |
| 591 | + |
| 592 | +@pytest.mark.asyncio |
| 593 | +async def test_stream_emits_message_events_with_message_id(): |
| 594 | + """Streaming must emit UiPathConversationMessageEvent payloads with a message_id.""" |
| 595 | + from pydantic_ai.models.test import TestModel |
| 596 | + |
| 597 | + agent = Agent(TestModel(custom_output_text="Hi there"), name="msg_agent") |
| 598 | + runtime = UiPathPydanticAIRuntime(agent=agent, runtime_id="test", entrypoint="test") |
| 599 | + |
| 600 | + msg_events: list[UiPathConversationMessageEvent] = [] |
| 601 | + async for event in runtime.stream(input=_uipath_input("Hello")): |
| 602 | + if isinstance(event, UiPathRuntimeMessageEvent): |
| 603 | + payload = event.payload |
| 604 | + assert isinstance(payload, UiPathConversationMessageEvent) |
| 605 | + msg_events.append(payload) |
| 606 | + |
| 607 | + assert len(msg_events) >= 3 # START + at least one CHUNK + END |
| 608 | + # All events share the same message_id |
| 609 | + ids = {e.message_id for e in msg_events} |
| 610 | + assert len(ids) == 1 |
| 611 | + |
| 612 | + |
| 613 | +@pytest.mark.asyncio |
| 614 | +async def test_stream_message_lifecycle_start_chunks_end(): |
| 615 | + """Streaming follows START -> CHUNK(s) -> END lifecycle.""" |
| 616 | + from pydantic_ai.models.test import TestModel |
| 617 | + |
| 618 | + agent = Agent(TestModel(custom_output_text="Hello world"), name="lc_agent") |
| 619 | + runtime = UiPathPydanticAIRuntime(agent=agent, runtime_id="test", entrypoint="test") |
| 620 | + |
| 621 | + msg_events: list[UiPathConversationMessageEvent] = [] |
| 622 | + async for event in runtime.stream(input=_uipath_input("Say hello")): |
| 623 | + if isinstance(event, UiPathRuntimeMessageEvent): |
| 624 | + msg_events.append(event.payload) |
| 625 | + |
| 626 | + # First event: START (has start + content_part.start) |
| 627 | + first = msg_events[0] |
| 628 | + assert first.start is not None |
| 629 | + assert first.start.role == "assistant" |
| 630 | + assert first.start.timestamp is not None |
| 631 | + assert first.content_part is not None |
| 632 | + assert first.content_part.start is not None |
| 633 | + assert first.content_part.start.mime_type == "text/plain" |
| 634 | + |
| 635 | + # Middle events: CHUNK (has content_part.chunk) |
| 636 | + chunks = msg_events[1:-1] |
| 637 | + assert len(chunks) >= 1 |
| 638 | + for chunk_event in chunks: |
| 639 | + assert chunk_event.content_part is not None |
| 640 | + assert chunk_event.content_part.chunk is not None |
| 641 | + assert isinstance(chunk_event.content_part.chunk.data, str) |
| 642 | + assert len(chunk_event.content_part.chunk.data) > 0 |
| 643 | + |
| 644 | + # Last event: END (has end + content_part.end) |
| 645 | + last = msg_events[-1] |
| 646 | + assert last.end is not None |
| 647 | + assert last.content_part is not None |
| 648 | + assert last.content_part.end is not None |
| 649 | + |
| 650 | + |
| 651 | +@pytest.mark.asyncio |
| 652 | +async def test_stream_token_chunks_reassemble_to_full_text(): |
| 653 | + """Concatenating all chunk data must produce the full response text.""" |
| 654 | + from pydantic_ai.models.test import TestModel |
| 655 | + |
| 656 | + expected_text = "The quick brown fox jumps over the lazy dog" |
| 657 | + agent = Agent(TestModel(custom_output_text=expected_text), name="concat_agent") |
| 658 | + runtime = UiPathPydanticAIRuntime(agent=agent, runtime_id="test", entrypoint="test") |
| 659 | + |
| 660 | + chunk_texts: list[str] = [] |
| 661 | + async for event in runtime.stream(input=_uipath_input("Tell me something")): |
| 662 | + if isinstance(event, UiPathRuntimeMessageEvent): |
| 663 | + payload = event.payload |
| 664 | + if payload.content_part and payload.content_part.chunk: |
| 665 | + chunk_texts.append(payload.content_part.chunk.data) |
| 666 | + |
| 667 | + reassembled = "".join(chunk_texts) |
| 668 | + assert reassembled == expected_text |
| 669 | + |
| 670 | + |
| 671 | +@pytest.mark.asyncio |
| 672 | +async def test_stream_content_part_id_consistent(): |
| 673 | + """All content_part events in a message must share the same content_part_id.""" |
| 674 | + from pydantic_ai.models.test import TestModel |
| 675 | + |
| 676 | + agent = Agent(TestModel(custom_output_text="Consistent IDs"), name="cpid_agent") |
| 677 | + runtime = UiPathPydanticAIRuntime(agent=agent, runtime_id="test", entrypoint="test") |
| 678 | + |
| 679 | + content_part_ids: set[str] = set() |
| 680 | + async for event in runtime.stream(input=_uipath_input("Check IDs")): |
| 681 | + if isinstance(event, UiPathRuntimeMessageEvent): |
| 682 | + payload = event.payload |
| 683 | + if payload.content_part: |
| 684 | + content_part_ids.add(payload.content_part.content_part_id) |
| 685 | + |
| 686 | + assert len(content_part_ids) == 1 |
| 687 | + |
| 688 | + |
| 689 | +@pytest.mark.asyncio |
| 690 | +async def test_stream_with_tools_emits_message_events(): |
| 691 | + """Streaming an agent with tools must emit message events for the final text response.""" |
| 692 | + from pydantic_ai.models.test import TestModel |
| 693 | + |
| 694 | + def my_tool(ctx, query: str) -> str: |
| 695 | + """Search tool. |
| 696 | +
|
| 697 | + Args: |
| 698 | + ctx: The agent context. |
| 699 | + query: The search query. |
| 700 | +
|
| 701 | + Returns: |
| 702 | + Search results. |
| 703 | + """ |
| 704 | + return f"Result for {query}" |
| 705 | + |
| 706 | + agent = Agent(TestModel(), name="tool_msg_agent", tools=[my_tool]) |
| 707 | + runtime = UiPathPydanticAIRuntime(agent=agent, runtime_id="test", entrypoint="test") |
| 708 | + |
| 709 | + msg_events: list[UiPathConversationMessageEvent] = [] |
| 710 | + async for event in runtime.stream(input=_uipath_input("Search for cats")): |
| 711 | + if isinstance(event, UiPathRuntimeMessageEvent): |
| 712 | + msg_events.append(event.payload) |
| 713 | + |
| 714 | + # Should have at least one message lifecycle (final response after tool call) |
| 715 | + assert len(msg_events) >= 3 |
| 716 | + |
| 717 | + # Verify START/END presence |
| 718 | + starts = [e for e in msg_events if e.start is not None] |
| 719 | + ends = [e for e in msg_events if e.end is not None] |
| 720 | + assert len(starts) >= 1 |
| 721 | + assert len(ends) >= 1 |
| 722 | + |
| 723 | + # Text chunks should exist |
| 724 | + chunks = [e for e in msg_events if e.content_part and e.content_part.chunk] |
| 725 | + assert len(chunks) >= 1 |
| 726 | + |
| 727 | + |
| 728 | +@pytest.mark.asyncio |
| 729 | +async def test_stream_tool_only_turn_skips_message_events(): |
| 730 | + """Model turns that produce only tool calls (no text) should not emit message events.""" |
| 731 | + from pydantic_ai.models.test import TestModel |
| 732 | + from uipath.runtime.events import ( |
| 733 | + UiPathRuntimeStateEvent, |
| 734 | + UiPathRuntimeStatePhase, |
| 735 | + ) |
| 736 | + |
| 737 | + def my_tool(ctx, query: str) -> str: |
| 738 | + """A tool. |
| 739 | +
|
| 740 | + Args: |
| 741 | + ctx: The agent context. |
| 742 | + query: The query. |
| 743 | +
|
| 744 | + Returns: |
| 745 | + Results. |
| 746 | + """ |
| 747 | + return "result" |
| 748 | + |
| 749 | + # TestModel with tools: first turn calls tool (no text), second turn returns text |
| 750 | + agent = Agent(TestModel(), name="skip_agent", tools=[my_tool]) |
| 751 | + runtime = UiPathPydanticAIRuntime(agent=agent, runtime_id="test", entrypoint="test") |
| 752 | + |
| 753 | + msg_events: list[UiPathConversationMessageEvent] = [] |
| 754 | + state_events: list[UiPathRuntimeStateEvent] = [] |
| 755 | + async for event in runtime.stream(input=_uipath_input("Do something")): |
| 756 | + if isinstance(event, UiPathRuntimeMessageEvent): |
| 757 | + msg_events.append(event.payload) |
| 758 | + elif isinstance(event, UiPathRuntimeStateEvent): |
| 759 | + state_events.append(event) |
| 760 | + |
| 761 | + # Should have multiple model turns via state events (tool turn + final turn) |
| 762 | + agent_started = [ |
| 763 | + e |
| 764 | + for e in state_events |
| 765 | + if e.node_name == "skip_agent" and e.phase == UiPathRuntimeStatePhase.STARTED |
| 766 | + ] |
| 767 | + assert len(agent_started) >= 2 # at least 2 model request turns |
| 768 | + |
| 769 | + # Message events only come from the text-producing turn(s) |
| 770 | + message_ids = {e.message_id for e in msg_events} |
| 771 | + assert len(message_ids) == 1 # only the final text response |
0 commit comments