diff --git a/bytebuddy-proxy-support/src/main/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactory.java b/bytebuddy-proxy-support/src/main/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactory.java index 53accf672..049c3df97 100644 --- a/bytebuddy-proxy-support/src/main/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactory.java +++ b/bytebuddy-proxy-support/src/main/java/dev/restate/bytebuddy/proxysupport/ByteBuddyProxyFactory.java @@ -8,6 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.bytebuddy.proxysupport; +import static dev.restate.common.reflections.ReflectionUtils.findRestateAnnotatedClass; import static net.bytebuddy.matcher.ElementMatchers.*; import dev.restate.common.reflections.ProxyFactory; @@ -95,14 +96,8 @@ private Class generateProxyClass(Class clazz) throws NoSuchFieldExcept if (!clazz.isInterface()) { // We perform here some additional validation of the handlers that won't be executed by // bytebuddy and can easily lead to strange behavior - var methods = - ReflectionUtils.getUniqueDeclaredMethods( - clazz, - method -> - ReflectionUtils.findAnnotation(method, Handler.class) != null - || ReflectionUtils.findAnnotation(method, Shared.class) != null - || ReflectionUtils.findAnnotation(method, Workflow.class) != null - || ReflectionUtils.findAnnotation(method, Exclusive.class) != null); + var restateAnnotatedClazz = findRestateAnnotatedClass(clazz); + var methods = ReflectionUtils.findRestateHandlers(restateAnnotatedClazz); for (var method : methods) { validateMethod(method); } diff --git a/common/src/main/java/dev/restate/common/reflections/ReflectionUtils.java b/common/src/main/java/dev/restate/common/reflections/ReflectionUtils.java index ca55efb1d..15fd0e33f 100644 --- a/common/src/main/java/dev/restate/common/reflections/ReflectionUtils.java +++ b/common/src/main/java/dev/restate/common/reflections/ReflectionUtils.java @@ -31,141 +31,29 @@ public class ReflectionUtils { /** Record containing handler information extracted from annotations. */ public record HandlerInfo(String name, boolean shared) {} - /** - * Find a single {@link Annotation} of {@code annotationType} on the supplied {@link Class}, - * traversing its interfaces, annotations, and superclasses if the annotation is not directly - * present on the given class itself. - * - *

This method explicitly handles class-level annotations which are not declared as {@linkplain - * java.lang.annotation.Inherited inherited} as well as meta-annotations and annotations on - * interfaces. - * - *

The algorithm operates as follows: - * - *

    - *
  1. Search for the annotation on the given class and return it if found. - *
  2. Recursively search through all interfaces that the given class declares. - *
  3. Recursively search through the superclass hierarchy of the given class. - *
- * - *

Note: in this context, the term recursively means that the search process continues - * by returning to step #1 with the current interface, annotation, or superclass as the class to - * look for annotations on. - * - * @param clazz the class to look for annotations on - * @param annotationType the type of annotation to look for - * @return the first matching annotation, or {@code null} if not found - */ - @Nullable - public static A findAnnotation( - Class clazz, @Nullable Class annotationType) { - if (annotationType == null) { - return null; - } - return findAnnotation(clazz, annotationType, new java.util.HashSet<>()); + public static Method @NonNull [] findRestateHandlers(Class restateAnnotatedClazz) { + return getUniqueDeclaredMethods( + restateAnnotatedClazz, + method -> + method.getDeclaredAnnotation(Handler.class) != null + || method.getDeclaredAnnotation(Shared.class) != null + || method.getDeclaredAnnotation(Workflow.class) != null + || method.getDeclaredAnnotation(Exclusive.class) != null); } - @Nullable - private static A findAnnotation( - Class clazz, Class annotationType, java.util.Set visited) { - - if (clazz == null || clazz == Object.class) { - return null; - } - - // Check if the annotation is directly present on the class - A annotation = clazz.getDeclaredAnnotation(annotationType); - if (annotation != null) { - return annotation; + /** Find the class where the Restate annotations are declared. */ + public static Class findRestateAnnotatedClass(Class clazz) { + Class restateServiceDefinitionClazz = findRestateAnnotatedClassFromHierarchy(clazz); + if (restateServiceDefinitionClazz != null) { + return restateServiceDefinitionClazz; } - // Search on interfaces - for (Class ifc : clazz.getInterfaces()) { - annotation = findAnnotation(ifc, annotationType, visited); - if (annotation != null) { - return annotation; - } - } - - // Search on superclass - return findAnnotation(clazz.getSuperclass(), annotationType, visited); + throw new IllegalArgumentException( + "Cannot find a Restate annotated class in the type hierarchy starting from " + + clazz.getName()); } - /** - * Find a single {@link Annotation} of {@code annotationType} on the supplied {@link Method}, - * traversing its super methods if the annotation is not directly present on the given - * method itself. - * - *

Annotations on methods are not inherited by default, so we need to handle this explicitly. - * - * @param method the method to look for annotations on - * @param annotationType the type of annotation to look for - * @return the first matching annotation, or {@code null} if not found - */ - @Nullable - public static A findAnnotation( - Method method, @Nullable Class annotationType) { - if (annotationType == null) { - return null; - } - - // Check if the annotation is directly present on the method - A annotation = method.getDeclaredAnnotation(annotationType); - if (annotation != null) { - return annotation; - } - - // Search through the type hierarchy - Class clazz = method.getDeclaringClass(); - return findAnnotationInTypeHierarchy(clazz, method, annotationType, new java.util.HashSet<>()); - } - - @Nullable - private static A findAnnotationInTypeHierarchy( - Class clazz, Method method, Class annotationType, java.util.Set> visited) { - - if (clazz == null || clazz == Object.class || !visited.add(clazz)) { - return null; - } - - // Try to find an equivalent method in this class/interface - Method equivalentMethod = null; - try { - equivalentMethod = clazz.getDeclaredMethod(method.getName(), method.getParameterTypes()); - } catch (NoSuchMethodException ex) { - // No such method in this class, continue searching - } - - if (equivalentMethod != null) { - A annotation = equivalentMethod.getDeclaredAnnotation(annotationType); - if (annotation != null) { - return annotation; - } - } - - // Search in interfaces - for (Class ifc : clazz.getInterfaces()) { - A annotation = findAnnotationInTypeHierarchy(ifc, method, annotationType, visited); - if (annotation != null) { - return annotation; - } - } - - // Search in superclass - return findAnnotationInTypeHierarchy(clazz.getSuperclass(), method, annotationType, visited); - } - - public static String extractServiceName(Class clazz) { - // Fallback: infer from hierarchy against known Restate markers - String inferred = inferRestateNameFromHierarchy(clazz); - if (inferred != null) { - return inferred; - } - - throw new IllegalArgumentException("Cannot infer Restate name from type: " + clazz.getName()); - } - - private static String inferRestateNameFromHierarchy(Class type) { + private static Class findRestateAnnotatedClassFromHierarchy(Class type) { if (type == null || Object.class.equals(type)) { return null; } @@ -182,37 +70,37 @@ private static String inferRestateNameFromHierarchy(Class type) { || (RESTATE_SPRING_WORKFLOW_ANNOTATION != null && type.getAnnotation(RESTATE_SPRING_WORKFLOW_ANNOTATION) != null); if (isRestateAnnotated) { - return extractNameFromAnnotations(type); + return type; } // Check parent interfaces for (Class parent : type.getInterfaces()) { - String res = inferRestateNameFromHierarchy(parent); + Class res = findRestateAnnotatedClassFromHierarchy(parent); if (res != null) { return res; } } // Recurse into superclass - return inferRestateNameFromHierarchy(type.getSuperclass()); + return findRestateAnnotatedClassFromHierarchy(type.getSuperclass()); } - private static String extractNameFromAnnotations(Class type) { + public static String extractServiceName(Class clazz) { // Check for @Name annotation first - var nameAnnotation = type.getAnnotation(Name.class); + var nameAnnotation = clazz.getAnnotation(Name.class); if (nameAnnotation != null && nameAnnotation.value() != null && !nameAnnotation.value().isEmpty()) { return nameAnnotation.value(); } // Default to simple class name - return type.getSimpleName(); + return clazz.getSimpleName(); } public static boolean hasServiceAnnotation(Class clazz) { - return findAnnotation(clazz, Service.class) != null + return clazz.getDeclaredAnnotation(Service.class) != null || (RESTATE_SPRING_SERVICE_ANNOTATION != null - && findAnnotation(clazz, RESTATE_SPRING_SERVICE_ANNOTATION) != null); + && clazz.getDeclaredAnnotation(RESTATE_SPRING_SERVICE_ANNOTATION) != null); } public static void mustHaveServiceAnnotation(Class clazz) { @@ -220,14 +108,14 @@ public static void mustHaveServiceAnnotation(Class clazz) { throw new IllegalArgumentException( "The given class " + clazz.getName() - + " is not annotated with the Restate service annotation"); + + " is not annotated with the Restate service annotation. In case the service annotation is declared on a parent interface, use the interface to execute requests instead of the implementation class."); } } public static boolean hasVirtualObjectAnnotation(Class clazz) { - return findAnnotation(clazz, VirtualObject.class) != null + return clazz.getDeclaredAnnotation(VirtualObject.class) != null || (RESTATE_SPRING_VIRTUAL_OBJECT_ANNOTATION != null - && findAnnotation(clazz, RESTATE_SPRING_VIRTUAL_OBJECT_ANNOTATION) != null); + && clazz.getDeclaredAnnotation(RESTATE_SPRING_VIRTUAL_OBJECT_ANNOTATION) != null); } public static void mustHaveVirtualObjectAnnotation(Class clazz) { @@ -235,14 +123,14 @@ public static void mustHaveVirtualObjectAnnotation(Class clazz) { throw new IllegalArgumentException( "The given class " + clazz.getName() - + " is not annotated with the Restate virtualObject annotation"); + + " is not annotated with the Restate virtualObject annotation. In case the virtual object annotation is declared on a parent interface, use the interface to execute requests instead of the implementation class."); } } public static boolean hasWorkflowAnnotation(Class clazz) { - return findAnnotation(clazz, Workflow.class) != null + return clazz.getDeclaredAnnotation(Workflow.class) != null || (RESTATE_SPRING_WORKFLOW_ANNOTATION != null - && findAnnotation(clazz, RESTATE_SPRING_WORKFLOW_ANNOTATION) != null); + && clazz.getDeclaredAnnotation(RESTATE_SPRING_WORKFLOW_ANNOTATION) != null); } public static void mustHaveWorkflowAnnotation(Class clazz) { @@ -250,16 +138,16 @@ public static void mustHaveWorkflowAnnotation(Class clazz) { throw new IllegalArgumentException( "The given class " + clazz.getName() - + " is not annotated with the Restate workflow annotation"); + + " is not annotated with the Restate workflow annotation. In case the workflow annotation is declared on a parent interface, use the interface to execute requests instead of the implementation class."); } } public static HandlerInfo mustHaveHandlerAnnotation(@NonNull Method method) { // Check for @Handler or @Shared annotation (Shared implies Handler) - var handlerAnnotation = findAnnotation(method, Handler.class); - var sharedAnnotation = findAnnotation(method, Shared.class); - var exclusiveAnnotation = findAnnotation(method, Exclusive.class); - var workflowAnnotation = findAnnotation(method, Workflow.class); + var handlerAnnotation = method.getDeclaredAnnotation(Handler.class); + var sharedAnnotation = method.getDeclaredAnnotation(Shared.class); + var exclusiveAnnotation = method.getDeclaredAnnotation(Exclusive.class); + var workflowAnnotation = method.getDeclaredAnnotation(Workflow.class); if (handlerAnnotation == null && sharedAnnotation == null @@ -279,7 +167,7 @@ public static HandlerInfo mustHaveHandlerAnnotation(@NonNull Method method) { } // Extract the name from @Name annotation, or default to method name - var nameAnnotation = findAnnotation(method, Name.class); + var nameAnnotation = method.getDeclaredAnnotation(Name.class); String handlerName; if (nameAnnotation != null && nameAnnotation.value() != null @@ -295,55 +183,6 @@ public static HandlerInfo mustHaveHandlerAnnotation(@NonNull Method method) { return new HandlerInfo(handlerName, isShared); } - /** - * Walks the type hierarchy to find where the given rawType interface was parameterized. This - * handles inheritance chains and multiple interfaces correctly. - * - * @param concreteClass The concrete class to start searching from - * @param rawType The raw interface type to find (e.g., Function.class) - * @return The ParameterizedType with resolved type arguments, or null if not found - */ - public static ParameterizedType findParameterizedType(Class concreteClass, Class rawType) { - if (concreteClass == null || Object.class.equals(concreteClass)) { - return null; - } - - // Check direct interfaces - for (Type genericInterface : concreteClass.getGenericInterfaces()) { - ParameterizedType result = findParameterizedTypeInType(genericInterface, rawType); - if (result != null) { - return result; - } - } - - // Check superclass - Type genericSuperclass = concreteClass.getGenericSuperclass(); - if (genericSuperclass != null) { - ParameterizedType result = findParameterizedTypeInType(genericSuperclass, rawType); - if (result != null) { - return result; - } - } - - // Recurse up the hierarchy - return findParameterizedType(concreteClass.getSuperclass(), rawType); - } - - private static ParameterizedType findParameterizedTypeInType(Type type, Class rawType) { - if (type instanceof ParameterizedType paramType) { - if (paramType.getRawType().equals(rawType)) { - return paramType; - } - // Check if this parameterized type extends/implements the target - if (paramType.getRawType() instanceof Class clazz) { - return findParameterizedType(clazz, rawType); - } - } else if (type instanceof Class clazz) { - return findParameterizedType(clazz, rawType); - } - return null; - } - public static boolean isKotlinClass(Class clazz) { return Arrays.stream(clazz.getDeclaredAnnotations()) .anyMatch(annotation -> annotation.annotationType().getName().equals("kotlin.Metadata")); @@ -389,8 +228,6 @@ public static boolean isKotlinClass(Class clazz) { private static final Method[] EMPTY_METHOD_ARRAY = new Method[0]; - private static final Field[] EMPTY_FIELD_ARRAY = new Field[0]; - private static final Object[] EMPTY_OBJECT_ARRAY = new Object[0]; /** @@ -400,10 +237,6 @@ public static boolean isKotlinClass(Class clazz) { private static final Map, Method[]> declaredMethodsCache = new ConcurrentReferenceHashMap<>(256); - /** Cache for {@link Class#getDeclaredFields()}, allowing for fast iteration. */ - private static final Map, Field[]> declaredFieldsCache = - new ConcurrentReferenceHashMap<>(256); - // Exception handling /** @@ -903,273 +736,6 @@ public static boolean isCglibRenamedMethod(Method renamedMethod) { return false; } - /** - * Make the given method accessible, explicitly setting it accessible if necessary. The {@code - * setAccessible(true)} method is only called when actually necessary, to avoid unnecessary - * conflicts. - * - * @param method the method to make accessible - * @see Method#setAccessible - */ - @SuppressWarnings("deprecation") - public static void makeAccessible(Method method) { - if ((!Modifier.isPublic(method.getModifiers()) - || !Modifier.isPublic(method.getDeclaringClass().getModifiers())) - && !method.isAccessible()) { - method.setAccessible(true); - } - } - - // Field handling - - /** - * Attempt to find a {@link Field field} on the supplied {@link Class} with the supplied {@code - * name}. Searches all superclasses up to {@link Object}. - * - * @param clazz the class to introspect - * @param name the name of the field - * @return the corresponding Field object, or {@code null} if not found - */ - public static @Nullable Field findField(Class clazz, String name) { - return findField(clazz, name, null); - } - - /** - * Attempt to find a {@link Field field} on the supplied {@link Class} with the supplied {@code - * name} and/or {@link Class type}. Searches all superclasses up to {@link Object}. - * - * @param clazz the class to introspect - * @param name the name of the field (may be {@code null} if type is specified) - * @param type the type of the field (may be {@code null} if name is specified) - * @return the corresponding Field object, or {@code null} if not found - */ - public static @Nullable Field findField( - Class clazz, @Nullable String name, @Nullable Class type) { - Class searchType = clazz; - while (Object.class != searchType && searchType != null) { - Field[] fields = getDeclaredFields(searchType); - for (Field field : fields) { - if ((name == null || name.equals(field.getName())) - && (type == null || type.equals(field.getType()))) { - return field; - } - } - searchType = searchType.getSuperclass(); - } - return null; - } - - /** - * Attempt to find a {@link Field field} on the supplied {@link Class} with the supplied {@code - * name}. Searches all superclasses up to {@link Object}. - * - * @param clazz the class to introspect - * @param name the name of the field (with upper/lower case to be ignored) - * @return the corresponding Field object, or {@code null} if not found - * @since 6.1 - */ - public static @Nullable Field findFieldIgnoreCase(Class clazz, String name) { - Class searchType = clazz; - while (Object.class != searchType && searchType != null) { - Field[] fields = getDeclaredFields(searchType); - for (Field field : fields) { - if (name.equalsIgnoreCase(field.getName())) { - return field; - } - } - searchType = searchType.getSuperclass(); - } - return null; - } - - /** - * Set the field represented by the supplied {@linkplain Field field object} on the specified - * {@linkplain Object target object} to the specified {@code value}. - * - *

In accordance with {@link Field#set(Object, Object)} semantics, the new value is - * automatically unwrapped if the underlying field has a primitive type. - * - *

This method does not support setting {@code static final} fields. - * - *

Thrown exceptions are handled via a call to {@link #handleReflectionException(Exception)}. - * - * @param field the field to set - * @param target the target object on which to set the field (or {@code null} for a static field) - * @param value the value to set (may be {@code null}) - */ - public static void setField(Field field, @Nullable Object target, @Nullable Object value) { - try { - field.set(target, value); - } catch (IllegalAccessException ex) { - handleReflectionException(ex); - } - } - - /** - * Get the field represented by the supplied {@link Field field object} on the specified {@link - * Object target object}. In accordance with {@link Field#get(Object)} semantics, the returned - * value is automatically wrapped if the underlying field has a primitive type. - * - *

Thrown exceptions are handled via a call to {@link #handleReflectionException(Exception)}. - * - * @param field the field to get - * @param target the target object from which to get the field (or {@code null} for a static - * field) - * @return the field's current value - */ - public static @Nullable Object getField(Field field, @Nullable Object target) { - try { - return field.get(target); - } catch (IllegalAccessException ex) { - handleReflectionException(ex); - } - throw new IllegalStateException("Should never get here"); - } - - /** - * Invoke the given callback on all locally declared fields in the given class. - * - * @param clazz the target class to analyze - * @param fc the callback to invoke for each field - * @throws IllegalStateException if introspection fails - * @see #doWithFields - * @since 4.2 - */ - public static void doWithLocalFields(Class clazz, FieldCallback fc) { - for (Field field : getDeclaredFields(clazz)) { - try { - fc.doWith(field); - } catch (IllegalAccessException ex) { - throw new IllegalStateException( - "Not allowed to access field '" + field.getName() + "': " + ex); - } - } - } - - /** - * Invoke the given callback on all fields in the target class, going up the class hierarchy to - * get all declared fields. - * - * @param clazz the target class to analyze - * @param fc the callback to invoke for each field - * @throws IllegalStateException if introspection fails - */ - public static void doWithFields(Class clazz, FieldCallback fc) { - doWithFields(clazz, fc, null); - } - - /** - * Invoke the given callback on all fields in the target class, going up the class hierarchy to - * get all declared fields. - * - * @param clazz the target class to analyze - * @param fc the callback to invoke for each field - * @param ff the filter that determines the fields to apply the callback to - * @throws IllegalStateException if introspection fails - */ - public static void doWithFields(Class clazz, FieldCallback fc, @Nullable FieldFilter ff) { - // Keep backing up the inheritance hierarchy. - Class targetClass = clazz; - do { - for (Field field : getDeclaredFields(targetClass)) { - if (ff != null && !ff.matches(field)) { - continue; - } - try { - fc.doWith(field); - } catch (IllegalAccessException ex) { - throw new IllegalStateException( - "Not allowed to access field '" + field.getName() + "': " + ex); - } - } - targetClass = targetClass.getSuperclass(); - } while (targetClass != null && targetClass != Object.class); - } - - /** - * This variant retrieves {@link Class#getDeclaredFields()} from a local cache in order to avoid - * defensive array copying. - * - * @param clazz the class to introspect - * @return the cached array of fields - * @throws IllegalStateException if introspection fails - * @see Class#getDeclaredFields() - */ - private static Field[] getDeclaredFields(Class clazz) { - Field[] result = declaredFieldsCache.get(clazz); - if (result == null) { - try { - result = clazz.getDeclaredFields(); - declaredFieldsCache.put(clazz, (result.length == 0 ? EMPTY_FIELD_ARRAY : result)); - } catch (Throwable ex) { - throw new IllegalStateException( - "Failed to introspect Class [" - + clazz.getName() - + "] from ClassLoader [" - + clazz.getClassLoader() - + "]", - ex); - } - } - return result; - } - - /** - * Given the source object and the destination, which must be the same class or a subclass, copy - * all fields, including inherited fields. Designed to work on objects with public no-arg - * constructors. - * - * @throws IllegalStateException if introspection fails - */ - public static void shallowCopyFieldState(final Object src, final Object dest) { - if (!src.getClass().isAssignableFrom(dest.getClass())) { - throw new IllegalArgumentException( - "Destination class [" - + dest.getClass().getName() - + "] must be same or subclass as source class [" - + src.getClass().getName() - + "]"); - } - doWithFields( - src.getClass(), - field -> { - makeAccessible(field); - Object srcValue = field.get(src); - field.set(dest, srcValue); - }, - COPYABLE_FIELDS); - } - - /** - * Determine whether the given field is a "public static final" constant. - * - * @param field the field to check - */ - public static boolean isPublicStaticFinal(Field field) { - int modifiers = field.getModifiers(); - return (Modifier.isPublic(modifiers) - && Modifier.isStatic(modifiers) - && Modifier.isFinal(modifiers)); - } - - /** - * Make the given field accessible, explicitly setting it accessible if necessary. The {@code - * setAccessible(true)} method is only called when actually necessary, to avoid unnecessary - * conflicts. - * - * @param field the field to make accessible - * @see Field#setAccessible - */ - @SuppressWarnings("deprecation") - public static void makeAccessible(Field field) { - if ((!Modifier.isPublic(field.getModifiers()) - || !Modifier.isPublic(field.getDeclaringClass().getModifiers()) - || Modifier.isFinal(field.getModifiers())) - && !field.isAccessible()) { - field.setAccessible(true); - } - } - // Cache handling /** @@ -1179,7 +745,6 @@ public static void makeAccessible(Field field) { */ public static void clearCache() { declaredMethodsCache.clear(); - declaredFieldsCache.clear(); } /** Action to take on each method. */ diff --git a/common/src/main/java/dev/restate/serde/Serde.java b/common/src/main/java/dev/restate/serde/Serde.java index 7f4e6af95..7e62e8fd1 100644 --- a/common/src/main/java/dev/restate/serde/Serde.java +++ b/common/src/main/java/dev/restate/serde/Serde.java @@ -153,6 +153,11 @@ public Slice serialize(byte[] value) { public byte[] deserialize(@NonNull Slice value) { return value.toByteArray(); } + + @Override + public String contentType() { + return "application/octet-stream"; + } }; /** Passthrough serializer/deserializer */ diff --git a/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java b/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java index 67351e55f..e6077b6f6 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java +++ b/sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java @@ -8,6 +8,8 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.internal; +import static dev.restate.common.reflections.ReflectionUtils.findRestateAnnotatedClass; + import dev.restate.common.reflections.ReflectionUtils; import dev.restate.common.reflections.RestateUtils; import dev.restate.sdk.*; @@ -47,9 +49,16 @@ public ServiceDefinition create( Class serviceClazz = serviceInstance.getClass(); - boolean hasServiceAnnotation = ReflectionUtils.hasServiceAnnotation(serviceClazz); - boolean hasVirtualObjectAnnotation = ReflectionUtils.hasVirtualObjectAnnotation(serviceClazz); - boolean hasWorkflowAnnotation = ReflectionUtils.hasWorkflowAnnotation(serviceClazz); + // The behavior of the reflections work as follows: + // * There is one class that has all the restate annotations. That being either the serviceClazz + // itself (concrete class) or some interface in the hierarchy. + // * Then there is the serviceInstance, which is where we call the methods themselves. + Class restateAnnotatedClazz = findRestateAnnotatedClass(serviceClazz); + + boolean hasServiceAnnotation = ReflectionUtils.hasServiceAnnotation(restateAnnotatedClazz); + boolean hasVirtualObjectAnnotation = + ReflectionUtils.hasVirtualObjectAnnotation(restateAnnotatedClazz); + boolean hasWorkflowAnnotation = ReflectionUtils.hasWorkflowAnnotation(restateAnnotatedClazz); boolean hasAnyAnnotation = hasServiceAnnotation || hasVirtualObjectAnnotation || hasWorkflowAnnotation; @@ -71,21 +80,14 @@ public ServiceDefinition create( + "exactly one annotation between @Service/@VirtualObject/@Workflow, more than one annotation found"); } - var serviceName = ReflectionUtils.extractServiceName(serviceClazz); + var serviceName = ReflectionUtils.extractServiceName(restateAnnotatedClazz); var serviceType = hasServiceAnnotation ? ServiceType.SERVICE : hasVirtualObjectAnnotation ? ServiceType.VIRTUAL_OBJECT : ServiceType.WORKFLOW; - var serdeFactory = resolveSerdeFactory(serviceClazz); - - var methods = - ReflectionUtils.getUniqueDeclaredMethods( - serviceClazz, - method -> - ReflectionUtils.findAnnotation(method, Handler.class) != null - || ReflectionUtils.findAnnotation(method, Shared.class) != null - || ReflectionUtils.findAnnotation(method, Workflow.class) != null - || ReflectionUtils.findAnnotation(method, Exclusive.class) != null); + var serdeFactory = resolveSerdeFactory(restateAnnotatedClazz); + + var methods = ReflectionUtils.findRestateHandlers(restateAnnotatedClazz); if (methods.length == 0) { throw new MalformedRestateServiceException(serviceName, "No @Handler method found"); } @@ -332,17 +334,17 @@ private Serde resolveOutputSerde( return serde; } - private SerdeFactory resolveSerdeFactory(Class serviceClazz) { + private SerdeFactory resolveSerdeFactory(Class restateAnnotatedClazz) { // Check for CustomSerdeFactory annotation CustomSerdeFactory customSerdeFactoryAnnotation = - ReflectionUtils.findAnnotation(serviceClazz, CustomSerdeFactory.class); + restateAnnotatedClazz.getDeclaredAnnotation(CustomSerdeFactory.class); if (customSerdeFactoryAnnotation != null) { try { return customSerdeFactoryAnnotation.value().getDeclaredConstructor().newInstance(); } catch (Exception e) { throw new MalformedRestateServiceException( - serviceClazz.getSimpleName(), + restateAnnotatedClazz.getSimpleName(), "Failed to instantiate custom SerdeFactory: " + customSerdeFactoryAnnotation.value().getName(), e); @@ -369,7 +371,7 @@ private SerdeFactory resolveSerdeFactory(Class serviceClazz) { return this.cachedDefaultSerdeFactory; } catch (Exception e) { throw new MalformedRestateServiceException( - serviceClazz.getSimpleName(), + restateAnnotatedClazz.getSimpleName(), "Failed to load JacksonSerdeFactory for Java service. " + "Make sure sdk-serde-jackson is on the classpath.", e); diff --git a/sdk-core/build.gradle.kts b/sdk-core/build.gradle.kts index 7a42dd7fa..3fd1c34ba 100644 --- a/sdk-core/build.gradle.kts +++ b/sdk-core/build.gradle.kts @@ -116,6 +116,7 @@ tasks { "dev.restate.sdk.core.javaapi.reflections.ObjectGreeterImplementedFromInterface", "dev.restate.sdk.core.javaapi.reflections.PrimitiveTypes", "dev.restate.sdk.core.javaapi.reflections.RawInputOutput", + "dev.restate.sdk.core.javaapi.reflections.RawService", "dev.restate.sdk.core.javaapi.reflections.ServiceGreeter", ) diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/RawService.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/RawService.java new file mode 100644 index 000000000..529873bfb --- /dev/null +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/RawService.java @@ -0,0 +1,22 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.javaapi.reflections; + +import dev.restate.sdk.annotation.Handler; +import dev.restate.sdk.annotation.Name; +import dev.restate.sdk.annotation.Raw; +import dev.restate.sdk.annotation.Service; + +@Service +@Name("RawService") +public interface RawService { + @Handler + @Raw + byte[] echo(@Raw byte[] input); +} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/RawServiceImpl.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/RawServiceImpl.java new file mode 100644 index 000000000..ee9e68528 --- /dev/null +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/RawServiceImpl.java @@ -0,0 +1,16 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.core.javaapi.reflections; + +public class RawServiceImpl implements RawService { + @Override + public byte[] echo(byte[] input) { + return input; + } +} diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ReflectionDiscoveryTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ReflectionDiscoveryTest.java index d80c566e6..9917863a7 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ReflectionDiscoveryTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/javaapi/reflections/ReflectionDiscoveryTest.java @@ -19,6 +19,7 @@ import dev.restate.sdk.core.javaapi.GreeterWithExplicitName; import dev.restate.sdk.core.javaapi.GreeterWithExplicitNameHandlers; import dev.restate.sdk.endpoint.Endpoint; +import dev.restate.serde.Serde; import org.junit.jupiter.api.Test; public class ReflectionDiscoveryTest { @@ -53,6 +54,44 @@ void checkCustomOutputContentType() { .isEqualTo("application/vnd.my.custom"); } + @Test + void checkRawInputContentType() { + assertThatDiscovery(new RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawInput") + .extracting(Handler::getInput, type(Input.class)) + .extracting(Input::getContentType) + .isEqualTo(Serde.RAW.contentType()); + } + + @Test + void checkRawOutputContentType() { + assertThatDiscovery(new RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawOutput") + .extracting(Handler::getOutput, type(Output.class)) + .extracting(Output::getContentType) + .isEqualTo(Serde.RAW.contentType()); + } + + @Test + void checkRawInfoFromInterface() { + var handlerAssert = + assertThatDiscovery(new RawServiceImpl()) + .extractingService("RawService") + .extractingHandler("echo"); + + handlerAssert + .extracting(Handler::getInput, type(Input.class)) + .extracting(Input::getContentType) + .isEqualTo(Serde.RAW.contentType()); + + handlerAssert + .extracting(Handler::getOutput, type(Output.class)) + .extracting(Output::getContentType) + .isEqualTo(Serde.RAW.contentType()); + } + @Test void explicitNames() { assertThatDiscovery((GreeterWithExplicitName) (context, request) -> "")