diff --git a/http-validation/src/main/java/io/micronaut/validation/routes/RouteValidationVisitor.java b/http-validation/src/main/java/io/micronaut/validation/routes/RouteValidationVisitor.java index d716550c80c..8975ab3109f 100644 --- a/http-validation/src/main/java/io/micronaut/validation/routes/RouteValidationVisitor.java +++ b/http-validation/src/main/java/io/micronaut/validation/routes/RouteValidationVisitor.java @@ -29,6 +29,7 @@ import io.micronaut.validation.routes.rules.MissingParameterRule; import io.micronaut.validation.routes.rules.NullableParameterRule; import io.micronaut.validation.routes.rules.RequestBeanParameterRule; +import io.micronaut.validation.routes.rules.SuspendedReactiveReturnTypeRule; import io.micronaut.validation.routes.rules.RouteValidationRule; import org.jspecify.annotations.NullUnmarked; @@ -74,7 +75,13 @@ public void visitMethod(MethodElement element, VisitorContext context) { return; } - AnnotationValue mappingAnnotation = element.getAnnotation(METHOD_MAPPING_ANN); + AnnotationValue mappingAnnotation = element.getDeclaredAnnotation(METHOD_MAPPING_ANN); + if (mappingAnnotation == null && element.getAnnotationMetadata().hasStereotype(METHOD_MAPPING_ANN)) { + mappingAnnotation = element.getAnnotationMetadata().getAnnotationValuesByStereotype(METHOD_MAPPING_ANN) + .stream() + .findFirst() + .orElse(null); + } if (mappingAnnotation != null) { Set uris = CollectionUtils.setOf(mappingAnnotation.stringValues("uris")); mappingAnnotation.stringValue().ifPresent(uris::add); @@ -116,6 +123,7 @@ public void start(VisitorContext visitorContext) { rules.add(new NullableParameterRule()); rules.add(new RequestBeanParameterRule()); rules.add(new ClientTypesRule()); + rules.add(new SuspendedReactiveReturnTypeRule()); } /** diff --git a/http-validation/src/main/java/io/micronaut/validation/routes/rules/SuspendedReactiveReturnTypeRule.java b/http-validation/src/main/java/io/micronaut/validation/routes/rules/SuspendedReactiveReturnTypeRule.java new file mode 100644 index 00000000000..0bc46f257c1 --- /dev/null +++ b/http-validation/src/main/java/io/micronaut/validation/routes/rules/SuspendedReactiveReturnTypeRule.java @@ -0,0 +1,47 @@ +/* + * Copyright 2017-2021 original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.micronaut.validation.routes.rules; + +import io.micronaut.http.uri.UriMatchTemplate; +import io.micronaut.inject.ast.ClassElement; +import io.micronaut.inject.ast.MethodElement; +import io.micronaut.inject.ast.ParameterElement; +import io.micronaut.validation.routes.RouteValidationResult; + +import java.util.List; +import java.util.concurrent.CompletionStage; + +import org.reactivestreams.Publisher; + +/** + * Validates that suspended route methods do not declare async or reactive return types. + */ +public final class SuspendedReactiveReturnTypeRule implements RouteValidationRule { + + private static final String MESSAGE = "Unsupported suspended controller return type [%s]. Suspend functions must not return reactive or async types."; + + @Override + public RouteValidationResult validate(List templates, ParameterElement[] parameters, MethodElement method) { + if (!method.isSuspend()) { + return new RouteValidationResult(new String[0]); + } + ClassElement returnType = method.getReturnType(); + if (returnType.isAssignable(Publisher.class) || returnType.isAssignable(CompletionStage.class) || returnType.getName().equals("kotlinx.coroutines.flow.Flow")) { + return new RouteValidationResult(new String[]{MESSAGE.formatted(returnType.getName())}); + } + return new RouteValidationResult(new String[0]); + } +} diff --git a/inject-kotlin-test/src/main/groovy/io/micronaut/annotation/processing/test/AbstractKotlinCompilerSpec.groovy b/inject-kotlin-test/src/main/groovy/io/micronaut/annotation/processing/test/AbstractKotlinCompilerSpec.groovy index f92bbc0836b..671e04faa44 100644 --- a/inject-kotlin-test/src/main/groovy/io/micronaut/annotation/processing/test/AbstractKotlinCompilerSpec.groovy +++ b/inject-kotlin-test/src/main/groovy/io/micronaut/annotation/processing/test/AbstractKotlinCompilerSpec.groovy @@ -15,6 +15,8 @@ */ package io.micronaut.annotation.processing.test +import com.google.devtools.ksp.processing.SymbolProcessorProvider + import io.micronaut.context.ApplicationContext import io.micronaut.context.Qualifier import io.micronaut.core.annotation.Experimental @@ -174,6 +176,10 @@ class AbstractKotlinCompilerSpec extends Specification { KotlinCompiler.buildBeanDefinition(className, cls) } + protected BeanDefinition buildBeanDefinition(String className, @Language("kotlin") String cls, List extraSymbolProcessorProviders) { + KotlinCompiler.buildBeanDefinition(className, cls, extraSymbolProcessorProviders) + } + /** * Create a rough source signature of the given ClassElement, using {@link io.micronaut.inject.ast.ClassElement#getBoundGenericTypes()}. * Can be used to test that {@link io.micronaut.inject.ast.ClassElement#getBoundGenericTypes()} returns the right types in the right diff --git a/inject-kotlin-test/src/main/groovy/io/micronaut/annotation/processing/test/KotlinCompiler.java b/inject-kotlin-test/src/main/groovy/io/micronaut/annotation/processing/test/KotlinCompiler.java index a760d295578..0b5409cc203 100644 --- a/inject-kotlin-test/src/main/groovy/io/micronaut/annotation/processing/test/KotlinCompiler.java +++ b/inject-kotlin-test/src/main/groovy/io/micronaut/annotation/processing/test/KotlinCompiler.java @@ -17,6 +17,7 @@ import com.google.devtools.ksp.processing.SymbolProcessor; import com.google.devtools.ksp.processing.SymbolProcessorEnvironment; +import com.google.devtools.ksp.processing.SymbolProcessorProvider; import com.google.devtools.ksp.symbol.KSClassDeclaration; import com.tschuchort.compiletesting.JvmCompilationResult; import com.tschuchort.compiletesting.KotlinCompilation; @@ -89,8 +90,12 @@ public class KotlinCompiler { } public static URLClassLoader buildClassLoader(String name, @Language("kotlin") String clazz) { + return buildClassLoader(name, clazz, Collections.emptyList()); + } + + public static URLClassLoader buildClassLoader(String name, @Language("kotlin") String clazz, List extraSymbolProcessorProviders) { Pair, Pair> resultPair = compile(name, clazz, classElement -> { - }); + }, extraSymbolProcessorProviders); return toClassLoader(resultPair); } @@ -130,6 +135,10 @@ private static URLClassLoader toClassLoader(Pair, Pair> compile(String name, @Language("kotlin") String clazz, Consumer classElements) { + return compile(name, clazz, classElements, Collections.emptyList()); + } + + public static Pair, Pair> compile(String name, @Language("kotlin") String clazz, Consumer classElements, List extraSymbolProcessorProviders) { try { Files.deleteIfExists(KOTLIN_COMPILATION.getWorkingDir().toPath()); } catch (IOException e) { @@ -143,7 +152,11 @@ public static Pair, Pair symbolProcessorProviders = new ArrayList<>(); + symbolProcessorProviders.add(classElementTypeElementSymbolProcessorProvider); + symbolProcessorProviders.add(new BeanDefinitionProcessorProvider()); + symbolProcessorProviders.addAll(extraSymbolProcessorProviders); + KspKt.setSymbolProcessorProviders(KSP_COMPILATION, symbolProcessorProviders); JvmCompilationResult kspResult = KSP_COMPILATION.compile(); if (kspResult.getExitCode() != KotlinCompilation.ExitCode.OK) { throw new RuntimeException(kspResult.getMessages()); @@ -153,6 +166,10 @@ public static Pair, Pair, Pair> compileJava(String name, @Language("java") String clazz, Consumer classElements) { + return compileJava(name, clazz, classElements, Collections.emptyList()); + } + + public static Pair, Pair> compileJava(String name, @Language("java") String clazz, Consumer classElements, List extraSymbolProcessorProviders) { try { Files.deleteIfExists(KOTLIN_COMPILATION.getWorkingDir().toPath()); } catch (IOException e) { @@ -166,7 +183,11 @@ public static Pair, Pair symbolProcessorProviders = new ArrayList<>(); + symbolProcessorProviders.add(classElementTypeElementSymbolProcessorProvider); + symbolProcessorProviders.add(new BeanDefinitionProcessorProvider()); + symbolProcessorProviders.addAll(extraSymbolProcessorProviders); + KspKt.setSymbolProcessorProviders(KSP_COMPILATION, symbolProcessorProviders); JvmCompilationResult kspResult = KSP_COMPILATION.compile(); if (kspResult.getExitCode() != KotlinCompilation.ExitCode.OK) { throw new RuntimeException(kspResult.getMessages()); @@ -185,13 +206,22 @@ public static BeanIntrospection buildBeanIntrospection(String name, @Language } public static BeanDefinition buildBeanDefinition(String name, @Language("kotlin") String clazz) throws InstantiationException, NoSuchMethodException, InvocationTargetException, IllegalAccessException { + return buildBeanDefinition(name, clazz, Collections.emptyList()); + } + + public static BeanDefinition buildBeanDefinition(String name, @Language("kotlin") String clazz, List extraSymbolProcessorProviders) throws InstantiationException, NoSuchMethodException, InvocationTargetException, IllegalAccessException { return buildBeanDefinition(NameUtils.getPackageName(name), NameUtils.getSimpleName(name), - clazz); + clazz, + extraSymbolProcessorProviders); } public static BeanDefinition buildBeanDefinition(String packageName, String simpleName, @Language("kotlin") String clazz) throws InstantiationException, NoSuchMethodException, InvocationTargetException, IllegalAccessException { - final URLClassLoader classLoader = buildClassLoader(packageName + "." + simpleName, clazz); + return buildBeanDefinition(packageName, simpleName, clazz, Collections.emptyList()); + } + + public static BeanDefinition buildBeanDefinition(String packageName, String simpleName, @Language("kotlin") String clazz, List extraSymbolProcessorProviders) throws InstantiationException, NoSuchMethodException, InvocationTargetException, IllegalAccessException { + final URLClassLoader classLoader = buildClassLoader(packageName + "." + simpleName, clazz, extraSymbolProcessorProviders); String beanDefName = (simpleName.startsWith("$") ? "" : '$') + simpleName + BeanDefinitionWriter.CLASS_SUFFIX; String beanFullName = packageName + "." + beanDefName; return (BeanDefinition) loadDefinition(classLoader, beanFullName); diff --git a/inject-kotlin/build.gradle.kts b/inject-kotlin/build.gradle.kts index 41413b2855b..445da6c478b 100644 --- a/inject-kotlin/build.gradle.kts +++ b/inject-kotlin/build.gradle.kts @@ -33,6 +33,9 @@ dependencies { testImplementation(projects.micronautContext) testImplementation(projects.micronautJacksonDatabind) testImplementation(projects.micronautInjectKotlinTest) + testImplementation(projects.micronautHttpValidation) + testImplementation(projects.micronautRouter) + testImplementation(projects.micronautHttpServer) testImplementation(libs.managed.kotlin.stdlib) testImplementation(projects.micronautHttpClient) testImplementation(libs.managed.jackson.annotations) diff --git a/inject-kotlin/src/main/kotlin/io/micronaut/kotlin/processing/visitor/TypeElementSymbolProcessor.kt b/inject-kotlin/src/main/kotlin/io/micronaut/kotlin/processing/visitor/TypeElementSymbolProcessor.kt index eae2e61769e..dcc621240ee 100644 --- a/inject-kotlin/src/main/kotlin/io/micronaut/kotlin/processing/visitor/TypeElementSymbolProcessor.kt +++ b/inject-kotlin/src/main/kotlin/io/micronaut/kotlin/processing/visitor/TypeElementSymbolProcessor.kt @@ -201,7 +201,7 @@ internal open class TypeElementSymbolProcessor(private val environment: SymbolPr } } - private fun findTypeElementVisitors(): Collection> { + protected open fun findTypeElementVisitors(): Collection> { val typeElementVisitors: MutableMap> = HashMap(10) for (definition in SERVICE_LOADER) { if (definition.isPresent) { diff --git a/inject-kotlin/src/test/groovy/io/micronaut/kotlin/processing/visitor/SuspendedReactiveRouteValidationSpec.groovy b/inject-kotlin/src/test/groovy/io/micronaut/kotlin/processing/visitor/SuspendedReactiveRouteValidationSpec.groovy new file mode 100644 index 00000000000..b5d52e36696 --- /dev/null +++ b/inject-kotlin/src/test/groovy/io/micronaut/kotlin/processing/visitor/SuspendedReactiveRouteValidationSpec.groovy @@ -0,0 +1,63 @@ +package io.micronaut.kotlin.processing.visitor + +import com.google.devtools.ksp.processing.SymbolProcessor +import com.google.devtools.ksp.processing.SymbolProcessorEnvironment +import com.google.devtools.ksp.processing.SymbolProcessorProvider +import io.micronaut.annotation.processing.test.AbstractKotlinCompilerSpec +import io.micronaut.inject.visitor.TypeElementVisitor +import io.micronaut.validation.routes.RouteValidationVisitor + +class SuspendedReactiveRouteValidationSpec extends AbstractKotlinCompilerSpec { + + private SymbolProcessorProvider routeValidationVisitorProcessorProvider() { + return { SymbolProcessorEnvironment environment -> + new TypeElementSymbolProcessor(environment) { + @Override + protected Collection> findTypeElementVisitors() { + return [new RouteValidationVisitor()] + } + } + } as SymbolProcessorProvider + } + + void "test suspended controller with regular return type compiles"() { + when: + buildBeanDefinition('test.TestController', ''' +package test + +import io.micronaut.http.annotation.Controller +import io.micronaut.http.annotation.Get + +@Controller("/test") +class TestController { + @Get + suspend fun hello(): String = "ok" +} +''', [routeValidationVisitorProcessorProvider()]) + + then: + noExceptionThrown() + } + + void "test suspended controller with reactive return type fails validation"() { + when: + buildBeanDefinition('test.TestController', ''' +package test + +import io.micronaut.http.annotation.Controller +import io.micronaut.http.annotation.Get +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flowOf + +@Controller("/test") +class TestController { + @Get + suspend fun hello(): Flow = flowOf("ok") +} +''', [routeValidationVisitorProcessorProvider()]) + + then: + def e = thrown(RuntimeException) + e.message.contains('Unsupported suspended controller return type [kotlinx.coroutines.flow.Flow]. Suspend functions must not return reactive or async types.') + } +}