Skip to content

Commit c58fd1e

Browse files
edburnsCopilot
andcommitted
Reject mismatched numeric defaults for integral params
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent ebd8b6d commit c58fd1e

2 files changed

Lines changed: 120 additions & 0 deletions

File tree

java/src/main/java/com/github/copilot/tool/CopilotToolProcessor.java

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,13 @@ public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment
7474
processingEnv.getMessager().printMessage(Diagnostic.Kind.ERROR,
7575
"@Param cannot have both required=true and a non-empty defaultValue", param);
7676
}
77+
if (paramAnnotation != null && !paramAnnotation.defaultValue().isEmpty()) {
78+
String defaultValidationError = validateDefaultValueCompatibility(param.asType(),
79+
paramAnnotation.defaultValue());
80+
if (defaultValidationError != null) {
81+
processingEnv.getMessager().printMessage(Diagnostic.Kind.ERROR, defaultValidationError, param);
82+
}
83+
}
7784
if (paramAnnotation != null
7885
&& !paramAnnotation.required()
7986
&& paramAnnotation.defaultValue().isEmpty()
@@ -409,6 +416,9 @@ private String generateArgExtractionFromMap(String paramName, TypeMirror type) {
409416
if (type.getKind().isPrimitive()) {
410417
return generatePrimitiveExtraction("args.get(\"" + paramName + "\")", type);
411418
}
419+
if (type.getKind() == TypeKind.ARRAY) {
420+
return generateGenericTypeReferenceConversion("args.get(\"" + paramName + "\")", type);
421+
}
412422
if (type.getKind() == TypeKind.DECLARED) {
413423
TypeElement typeElement = (TypeElement) ((DeclaredType) type).asElement();
414424
String qualifiedName = typeElement.getQualifiedName().toString();
@@ -434,6 +444,9 @@ private String generateArgExtraction(String varExpr, TypeMirror type) {
434444
if (type.getKind().isPrimitive()) {
435445
return generatePrimitiveExtraction(varExpr, type);
436446
}
447+
if (type.getKind() == TypeKind.ARRAY) {
448+
return generateGenericTypeReferenceConversion(varExpr, type);
449+
}
437450
if (type.getKind() == TypeKind.DECLARED) {
438451
TypeElement typeElement = (TypeElement) ((DeclaredType) type).asElement();
439452
String qualifiedName = typeElement.getQualifiedName().toString();
@@ -594,6 +607,94 @@ private String generateDefaultLiteral(TypeMirror type, String defaultValue) {
594607
return "\"" + escapeJava(defaultValue) + "\"";
595608
}
596609

610+
private String validateDefaultValueCompatibility(TypeMirror type, String defaultValue) {
611+
if (type.getKind().isPrimitive()) {
612+
return validatePrimitiveDefault(type.getKind(), defaultValue);
613+
}
614+
if (type.getKind() == TypeKind.DECLARED) {
615+
TypeElement typeElement = (TypeElement) ((DeclaredType) type).asElement();
616+
String qualifiedName = typeElement.getQualifiedName().toString();
617+
if ("java.lang.String".equals(qualifiedName)) {
618+
return null;
619+
}
620+
if ("java.lang.Boolean".equals(qualifiedName)) {
621+
return validateBooleanDefault(defaultValue);
622+
}
623+
if ("java.lang.Character".equals(qualifiedName)) {
624+
return validateCharacterDefault(defaultValue);
625+
}
626+
if (isBoxedNumeric(qualifiedName)) {
627+
return validatePrimitiveDefault(boxedTypeKind(qualifiedName), defaultValue);
628+
}
629+
}
630+
return null;
631+
}
632+
633+
private String validatePrimitiveDefault(TypeKind kind, String defaultValue) {
634+
try {
635+
switch (kind) {
636+
case INT :
637+
Integer.parseInt(defaultValue);
638+
return null;
639+
case LONG :
640+
Long.parseLong(defaultValue);
641+
return null;
642+
case SHORT :
643+
Short.parseShort(defaultValue);
644+
return null;
645+
case BYTE :
646+
Byte.parseByte(defaultValue);
647+
return null;
648+
case DOUBLE :
649+
Double.parseDouble(defaultValue);
650+
return null;
651+
case FLOAT :
652+
Float.parseFloat(defaultValue);
653+
return null;
654+
case BOOLEAN :
655+
return validateBooleanDefault(defaultValue);
656+
case CHAR :
657+
return validateCharacterDefault(defaultValue);
658+
default :
659+
return null;
660+
}
661+
} catch (NumberFormatException ex) {
662+
return "@Param defaultValue '" + defaultValue + "' is not valid for " + kind.name().toLowerCase()
663+
+ " parameters";
664+
}
665+
}
666+
667+
private String validateBooleanDefault(String defaultValue) {
668+
if ("true".equalsIgnoreCase(defaultValue) || "false".equalsIgnoreCase(defaultValue)) {
669+
return null;
670+
}
671+
return "@Param defaultValue '" + defaultValue + "' is not valid for boolean parameters";
672+
}
673+
674+
private String validateCharacterDefault(String defaultValue) {
675+
return defaultValue != null && defaultValue.length() == 1 ? null
676+
: "@Param defaultValue '" + defaultValue + "' is not valid for char parameters";
677+
}
678+
679+
private TypeKind boxedTypeKind(String qualifiedName) {
680+
switch (qualifiedName) {
681+
case "java.lang.Integer" :
682+
return TypeKind.INT;
683+
case "java.lang.Long" :
684+
return TypeKind.LONG;
685+
case "java.lang.Double" :
686+
return TypeKind.DOUBLE;
687+
case "java.lang.Float" :
688+
return TypeKind.FLOAT;
689+
case "java.lang.Short" :
690+
return TypeKind.SHORT;
691+
case "java.lang.Byte" :
692+
return TypeKind.BYTE;
693+
default :
694+
return TypeKind.NONE;
695+
}
696+
}
697+
597698
private String getParamName(VariableElement param) {
598699
Param paramAnnotation = param.getAnnotation(Param.class);
599700
if (paramAnnotation != null && !paramAnnotation.name().isEmpty()) {

java/src/test/java/com/github/copilot/tool/CopilotToolProcessorTest.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,25 @@ public String doWork(
509509
"Expected string default \"hello\" as quoted string. Generated:\n" + generated);
510510
}
511511

512+
@Test
513+
void rejectsMismatchedNumericDefaultForIntegralParameters() {
514+
String source = """
515+
package test;
516+
import com.github.copilot.tool.CopilotTool;
517+
import com.github.copilot.tool.Param;
518+
public class MismatchedDefaults {
519+
@CopilotTool("Tool with bad default")
520+
public String doWork(@Param(value = "Limit", required = false, defaultValue = "1.5") int limit) {
521+
return String.valueOf(limit);
522+
}
523+
}
524+
""";
525+
526+
CompilationResult result = compileWithProcessor(List.of(inMemorySource("test.MismatchedDefaults", source)));
527+
assertTrue(hasErrorContaining(result, "not valid for int parameters"),
528+
"Expected compile error for mismatched int defaultValue, got: " + result.diagnostics);
529+
}
530+
512531
// ── Test: package-private methods are allowed ───────────────────────────────
513532

514533
@Test

0 commit comments

Comments
 (0)