diff --git a/MaxKernel/hitl_agent/agent.py b/MaxKernel/hitl_agent/agent.py index b34fe68..b079de1 100644 --- a/MaxKernel/hitl_agent/agent.py +++ b/MaxKernel/hitl_agent/agent.py @@ -25,6 +25,7 @@ plan_kernel_agent, validate_kernel_compilation_agent, ) +from hitl_agent.subagents.autotuning.agent import autotune_agent from hitl_agent.subagents.profiling import profile_agent from hitl_agent.subagents.testing import ( unified_test_agent, @@ -52,6 +53,7 @@ unified_test_agent, # Step 5: Run tests and provide summary profile_agent, # Step 6: Profile for bottlenecks gpu_to_jax_agent, # GPU-to-JAX conversion + autotune_agent, # Step 7: Auto-tune kernel ], tools=[ filesystem_tool_r diff --git a/MaxKernel/hitl_agent/dependency/agent_requirements.txt b/MaxKernel/hitl_agent/dependency/agent_requirements.txt index dab5d1b..1e80d1d 100644 --- a/MaxKernel/hitl_agent/dependency/agent_requirements.txt +++ b/MaxKernel/hitl_agent/dependency/agent_requirements.txt @@ -5,4 +5,5 @@ google-cloud-aiplatform[adk,agent_engines] cloudpickle xprof tpu-info -pytest-asyncio \ No newline at end of file +pytest-asyncio +requests \ No newline at end of file diff --git a/MaxKernel/hitl_agent/prompts/interactive_prompt.py b/MaxKernel/hitl_agent/prompts/interactive_prompt.py index 2f8ca6e..e790952 100644 --- a/MaxKernel/hitl_agent/prompts/interactive_prompt.py +++ b/MaxKernel/hitl_agent/prompts/interactive_prompt.py @@ -18,6 +18,7 @@ * **GenerateTestFileAgent**: Generates a comprehensive pytest test file with compilation, correctness, and performance tests for kernel files. * **UnifiedTestAgent**: Executes the generated pytest test file on TPU and provides comprehensive results including full tracebacks. Automatically manages server lifecycle (starts/stops TPU and eval servers as needed). * **ProfileAgentOrchestrator**: Profiles a kernel to identify performance bottlenecks (DMAs, memory transfers, compute ratios). 
+ * **AutotuneAgent**: Auto-tunes Pallas kernels by searching over parameter spaces (like block sizes) to minimize execution time. * **GpuToJaxAgent**: Converts/writes GPU code (CUDA/Triton/PyTorch) to JAX/Pallas. ### Your Reasoning Process @@ -77,6 +78,9 @@ * **Action**: Delegate to `ProfileAgentOrchestrator` to generate and run profiling scripts. * **Note**: This identifies performance bottlenecks like memory transfers vs compute ratios. + * **If the request is to AUTO-TUNE a kernel** (like "Autotune kernel.py", "Search for best parameters", "Optimize block sizes"): + * **Action**: Delegate to `AutotuneAgent` to perform grid search. + * **If the request is GPU-to-JAX conversion**: * **Action**: Delegate to `GpuToJaxAgent` (it handles its own plan-approve-implement workflow). @@ -90,6 +94,7 @@ 4. **Validation Phase (optional)**: User requests validation → You delegate to `ValidateKernelCompilationAgent` → Compilation validated/fixed → **Return control to user** 5. **Test Generation Phase (optional)**: User requests tests → You delegate to `GenerateTestFileAgent` → Test file generated → **Return control to user** 6. **Test Execution Phase (optional)**: User requests test execution → You delegate to `UnifiedTestAgent` → Tests run → **Return control to user** +7. **Autotune Phase (optional)**: User requests auto-tuning → You delegate to `AutotuneAgent` → Parameters optimized → **Return control to user** **Remember**: After ANY agent completes (planning, implementation, testing, profiling, etc.), immediately return control. The user decides the next step, not you. 
# NOTE(review): this /autotune hunk is added near-verbatim to BOTH
# cpu_server.py and tpu_server.py. The cpu_server copy adds `import itertools`
# and `import re`; the tpu_server copy adds ONLY `import itertools`, yet its
# handler calls `re.search(...)` — that endpoint raises NameError on the first
# successful run. Fixed below (both files need these imports); consider
# extracting the shared logic into a server_utils helper so the copies cannot
# drift again.

import itertools
import re


class AutotuneRequest(BaseModel):
    """Request body for the /autotune grid-search endpoint."""

    code_template: str              # Python source with {placeholder} params
    search_space: dict[str, list]   # placeholder name -> candidate values
    # Per-configuration subprocess timeout in seconds. NOTE(review): the
    # autotune_tool client sends 300; this default only applies when the
    # field is omitted.
    timeout: Optional[int] = 30


def _parse_result_time(output: str) -> Optional[float]:
    """Extract the reported execution time from a run's stdout.

    The executed template is expected to print a line of the form
    ``RESULT_TIME: <number>``. Returns the number, or None when absent.
    """
    match = re.search(r"RESULT_TIME:\s*([0-9.]+)", output)
    return float(match.group(1)) if match else None


@app.post("/autotune", response_model=CodeResponse)
async def autotune(request: AutotuneRequest):
    """Grid-search the request's search space and report the best config.

    For every combination in the cartesian product of the search space, the
    code template is formatted, written to a temp file, and executed in a
    subprocess; the combination with the smallest reported RESULT_TIME wins.

    Returns:
        CodeResponse with exit_code 0 and a JSON payload
        {best_cfg, best_time, best_output} on success, or exit_code -1 when
        no configuration succeeded.
    """
    logging.info("Starting autotune")
    async with performance_semaphore:
        try:
            keys = list(request.search_space.keys())
            values = list(request.search_space.values())

            best_time = float("inf")
            best_cfg = None
            best_output = ""

            # itertools.product enumerates the full grid; no need to
            # materialize the combinations list up front.
            for combo in itertools.product(*values):
                cfg = dict(zip(keys, combo))
                try:
                    code_content = request.code_template.format(**cfg)
                except KeyError as e:
                    # Template references a placeholder missing from the
                    # search space; skip this combination.
                    logging.error(
                        "KeyError during template formatting: %s. Config: %s",
                        e,
                        cfg,
                    )
                    continue

                with tempfile.NamedTemporaryFile(
                    mode="w", suffix=".py", delete=False
                ) as temp_file:
                    temp_file.write(code_content)
                    temp_file_path = temp_file.name

                try:
                    process = await asyncio.create_subprocess_exec(
                        sys.executable,
                        temp_file_path,
                        stdout=asyncio.subprocess.PIPE,
                        stderr=asyncio.subprocess.PIPE,
                        cwd=tempfile.gettempdir(),
                        # NOTE(review): the cpu_server.py copy additionally
                        # passes env=get_cpu_env() here to force the CPU
                        # backend; the tpu_server.py copy inherits the env.
                    )
                    stdout, stderr = await asyncio.wait_for(
                        process.communicate(), timeout=request.timeout
                    )

                    output = stdout.decode("utf-8") if stdout else ""
                    error = stderr.decode("utf-8") if stderr else ""

                    if process.returncode == 0:
                        time_taken = _parse_result_time(output)
                        if time_taken is None:
                            logging.warning(
                                "No RESULT_TIME found in output for config %s",
                                cfg,
                            )
                        elif time_taken < best_time:
                            best_time = time_taken
                            best_cfg = cfg
                            best_output = output
                    else:
                        logging.warning(
                            "Config %s failed with exit code %s. Stderr: %s",
                            cfg,
                            process.returncode,
                            error,
                        )

                except asyncio.TimeoutError:
                    logging.warning("Config %s timed out", cfg)
                    process.kill()
                    await process.wait()
                except Exception as e:
                    logging.error("Error running config %s: %s", cfg, e)
                finally:
                    # Best-effort cleanup of the generated script.
                    try:
                        os.unlink(temp_file_path)
                    except OSError:
                        pass

            if best_cfg is None:
                return CodeResponse(
                    output="",
                    error="No successful configuration found during autotune.",
                    exit_code=-1,
                )

            output_data = {
                "best_cfg": best_cfg,
                "best_time": best_time,
                "best_output": best_output,
            }
            logging.info("Autotune finished")
            return CodeResponse(
                output=json.dumps(output_data), error=None, exit_code=0
            )

        except Exception as e:
            logging.error("Autotune failed with error: %s", e)
            raise HTTPException(status_code=500, detail=f"Autotune error: {str(e)}")
b/MaxKernel/hitl_agent/subagents/autotuning/prompts/autotune_prompt.py new file mode 100644 index 0000000..f9bbbf4 --- /dev/null +++ b/MaxKernel/hitl_agent/subagents/autotuning/prompts/autotune_prompt.py @@ -0,0 +1,18 @@ +"""Prompt for AutotuneAgent.""" + +PROMPT = """You are a specialized agent for auto-tuning Pallas kernels. +Your goal is to find the optimal parameters (like block sizes) for a given kernel to minimize execution time. + +You have access to the `autotune_tool` which performs a grid search. + +To use the tool, you must: +1. Identify the parameters that can be tuned in the kernel (e.g., BLOCK_M, BLOCK_N). +2. Create a code template from the kernel code, replacing the specific parameter values with placeholders enclosed in curly braces (for example, BLOCK_M should become BLOCK_M enclosed in curly braces). +3. Ensure the template code prints "RESULT_TIME: " to indicate the execution time. You may need to wrap the kernel call in a loop or use `jax.block_until_ready()` to get accurate timing. +4. Define a search space as a dictionary mapping placeholder names to lists of suggested values. +5. Call `autotune_tool` with the kernel name, code template, and search space. + +After the tool returns, report the best configuration found to the user. + +If the user didn't provide a specific kernel or search space, ask them for it or read it from the work directory if a plan or implementation file exists. 
+""" diff --git a/MaxKernel/hitl_agent/subagents/testing/prompts/summarize_test_results_prompt.py b/MaxKernel/hitl_agent/subagents/testing/prompts/summarize_test_results_prompt.py index 5020b25..26f7145 100644 --- a/MaxKernel/hitl_agent/subagents/testing/prompts/summarize_test_results_prompt.py +++ b/MaxKernel/hitl_agent/subagents/testing/prompts/summarize_test_results_prompt.py @@ -55,10 +55,10 @@ **Recommendation Guidelines:** -- If tests **passed**: Suggest next steps (profiling for bottlenecks, testing with more input sizes, production deployment considerations) +- If tests **passed**: Suggest next steps (auto-tuning parameters like block sizes using `AutotuneAgent`, profiling for bottlenecks, testing with more input sizes, production deployment considerations) - If **compilation failed**: Provide specific fixes based on the error (API signature issues, import problems, syntax errors) - If **correctness failed**: Suggest debugging approaches (check block boundaries, verify reduction operations, inspect memory access patterns, adjust tolerances) -- If **performance is poor**: Suggest optimization opportunities (block size tuning, memory layout optimization, pipelining, prefetching) +- If **performance is poor**: Suggest optimization opportunities (block size tuning using `AutotuneAgent`, memory layout optimization, pipelining, prefetching) **Important**: - Research the specific error types using your tools before making recommendations diff --git a/MaxKernel/hitl_agent/tools/autotune_tool.py b/MaxKernel/hitl_agent/tools/autotune_tool.py new file mode 100644 index 0000000..aac0841 --- /dev/null +++ b/MaxKernel/hitl_agent/tools/autotune_tool.py @@ -0,0 +1,103 @@ +"""Standalone tool for auto-tuning Pallas kernels using grid search on remote servers.""" + +import json +import logging +from typing import Any +from google.adk import tools +import requests + +from hitl_agent.constants import CPU_SERVER_PORT, TPU_SERVER_PORT + + +def autotune_kernel( + kernel_name: str, 
+ code_template: str, + search_space: dict[str, list[Any]], + backend: str = "tpu", + server_addr: str = "http://localhost", +) -> dict: + """Runs a grid search to auto-tune a Pallas kernel on a remote server. + + Args: + kernel_name: Name of the kernel. + code_template: Python code containing placeholders for parameters to be + tuned. It should produce a line like "RESULT_TIME: " in its + output to indicate performance. + search_space: A dictionary mapping placeholder names to lists of feasible + values. + backend: 'tpu' or 'cpu'. + server_addr: Address of the server (default: http://localhost). + + Returns: + A dictionary containing the status, optimal parameters, and a summary of + results. + """ + logging.info( + f"Starting remote autotuning for kernel: {kernel_name} on {backend}" + ) + + if backend == "tpu": + port = TPU_SERVER_PORT + elif backend == "cpu": + port = CPU_SERVER_PORT + else: + return {"status": "error", "message": f"Invalid backend: {backend}"} + + url = f"{server_addr}:{port}/autotune" + + try: + response = requests.post( + url, + json={ + "code_template": code_template, + "search_space": search_space, + "timeout": 300, + }, + timeout=3600, # 1 hour timeout for the whole autotune request + ) + + if response.status_code == 200: + result = response.json() + if result["exit_code"] == 0: + try: + output_data = json.loads(result["output"]) + logging.info( + f"Autotuning completed. 
Best config: {output_data['best_cfg']} with time {output_data['best_time']} ms" + ) + return { + "status": "success", + "message": "Autotuning completed", + "best_config": output_data["best_cfg"], + "best_time_ms": output_data["best_time"], + "best_output": output_data["best_output"], + } + except json.JSONDecodeError: + logging.warning("Failed to decode JSON from server output.") + return { + "status": "success", + "message": "Autotuning completed (raw output)", + "raw_output": result["output"], + } + else: + return { + "status": "failed", + "message": result["error"] or "Autotune failed on server", + "server_output": result["output"], + } + else: + return { + "status": "error", + "message": f"Server returned status code {response.status_code}: {response.text}", + } + + except requests.exceptions.ConnectionError: + return { + "status": "error", + "message": f"Could not connect to server at {url}. Make sure it is running.", + } + except Exception as e: + return {"status": "error", "message": str(e)} + + +# Wrap the function with FunctionTool for compatibility with ADK agents +autotune_tool = tools.FunctionTool(autotune_kernel) diff --git a/MaxKernel/hitl_agent/tools/tools.py b/MaxKernel/hitl_agent/tools/tools.py index 8aa108a..515f94e 100644 --- a/MaxKernel/hitl_agent/tools/tools.py +++ b/MaxKernel/hitl_agent/tools/tools.py @@ -10,11 +10,11 @@ from google.adk.tools.retrieval.vertex_ai_rag_retrieval import ( VertexAiRagRetrieval, ) -from mcp import StdioServerParameters -from vertexai.preview import rag - from hitl_agent.config import RAG_CORPUS, WORKDIR +from hitl_agent.tools.autotune_tool import autotune_tool from hitl_agent.tools.search_api_tool import search_api_tool +from mcp import StdioServerParameters +from vertexai.preview import rag # Custom VertexAiRagRetrieval that forces function_declarations mode to avoid @@ -91,9 +91,10 @@ async def process_llm_request( ) __all__ = [ - "search_api_tool", - "filesystem_tool_r", - "filesystem_tool_rw", - 
"vertex_ai_rag_tool", - "CompatibleVertexAiRagRetrieval", + "search_api_tool", + "filesystem_tool_r", + "filesystem_tool_rw", + "vertex_ai_rag_tool", + "CompatibleVertexAiRagRetrieval", + "autotune_tool", ]