Skip to content

Commit 98c7176

Browse files
dorellangcopybara-github
authored andcommitted
Create dummy agent benchmark (to measure provisioning time only).
PiperOrigin-RevId: 941337523
1 parent c3944d8 commit 98c7176

3 files changed

Lines changed: 159 additions & 0 deletions

File tree

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Pirate Agent using ADK."""
2+
3+
import os
4+
from typing import Any, override
5+
import urllib.parse
6+
7+
from absl import logging
8+
import adk_utils
9+
import common_utils
10+
from google.adk.agents import llm_agent
11+
from google.cloud import storage
12+
13+
14+
class AgentHandler(adk_utils.AdkAgentHandler):
15+
"""Standard interface for a pirate agent written in ADK."""
16+
17+
@override
18+
def get_app_name_prefix(self) -> str:
19+
return "pirate_agent_adk"
20+
21+
@override
22+
def _create_agent(self, config: common_utils.AgentConfig) -> Any:
23+
return llm_agent.Agent(
24+
name="pirate_supervisor",
25+
description="A pirate agent.",
26+
instruction=(
27+
"You are a pirate. You must reply back to the user's prompt by"
28+
" translating it to pirate speak."
29+
),
30+
model=config.model,
31+
)
32+
33+
@override
34+
def export_results(
35+
self, output_dir: str, response_text: str, generic_metrics: dict[str, Any]
36+
) -> None:
37+
answer_file = "answer.txt"
38+
with open(answer_file, "w") as f:
39+
f.write(response_text)
40+
41+
target_answer_path = os.path.join(output_dir, answer_file)
42+
43+
if target_answer_path.startswith("gs://"):
44+
parsed_url = urllib.parse.urlparse(target_answer_path)
45+
bucket_name = parsed_url.netloc
46+
blob_path = parsed_url.path.strip("/")
47+
storage_client = storage.Client()
48+
bucket = storage_client.bucket(bucket_name)
49+
blob = bucket.blob(blob_path)
50+
blob.upload_from_filename(answer_file)
51+
logging.info("Successfully uploaded answer to %s", target_answer_path)
52+
53+
results = {
54+
"metrics": generic_metrics,
55+
"artifacts": {"answer": target_answer_path},
56+
"response": response_text,
57+
}
58+
target_path = os.path.join(output_dir, "results.json")
59+
common_utils.upload_dict_to_gcs(results, target_path)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
[project]
2+
name = "pirate_agent_adk"
3+
version = "0.1.0"
4+
description = "Pirate ADK Agent"
5+
requires-python = ">=3.12"
6+
dependencies = [
7+
"absl-py",
8+
"cloudpickle",
9+
"google-adk",
10+
"google-cloud-aiplatform",
11+
"google-cloud-storage",
12+
"opentelemetry-sdk",
13+
"pydantic"
14+
]
15+
16+
[tool.setuptools]
17+
py-modules = [
18+
"pirate_agent_adk",
19+
"adk_utils",
20+
"common_utils",
21+
]
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""Trivial agent benchmark which translates to pirate-speak.."""
2+
3+
from collections.abc import Iterable
4+
from typing import Any, override
5+
6+
from absl import flags
7+
from perfkitbenchmarker import ai_agent_benchmark_helper as agent
8+
from perfkitbenchmarker import benchmark_spec
9+
from perfkitbenchmarker import configs
10+
from perfkitbenchmarker import sample
11+
12+
BENCHMARK_NAME = 'pirate_agent'
13+
BENCHMARK_CONFIG = """
14+
pirate_agent:
15+
description: >
16+
Pirate Agent benchmark.
17+
flags:
18+
gcloud_scopes: cloud-platform
19+
vm_groups:
20+
clients:
21+
vm_spec: *default_dual_core
22+
vm_count: 1
23+
ai_agent_service:
24+
cloud: GCP
25+
deployment_type: agent_engine
26+
agent: pirate_agent
27+
framework: adk
28+
model: gemini-3-flash-preview
29+
"""
30+
31+
32+
FLAGS = flags.FLAGS
33+
34+
35+
def GetConfig(user_config: dict[str, Any]) -> dict[str, Any]:
36+
"""Loads and returns benchmark config."""
37+
return configs.LoadConfig(BENCHMARK_CONFIG, user_config, BENCHMARK_NAME)
38+
39+
40+
class PirateAgent(agent.BaseAgent):
41+
"""Agent for translating to pirate."""
42+
43+
@override
44+
def GetPrompts(self) -> Iterable[agent.Prompt]:
45+
"""Generates prompts for the agent."""
46+
prompt_text = 'Hello, world! Can you tell me a story about cloud computing?'
47+
return [agent.Prompt(id='default', session_id='default', text=prompt_text)]
48+
49+
@property
50+
@override
51+
def agent_name(self):
52+
return 'pirate_agent'
53+
54+
@override
55+
def UploadValidatorScript(self) -> None:
56+
pass
57+
58+
@override
59+
def RunValidationLogic(
60+
self,
61+
prompt: agent.Prompt,
62+
results: agent.PromptResults,
63+
) -> tuple[float, list[sample.Sample]]:
64+
return 1.0, []
65+
66+
67+
def Prepare(spec: benchmark_spec.BenchmarkSpec) -> None:
68+
"""Configures the VM before running."""
69+
PirateAgent.GetOrCreateFromSpec(spec).Prepare()
70+
71+
72+
def Run(spec: benchmark_spec.BenchmarkSpec) -> list[sample.Sample]:
73+
"""Runs the benchmark."""
74+
return PirateAgent.GetOrCreateFromSpec(spec).Run()
75+
76+
77+
def Cleanup(spec: benchmark_spec.BenchmarkSpec) -> None:
78+
"""Cleanup the resources."""
79+
PirateAgent.GetOrCreateFromSpec(spec).Cleanup()

0 commit comments

Comments
 (0)