Skip to content

Commit 15ade96

Browse files
fix: auto-add missing standard library imports in AI-generated Java tests
AI-generated test code sometimes uses standard library classes like Arrays, List, HashMap etc. without the corresponding import statement, causing compilation failures. Add ensure_common_java_imports() that detects usage of common classes and adds missing imports automatically. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 29dd7a3 commit 15ade96

1 file changed

Lines changed: 30 additions & 0 deletions

File tree

codeflash/languages/java/instrumentation.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,6 +1260,35 @@ def remove_instrumentation(source: str) -> str:
12601260
return source
12611261

12621262

1263+
_COMMON_JAVA_IMPORTS = {
1264+
"Arrays": "import java.util.Arrays;",
1265+
"List": "import java.util.List;",
1266+
"ArrayList": "import java.util.ArrayList;",
1267+
"Map": "import java.util.Map;",
1268+
"HashMap": "import java.util.HashMap;",
1269+
"Set": "import java.util.Set;",
1270+
"HashSet": "import java.util.HashSet;",
1271+
"Collections": "import java.util.Collections;",
1272+
"Collectors": "import java.util.stream.Collectors;",
1273+
"Random": "import java.util.Random;",
1274+
"BigDecimal": "import java.math.BigDecimal;",
1275+
"BigInteger": "import java.math.BigInteger;",
1276+
}
1277+
1278+
1279+
def ensure_common_java_imports(test_code: str) -> str:
1280+
for class_name, import_stmt in _COMMON_JAVA_IMPORTS.items():
1281+
if not re.search(rf"\b{class_name}\b", test_code):
1282+
continue
1283+
if import_stmt in test_code:
1284+
continue
1285+
package = import_stmt.split()[1].rsplit(".", 1)[0]
1286+
if f"import {package}.*;" in test_code:
1287+
continue
1288+
test_code = _add_import(test_code, import_stmt)
1289+
return test_code
1290+
1291+
12631292
def instrument_generated_java_test(
12641293
test_code: str,
12651294
function_name: str,
@@ -1290,6 +1319,7 @@ def instrument_generated_java_test(
12901319
from codeflash.languages.java.remove_asserts import transform_java_assertions
12911320

12921321
test_code = transform_java_assertions(test_code, function_name, qualified_name)
1322+
test_code = ensure_common_java_imports(test_code)
12931323

12941324
# Extract class name from the test code
12951325
# Use pattern that starts at beginning of line to avoid matching words in comments

0 commit comments

Comments
 (0)