Skip to content

Commit 9f8e892

Browse files
abhinav-aegisekzhulspinheiro
authored
Added Graph Based Execution functionality to Autogen (#6333)
Closes #4623 ### Add Directed Graph-based Group Chat Execution Engine (`DiGraphGroupChat`) This PR introduces a new graph-based execution framework for Autogen agent teams, located under `autogen_agentchat/teams/_group_chat/_graph`. **Key Features:** - **`DiGraphGroupChat`**: A new group chat implementation that executes agents based on a user-defined directed graph (DAG or cyclic with exit conditions). - **`AGGraphBuilder`**: A fluent builder API to programmatically construct graphs. - **`MessageFilterAgent`**: A wrapper to restrict what messages an agent sees before invocation, supporting per-source and per-position filtering. **Capabilities:** - Supports sequential, parallel, conditional, and cyclic workflows. - Enables fine-grained control over both execution order and message context. - Compatible with existing Autogen agents and runtime interfaces. **Tests:** - Located in `autogen_agentchat/tests/test_group_chat_graph.py` - Includes unit and integration tests covering: - Graph validation - Execution paths - Conditional routing - Loops with exit conditions - Message filtering Let me know if anything needs refactoring or if you'd like the components split further. --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> Co-authored-by: Leonardo Pinheiro <leosantospinheiro@gmail.com>
1 parent fcbac2d commit 9f8e892

7 files changed

Lines changed: 2480 additions & 0 deletions

File tree

python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ._assistant_agent import AssistantAgent
77
from ._base_chat_agent import BaseChatAgent
88
from ._code_executor_agent import CodeExecutorAgent
9+
from ._message_filter_agent import MessageFilterAgent, MessageFilterConfig, PerSourceFilter
910
from ._society_of_mind_agent import SocietyOfMindAgent
1011
from ._user_proxy_agent import UserProxyAgent
1112

@@ -15,4 +16,7 @@
1516
"CodeExecutorAgent",
1617
"SocietyOfMindAgent",
1718
"UserProxyAgent",
19+
"MessageFilterAgent",
20+
"MessageFilterConfig",
21+
"PerSourceFilter",
1822
]
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
from typing import AsyncGenerator, List, Literal, Optional, Sequence, Union
2+
3+
from autogen_core import CancellationToken, Component, ComponentModel
4+
from pydantic import BaseModel
5+
6+
from autogen_agentchat.agents import BaseChatAgent
7+
from autogen_agentchat.base import Response
8+
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage
9+
10+
# ------------------------------
11+
# Message Filter Config
12+
# ------------------------------
13+
14+
15+
class PerSourceFilter(BaseModel):
16+
source: str
17+
position: Optional[Literal["first", "last"]] = None
18+
count: Optional[int] = None
19+
20+
21+
class MessageFilterConfig(BaseModel):
22+
per_source: List[PerSourceFilter]
23+
24+
25+
# ------------------------------
26+
# Component Config
27+
# ------------------------------
28+
29+
30+
class MessageFilterAgentConfig(BaseModel):
31+
name: str
32+
wrapped_agent: ComponentModel
33+
filter: MessageFilterConfig
34+
35+
36+
# ------------------------------
37+
# Message Filter Agent
38+
# ------------------------------
39+
40+
41+
class MessageFilterAgent(BaseChatAgent, Component[MessageFilterAgentConfig]):
42+
"""
43+
A wrapper agent that filters incoming messages before passing them to the inner agent.
44+
45+
.. warning::
46+
47+
This is an experimental feature, and the API will change in the future releases.
48+
49+
This is useful in scenarios like multi-agent workflows where an agent should only
50+
process a subset of the full message history—for example, only the last message
51+
from each upstream agent, or only the first message from a specific source.
52+
53+
Filtering is configured using :class:`MessageFilterConfig`, which supports:
54+
- Filtering by message source (e.g., only messages from "user" or another agent)
55+
- Selecting the first N or last N messages from each source
56+
- If position is `None`, all messages from that source are included
57+
58+
This agent is compatible with both direct message passing and team-based execution
59+
such as :class:`~autogen_agentchat.teams.GraphFlow`.
60+
61+
Example:
62+
>>> agent_a = MessageFilterAgent(
63+
... name="A",
64+
... wrapped_agent=some_other_agent,
65+
... filter=MessageFilterConfig(
66+
... per_source=[
67+
... PerSourceFilter(source="user", position="first", count=1),
68+
... PerSourceFilter(source="B", position="last", count=2),
69+
... ]
70+
... ),
71+
... )
72+
73+
Example use case with Graph:
74+
Suppose you have a looping multi-agent graph: A → B → A → B → C.
75+
76+
You want:
77+
- A to only see the user message and the last message from B
78+
- B to see the user message, last message from A, and its own prior responses (for reflection)
79+
- C to see the user message and the last message from B
80+
81+
Wrap the agents like so:
82+
83+
>>> agent_a = MessageFilterAgent(
84+
... name="A",
85+
... wrapped_agent=agent_a_inner,
86+
... filter=MessageFilterConfig(
87+
... per_source=[
88+
... PerSourceFilter(source="user", position="first", count=1),
89+
... PerSourceFilter(source="B", position="last", count=1),
90+
... ]
91+
... ),
92+
... )
93+
94+
>>> agent_b = MessageFilterAgent(
95+
... name="B",
96+
... wrapped_agent=agent_b_inner,
97+
... filter=MessageFilterConfig(
98+
... per_source=[
99+
... PerSourceFilter(source="user", position="first", count=1),
100+
... PerSourceFilter(source="A", position="last", count=1),
101+
... PerSourceFilter(source="B", position="last", count=10),
102+
... ]
103+
... ),
104+
... )
105+
106+
>>> agent_c = MessageFilterAgent(
107+
... name="C",
108+
... wrapped_agent=agent_c_inner,
109+
... filter=MessageFilterConfig(
110+
... per_source=[
111+
... PerSourceFilter(source="user", position="first", count=1),
112+
... PerSourceFilter(source="B", position="last", count=1),
113+
... ]
114+
... ),
115+
... )
116+
117+
Then define the graph:
118+
119+
>>> graph = DiGraph(
120+
... nodes={
121+
... "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
122+
... "B": DiGraphNode(
123+
... name="B",
124+
... edges=[
125+
... DiGraphEdge(target="C", condition="exit"),
126+
... DiGraphEdge(target="A", condition="loop"),
127+
... ],
128+
... ),
129+
... "C": DiGraphNode(name="C", edges=[]),
130+
... },
131+
... default_start_node="A",
132+
... )
133+
134+
This will ensure each agent sees only what is needed for its decision or action logic.
135+
"""
136+
137+
component_config_schema = MessageFilterAgentConfig
138+
component_provider_override = "autogen_agentchat.agents.MessageFilterAgent"
139+
140+
def __init__(
141+
self,
142+
name: str,
143+
wrapped_agent: BaseChatAgent,
144+
filter: MessageFilterConfig,
145+
):
146+
super().__init__(name=name, description=f"{wrapped_agent.description} (with message filtering)")
147+
self._wrapped_agent = wrapped_agent
148+
self._filter = filter
149+
150+
@property
151+
def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
152+
return self._wrapped_agent.produced_message_types
153+
154+
def _apply_filter(self, messages: Sequence[BaseChatMessage]) -> Sequence[BaseChatMessage]:
155+
result: List[BaseChatMessage] = []
156+
157+
for source_filter in self._filter.per_source:
158+
msgs = [m for m in messages if m.source == source_filter.source]
159+
160+
if source_filter.position == "first" and source_filter.count:
161+
msgs = msgs[: source_filter.count]
162+
elif source_filter.position == "last" and source_filter.count:
163+
msgs = msgs[-source_filter.count :]
164+
165+
result.extend(msgs)
166+
167+
return result
168+
169+
async def on_messages(
170+
self,
171+
messages: Sequence[BaseChatMessage],
172+
cancellation_token: CancellationToken,
173+
) -> Response:
174+
filtered = self._apply_filter(messages)
175+
return await self._wrapped_agent.on_messages(filtered, cancellation_token)
176+
177+
async def on_messages_stream(
178+
self,
179+
messages: Sequence[BaseChatMessage],
180+
cancellation_token: CancellationToken,
181+
) -> AsyncGenerator[Union[BaseAgentEvent, BaseChatMessage, Response], None]:
182+
filtered = self._apply_filter(messages)
183+
async for item in self._wrapped_agent.on_messages_stream(filtered, cancellation_token):
184+
yield item
185+
186+
async def on_reset(self, cancellation_token: CancellationToken) -> None:
187+
await self._wrapped_agent.on_reset(cancellation_token)
188+
189+
def _to_config(self) -> MessageFilterAgentConfig:
190+
return MessageFilterAgentConfig(
191+
name=self.name,
192+
wrapped_agent=self._wrapped_agent.dump_component(),
193+
filter=self._filter,
194+
)
195+
196+
@classmethod
197+
def _from_config(cls, config: MessageFilterAgentConfig) -> "MessageFilterAgent":
198+
wrapped = BaseChatAgent.load_component(config.wrapped_agent)
199+
return cls(
200+
name=config.name,
201+
wrapped_agent=wrapped,
202+
filter=config.filter,
203+
)

python/packages/autogen-agentchat/src/autogen_agentchat/teams/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@
44
"""
55

66
from ._group_chat._base_group_chat import BaseGroupChat
7+
from ._group_chat._graph import (
8+
DiGraph,
9+
DiGraphBuilder,
10+
DiGraphEdge,
11+
DiGraphNode,
12+
GraphFlow,
13+
)
714
from ._group_chat._magentic_one import MagenticOneGroupChat
815
from ._group_chat._round_robin_group_chat import RoundRobinGroupChat
916
from ._group_chat._selector_group_chat import SelectorGroupChat
@@ -15,4 +22,9 @@
1522
"SelectorGroupChat",
1623
"Swarm",
1724
"MagenticOneGroupChat",
25+
"DiGraphBuilder",
26+
"DiGraph",
27+
"DiGraphNode",
28+
"DiGraphEdge",
29+
"GraphFlow",
1830
]
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from ._digraph_group_chat import (
2+
DiGraph,
3+
DiGraphEdge,
4+
DiGraphNode,
5+
GraphFlow,
6+
GraphFlowManager,
7+
)
8+
from ._graph_builder import DiGraphBuilder
9+
10+
__all__ = [
11+
"GraphFlow",
12+
"DiGraph",
13+
"GraphFlowManager",
14+
"DiGraphNode",
15+
"DiGraphEdge",
16+
"DiGraphBuilder",
17+
]

0 commit comments

Comments
 (0)