2 changes: 2 additions & 0 deletions MaxKernel/hitl_agent/agent.py
@@ -25,6 +25,7 @@
plan_kernel_agent,
validate_kernel_compilation_agent,
)
from hitl_agent.subagents.autotuning.agent import autotune_agent
from hitl_agent.subagents.profiling import profile_agent
from hitl_agent.subagents.testing import (
unified_test_agent,
@@ -52,6 +53,7 @@
unified_test_agent, # Step 5: Run tests and provide summary
profile_agent, # Step 6: Profile for bottlenecks
gpu_to_jax_agent, # GPU-to-JAX conversion
autotune_agent, # Step 7: Auto-tune kernel
],
tools=[
filesystem_tool_r
3 changes: 2 additions & 1 deletion MaxKernel/hitl_agent/dependency/agent_requirements.txt
@@ -5,4 +5,5 @@ google-cloud-aiplatform[adk,agent_engines]
cloudpickle
xprof
tpu-info
pytest-asyncio
pytest-asyncio
requests
11 changes: 11 additions & 0 deletions MaxKernel/hitl_agent/prompts/interactive_prompt.py
@@ -18,6 +18,7 @@
* **GenerateTestFileAgent**: Generates a comprehensive pytest test file with compilation, correctness, and performance tests for kernel files.
* **UnifiedTestAgent**: Executes the generated pytest test file on TPU and provides comprehensive results including full tracebacks. Automatically manages server lifecycle (starts/stops TPU and eval servers as needed).
* **ProfileAgentOrchestrator**: Profiles a kernel to identify performance bottlenecks (DMAs, memory transfers, compute ratios).
* **AutotuneAgent**: Auto-tunes Pallas kernels by searching over parameter spaces (like block sizes) to minimize execution time.
* **GpuToJaxAgent**: Converts/writes GPU code (CUDA/Triton/PyTorch) to JAX/Pallas.

### Your Reasoning Process
@@ -77,6 +78,9 @@
* **Action**: Delegate to `ProfileAgentOrchestrator` to generate and run profiling scripts.
* **Note**: This identifies performance bottlenecks like memory transfers vs compute ratios.

* **If the request is to AUTO-TUNE a kernel** (like "Autotune kernel.py", "Search for best parameters", "Optimize block sizes"):
* **Action**: Delegate to `AutotuneAgent` to perform grid search.

* **If the request is GPU-to-JAX conversion**:
* **Action**: Delegate to `GpuToJaxAgent` (it handles its own plan-approve-implement workflow).

@@ -90,6 +94,7 @@
4. **Validation Phase (optional)**: User requests validation → You delegate to `ValidateKernelCompilationAgent` → Compilation validated/fixed → **Return control to user**
5. **Test Generation Phase (optional)**: User requests tests → You delegate to `GenerateTestFileAgent` → Test file generated → **Return control to user**
6. **Test Execution Phase (optional)**: User requests test execution → You delegate to `UnifiedTestAgent` → Tests run → **Return control to user**
7. **Autotune Phase (optional)**: User requests auto-tuning → You delegate to `AutotuneAgent` → Parameters optimized → **Return control to user**

**Remember**: After ANY agent completes (planning, implementation, testing, profiling, etc.), immediately return control. The user decides the next step, not you.

@@ -152,6 +157,12 @@
You: Delegate to ProfileAgentOrchestrator → Profiling analysis complete → [END TURN, wait for user]
```

**Example 7: Autotuning**
```
User: "Autotune optimized_kernel.py"
You: Delegate to AutotuneAgent → Grid search complete with best config → [END TURN, wait for user]
```

**ANTI-PATTERN - NEVER DO THIS:**
```
❌ WRONG:
101 changes: 101 additions & 0 deletions MaxKernel/hitl_agent/server_utils/cpu_server.py
@@ -4,6 +4,8 @@
import os
import sys
import tempfile
import itertools
import re
from typing import Optional

from fastapi import FastAPI, HTTPException
@@ -38,6 +40,12 @@ class CodeResponse(BaseModel):
exit_code: int


class AutotuneRequest(BaseModel):
code_template: str
search_space: dict[str, list]
timeout: Optional[int] = 30


class GetBackendVersionResponse(BaseModel):
backend_version: str

@@ -306,6 +314,99 @@ async def performance_test(request: CodeRequest):
logging.info("Performance test finished")


@app.post("/autotune", response_model=CodeResponse)
async def autotune(request: AutotuneRequest):
logging.info("Starting autotune on CPU backend")
async with performance_semaphore:
try:
# Generate all combinations
keys = list(request.search_space.keys())
values = list(request.search_space.values())
combinations = list(itertools.product(*values))

best_time = float("inf")
best_cfg = None
best_output = ""

for combo in combinations:
cfg = dict(zip(keys, combo))
try:
code_content = request.code_template.format(**cfg)
except KeyError as e:
logging.error(f"KeyError during template formatting: {e}. Config: {cfg}")
continue

# Execute the code
with tempfile.NamedTemporaryFile(
mode="w", suffix=".py", delete=False
) as temp_file:
temp_file.write(code_content)
temp_file_path = temp_file.name

try:
process = await asyncio.create_subprocess_exec(
sys.executable,
temp_file_path,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=tempfile.gettempdir(),
env=get_cpu_env(), # Force CPU backend
)

stdout, stderr = await asyncio.wait_for(
process.communicate(), timeout=request.timeout
)

output = stdout.decode("utf-8") if stdout else ""
error = stderr.decode("utf-8") if stderr else ""
exit_code = process.returncode

if exit_code == 0:
# Parse RESULT_TIME
match = re.search(r"RESULT_TIME:\s*([0-9.]+)", output)
if match:
time_taken = float(match.group(1))
if time_taken < best_time:
best_time = time_taken
best_cfg = cfg
best_output = output
else:
logging.warning(f"No RESULT_TIME found in output for config {cfg}")
else:
logging.warning(f"Config {cfg} failed with exit code {exit_code}. Stderr: {error}")

except asyncio.TimeoutError:
logging.warning(f"Config {cfg} timed out")
process.kill()
await process.wait()
except Exception as e:
logging.error(f"Error running config {cfg}: {e}")
finally:
try:
os.unlink(temp_file_path)
except OSError:
pass

if best_cfg is None:
return CodeResponse(
output="",
error="No successful configuration found during autotune.",
exit_code=-1,
)

output_data = {
"best_cfg": best_cfg,
"best_time": best_time,
"best_output": best_output,
}
logging.info("Autotune finished on CPU backend")
return CodeResponse(output=json.dumps(output_data), error=None, exit_code=0)

except Exception as e:
logging.error(f"Autotune failed with error: {str(e)}")
raise HTTPException(status_code=500, detail=f"Autotune error: {str(e)}")


@app.post("/profile", response_model=CodeResponse)
async def profile(request: CodeRequest):
logging.info("Starting profile on CPU backend")
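The search loop added in this endpoint — enumerate every `itertools.product` combination, run each config, and keep the lowest reported `RESULT_TIME` — can be exercised in isolation. A minimal sketch; the runner callback and the config names/values are illustrative, not taken from this PR:

```python
import itertools
import re

def pick_best(search_space, run):
    """Grid search: try every config, keep the one with the lowest RESULT_TIME."""
    keys = list(search_space.keys())
    best_time, best_cfg = float("inf"), None
    for combo in itertools.product(*search_space.values()):
        cfg = dict(zip(keys, combo))
        output = run(cfg)  # stand-in for the subprocess execution in the endpoint
        match = re.search(r"RESULT_TIME:\s*([0-9.]+)", output)
        if match and float(match.group(1)) < best_time:
            best_time, best_cfg = float(match.group(1)), cfg
    return best_cfg, best_time

# Fake runner: larger BLOCK_M "runs" faster in this toy setup.
fake_run = lambda cfg: f"RESULT_TIME: {1.0 / cfg['BLOCK_M']}"
best_cfg, best_time = pick_best({"BLOCK_M": [128, 256, 512]}, fake_run)
print(best_cfg)  # {'BLOCK_M': 512}
```

Configs that fail to produce a `RESULT_TIME` line are simply skipped, matching the endpoint's warning-and-continue behavior.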
99 changes: 99 additions & 0 deletions MaxKernel/hitl_agent/server_utils/tpu_server.py
@@ -6,6 +6,7 @@
import subprocess
import sys
import tempfile
import itertools
import re
from typing import Optional

from fastapi import FastAPI, HTTPException
@@ -40,6 +41,12 @@ class CodeResponse(BaseModel):
exit_code: int


class AutotuneRequest(BaseModel):
code_template: str
search_space: dict[str, list]
timeout: Optional[int] = 30


class GetTpuVersionResponse(BaseModel):
tpu_version: str

@@ -293,6 +300,98 @@ async def performance_test(request: CodeRequest):
logging.info("Performance test finished")


@app.post("/autotune", response_model=CodeResponse)
async def autotune(request: AutotuneRequest):
logging.info("Starting autotune")
async with performance_semaphore:
try:
# Generate all combinations
keys = list(request.search_space.keys())
values = list(request.search_space.values())
combinations = list(itertools.product(*values))

best_time = float("inf")
best_cfg = None
best_output = ""

for combo in combinations:
cfg = dict(zip(keys, combo))
try:
code_content = request.code_template.format(**cfg)
except KeyError as e:
logging.error(f"KeyError during template formatting: {e}. Config: {cfg}")
continue

# Execute the code
with tempfile.NamedTemporaryFile(
mode="w", suffix=".py", delete=False
) as temp_file:
temp_file.write(code_content)
temp_file_path = temp_file.name

try:
process = await asyncio.create_subprocess_exec(
sys.executable,
temp_file_path,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=tempfile.gettempdir(),
)

stdout, stderr = await asyncio.wait_for(
process.communicate(), timeout=request.timeout
)

output = stdout.decode("utf-8") if stdout else ""
error = stderr.decode("utf-8") if stderr else ""
exit_code = process.returncode

if exit_code == 0:
# Parse RESULT_TIME
match = re.search(r"RESULT_TIME:\s*([0-9.]+)", output)
if match:
time_taken = float(match.group(1))
if time_taken < best_time:
best_time = time_taken
best_cfg = cfg
best_output = output
else:
logging.warning(f"No RESULT_TIME found in output for config {cfg}")
else:
logging.warning(f"Config {cfg} failed with exit code {exit_code}. Stderr: {error}")

except asyncio.TimeoutError:
logging.warning(f"Config {cfg} timed out")
process.kill()
await process.wait()
except Exception as e:
logging.error(f"Error running config {cfg}: {e}")
finally:
try:
os.unlink(temp_file_path)
except OSError:
pass

if best_cfg is None:
return CodeResponse(
output="",
error="No successful configuration found during autotune.",
exit_code=-1,
)

output_data = {
"best_cfg": best_cfg,
"best_time": best_time,
"best_output": best_output,
}
logging.info("Autotune finished")
return CodeResponse(output=json.dumps(output_data), error=None, exit_code=0)

except Exception as e:
logging.error(f"Autotune failed with error: {str(e)}")
raise HTTPException(status_code=500, detail=f"Autotune error: {str(e)}")


@app.post("/profile", response_model=CodeResponse)
async def profile(request: CodeRequest):
logging.info("Starting profile")
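Both servers accept the same `AutotuneRequest` shape, so a client call from the tool side might look like the following. The URL, port, and template contents are assumptions for illustration; only the field names come from the Pydantic models above:

```python
import json

# Payload matching the AutotuneRequest model: a format-string template,
# a search space of candidate values per placeholder, and a per-run timeout.
payload = {
    "code_template": "print('RESULT_TIME: 0.001')  # BLOCK_M={BLOCK_M}",
    "search_space": {"BLOCK_M": [128, 256, 512]},
    "timeout": 30,
}

body = json.dumps(payload)  # what would go over the wire

# With a server running (hypothetical URL/port):
# import requests
# resp = requests.post("http://localhost:8000/autotune", json=payload, timeout=600)
# best_cfg = json.loads(resp.json()["output"])["best_cfg"]
```

On success the endpoint packs `best_cfg`, `best_time`, and `best_output` as a JSON string into `CodeResponse.output`, so the client decodes twice: once for the response model, once for that payload.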
19 changes: 19 additions & 0 deletions MaxKernel/hitl_agent/subagents/autotuning/agent.py
@@ -0,0 +1,19 @@
"""Autotuning agent."""

from hitl_agent.config import model_config, thinking_planner
from hitl_agent.constants import MODEL_NAME
from hitl_agent.custom_types import CustomLlmAgent
from hitl_agent.subagents.autotuning.prompts import autotune_prompt
from hitl_agent.tools.tools import autotune_tool, filesystem_tool_rw

autotune_agent = CustomLlmAgent(
name="AutotuneAgent",
model=MODEL_NAME,
generate_content_config=model_config,
planner=thinking_planner,
instruction=autotune_prompt.PROMPT,
description="Auto-tunes Pallas kernels by searching over parameter spaces.",
tools=[autotune_tool, filesystem_tool_rw],
)

__all__ = ["autotune_agent"]
18 changes: 18 additions & 0 deletions MaxKernel/hitl_agent/subagents/autotuning/prompts/autotune_prompt.py
@@ -0,0 +1,18 @@
"""Prompt for AutotuneAgent."""

PROMPT = """You are a specialized agent for auto-tuning Pallas kernels.
Your goal is to find the optimal parameters (like block sizes) for a given kernel to minimize execution time.

You have access to the `autotune_tool` which performs a grid search.

To use the tool, you must:
1. Identify the parameters that can be tuned in the kernel (e.g., BLOCK_M, BLOCK_N).
2. Create a code template from the kernel code, replacing the specific parameter values with placeholders enclosed in curly braces (for example, `BLOCK_M = 128` becomes `BLOCK_M = {BLOCK_M}`).
3. Ensure the template code prints "RESULT_TIME: <float>" to indicate the execution time. You may need to wrap the kernel call in a loop or use `jax.block_until_ready()` to get accurate timing.
4. Define a search space as a dictionary mapping placeholder names to lists of suggested values.
5. Call `autotune_tool` with the kernel name, code template, and search space.

After the tool returns, report the best configuration found to the user.

If the user didn't provide a specific kernel or search space, ask them for it or read it from the work directory if a plan or implementation file exists.
"""
@@ -55,10 +55,10 @@

**Recommendation Guidelines:**

- If tests **passed**: Suggest next steps (profiling for bottlenecks, testing with more input sizes, production deployment considerations)
- If tests **passed**: Suggest next steps (auto-tuning parameters like block sizes using `AutotuneAgent`, profiling for bottlenecks, testing with more input sizes, production deployment considerations)
- If **compilation failed**: Provide specific fixes based on the error (API signature issues, import problems, syntax errors)
- If **correctness failed**: Suggest debugging approaches (check block boundaries, verify reduction operations, inspect memory access patterns, adjust tolerances)
- If **performance is poor**: Suggest optimization opportunities (block size tuning, memory layout optimization, pipelining, prefetching)
- If **performance is poor**: Suggest optimization opportunities (block size tuning using `AutotuneAgent`, memory layout optimization, pipelining, prefetching)

**Important**:
- Research the specific error types using your tools before making recommendations