Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@
from ...state import TeamState
from ._chat_agent_container import ChatAgentContainer
from ._events import (
GroupChatGetThread,
GroupChatPause,
GroupChatReset,
GroupChatResume,
GroupChatStart,
GroupChatTermination,
GroupChatThread,
SerializableException,
)
from ._sequential_routed_agent import SequentialRoutedAgent
Expand Down Expand Up @@ -745,6 +747,36 @@ async def resume(self) -> None:
recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
)

async def get_thread(self) -> Sequence[BaseAgentEvent | BaseChatMessage]:
"""Get the current message thread for the group chat team.

The returned sequence is a snapshot of the manager's internal thread and
contains the task messages, agent events, and chat messages accumulated
so far. Mutating the returned sequence does not mutate the team state.
"""

if not self._initialized:
await self._init(self._runtime)

started_runtime = False
if self._embedded_runtime and not self._is_running:
assert isinstance(self._runtime, SingleThreadedAgentRuntime)
self._runtime.start()
started_runtime = True

try:
thread = await self._runtime.send_message(
GroupChatGetThread(),
recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
)
if not isinstance(thread, GroupChatThread):
raise RuntimeError(f"Expected GroupChatThread response, got {type(thread)}.")
return list(thread.messages)
finally:
if started_runtime:
assert isinstance(self._runtime, SingleThreadedAgentRuntime)
await self._runtime.stop_when_idle()

async def save_state(self) -> Mapping[str, Any]:
"""Save the state of the group chat team.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ._events import (
GroupChatAgentResponse,
GroupChatError,
GroupChatGetThread,
GroupChatMessage,
GroupChatPause,
GroupChatRequestPublish,
Expand All @@ -17,6 +18,7 @@
GroupChatStart,
GroupChatTeamResponse,
GroupChatTermination,
GroupChatThread,
SerializableException,
)
from ._sequential_routed_agent import SequentialRoutedAgent
Expand Down Expand Up @@ -56,6 +58,7 @@ def __init__(
GroupChatTeamResponse,
GroupChatMessage,
GroupChatReset,
GroupChatGetThread,
],
)
if max_turns is not None and max_turns <= 0:
Expand Down Expand Up @@ -285,6 +288,11 @@ async def handle_resume(self, message: GroupChatResume, ctx: MessageContext) ->
"""Resume the group chat manager. This is a no-op in the base class."""
pass

@rpc
async def handle_get_thread(self, message: GroupChatGetThread, ctx: MessageContext) -> GroupChatThread:
"""Get a snapshot of the current group chat message thread."""
return GroupChatThread(messages=list(self._message_thread))

@abstractmethod
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
"""Validate the state of the group chat given the start messages.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,19 @@ class GroupChatResume(BaseModel):
...


class GroupChatGetThread(BaseModel):
"""A request to get the current group chat message thread."""

...


class GroupChatThread(BaseModel):
"""A response containing the current group chat message thread."""

messages: List[SerializeAsAny[BaseAgentEvent | BaseChatMessage]]
"""The messages and events in the current group chat thread."""


class GroupChatError(BaseModel):
"""A message indicating that an error occurred in the group chat."""

Expand Down
26 changes: 26 additions & 0 deletions python/packages/autogen-agentchat/tests/test_group_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,32 @@ async def test_round_robin_group_chat_with_resume_and_reset(runtime: AgentRuntim
assert result.stop_reason is not None


@pytest.mark.asyncio
async def test_round_robin_group_chat_get_thread(runtime: AgentRuntime | None) -> None:
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
agent_2 = _EchoAgent("agent_2", description="echo agent 2")
termination = MaxMessageTermination(2)
team = RoundRobinGroupChat(
participants=[agent_1, agent_2], termination_condition=termination, runtime=runtime
)

assert await team.get_thread() == []

result = await team.run(task="Write a program that prints 'Hello, world!'")
thread = await team.get_thread()
assert list(thread) == result.messages

thread_list = list(thread)
thread_list.append(TextMessage(content="local mutation", source="test"))
assert list(await team.get_thread()) == result.messages

result = await team.run()
assert list(await team.get_thread())[-len(result.messages) :] == result.messages

await team.reset()
assert await team.get_thread() == []


@pytest.mark.asyncio
async def test_round_robin_group_chat_with_exception_raised_from_agent(runtime: AgentRuntime | None) -> None:
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
Expand Down