|
40 | 40 | from codeflash.discovery.functions_to_optimize import FunctionToOptimize |
41 | 41 | from codeflash.languages.base import CodeContext, FunctionFilterCriteria, HelperFunction, TestInfo, TestResult |
42 | 42 | from codeflash.languages.java.concurrency_analyzer import ConcurrencyInfo |
43 | | - from codeflash.models.models import GeneratedTestsList |
| 43 | + from codeflash.models.models import GeneratedTestsList, InvocationId |
44 | 44 |
|
45 | 45 | logger = logging.getLogger(__name__) |
46 | 46 |
|
@@ -199,6 +199,40 @@ def remove_test_functions(self, test_source: str, functions_to_remove: list[str] |
199 | 199 | """Remove specific test functions from test source code.""" |
200 | 200 | return remove_test_functions(test_source, functions_to_remove, self._analyzer) |
201 | 201 |
|
| 202 | + def remove_test_functions_from_generated_tests( |
| 203 | + self, generated_tests: GeneratedTestsList, functions_to_remove: list[str] |
| 204 | + ) -> GeneratedTestsList: |
| 205 | + """Remove specific test functions from generated tests.""" |
| 206 | + from codeflash.models.models import GeneratedTests, GeneratedTestsList |
| 207 | + |
| 208 | + updated_tests: list[GeneratedTests] = [] |
| 209 | + for test in generated_tests.generated_tests: |
| 210 | + updated_tests.append( |
| 211 | + GeneratedTests( |
| 212 | + generated_original_test_source=self.remove_test_functions( |
| 213 | + test.generated_original_test_source, functions_to_remove |
| 214 | + ), |
| 215 | + instrumented_behavior_test_source=test.instrumented_behavior_test_source, |
| 216 | + instrumented_perf_test_source=test.instrumented_perf_test_source, |
| 217 | + behavior_file_path=test.behavior_file_path, |
| 218 | + perf_file_path=test.perf_file_path, |
| 219 | + ) |
| 220 | + ) |
| 221 | + return GeneratedTestsList(generated_tests=updated_tests) |
| 222 | + |
| 223 | + def add_runtime_comments_to_generated_tests( |
| 224 | + self, |
| 225 | + generated_tests: GeneratedTestsList, |
| 226 | + original_runtimes: dict[InvocationId, list[int]], |
| 227 | + optimized_runtimes: dict[InvocationId, list[int]], |
| 228 | + tests_project_rootdir: Path | None = None, |
| 229 | + ) -> GeneratedTestsList: |
| 230 | + """Add runtime comments to generated tests.""" |
| 231 | + _ = tests_project_rootdir |
| 232 | + # For Java, we currently don't add runtime comments to generated tests |
| 233 | + # Return the generated tests unchanged |
| 234 | + return generated_tests |
| 235 | + |
202 | 236 | # === Test Result Comparison === |
203 | 237 |
|
204 | 238 | def compare_test_results( |
|
0 commit comments