Skip to content

Commit d8ff5d6

Browse files
committed
feat: add agent local examples
1 parent 0777012 commit d8ff5d6

File tree

2 files changed

+137
-158
lines changed

2 files changed

+137
-158
lines changed

veadk/agent.py

Lines changed: 18 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,12 @@
2626

2727
import uuid
2828

29-
from google.adk.agents import LlmAgent, RunConfig
29+
from google.adk.agents import LlmAgent
3030
from google.adk.agents.base_agent import BaseAgent
3131
from google.adk.agents.context_cache_config import ContextCacheConfig
3232
from google.adk.agents.llm_agent import InstructionProvider, ToolUnion
33-
from google.adk.agents.run_config import StreamingMode
33+
from google.adk.examples.base_example_provider import BaseExampleProvider
3434
from google.adk.models.lite_llm import LiteLlm
35-
from google.adk.runners import Runner
36-
from google.genai import types
3735
from pydantic import ConfigDict, Field
3836
from typing_extensions import Any
3937

@@ -42,7 +40,6 @@
4240
DEFAULT_AGENT_NAME,
4341
DEFAULT_MODEL_EXTRA_CONFIG,
4442
)
45-
from veadk.evaluation import EvalSetRecorder
4643
from veadk.knowledgebase import KnowledgeBase
4744
from veadk.memory.long_term_memory import LongTermMemory
4845
from veadk.memory.short_term_memory import ShortTermMemory
@@ -87,6 +84,8 @@ class Agent(LlmAgent):
8784
tracers (list[BaseTracer]): List of tracers used for telemetry and monitoring.
8885
enable_authz (bool): Whether to enable agent authorization checks.
8986
auto_save_session (bool): Whether to automatically save sessions to long-term memory.
87+
skills (list[str]): List of skills that equip the agent with specific capabilities.
88+
example_store (Optional[BaseExampleProvider]): Example store for providing example Q/A.
9089
"""
9190

9291
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
@@ -147,6 +146,8 @@ class Agent(LlmAgent):
147146

148147
skills: list[str] = Field(default_factory=list)
149148

149+
example_store: Optional[BaseExampleProvider] = None
150+
150151
def model_post_init(self, __context: Any) -> None:
151152
super().model_post_init(None) # for sub_agents init
152153

@@ -282,6 +283,11 @@ def model_post_init(self, __context: Any) -> None:
282283
if self.skills:
283284
self.load_skills()
284285

286+
if self.example_store:
287+
from google.adk.tools.example_tool import ExampleTool
288+
289+
self.tools.append(ExampleTool(examples=self.example_store))
290+
285291
logger.info(f"VeADK version: {VERSION}")
286292

287293
logger.info(f"{self.__class__.__name__} `{self.name}` init done.")
@@ -297,18 +303,19 @@ def update_model(self, model_name: str):
297303

298304
def load_skills(self):
299305
from pathlib import Path
306+
300307
from veadk.skills.skill import Skill
301308
from veadk.skills.utils import (
302-
load_skills_from_directory,
303309
load_skills_from_cloud,
310+
load_skills_from_directory,
304311
)
305312
from veadk.tools.builtin_tools.playwright import playwright_tools
306313
from veadk.tools.skills_tools import (
307314
SkillsTool,
315+
bash_tool,
316+
edit_file_tool,
308317
read_file_tool,
309318
write_file_tool,
310-
edit_file_tool,
311-
bash_tool,
312319
)
313320

314321
skills: Dict[str, Skill] = {}
@@ -338,61 +345,6 @@ def load_skills(self):
338345
self.tools.append(bash_tool)
339346
self.tools.append(playwright_tools)
340347

341-
async def _run(
342-
self,
343-
runner,
344-
user_id: str,
345-
session_id: str,
346-
message: types.Content,
347-
stream: bool,
348-
run_processor: Optional[BaseRunProcessor] = None,
349-
):
350-
"""Internal run method with run processor support.
351-
352-
Args:
353-
runner: The Runner instance.
354-
user_id: User ID for the session.
355-
session_id: Session ID.
356-
message: The message to send.
357-
stream: Whether to stream the output.
358-
run_processor: Optional run processor to use. If not provided, uses self.run_processor.
359-
360-
Returns:
361-
The final output string.
362-
"""
363-
stream_mode = StreamingMode.SSE if stream else StreamingMode.NONE
364-
365-
# Use provided run_processor or fall back to instance's run_processor
366-
processor = run_processor or self.run_processor
367-
368-
@processor.process_run(runner=runner, message=message)
369-
async def event_generator():
370-
async for event in runner.run_async(
371-
user_id=user_id,
372-
session_id=session_id,
373-
new_message=message,
374-
run_config=RunConfig(streaming_mode=stream_mode),
375-
):
376-
if event.get_function_calls():
377-
for function_call in event.get_function_calls():
378-
logger.debug(f"Function call: {function_call}")
379-
elif (
380-
event.content is not None
381-
and event.content.parts[0].text is not None
382-
and len(event.content.parts[0].text.strip()) > 0
383-
):
384-
yield event.content.parts[0].text
385-
386-
final_output = ""
387-
async for chunk in event_generator():
388-
if stream:
389-
print(chunk, end="", flush=True)
390-
final_output += chunk
391-
if stream:
392-
print() # end with a new line
393-
394-
return final_output
395-
396348
def _prepare_tracers(self):
397349
enable_apmplus_tracer = os.getenv("ENABLE_APMPLUS", "false").lower() == "true"
398350
enable_cozeloop_tracer = os.getenv("ENABLE_COZELOOP", "false").lower() == "true"
@@ -439,99 +391,7 @@ def _prepare_tracers(self):
439391
f"Opentelemetry Tracer init {len(self.tracers[0].exporters)} exporters" # type: ignore
440392
)
441393

442-
async def run(
443-
self,
444-
prompt: str | list[str],
445-
stream: bool = False,
446-
app_name: str = "veadk_app",
447-
user_id: str = "veadk_user",
448-
session_id="veadk_session",
449-
load_history_sessions_from_db: bool = False,
450-
db_url: str = "",
451-
collect_runtime_data: bool = False,
452-
eval_set_id: str = "",
453-
save_session_to_memory: bool = False,
454-
run_processor: Optional[BaseRunProcessor] = None,
455-
):
456-
"""Running the agent. The runner and session service will be created automatically.
457-
458-
For production, consider using Google-ADK runner to run agent, rather than invoking this method.
459-
460-
Args:
461-
prompt (str | list[str]): The prompt to run the agent.
462-
stream (bool, optional): Whether to stream the output. Defaults to False.
463-
app_name (str, optional): The name of the application. Defaults to "veadk_app".
464-
user_id (str, optional): The id of the user. Defaults to "veadk_user".
465-
session_id (str, optional): The id of the session. Defaults to "veadk_session".
466-
load_history_sessions_from_db (bool, optional): Whether to load history sessions from database. Defaults to False.
467-
db_url (str, optional): The url of the database. Defaults to "".
468-
collect_runtime_data (bool, optional): Whether to collect runtime data. Defaults to False.
469-
eval_set_id (str, optional): The id of the eval set. Defaults to "".
470-
save_session_to_memory (bool, optional): Whether to save this turn session to memory. Defaults to False.
471-
run_processor (Optional[BaseRunProcessor], optional): Optional run processor to use for this run.
472-
If not provided, uses the agent's default run_processor. Defaults to None.
473-
"""
474-
475-
logger.warning(
476-
"Running agent in this function is only for development and testing, do not use this function in production. For production, consider using `Google ADK Runner` to run agent, rather than invoking this method."
477-
)
478-
logger.info(
479-
f"Run agent {self.name}: app_name: {app_name}, user_id: {user_id}, session_id: {session_id}."
480-
)
481-
prompt = [prompt] if isinstance(prompt, str) else prompt
482-
483-
# memory service
484-
short_term_memory = ShortTermMemory(
485-
backend="database" if load_history_sessions_from_db else "local",
486-
db_url=db_url,
394+
async def run(self, **kwargs):
395+
raise NotImplementedError(
396+
"Run method in VeADK agent is deprecated since version 0.5.6. Please use runner.run_async instead. Ref: https://agentkit.gitbook.io/docs/runner/overview"
487397
)
488-
session_service = short_term_memory.session_service
489-
await short_term_memory.create_session(
490-
app_name=app_name, user_id=user_id, session_id=session_id
491-
)
492-
493-
# runner
494-
runner = Runner(
495-
agent=self,
496-
app_name=app_name,
497-
session_service=session_service,
498-
memory_service=self.long_term_memory,
499-
)
500-
501-
logger.info(f"Begin to process prompt {prompt}")
502-
# run
503-
final_output = ""
504-
for _prompt in prompt:
505-
message = types.Content(role="user", parts=[types.Part(text=_prompt)])
506-
final_output = await self._run(
507-
runner, user_id, session_id, message, stream, run_processor
508-
)
509-
510-
# VeADK features
511-
if save_session_to_memory:
512-
assert self.long_term_memory is not None, (
513-
"Long-term memory is not initialized in agent"
514-
)
515-
session = await session_service.get_session(
516-
app_name=app_name,
517-
user_id=user_id,
518-
session_id=session_id,
519-
)
520-
if session:
521-
await self.long_term_memory.add_session_to_memory(session)
522-
logger.info(f"Add session `{session.id}` to your long-term memory.")
523-
else:
524-
logger.error(
525-
f"Session {session_id} not found in session service, cannot save to long-term memory."
526-
)
527-
528-
if collect_runtime_data:
529-
eval_set_recorder = EvalSetRecorder(session_service, eval_set_id)
530-
dump_path = await eval_set_recorder.dump(app_name, user_id, session_id)
531-
self._dump_path = dump_path # just for test/debug/instrumentation
532-
533-
if self.tracers:
534-
for tracer in self.tracers:
535-
tracer.dump(user_id=user_id, session_id=session_id)
536-
537-
return final_output
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any
16+
17+
from google.adk.examples.base_example_provider import BaseExampleProvider
18+
from google.adk.examples.example import Example as ADKExample
19+
from google.genai.types import Content, FunctionCall, Part
20+
from pydantic import BaseModel, Field
21+
from typing_extensions import override
22+
23+
24+
class ExampleFunctionCall(BaseModel):
25+
function_name: str
26+
arguments: dict[str, Any] = Field(default_factory=dict)
27+
28+
29+
class Example(BaseModel):
30+
input: str
31+
expected_output: str | None
32+
expected_function_call: ExampleFunctionCall | None
33+
34+
35+
class InMemoryExampleStore(BaseExampleProvider):
36+
def __init__(
37+
self,
38+
name: str = "in_memory_example_store",
39+
examples: list[Example | ADKExample] | None = None,
40+
):
41+
self.name = name
42+
if examples:
43+
self.examples: list[ADKExample] = self.convert_examples_to_adk_examples(
44+
examples
45+
)
46+
else:
47+
self.examples: list[ADKExample] = []
48+
49+
def add_example(self, example: Example | ADKExample):
50+
"""Add an example to the provider.
51+
52+
Args:
53+
example: A VeADK example or ADK example.
54+
"""
55+
self.examples.append(self.convert_examples_to_adk_examples([example])[0])
56+
57+
def convert_examples_to_adk_examples(
58+
self,
59+
examples: list[Example | ADKExample],
60+
) -> list[ADKExample]:
61+
"""Convert VeADK example to ADK example.
62+
63+
Args:
64+
examples: A list of VeADK example or ADK example.
65+
66+
Returns:
67+
A list of ADK example.
68+
"""
69+
adk_examples = []
70+
for example in examples:
71+
if isinstance(example, ADKExample):
72+
adk_examples.append(example)
73+
else:
74+
output_string_content = (
75+
Content(parts=[Part(text=example.expected_output)], role="model")
76+
if example.expected_output
77+
else None
78+
)
79+
output_fc_content = (
80+
Content(
81+
parts=[
82+
Part(
83+
function_call=FunctionCall(
84+
name=example.expected_function_call.function_name,
85+
args=example.expected_function_call.arguments,
86+
)
87+
)
88+
],
89+
role="model",
90+
)
91+
if example.expected_function_call
92+
else None
93+
)
94+
95+
output = []
96+
if output_string_content:
97+
output.append(output_string_content)
98+
if output_fc_content:
99+
output.append(output_fc_content)
100+
101+
adk_examples.append(
102+
ADKExample(
103+
input=Content(parts=[Part(text=example.input)], role="user"),
104+
output=output,
105+
)
106+
)
107+
return adk_examples
108+
109+
@override
110+
def get_examples(self, query: str) -> list[ADKExample]:
111+
"""Simply return all examples.
112+
113+
Args:
114+
query: The query to get examples for.
115+
116+
Returns:
117+
A list of Example objects.
118+
"""
119+
return self.examples

0 commit comments

Comments
 (0)