-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbase_agent.py
More file actions
365 lines (310 loc) · 13.2 KB
/
base_agent.py
File metadata and controls
365 lines (310 loc) · 13.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional, Union
from google.adk.agents import LlmAgent
from google.adk.tools import FunctionTool
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext
import structlog
from config.settings import settings
logger = structlog.get_logger(__name__)
class BaseUseCaseAgent(ABC):
    """
    Abstract base class for all use case agents.

    Why this design:
    1. Enforces consistent structure across all use cases
    2. Provides common functionality (context injection, logging)
    3. Enables polymorphic usage in experiments
    4. Wraps a single ADK ``LlmAgent`` that is built from the subclass's
       tools and instructions

    Subclasses must implement ``_initialize_tools`` and ``_get_instructions``.
    Both are called from ``__init__`` (via ``_create_adk_agent``), so they must
    not depend on attributes a subclass sets *after* calling
    ``super().__init__``.

    Annotations that name ADK types are written as strings (forward
    references) so they are never evaluated when the class body executes.
    """

    def __init__(self, model: str, use_case: str, role: str):
        """
        Store the agent identity, build its tools and create the ADK agent.

        Args:
            model: Model identifier handed to ``LlmAgent``.
            use_case: Use case label, e.g. "system_design".
            role: Role label, e.g. "producer", "critic".
        """
        self.model = model
        self.use_case = use_case  # e.g., "system_design"
        self.role = role  # e.g., "producer", "critic"
        self.name = f"{use_case}_{role}"

        # Initialize components. Tools must exist before the ADK agent is
        # created because _create_adk_agent reads self.tools.
        self.tools = self._initialize_tools()
        self.agent = self._create_adk_agent()

        logger.info(
            "Agent initialized",
            agent_name=self.name,
            model=self.model,
            use_case=self.use_case,
            role=self.role,
            tool_count=len(self.tools),
        )

    @abstractmethod
    def _initialize_tools(self) -> "List[BaseTool]":
        """
        Initialize tools specific to this agent's role and use case.

        Why abstract: Each agent needs different tools based on its purpose.

        Examples:
        - System design producer: pricing tools, architecture tools
        - System design critic: security analysis, best practices checker
        - Content producer: research tools, style guides
        - Content critic: grammar checker, fact checker

        Returns:
            The list of tools stored on ``self.tools``.
        """

    @abstractmethod
    def _get_instructions(self) -> str:
        """
        Get agent-specific instructions.

        Why abstract: Instructions must be tailored to:
        1. Use case domain knowledge
        2. Agent role (producer vs critic)
        3. Expected output format
        4. Quality standards

        Returns:
            The instruction text passed to ``LlmAgent(instruction=...)``.
        """

    def _create_adk_agent(self) -> "LlmAgent":
        """
        Create the ADK ``LlmAgent`` for the current ``self.model``.

        The agent is wired with this instance's before/after tool callbacks.
        For the "researcher" role the configured ``self.tools`` are replaced by
        ADK's ``google_search`` tool.

        Returns:
            The newly constructed ``LlmAgent``.

        Raises:
            Exception: Any error raised while building the agent is logged and
                re-raised unchanged.
        """
        try:
            # Special handling for researcher role that needs Google Search
            tools_to_use = self.tools
            if self.role == "researcher":
                # Imported lazily so only researcher agents need this tool.
                from google.adk.tools import google_search

                tools_to_use = [google_search]  # Use Google Search for researcher
                logger.info("Using Google Search tool for researcher agent")

            agent = LlmAgent(
                model=self.model,
                name=self.name,
                description=f"{self.use_case} {self.role} agent",
                instruction=self._get_instructions(),
                tools=tools_to_use,
                before_tool_callback=self._before_tool_callback,
                after_tool_callback=self._after_tool_callback,
            )
            logger.info(
                "ADK agent created successfully",
                agent_name=self.name,
                model=self.model,
            )
            return agent
        except Exception as e:
            logger.error(
                "Failed to create ADK agent",
                agent_name=self.name,
                model=self.model,
                error=str(e),
            )
            raise

    def _before_tool_callback(self, tool: "BaseTool", args: Dict[str, Any],
                              tool_context: "ToolContext") -> None:
        """
        Inject common context into all tool calls.

        Mutates ``args`` in place with values read from ``tool_context.state``:
        1. ``experiment_id`` (only when "experiment:id" is set)
        2. ``iteration_context`` (always; defaults to 0)
        3. Cost-tracking fields when ``settings.cost_tracking_enabled``
        4. ``session_id`` and ``user_context`` when present

        Args:
            tool: The tool about to be invoked (only its ``name`` is read).
            args: Keyword arguments for the tool call; modified in place.
            tool_context: ADK tool context whose ``state`` is read.

        Returns:
            None. NOTE(review): per the ADK callback contract a ``None`` return
            is expected to let the tool call proceed - confirm against the ADK
            version in use.
        """
        state = tool_context.state

        # Research experiment tracking
        experiment_id = state.get("experiment:id")
        if experiment_id:
            args["experiment_id"] = experiment_id

        # Iteration context for reflection research
        iteration_count = state.get("reflection:iteration", 0)
        args["iteration_context"] = iteration_count

        # Cost tracking for research
        # NOTE(review): nesting reconstructed - all four keys are assumed to be
        # gated by cost_tracking_enabled; confirm against the original file.
        if settings.cost_tracking_enabled:
            args["track_usage"] = True
            args["use_case"] = self.use_case
            args["agent_role"] = self.role
            args["model"] = self.model

        # Session and auth context
        session_id = state.get("session:session_id")
        if session_id:
            args["session_id"] = session_id

        # User context for personalization
        user_context = state.get("user:context", {})
        if user_context:
            args["user_context"] = user_context

        logger.debug(
            "Tool context injected",
            tool_name=tool.name,
            agent_name=self.name,
            experiment_id=experiment_id,
            iteration=iteration_count,
        )

    def _after_tool_callback(self, *args, **kwargs) -> None:
        """
        Log a finished tool execution and record it in the tool context state.

        Parameters are extracted flexibly:
        - ``tool``: first positional argument, else ``kwargs["tool"]``
        - result: second positional argument, else ``kwargs["result"]``, else
          ``kwargs["tool_response"]`` (the keyword ADK uses for the tool
          output; without this fallback every keyword-style invocation is
          recorded as unsuccessful)
        - ``tool_context``: third positional argument, else
          ``kwargs["tool_context"]``

        A ``None`` result is treated as a failed execution. When the context
        exposes ``state``, a summary entry is appended under the
        "tool_executions" key.

        Returns:
            None, so the tool's own result is not replaced by this callback.
        """
        tool = args[0] if len(args) > 0 else kwargs.get('tool')
        if len(args) > 1:
            result = args[1]
        else:
            result = kwargs.get('result')
            if result is None:
                result = kwargs.get('tool_response')
        tool_context = args[2] if len(args) > 2 else kwargs.get('tool_context')

        tool_name = getattr(tool, 'name', 'unknown') if tool else 'unknown'
        succeeded = result is not None
        result_type = type(result).__name__ if succeeded else 'None'

        logger.info(
            "Tool execution completed",
            tool_name=tool_name,
            agent_name=self.name,
            result_type=result_type,
            success=succeeded,
            args_received=len(args),
            kwargs_received=list(kwargs.keys()),
        )

        # Log the actual result for debugging (truncate if too long)
        if succeeded:
            result_str = str(result)
            if len(result_str) > 200:
                result_str = result_str[:200] + "..."
            logger.debug(
                "Tool execution result",
                tool_name=tool_name,
                result=result_str,
            )

        # Track tool performance for research
        if tool_context and hasattr(tool_context, 'state'):
            # Build a new list and assign it back instead of mutating the
            # stored value in place, so the state object sees the write.
            previous = tool_context.state.get('tool_executions', [])
            tool_context.state['tool_executions'] = [
                *previous,
                {
                    'tool_name': tool_name,
                    'agent_name': self.name,
                    'success': succeeded,
                    'result_type': result_type,
                },
            ]

    async def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute the agent once with ``input_data`` and return a summary dict.

        A new session service, session and runner are created on every call.
        The user message is ``input_data["input"]`` when present, otherwise the
        string form of the whole dict.

        Args:
            input_data: Request payload; only the optional "input" key is read.

        Returns:
            A dict with keys "response", "session_id", "events_count" and
            "agent_name".

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            logger.info(
                "Agent execution started",
                agent_name=self.name,
                input_keys=list(input_data.keys()),
            )

            # ADK runtime components are imported lazily, at execution time.
            from google.adk.runners import Runner
            from google.adk.sessions import DatabaseSessionService
            from google.genai import types

            # Create session service; falls back to a local SQLite file when no
            # database URL is configured.
            session_service = DatabaseSessionService(
                db_url=settings.database_url or "sqlite:///research.db"
            )

            # Create session
            user_id = "research_user"
            session = await session_service.create_session(
                app_name=self.name,
                user_id=user_id,
            )

            # Create runner with the same app_name the session was created under
            runner = Runner(
                app_name=self.name,
                agent=self.agent,
                session_service=session_service,
            )

            # Create ADK Content object. The text is coerced to str so a
            # non-string "input" value does not break message construction.
            message_text = input_data.get("input", str(input_data))
            if not isinstance(message_text, str):
                message_text = str(message_text)
            content = types.Content(role="user", parts=[types.Part(text=message_text)])

            # Execute agent and collect every emitted event
            invocation_events = [
                evt
                async for evt in runner.run_async(
                    user_id=session.user_id,
                    session_id=session.id,
                    new_message=content,
                )
            ]

            # Extract final response from the collected events
            final_response = self._extract_text_from_events(invocation_events)

            result = {
                "response": final_response,
                "session_id": session.id,
                "events_count": len(invocation_events),
                "agent_name": self.name,
            }
            logger.info(
                "Agent execution completed",
                agent_name=self.name,
                response_length=len(final_response) if final_response else 0,
                events_count=len(invocation_events),
            )
            return result
        except Exception as e:
            logger.error(
                "Agent execution failed",
                agent_name=self.name,
                error=str(e),
                input_data=input_data,
            )
            raise

    def _extract_text_from_events(self, events: List[Any]) -> Optional[str]:
        """
        Combine the text content of ADK events into a single response string.

        Events are read via duck typing: ``evt.content.parts`` where each part
        may carry ``text`` and/or ``function_call``. Parts with a function call
        are skipped when collecting text.

        Args:
            events: Events collected from one agent run.

        Returns:
            - All text parts of all events joined by single spaces, or
            - a summary naming the called functions when there is no text but
              function calls were seen, or
            - a fixed "no text" message otherwise.
            In practice a string is always returned.
        """

        def parts_of(evt) -> List[Any]:
            # Events without content, or with empty/missing parts, yield nothing.
            content = getattr(evt, "content", None)
            if content is None:
                return []
            return getattr(content, "parts", None) or []

        # Text from every non-function-call part, in event order.
        text_responses = []
        for evt in events:
            texts = [
                text
                for part in parts_of(evt)
                if getattr(part, "function_call", None) is None
                and (text := getattr(part, "text", None))
            ]
            if texts:
                text_responses.append(" ".join(texts))

        if text_responses:
            return " ".join(text_responses)

        # No text at all: fall back to a summary of the function calls made.
        function_calls = [
            getattr(func_call, "name", "unknown_function")
            for evt in events
            for part in parts_of(evt)
            if (func_call := getattr(part, "function_call", None))
        ]
        if function_calls:
            return f"Successfully executed tools: {', '.join(function_calls)}. The function calls completed successfully."

        return "No text response generated from the agent"

    def update_model(self, new_model: str) -> None:
        """
        Update the model for this agent and rebuild the ADK agent.

        Why needed: Research experiments test different models
        on the same agent configuration.

        Args:
            new_model: Model identifier; must be in ``settings.available_models``.

        Raises:
            ValueError: If ``new_model`` is not an available model.
            Exception: If rebuilding the agent fails; in that case
                ``self.model`` is restored so it keeps matching ``self.agent``.
        """
        if new_model not in settings.available_models:
            raise ValueError(f"Model {new_model} not in available models: {settings.available_models}")

        old_model = self.model
        self.model = new_model
        try:
            # Recreate ADK agent with new model
            self.agent = self._create_adk_agent()
        except Exception:
            # Roll back so model and agent never disagree after a failure.
            self.model = old_model
            raise

        logger.info(
            "Agent model updated",
            agent_name=self.name,
            old_model=old_model,
            new_model=new_model,
        )