diff --git a/perfkitbenchmarker/data/agents/pirate_agent/adk/pirate_agent_adk.py b/perfkitbenchmarker/data/agents/pirate_agent/adk/pirate_agent_adk.py
new file mode 100644
index 0000000000..432ea7c649
--- /dev/null
+++ b/perfkitbenchmarker/data/agents/pirate_agent/adk/pirate_agent_adk.py
@@ -0,0 +1,79 @@
+"""Pirate Agent using ADK."""
+
+import os
+from typing import Any, override
+import urllib.parse
+
+from absl import logging
+import adk_utils
+import common_utils
+from google.adk.agents import llm_agent
+from google.cloud import storage
+
+
+class AgentHandler(adk_utils.AdkAgentHandler):
+  """Standard interface for a pirate agent written in ADK."""
+
+  @override
+  def get_app_name_prefix(self) -> str:
+    """Returns the app name prefix for this agent."""
+    return "pirate_agent_adk"
+
+  @override
+  def _create_agent(self, config: common_utils.AgentConfig) -> Any:
+    """Builds the ADK agent that answers in pirate speak.
+
+    Args:
+      config: Agent configuration; only `config.model` is read here.
+
+    Returns:
+      The `llm_agent.Agent` instance.
+    """
+    return llm_agent.Agent(
+        name="pirate_supervisor",
+        description="A pirate agent.",
+        instruction=(
+            "You are a pirate. You must reply back to the user's prompt by"
+            " translating it to pirate speak."
+        ),
+        model=config.model,
+    )
+
+  @override
+  def export_results(
+      self, output_dir: str, response_text: str, generic_metrics: dict[str, Any]
+  ) -> None:
+    """Writes the answer to answer.txt and exports it with results.json.
+
+    Args:
+      output_dir: Destination directory; a gs:// URL triggers the upload.
+      response_text: The agent's reply to export.
+      generic_metrics: Stored under the "metrics" key of the results dict.
+    """
+    answer_file = "answer.txt"
+    # Model output routinely contains non-ASCII text; pin the encoding so the
+    # write does not depend on the host locale.
+    with open(answer_file, "w", encoding="utf-8") as f:
+      f.write(response_text)
+
+    target_answer_path = os.path.join(output_dir, answer_file)
+
+    # The answer file is only uploaded for gs:// destinations; for any other
+    # output_dir it stays in the current working directory.
+    if target_answer_path.startswith("gs://"):
+      parsed_url = urllib.parse.urlparse(target_answer_path)
+      bucket_name = parsed_url.netloc
+      blob_path = parsed_url.path.strip("/")
+      storage_client = storage.Client()
+      bucket = storage_client.bucket(bucket_name)
+      blob = bucket.blob(blob_path)
+      blob.upload_from_filename(answer_file)
+      logging.info("Successfully uploaded answer to %s", target_answer_path)
+
+    results = {
+        "metrics": generic_metrics,
+        "artifacts": {"answer": target_answer_path},
+        "response": response_text,
+    }
+    target_path = os.path.join(output_dir, "results.json")
+    common_utils.upload_dict_to_gcs(results, target_path)
diff --git a/perfkitbenchmarker/data/agents/pirate_agent/adk/pyproject.toml b/perfkitbenchmarker/data/agents/pirate_agent/adk/pyproject.toml
new file mode 100644
index 0000000000..09b789b7f5
--- /dev/null
+++ b/perfkitbenchmarker/data/agents/pirate_agent/adk/pyproject.toml
@@ -0,0 +1,21 @@
+[project]
+name = "pirate_agent_adk"
+version = "0.1.0"
+description = "Pirate ADK Agent"
+requires-python = ">=3.12"
+dependencies = [
+    "absl-py",
+    "cloudpickle",
+    "google-adk",
+    "google-cloud-aiplatform",
+    "google-cloud-storage",
+    "opentelemetry-sdk",
+    "pydantic"
+]
+
+[tool.setuptools]
+py-modules = [
+    "pirate_agent_adk",
+    "adk_utils",
+    "common_utils",
+]
diff --git a/perfkitbenchmarker/linux_benchmarks/pirate_agent_benchmark.py b/perfkitbenchmarker/linux_benchmarks/pirate_agent_benchmark.py
new file mode 100644
index 0000000000..4118777f3d
--- /dev/null
+++ b/perfkitbenchmarker/linux_benchmarks/pirate_agent_benchmark.py
@@ -0,0 +1,82 @@
+"""Trivial agent benchmark which translates to pirate-speak."""
+
+from collections.abc import Iterable
+from typing import Any, override
+
+from absl import flags
+from perfkitbenchmarker import ai_agent_benchmark_helper as agent
+from perfkitbenchmarker import benchmark_spec
+from perfkitbenchmarker import configs
+from perfkitbenchmarker import sample
+
+BENCHMARK_NAME = 'pirate_agent'
+BENCHMARK_CONFIG = """
+pirate_agent:
+  description: >
+    Pirate Agent benchmark.
+  flags:
+    gcloud_scopes: cloud-platform
+  vm_groups:
+    clients:
+      vm_spec: *default_dual_core
+      vm_count: 1
+  ai_agent_service:
+    cloud: GCP
+    deployment_type: agent_engine
+    agent: pirate_agent
+    framework: adk
+    model: gemini-3-flash-preview
+"""
+
+
+FLAGS = flags.FLAGS
+
+
+def GetConfig(user_config: dict[str, Any]) -> dict[str, Any]:
+  """Loads and returns benchmark config."""
+  return configs.LoadConfig(BENCHMARK_CONFIG, user_config, BENCHMARK_NAME)
+
+
+class PirateAgent(agent.BaseAgent):
+  """Agent for translating to pirate."""
+
+  @override
+  def GetPrompts(self) -> Iterable[agent.Prompt]:
+    """Generates prompts for the agent."""
+    prompt_text = 'Hello, world! Can you tell me a story about cloud computing?'
+    return [agent.Prompt(id='default', session_id='default', text=prompt_text)]
+
+  @property
+  @override
+  def agent_name(self) -> str:
+    """Name of the agent this benchmark drives."""
+    return 'pirate_agent'
+
+  @override
+  def UploadValidatorScript(self) -> None:
+    """No-op: this benchmark uploads no validator script."""
+    pass
+
+  @override
+  def RunValidationLogic(
+      self,
+      prompt: agent.Prompt,
+      results: agent.PromptResults,
+  ) -> tuple[float, list[sample.Sample]]:
+    """Performs no validation; always returns 1.0 with no samples."""
+    return 1.0, []
+
+
+def Prepare(spec: benchmark_spec.BenchmarkSpec) -> None:
+  """Configures the VM before running."""
+  PirateAgent.GetOrCreateFromSpec(spec).Prepare()
+
+
+def Run(spec: benchmark_spec.BenchmarkSpec) -> list[sample.Sample]:
+  """Runs the benchmark."""
+  return PirateAgent.GetOrCreateFromSpec(spec).Run()
+
+
+def Cleanup(spec: benchmark_spec.BenchmarkSpec) -> None:
+  """Cleans up the resources."""
+  PirateAgent.GetOrCreateFromSpec(spec).Cleanup()