|
"""LLM tool definitions for verification functions.

This module exposes verification functions as tools that can be called by LLMs.
Each tool has a JSON schema definition and a simplified wrapper function.
"""
| 6 | + |
| 7 | +from __future__ import annotations |
| 8 | + |
| 9 | +import os |
| 10 | +from pathlib import Path |
| 11 | +from typing import Any |
| 12 | + |
| 13 | +from pydantic import BaseModel, Field |
| 14 | + |
| 15 | +from codeflash.models.models import TestFile, TestFiles, TestType |
| 16 | +from codeflash.verification.parse_test_output import parse_test_xml |
| 17 | +from codeflash.verification.test_runner import run_behavioral_tests |
| 18 | +from codeflash.verification.verification_utils import TestConfig |
| 19 | + |
| 20 | + |
class TestFileInput(BaseModel):
    """Pydantic model describing one test-file entry passed to the tool."""

    # Required: where the test file lives on disk.
    test_file_path: str = Field(
        ...,
        description="Absolute path to the test file to run",
    )
    # Optional label for the kind of test; defaults to an existing unit test.
    test_type: str = Field(
        description="Type of test: 'existing_unit_test', 'generated_regression', 'replay_test', or 'concolic_coverage_test'",
        default="existing_unit_test",
    )
| 29 | + |
| 30 | + |
class RunBehavioralTestsInput(BaseModel):
    """Pydantic model of the arguments accepted by the run_behavioral_tests tool."""

    # Required: the test files to execute.
    test_files: list[TestFileInput] = Field(
        ...,
        description="List of test files to run",
    )
    # Optional: which runner to use; pytest unless stated otherwise.
    test_framework: str = Field(
        description="Test framework to use: 'pytest' or 'unittest'",
        default="pytest",
    )
    # Required: root directory of the project under test.
    project_root: str = Field(
        ...,
        description="Absolute path to the project root directory",
    )
    # Optional: per-test timeout in seconds; None is also accepted.
    pytest_timeout: int | None = Field(
        description="Timeout in seconds for each pytest test",
        default=30,
    )
    # Optional: verbose output switch, off by default.
    verbose: bool = Field(
        description="Enable verbose output",
        default=False,
    )
| 39 | + |
| 40 | + |
class TestResultOutput(BaseModel):
    """Pydantic model of the outcome reported for one executed test.

    Every field is required; ``test_function`` and ``runtime_ns`` may be None.
    """

    test_id: str = Field(
        ...,
        description="Unique identifier for the test",
    )
    test_file: str = Field(
        ...,
        description="Path to the test file",
    )
    test_function: str | None = Field(
        ...,
        description="Name of the test function",
    )
    passed: bool = Field(
        ...,
        description="Whether the test passed",
    )
    runtime_ns: int | None = Field(
        ...,
        description="Runtime in nanoseconds, if available",
    )
    timed_out: bool = Field(
        ...,
        description="Whether the test timed out",
    )
| 50 | + |
| 51 | + |
class RunBehavioralTestsOutput(BaseModel):
    """Pydantic model of the payload produced by the run_behavioral_tests tool.

    All fields are required except ``error``, which defaults to None.
    """

    success: bool = Field(
        ...,
        description="Whether the test run completed successfully",
    )
    total_tests: int = Field(
        ...,
        description="Total number of tests run",
    )
    passed_tests: int = Field(
        ...,
        description="Number of tests that passed",
    )
    failed_tests: int = Field(
        ...,
        description="Number of tests that failed",
    )
    results: list[TestResultOutput] = Field(
        ...,
        description="Detailed results for each test",
    )
    stdout: str = Field(
        ...,
        description="Standard output from the test run",
    )
    stderr: str = Field(
        ...,
        description="Standard error from the test run",
    )
    error: str | None = Field(
        description="Error message if the run failed",
        default=None,
    )
| 63 | + |
| 64 | + |
# JSON Schema for OpenAI-style function calling, assembled from named pieces.

# Schema of a single element of the "test_files" array.
_TEST_FILE_ITEM_SCHEMA = {
    "type": "object",
    "properties": {
        "test_file_path": {
            "type": "string",
            "description": "Absolute path to the test file to run",
        },
        "test_type": {
            "type": "string",
            "enum": [
                "existing_unit_test",
                "generated_regression",
                "replay_test",
                "concolic_coverage_test",
            ],
            "default": "existing_unit_test",
            "description": "Type of test being run",
        },
    },
    "required": ["test_file_path"],
}

# Schema of the tool's argument object; only test_files and project_root are mandatory.
_RUN_BEHAVIORAL_TESTS_PARAMETERS = {
    "type": "object",
    "properties": {
        "test_files": {
            "type": "array",
            "description": "List of test files to run",
            "items": _TEST_FILE_ITEM_SCHEMA,
        },
        "test_framework": {
            "type": "string",
            "enum": ["pytest", "unittest"],
            "default": "pytest",
            "description": "Test framework to use",
        },
        "project_root": {
            "type": "string",
            "description": "Absolute path to the project root directory",
        },
        "pytest_timeout": {
            "type": "integer",
            "default": 30,
            "description": "Timeout in seconds for each pytest test",
        },
        "verbose": {
            "type": "boolean",
            "default": False,
            "description": "Enable verbose output",
        },
    },
    "required": ["test_files", "project_root"],
}

RUN_BEHAVIORAL_TESTS_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "run_behavioral_tests",
        "description": (
            "Run behavioral tests to verify code correctness. "
            "This executes test files using pytest or unittest and returns detailed results "
            "including pass/fail status, runtime information, and any errors encountered."
        ),
        "parameters": _RUN_BEHAVIORAL_TESTS_PARAMETERS,
    },
}
| 121 | + |
| 122 | + |
def _test_type_from_string(test_type_str: str) -> TestType:
    """Translate a textual test-type label into its ``TestType`` member.

    Matching is case-insensitive. ``"concolic_test"`` is accepted as an alias
    of ``"concolic_coverage_test"``. Any label that is not recognised maps to
    ``TestType.EXISTING_UNIT_TEST``.
    """
    label = test_type_str.lower()
    if label == "generated_regression":
        return TestType.GENERATED_REGRESSION
    if label == "replay_test":
        return TestType.REPLAY_TEST
    if label in ("concolic_test", "concolic_coverage_test"):
        return TestType.CONCOLIC_COVERAGE_TEST
    # "existing_unit_test" and every unknown label end up here.
    return TestType.EXISTING_UNIT_TEST
| 133 | + |
| 134 | + |
def _tool_failure(message: str, stdout: str = "", stderr: str = "") -> dict[str, Any]:
    """Build the payload returned when the tool cannot produce test results."""
    return {
        "success": False,
        "total_tests": 0,
        "passed_tests": 0,
        "failed_tests": 0,
        "results": [],
        "stdout": stdout,
        "stderr": stderr,
        "error": message,
    }


def _resolve_test_file_path(entry: Any, index: int) -> Path:  # noqa: ANN401
    """Validate one ``test_files`` entry and return its resolved path.

    Raises:
        ValueError: If the entry has no usable 'test_file_path' value

    """
    try:
        raw_path = entry["test_file_path"]
    except (KeyError, TypeError) as exc:
        # A bare KeyError would stringify to just "'test_file_path'", which tells the caller nothing.
        msg = f"test_files[{index}] must be an object with a 'test_file_path' key"
        raise ValueError(msg) from exc
    if not isinstance(raw_path, (str, os.PathLike)) or not str(raw_path):
        msg = f"test_files[{index}]['test_file_path'] must be a non-empty path string"
        raise ValueError(msg)
    return Path(raw_path).resolve()


def run_behavioral_tests_tool(
    test_files: list[dict[str, Any]],
    project_root: str,
    test_framework: str = "pytest",
    pytest_timeout: int | None = 30,
    verbose: bool = False,  # noqa: FBT002, FBT001
) -> dict[str, Any]:
    """Run behavioral tests and return results in an LLM-friendly format.

    This is a simplified wrapper around run_behavioral_tests that accepts
    primitive types suitable for LLM tool calling and returns a structured
    dictionary response. It does not raise for failures inside the run:
    any exception is reported through the ``error`` key of a response whose
    ``success`` is False.

    Args:
        test_files: List of dicts with 'test_file_path' and optional 'test_type'
        project_root: Absolute path to the project root directory
        test_framework: Test framework to use ('pytest' or 'unittest')
        pytest_timeout: Timeout in seconds for each pytest test
        verbose: Enable verbose output

    Returns:
        Dictionary with the keys 'success', 'total_tests', 'passed_tests',
        'failed_tests', 'results', 'stdout', 'stderr' and 'error'. On failure
        the counts are zero, 'results' is empty and 'error' holds
        "<ExceptionType>: <message>"; 'stdout'/'stderr' are still filled in
        when the test process had already produced output.

    Example:
        >>> result = run_behavioral_tests_tool(
        ...     test_files=[{"test_file_path": "/path/to/test_example.py"}], project_root="/path/to/project"
        ... )
        >>> print(result["passed_tests"], "tests passed")

    """
    # Refuse an empty request up front instead of handing the runner nothing to run.
    if not test_files:
        return _tool_failure("No test files provided: 'test_files' must contain at least one entry")

    # Kept outside the try so the except branch can still report captured output.
    process = None
    try:
        project_root_path = Path(project_root).resolve()

        # Build TestFiles structure; the same path is used for all three file roles.
        test_file_objects = []
        for index, tf in enumerate(test_files):
            test_file_path = _resolve_test_file_path(tf, index)
            # A missing or null test_type falls back to the default label.
            test_type = _test_type_from_string(str(tf.get("test_type") or "existing_unit_test"))
            test_file_objects.append(
                TestFile(
                    instrumented_behavior_file_path=test_file_path,
                    benchmarking_file_path=test_file_path,
                    original_file_path=test_file_path,
                    test_type=test_type,
                )
            )
        test_files_model = TestFiles(test_files=test_file_objects)

        # Set up test environment
        test_env = os.environ.copy()
        test_env["CODEFLASH_TEST_ITERATION"] = "0"
        test_env["CODEFLASH_TRACER_DISABLE"] = "1"

        # Ensure PYTHONPATH includes project root. An empty existing value is treated
        # as unset so no leading separator (an implicit cwd entry) is introduced.
        existing_pythonpath = test_env.get("PYTHONPATH")
        if existing_pythonpath:
            test_env["PYTHONPATH"] = os.pathsep.join((existing_pythonpath, str(project_root_path)))
        else:
            test_env["PYTHONPATH"] = str(project_root_path)

        # Run the tests; only the result file path and the process object are used here.
        result_file_path, process, _, _ = run_behavioral_tests(
            test_paths=test_files_model,
            test_framework=test_framework,
            test_env=test_env,
            cwd=project_root_path,
            pytest_timeout=pytest_timeout,
            verbose=verbose,
        )

        try:
            # The project root doubles as the tests root for result parsing.
            test_config = TestConfig(
                tests_root=project_root_path,
                project_root_path=project_root_path,
                test_framework=test_framework,
                tests_project_rootdir=project_root_path,
            )
            test_results = parse_test_xml(
                test_xml_file_path=result_file_path,
                test_files=test_files_model,
                test_config=test_config,
                run_result=process,
            )
        finally:
            # Remove the result file even when parsing fails, so it does not leak.
            result_file_path.unlink(missing_ok=True)

        # Build response
        results_list = [
            {
                "test_id": result.id.id() if result.id else "",
                "test_file": str(result.file_name) if result.file_name else "",
                "test_function": result.id.test_function_name if result.id else None,
                "passed": result.did_pass,
                "runtime_ns": result.runtime,
                "timed_out": result.timed_out or False,
            }
            for result in test_results
        ]
        passed_count = sum(1 for entry in results_list if entry["passed"])

        return {
            "success": True,
            "total_tests": len(test_results),
            "passed_tests": passed_count,
            "failed_tests": len(results_list) - passed_count,
            "results": results_list,
            "stdout": process.stdout or "",
            "stderr": process.stderr or "",
            "error": None,
        }

    except Exception as e:  # noqa: BLE001 - tool boundary: report every failure to the caller
        return _tool_failure(
            f"{type(e).__name__}: {e}",
            stdout=getattr(process, "stdout", None) or "",
            stderr=getattr(process, "stderr", None) or "",
        )
| 271 | + |
| 272 | + |
# Tool registry: maps each tool name to its JSON schema and the callable that implements it.
AVAILABLE_TOOLS = {
    "run_behavioral_tests": {
        "schema": RUN_BEHAVIORAL_TESTS_TOOL_SCHEMA,
        "function": run_behavioral_tests_tool,
    },
}
| 277 | + |
| 278 | + |
def get_tool_schema(tool_name: str) -> dict[str, Any] | None:
    """Look up the JSON schema registered for a tool.

    Args:
        tool_name: Registry key of the tool whose schema is wanted

    Returns:
        The tool's schema dict, or None when no such tool is registered

    """
    entry = AVAILABLE_TOOLS.get(tool_name)
    if not entry:
        return None
    return entry["schema"]
| 291 | + |
| 292 | + |
def get_all_tool_schemas() -> list[dict[str, Any]]:
    """Collect the JSON schema of every registered tool.

    Returns:
        The schema dicts, in registry order

    """
    return [AVAILABLE_TOOLS[name]["schema"] for name in AVAILABLE_TOOLS]
| 301 | + |
| 302 | + |
def execute_tool(tool_name: str, **kwargs: Any) -> dict[str, Any]:  # noqa: ANN401
    """Dispatch a call to the registered tool with the given name.

    Args:
        tool_name: Registry key of the tool to run
        **kwargs: Keyword arguments forwarded unchanged to the tool function

    Returns:
        Whatever dictionary the tool function returns

    Raises:
        ValueError: If no tool is registered under tool_name

    """
    entry = AVAILABLE_TOOLS.get(tool_name)
    if entry:
        tool_function = entry["function"]
        return tool_function(**kwargs)
    msg = f"Unknown tool: {tool_name}"
    raise ValueError(msg)
0 commit comments