-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathcallbacks.py
More file actions
324 lines (274 loc) · 11.1 KB
/
callbacks.py
File metadata and controls
324 lines (274 loc) · 11.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
"""Callback utilities for HITL kernel generation agents."""
import json
import logging
import os
from typing import Any, Dict, Optional
from google.adk.agents.callback_context import CallbackContext
from google.adk.models import LlmResponse
from google.adk.tools import BaseTool, ToolContext
from hitl_agent.config import TPU_VERSION, WORKDIR
from hitl_agent.knowledge_base import pallas_docs, pallas_profiling_docs
def create_path_saver(state_key: str):
    """Build a tool callback that remembers the last file path a tool touched.

    Args:
        state_key: Key in ``tool_context.state`` under which the path is stored.

    Returns:
        A function with the ``after_tool_callback`` signature
        ``(tool, args, tool_context, tool_response)``. It always returns
        ``None``.
    """

    def save_path(
        tool: BaseTool,
        args: Dict[str, Any],
        tool_context: ToolContext,
        tool_response: Optional[Dict],
    ) -> Optional[Dict]:
        # Substring match rather than equality: MCP filesystem tools may be
        # exposed under prefixed or otherwise varying names.
        lowered_name = tool.name.lower()
        if "read" not in lowered_name and "write" not in lowered_name:
            return None

        target = args.get("path", None)
        if not target:
            # Nothing to record when the call carried no (or an empty) path.
            return None

        tool_context.state[state_key] = target
        logging.info(
            f"Saved file path to {state_key}: {target} (from tool: {tool.name})"
        )
        return None

    return save_path
def save_kernel_file_paths(
    tool: BaseTool,
    args: Dict[str, Any],
    tool_context: ToolContext,
    tool_response: Optional[Dict],
) -> Optional[Dict]:
    """Record kernel file paths by the order in which files are read.

    The first ``read_file`` call sets ``base_kernel_path``; every later
    ``read_file`` call sets (or overwrites) ``optimized_kernel_path``. Calls to
    any other tool, and ``read_file`` calls without a ``path`` argument, are
    ignored.

    Args:
        tool: The tool that was just invoked; only its ``name`` is inspected.
        args: Arguments the tool was called with; the ``path`` entry is used.
        tool_context: Context whose ``state`` mapping receives the paths.
        tool_response: Unused; present to satisfy the callback signature.

    Returns:
        Always ``None``.
    """
    if tool.name != "read_file":
        return None

    file_path = args.get("path", None)
    if not file_path:
        # Without this guard a path-less call would store None and could
        # clobber an optimized path that was already recorded.
        logging.warning("read_file call had no path; kernel paths unchanged")
        return None

    if not tool_context.state.get("base_kernel_path", None):
        # Base path missing or empty: this is the first file (base).
        tool_context.state["base_kernel_path"] = file_path
        logging.info(f"Set base kernel path: {file_path}")
    else:
        # Base already known: this is the second file (optimized).
        tool_context.state["optimized_kernel_path"] = file_path
        logging.info(f"Set optimized kernel path: {file_path}")
    return None
def save_kernel_and_plan_paths(
    tool: BaseTool,
    args: Dict[str, Any],
    tool_context: ToolContext,
    tool_response: Optional[Dict],
) -> Optional[Dict]:
    """Track the plan file and the kernel file touched during implementation.

    For any tool whose name contains "read" or "write" (case-insensitive), the
    ``path`` argument is stored in state: under ``kernel_plan_path`` when the
    path contains "plan" (case-insensitive) and ends with ".md", otherwise
    under ``optimized_kernel_path``.

    Returns:
        Always ``None``.
    """
    tool_name = tool.name.lower()
    if "read" not in tool_name and "write" not in tool_name:
        return None

    touched_path = args.get("path", None)
    if not touched_path:
        return None

    # A markdown file with "plan" anywhere in its path is taken to be the
    # optimization plan; every other file is assumed to be the kernel.
    looks_like_plan = "plan" in touched_path.lower() and touched_path.endswith(".md")
    if not looks_like_plan:
        tool_context.state["optimized_kernel_path"] = touched_path
        logging.info(
            f"Saved kernel path to optimized_kernel_path: {touched_path} (from tool: {tool.name})"
        )
        return None

    tool_context.state["kernel_plan_path"] = touched_path
    logging.info(
        f"Saved plan path to kernel_plan_path: {touched_path} (from tool: {tool.name})"
    )
    return None
def load_single_kernel_to_state(callback_context: CallbackContext):
    """Load the kernel file named by ``kernel_file_path`` into state.

    The file content is stored under ``kernel_code`` for use by
    compilation/profiling agents. If the file cannot be read, the error is
    logged and ``kernel_code`` is set to ``None``. If no path is present in
    state, only a warning is logged and state is left untouched.

    Args:
        callback_context: Context whose ``state`` mapping is read and updated.
    """
    file_path = callback_context.state.get("kernel_file_path", None)
    if not file_path:
        logging.warning("No kernel file path found in state")
        return

    try:
        # Explicit encoding so the result does not depend on the host locale.
        with open(file_path, "r", encoding="utf-8") as f:
            kernel_code = f.read()
    except Exception as e:
        # Best-effort by design: a bad path must not abort the agent run.
        logging.error(f"Failed to read kernel file: {e}")
        callback_context.state["kernel_code"] = None
        return

    callback_context.state["kernel_code"] = kernel_code
    logging.info(f"Loaded kernel code from {file_path}")
def load_profiling_script_to_state(callback_context: CallbackContext):
    """Load the script named by ``profiling_script_path`` into state.

    The file content is stored under ``profiling_script`` for use by the
    profiling execution agent. If the file cannot be read, the error is logged
    and ``profiling_script`` is set to ``None``. If no path is present in
    state, only a warning is logged and state is left untouched.

    Args:
        callback_context: Context whose ``state`` mapping is read and updated.
    """
    file_path = callback_context.state.get("profiling_script_path", None)
    if not file_path:
        logging.warning("No profiling script path found in state")
        return

    try:
        # Explicit encoding so the result does not depend on the host locale.
        with open(file_path, "r", encoding="utf-8") as f:
            profiling_code = f.read()
    except Exception as e:
        # Best-effort by design: a bad path must not abort the agent run.
        logging.error(f"Failed to read profiling script file: {e}")
        callback_context.state["profiling_script"] = None
        return

    callback_context.state["profiling_script"] = profiling_code
    logging.info(f"Loaded profiling script from {file_path}")
def load_two_kernels_to_state(callback_context: CallbackContext):
    """Load the base and optimized kernel files into state for comparison.

    Reads the files named by ``base_kernel_path`` and ``optimized_kernel_path``
    and stores their contents under ``base_kernel_code`` and
    ``optimized_kernel_code``. Each file is handled independently: a read
    failure is logged and sets that file's code entry to ``None``; a missing
    path only logs a warning and leaves that entry untouched.

    Args:
        callback_context: Context whose ``state`` mapping is read and updated.
    """
    state = callback_context.state
    # Resolve both paths up front, before any code entry is written.
    kernels = [
        (label, state.get(f"{label}_kernel_path", None))
        for label in ("base", "optimized")
    ]

    for label, path in kernels:
        code_key = f"{label}_kernel_code"
        if not path:
            logging.warning(f"No {label} kernel path found in state")
            continue

        try:
            # Explicit encoding so the result does not depend on the locale.
            with open(path, "r", encoding="utf-8") as f:
                code = f.read()
        except Exception as e:
            # Best-effort by design: one unreadable file must not prevent
            # the other from loading or abort the agent run.
            logging.error(f"Failed to read {label} kernel file: {e}")
            state[code_key] = None
            continue

        state[code_key] = code
        logging.info(f"Loaded {label} kernel code from {path}")
def _format_compilation_history(history):
    """Render compilation attempt records as readable text for the prompt."""
    lines = []
    for record in history:
        attempt_num = record.get("attempt", "?")
        success = record.get("success", False)
        status = "✓ SUCCESS" if success else "✗ FAILED"
        lines.append(f"**Attempt {attempt_num}:** {status}")

        # Only failed attempts carry the fix that was tried (if recorded).
        fix_summary = record.get("fix_summary", None)
        if not success and fix_summary:
            lines.append(f"Fix Applied: {fix_summary}")

        lines.append("")  # Blank line separating attempts
    return "\n".join(lines)


def load_kernel_and_plan_to_state(callback_context: CallbackContext):
    """Load the kernel file and optimization plan into state for the fix agent.

    Reads the files named by ``optimized_kernel_path`` and ``kernel_plan_path``
    and stores their contents under ``kernel_code`` and ``kernel_plan``. Either
    entry is set to ``None`` when its path is missing, the file does not exist,
    or reading fails. Also renders ``compilation_history`` into the
    ``compilation_history_formatted`` entry; when there is no history a fixed
    "first attempt" message is stored instead.

    Args:
        callback_context: Context whose ``state`` mapping is read and updated.
    """
    state = callback_context.state

    # Load kernel code
    kernel_path = state.get("optimized_kernel_path", None)
    if kernel_path and os.path.exists(kernel_path):
        try:
            # Explicit encoding so the result does not depend on the locale.
            with open(kernel_path, "r", encoding="utf-8") as f:
                state["kernel_code"] = f.read()
            logging.info(f"Loaded kernel code from {kernel_path}")
        except Exception as e:
            logging.error(f"Failed to read kernel file: {e}")
            state["kernel_code"] = None
    else:
        logging.warning("No kernel file path found in state or file does not exist")
        state["kernel_code"] = None

    # Load optimization plan if it exists
    plan_path = state.get("kernel_plan_path", None)
    if plan_path and os.path.exists(plan_path):
        try:
            with open(plan_path, "r", encoding="utf-8") as f:
                state["kernel_plan"] = f.read()
            logging.info(f"Loaded optimization plan from {plan_path}")
        except Exception as e:
            logging.error(f"Failed to read plan file: {e}")
            state["kernel_plan"] = None
    else:
        logging.info(
            "No optimization plan path found (this is okay for some workflows)"
        )
        state["kernel_plan"] = None

    # Store the readable history under a separate key so the raw records
    # in "compilation_history" stay untouched.
    history = state.get("compilation_history", [])
    if history:
        state["compilation_history_formatted"] = _format_compilation_history(history)
    else:
        state["compilation_history_formatted"] = (
            "No previous attempts (this is the first attempt)"
        )
def get_tpu_version_callback(callback_context: CallbackContext):
    """Load the TPU version and its specifications into state.

    Stores ``TPU_VERSION`` under ``tpu_version``, then looks that version up in
    ``hitl_agent/tpu_specs.json`` and stores the matching entry under
    ``tpu_specs``. When the version is absent from the file, a placeholder
    message string is stored instead; when the file cannot be loaded,
    ``tpu_specs`` is set to ``None``.

    Args:
        callback_context: Context whose ``state`` mapping is updated.
    """
    tpu_version = TPU_VERSION
    callback_context.state["tpu_version"] = tpu_version
    logging.info(f"Detected TPU version: {tpu_version}")

    try:
        # NOTE(review): this path is relative to the process working
        # directory, so the lookup only succeeds when the agent is launched
        # from the directory containing `hitl_agent/` — confirm with callers.
        with open("hitl_agent/tpu_specs.json", "r", encoding="utf-8") as f:
            tpu_specs = json.load(f)

        if tpu_version in tpu_specs:
            callback_context.state["tpu_specs"] = tpu_specs[tpu_version]
            logging.info(f"Loaded TPU specs for {tpu_version}")
        else:
            callback_context.state["tpu_specs"] = (
                "TPU specs not found for detected version."
            )
            # Previously logged as a successful load; report the miss instead.
            logging.warning(f"No TPU specs entry found for {tpu_version}")
    except Exception as e:
        # Best-effort by design: missing or malformed specs must not abort
        # the agent run.
        logging.error(f"Failed to load TPU specs: {e}")
        callback_context.state["tpu_specs"] = None
def add_workdir_callback(callback_context: CallbackContext):
    """Create the per-session working directory and record it in state.

    The directory is ``WORKDIR/<session id>``; it is created if it does not
    already exist and its path is stored under the ``workdir`` state key.

    Args:
        callback_context: Context providing ``session.id`` and ``state``.
    """
    workdir_for_session = os.path.join(WORKDIR, callback_context.session.id)
    # exist_ok=True: an already existing directory is not an error.
    os.makedirs(workdir_for_session, exist_ok=True)

    callback_context.state["workdir"] = workdir_for_session
    logging.info(f"Set working directory to: {workdir_for_session}")
def extract_fix_summary(
    callback_context: CallbackContext, llm_response: LlmResponse
) -> LlmResponse:
    """Capture the text of the model response as the fix summary.

    This is an after_model_callback that receives the LlmResponse directly.
    All non-empty text parts are joined with newlines, stripped of surrounding
    whitespace, and stored under the ``fix_summary`` state key. When the
    response has no content, no parts, or no text parts, a warning is logged
    and state is left untouched.

    Args:
        callback_context: Context whose ``state`` mapping receives the summary.
        llm_response: The model response to read text parts from.

    Returns:
        The same ``llm_response`` object that was passed in.
    """
    content = llm_response.content
    if content is None or not content.parts:
        logging.warning("No content in LlmResponse to extract fix summary from")
        return llm_response

    # Keep only parts that actually carry text; others are skipped.
    collected = [
        part.text for part in content.parts if getattr(part, "text", None)
    ]
    if not collected:
        logging.warning("No text parts found in LlmResponse")
        return llm_response

    fix_summary = "\n".join(collected).strip()
    callback_context.state["fix_summary"] = fix_summary
    logging.info(f"Captured fix summary ({len(fix_summary)} chars)")
    return llm_response
def add_pallas_docs(callback_context: CallbackContext):
    """Copy the Pallas documentation prompts into state.

    Stores ``pallas_docs.PROMPT`` under ``pallas_docs`` and
    ``pallas_profiling_docs.PROMPT`` under ``pallas_profiling_docs``.

    Args:
        callback_context: Context whose ``state`` mapping is updated.
    """
    state = callback_context.state
    general_docs = pallas_docs.PROMPT
    state["pallas_docs"] = general_docs
    profiling_docs = pallas_profiling_docs.PROMPT
    state["pallas_profiling_docs"] = profiling_docs
# Public API of this module: the names exported by a star-import.
__all__ = [
    # Tool callbacks: (tool, args, tool_context, tool_response)
    "create_path_saver",
    "save_kernel_file_paths",
    "save_kernel_and_plan_paths",
    # Callbacks taking only a callback_context: load files/config into state
    "load_single_kernel_to_state",
    "load_profiling_script_to_state",
    "load_two_kernels_to_state",
    "load_kernel_and_plan_to_state",
    "get_tpu_version_callback",
    "add_workdir_callback",
    # Model callback: (callback_context, llm_response)
    "extract_fix_summary",
    # Callback taking only a callback_context: adds documentation prompts
    "add_pallas_docs",
]