Skip to content

Commit d1d43af

Browse files
committed
feat: add supervisor flow
1 parent 384588f commit d1d43af

File tree

4 files changed

+152
-3
lines changed

4 files changed

+152
-3
lines changed

veadk/agent.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
import os
18-
from typing import Optional, Union, AsyncGenerator
18+
from typing import AsyncGenerator, Optional, Union
1919

2020
# If user didn't set LITELLM_LOCAL_MODEL_COST_MAP, set it to True
2121
# to enable local model cost map.
@@ -24,12 +24,15 @@
2424
if not os.getenv("LITELLM_LOCAL_MODEL_COST_MAP"):
2525
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
2626

27-
from google.adk.agents import LlmAgent, RunConfig, InvocationContext
27+
from google.adk.agents import InvocationContext, LlmAgent, RunConfig
2828
from google.adk.agents.base_agent import BaseAgent
2929
from google.adk.agents.context_cache_config import ContextCacheConfig
3030
from google.adk.agents.llm_agent import InstructionProvider, ToolUnion
3131
from google.adk.agents.run_config import StreamingMode
3232
from google.adk.events import Event, EventActions
33+
from google.adk.flows.llm_flows.auto_flow import AutoFlow
34+
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
35+
from google.adk.flows.llm_flows.single_flow import SingleFlow
3336
from google.adk.models.lite_llm import LiteLlm
3437
from google.adk.runners import Runner
3538
from google.genai import types
@@ -53,8 +56,8 @@
5356
from veadk.prompts.prompt_manager import BasePromptManager
5457
from veadk.tracing.base_tracer import BaseTracer
5558
from veadk.utils.logger import get_logger
56-
from veadk.utils.patches import patch_asyncio, patch_tracer
5759
from veadk.utils.misc import check_litellm_version
60+
from veadk.utils.patches import patch_asyncio, patch_tracer
5861
from veadk.version import VERSION
5962

6063
patch_tracer()
@@ -118,6 +121,8 @@ class Agent(LlmAgent):
118121

119122
enable_responses: bool = False
120123

124+
enable_shadow_agent: bool = False
125+
121126
context_cache_config: Optional[ContextCacheConfig] = None
122127

123128
run_processor: Optional[BaseRunProcessor] = Field(default=None, exclude=True)
@@ -292,6 +297,28 @@ def model_post_init(self, __context: Any) -> None:
292297
f"Agent: {self.model_dump(include={'name', 'model_name', 'model_api_base', 'tools'})}"
293298
)
294299

300+
@property
def _llm_flow(self) -> BaseLlmFlow:
    """Select the LLM flow used to drive this agent.

    Returns a single-turn flow when the agent cannot transfer control
    (transfers to parent and peers are disallowed and there are no
    sub-agents), otherwise an auto flow. When ``enable_shadow_agent``
    is set, the matching supervisor flow wraps the agent so a shadow
    "supervisor" agent can advise each model call.
    """
    is_single_flow = (
        self.disallow_transfer_to_parent
        and self.disallow_transfer_to_peers
        and not self.sub_agents
    )
    if is_single_flow:
        if self.enable_shadow_agent:
            # Imported lazily — and only when actually needed — to avoid
            # a circular import between veadk.agent and veadk.flows.
            from veadk.flows.supervisor_single_flow import SupervisorSingleFlow

            logger.debug(f"Enable supervisor flow for agent: {self.name}")
            return SupervisorSingleFlow(supervised_agent=self)
        return SingleFlow()

    if self.enable_shadow_agent:
        # Lazy import for the same circular-import reason as above.
        from veadk.flows.supervisor_auto_flow import SupervisorAutoFlow

        logger.debug(f"Enable supervisor flow for agent: {self.name}")
        return SupervisorAutoFlow(supervised_agent=self)
    return AutoFlow()
321+
295322
async def _run_async_impl(
296323
self, ctx: InvocationContext
297324
) -> AsyncGenerator[Event, None]:

veadk/agents/supervise_agent.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from google.adk.models.llm_request import LlmRequest
2+
from jinja2 import Template
3+
from pydantic import BaseModel
4+
5+
from veadk import Agent, Runner
6+
7+
8+
class SupervisorAgentOutput(BaseModel):
    """Structured response schema emitted by the supervisor agent."""

    advice: str = ""
    """
    Advices for the worker agent.
    For example, suggested function call / actions / responses.
    """
14+
15+
16+
# Jinja2 template for the supervisor agent's system prompt. Rendered with
# ``system_prompt`` (the worker agent's instruction) and ``agent_tools``
# (the worker agent's tool list) by ``build_supervisor``.
instruction = Template("""You are a supervisor of an agent system. The system prompt of worker agent is:

```system prompt
{{ system_prompt }}
```

```worker agent tools
{{ agent_tools }}
```

You should guide the agent to finish task. If you think the history execution is not correct, you should give your advice to the worker agent. If you think the history execution is correct, you should output an empty string.

Your final response should be in `json` format.
""")
30+
31+
32+
def build_supervisor(supervised_agent: Agent) -> Agent:
    """Build a supervisor agent that shadows *supervised_agent*.

    The supervisor's instruction embeds the worker agent's system prompt
    and tool list, and its output is constrained to
    ``SupervisorAgentOutput`` (JSON).

    Args:
        supervised_agent: The worker agent to supervise.

    Returns:
        A new ``Agent`` configured as the supervisor.
    """
    # Bug fix: the template also references {{ agent_tools }}; previously
    # only system_prompt was supplied, so the tools section rendered empty.
    custom_instruction = instruction.render(
        system_prompt=supervised_agent.instruction,
        agent_tools=supervised_agent.tools,
    )
    return Agent(
        name="supervisor",
        description="",
        instruction=custom_instruction,
        output_schema=SupervisorAgentOutput,
    )
42+
43+
44+
async def generate_advice(agent: Agent, llm_request: LlmRequest) -> str:
    """Run the supervisor *agent* over the request history and return its advice.

    Serializes the text / function-call / function-response parts of
    ``llm_request.contents`` into a readable trajectory (one entry per line)
    and asks the supervisor agent to critique it.

    Args:
        agent: The supervisor agent built by ``build_supervisor``.
        llm_request: The pending LLM request whose ``contents`` hold the
            conversation history so far.

    Returns:
        The supervisor's advice (empty when the trajectory looks correct).
    """
    runner = Runner(agent=agent)

    # Collect entries in a list and join once: avoids quadratic string
    # concatenation, and the "\n" separator keeps consecutive entries from
    # running together (previously they were concatenated with no delimiter).
    entries: list[str] = []
    for content in llm_request.contents:
        if not (content and content.parts):
            continue
        for part in content.parts:
            if part.text:
                entries.append(f"{content.role}: {part.text}")
            if part.function_call:
                entries.append(f"{content.role}: {part.function_call}")
            if part.function_response:
                entries.append(f"{content.role}: {part.function_response}")

    messages = "\n".join(entries)
    return await runner.run(messages="History trajectory is: " + messages)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from typing import AsyncGenerator
2+
3+
from google.adk.agents.invocation_context import InvocationContext
4+
from google.adk.events import Event
5+
from google.adk.models.llm_request import LlmRequest
6+
from google.adk.models.llm_response import LlmResponse
7+
from google.genai.types import Content, Part
8+
from typing_extensions import override
9+
10+
from veadk import Agent
11+
from veadk.agents.supervise_agent import generate_advice
12+
from veadk.flows.supervisor_single_flow import SupervisorSingleFlow
13+
14+
15+
class SupervisorAutoFlow(SupervisorSingleFlow):
    """Auto-flow variant that injects supervisor advice before each LLM call.

    Before delegating to the base flow, it asks the shadow supervisor
    (``self._supervisor``) for advice on the trajectory accumulated in the
    request and, when advice is given, appends it as a model-role message.
    """

    @override
    async def _call_llm_async(
        self,
        invocation_context: InvocationContext,
        llm_request: LlmRequest,
        model_response_event: Event,
    ) -> AsyncGenerator[LlmResponse, None]:
        # Ask the shadow supervisor for advice on the trajectory so far.
        advice = await generate_advice(self._supervisor, llm_request)

        # Replaces a leftover debug print; local import because this module
        # does not import logging at top level.
        import logging

        logging.getLogger(__name__).debug("Supervisor advice: %s", advice)

        # The supervisor returns an empty string when the execution looks
        # correct — only inject a message when there is actual advice.
        if advice:
            llm_request.contents.append(
                Content(parts=[Part(text=advice)], role="model")
            )

        async for llm_response in super()._call_llm_async(
            invocation_context, llm_request, model_response_event
        ):
            yield llm_response
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from typing import AsyncGenerator
2+
3+
from google.adk.agents.invocation_context import InvocationContext
4+
from google.adk.events import Event
5+
from google.adk.flows.llm_flows.single_flow import SingleFlow
6+
from google.adk.models.llm_request import LlmRequest
7+
from google.adk.models.llm_response import LlmResponse
8+
from typing_extensions import override
9+
10+
from veadk import Agent
11+
from veadk.agents.supervise_agent import build_supervisor
12+
13+
14+
class SupervisorSingleFlow(SingleFlow):
    """Single flow that owns a shadow supervisor for the supervised agent.

    Base single-turn flow behavior is unchanged here; subclasses (e.g. the
    auto-flow variant) consult ``self._supervisor`` to inject advice into
    the request before the model call.
    """

    def __init__(self, supervised_agent: Agent):
        """Create the flow and build the supervisor once.

        Args:
            supervised_agent: The worker agent this flow drives and supervises.
        """
        super().__init__()
        # Built once per flow; mirrors the supervised agent's instruction
        # and tool list (see build_supervisor).
        self._supervisor = build_supervisor(supervised_agent)

    # NOTE(review): the original pure pass-through `_call_llm_async`
    # override (it only delegated to super()) has been removed — inheriting
    # SingleFlow's implementation is behaviorally identical.

0 commit comments

Comments
 (0)