Skip to content

Commit 3d1da80

Browse files
Support Spring specific service/vobj/workflow annotations (#575)
1 parent 1247003 commit 3d1da80

File tree

5 files changed

+84
-43
lines changed

5 files changed

+84
-43
lines changed

client/src/main/java/dev/restate/client/Client.java

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
99
package dev.restate.client;
1010

11-
import static dev.restate.common.reflections.ReflectionUtils.mustHaveAnnotation;
12-
1311
import dev.restate.common.Output;
1412
import dev.restate.common.Request;
1513
import dev.restate.common.Target;
@@ -556,7 +554,7 @@ default Response<Output<Res>> getOutput() throws IngressException {
556554
*/
557555
@org.jetbrains.annotations.ApiStatus.Experimental
558556
default <SVC> SVC service(Class<SVC> clazz) {
559-
mustHaveAnnotation(clazz, Service.class);
557+
ReflectionUtils.mustHaveServiceAnnotation(clazz);
560558
var serviceName = ReflectionUtils.extractServiceName(clazz);
561559
return ProxySupport.createProxy(
562560
clazz,
@@ -607,7 +605,7 @@ default <SVC> SVC service(Class<SVC> clazz) {
607605
*/
608606
@org.jetbrains.annotations.ApiStatus.Experimental
609607
default <SVC> ClientServiceHandle<SVC> serviceHandle(Class<SVC> clazz) {
610-
mustHaveAnnotation(clazz, Service.class);
608+
ReflectionUtils.mustHaveServiceAnnotation(clazz);
611609
return new ClientServiceHandleImpl<>(this, clazz, null);
612610
}
613611

@@ -635,7 +633,7 @@ default <SVC> ClientServiceHandle<SVC> serviceHandle(Class<SVC> clazz) {
635633
*/
636634
@org.jetbrains.annotations.ApiStatus.Experimental
637635
default <SVC> SVC virtualObject(Class<SVC> clazz, String key) {
638-
mustHaveAnnotation(clazz, VirtualObject.class);
636+
ReflectionUtils.mustHaveVirtualObjectAnnotation(clazz);
639637
var serviceName = ReflectionUtils.extractServiceName(clazz);
640638
return ProxySupport.createProxy(
641639
clazz,
@@ -687,7 +685,7 @@ default <SVC> SVC virtualObject(Class<SVC> clazz, String key) {
687685
*/
688686
@org.jetbrains.annotations.ApiStatus.Experimental
689687
default <SVC> ClientServiceHandle<SVC> virtualObjectHandle(Class<SVC> clazz, String key) {
690-
mustHaveAnnotation(clazz, VirtualObject.class);
688+
ReflectionUtils.mustHaveVirtualObjectAnnotation(clazz);
691689
return new ClientServiceHandleImpl<>(this, clazz, key);
692690
}
693691

@@ -715,7 +713,7 @@ default <SVC> ClientServiceHandle<SVC> virtualObjectHandle(Class<SVC> clazz, Str
715713
*/
716714
@org.jetbrains.annotations.ApiStatus.Experimental
717715
default <SVC> SVC workflow(Class<SVC> clazz, String key) {
718-
mustHaveAnnotation(clazz, Workflow.class);
716+
ReflectionUtils.mustHaveWorkflowAnnotation(clazz);
719717
var serviceName = ReflectionUtils.extractServiceName(clazz);
720718
return ProxySupport.createProxy(
721719
clazz,
@@ -767,7 +765,7 @@ default <SVC> SVC workflow(Class<SVC> clazz, String key) {
767765
*/
768766
@org.jetbrains.annotations.ApiStatus.Experimental
769767
default <SVC> ClientServiceHandle<SVC> workflowHandle(Class<SVC> clazz, String key) {
770-
mustHaveAnnotation(clazz, Workflow.class);
768+
ReflectionUtils.mustHaveWorkflowAnnotation(clazz);
771769
return new ClientServiceHandleImpl<>(this, clazz, key);
772770
}
773771

common/src/main/java/dev/restate/common/reflections/ReflectionUtils.java

Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@
2020

2121
public class ReflectionUtils {
2222

23+
private static final @Nullable Class<? extends Annotation> RESTATE_SPRING_SERVICE_ANNOTATION =
24+
tryLoadClass("dev.restate.sdk.springboot.RestateService");
25+
private static final @Nullable Class<? extends Annotation>
26+
RESTATE_SPRING_VIRTUAL_OBJECT_ANNOTATION =
27+
tryLoadClass("dev.restate.sdk.springboot.RestateVirtualObject");
28+
private static final @Nullable Class<? extends Annotation> RESTATE_SPRING_WORKFLOW_ANNOTATION =
29+
tryLoadClass("dev.restate.sdk.springboot.RestateWorkflow");
30+
2331
/** Record containing handler information extracted from annotations. */
2432
public record HandlerInfo(String name, boolean shared) {}
2533

@@ -163,16 +171,17 @@ private static String inferRestateNameFromHierarchy(Class<?> type) {
163171
}
164172

165173
// Check if the type has any of the Restate component annotations
166-
var restateServiceAnnotation = type.getAnnotation(Service.class);
167-
if (restateServiceAnnotation != null) {
168-
return extractNameFromAnnotations(type);
169-
}
170-
var restateVirtualObjectAnnotation = type.getAnnotation(VirtualObject.class);
171-
if (restateVirtualObjectAnnotation != null) {
172-
return extractNameFromAnnotations(type);
173-
}
174-
var restateWorkflowAnnotation = type.getAnnotation(Workflow.class);
175-
if (restateWorkflowAnnotation != null) {
174+
var isRestateAnnotated =
175+
type.getAnnotation(Service.class) != null
176+
|| type.getAnnotation(VirtualObject.class) != null
177+
|| type.getAnnotation(Workflow.class) != null
178+
|| (RESTATE_SPRING_SERVICE_ANNOTATION != null
179+
&& type.getAnnotation(RESTATE_SPRING_SERVICE_ANNOTATION) != null)
180+
|| (RESTATE_SPRING_VIRTUAL_OBJECT_ANNOTATION != null
181+
&& type.getAnnotation(RESTATE_SPRING_VIRTUAL_OBJECT_ANNOTATION) != null)
182+
|| (RESTATE_SPRING_WORKFLOW_ANNOTATION != null
183+
&& type.getAnnotation(RESTATE_SPRING_WORKFLOW_ANNOTATION) != null);
184+
if (isRestateAnnotated) {
176185
return extractNameFromAnnotations(type);
177186
}
178187

@@ -200,17 +209,49 @@ private static String extractNameFromAnnotations(Class<?> type) {
200209
return type.getSimpleName();
201210
}
202211

203-
public static <A extends Annotation> A mustHaveAnnotation(
204-
Class<?> clazz, Class<A> annotationClazz) {
205-
A annotation = findAnnotation(clazz, annotationClazz);
206-
if (annotation == null) {
212+
public static boolean hasServiceAnnotation(Class<?> clazz) {
213+
return findAnnotation(clazz, Service.class) != null
214+
|| (RESTATE_SPRING_SERVICE_ANNOTATION != null
215+
&& findAnnotation(clazz, RESTATE_SPRING_SERVICE_ANNOTATION) != null);
216+
}
217+
218+
public static void mustHaveServiceAnnotation(Class<?> clazz) {
219+
if (!hasServiceAnnotation(clazz)) {
220+
throw new IllegalArgumentException(
221+
"The given class "
222+
+ clazz.getName()
223+
+ " is not annotated with the Restate service annotation");
224+
}
225+
}
226+
227+
public static boolean hasVirtualObjectAnnotation(Class<?> clazz) {
228+
return findAnnotation(clazz, VirtualObject.class) != null
229+
|| (RESTATE_SPRING_VIRTUAL_OBJECT_ANNOTATION != null
230+
&& findAnnotation(clazz, RESTATE_SPRING_VIRTUAL_OBJECT_ANNOTATION) != null);
231+
}
232+
233+
public static void mustHaveVirtualObjectAnnotation(Class<?> clazz) {
234+
if (!hasVirtualObjectAnnotation(clazz)) {
235+
throw new IllegalArgumentException(
236+
"The given class "
237+
+ clazz.getName()
238+
+ " is not annotated with the Restate virtualObject annotation");
239+
}
240+
}
241+
242+
public static boolean hasWorkflowAnnotation(Class<?> clazz) {
243+
return findAnnotation(clazz, Workflow.class) != null
244+
|| (RESTATE_SPRING_WORKFLOW_ANNOTATION != null
245+
&& findAnnotation(clazz, RESTATE_SPRING_WORKFLOW_ANNOTATION) != null);
246+
}
247+
248+
public static void mustHaveWorkflowAnnotation(Class<?> clazz) {
249+
if (!hasWorkflowAnnotation(clazz)) {
207250
throw new IllegalArgumentException(
208251
"The given class "
209252
+ clazz.getName()
210-
+ " is not annotated with @"
211-
+ annotationClazz.getSimpleName());
253+
+ " is not annotated with the Restate workflow annotation");
212254
}
213-
return annotation;
214255
}
215256

216257
public static HandlerInfo mustHaveHandlerAnnotation(@NonNull Method method) {
@@ -308,6 +349,15 @@ public static boolean isKotlinClass(Class<?> clazz) {
308349
.anyMatch(annotation -> annotation.annotationType().getName().equals("kotlin.Metadata"));
309350
}
310351

352+
@SuppressWarnings("unchecked")
353+
private static @Nullable <T> Class<T> tryLoadClass(String className) {
354+
try {
355+
return (Class<T>) Class.forName(className);
356+
} catch (ClassNotFoundException e) {
357+
return null;
358+
}
359+
}
360+
311361
// From Spring's ReflectionUtils
312362
// License Apache 2.0
313363

sdk-api/src/main/java/dev/restate/sdk/Restate.java

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
99
package dev.restate.sdk;
1010

11-
import static dev.restate.common.reflections.ReflectionUtils.mustHaveAnnotation;
12-
1311
import dev.restate.common.Request;
1412
import dev.restate.common.Slice;
1513
import dev.restate.common.Target;
@@ -430,7 +428,7 @@ public static AwakeableHandle awakeableHandle(String id) {
430428
*/
431429
@org.jetbrains.annotations.ApiStatus.Experimental
432430
public static <SVC> SVC service(Class<SVC> clazz) {
433-
mustHaveAnnotation(clazz, Service.class);
431+
ReflectionUtils.mustHaveServiceAnnotation(clazz);
434432
String serviceName = ReflectionUtils.extractServiceName(clazz);
435433
return ProxySupport.createProxy(
436434
clazz,
@@ -481,7 +479,7 @@ public static <SVC> SVC service(Class<SVC> clazz) {
481479
*/
482480
@org.jetbrains.annotations.ApiStatus.Experimental
483481
public static <SVC> ServiceHandle<SVC> serviceHandle(Class<SVC> clazz) {
484-
mustHaveAnnotation(clazz, Service.class);
482+
ReflectionUtils.mustHaveServiceAnnotation(clazz);
485483
return new ServiceHandleImpl<>(clazz, null);
486484
}
487485

@@ -506,7 +504,7 @@ public static <SVC> ServiceHandle<SVC> serviceHandle(Class<SVC> clazz) {
506504
*/
507505
@org.jetbrains.annotations.ApiStatus.Experimental
508506
public static <SVC> SVC virtualObject(Class<SVC> clazz, String key) {
509-
mustHaveAnnotation(clazz, VirtualObject.class);
507+
ReflectionUtils.mustHaveVirtualObjectAnnotation(clazz);
510508
String serviceName = ReflectionUtils.extractServiceName(clazz);
511509
return ProxySupport.createProxy(
512510
clazz,
@@ -558,7 +556,7 @@ public static <SVC> SVC virtualObject(Class<SVC> clazz, String key) {
558556
*/
559557
@org.jetbrains.annotations.ApiStatus.Experimental
560558
public static <SVC> ServiceHandle<SVC> virtualObjectHandle(Class<SVC> clazz, String key) {
561-
mustHaveAnnotation(clazz, VirtualObject.class);
559+
ReflectionUtils.mustHaveVirtualObjectAnnotation(clazz);
562560
return new ServiceHandleImpl<>(clazz, key);
563561
}
564562

@@ -583,7 +581,7 @@ public static <SVC> ServiceHandle<SVC> virtualObjectHandle(Class<SVC> clazz, Str
583581
*/
584582
@org.jetbrains.annotations.ApiStatus.Experimental
585583
public static <SVC> SVC workflow(Class<SVC> clazz, String key) {
586-
mustHaveAnnotation(clazz, Workflow.class);
584+
ReflectionUtils.mustHaveWorkflowAnnotation(clazz);
587585
String serviceName = ReflectionUtils.extractServiceName(clazz);
588586
return ProxySupport.createProxy(
589587
clazz,
@@ -635,7 +633,7 @@ public static <SVC> SVC workflow(Class<SVC> clazz, String key) {
635633
*/
636634
@org.jetbrains.annotations.ApiStatus.Experimental
637635
public static <SVC> ServiceHandle<SVC> workflowHandle(Class<SVC> clazz, String key) {
638-
mustHaveAnnotation(clazz, Workflow.class);
636+
ReflectionUtils.mustHaveWorkflowAnnotation(clazz);
639637
return new ServiceHandleImpl<>(clazz, key);
640638
}
641639

sdk-api/src/main/java/dev/restate/sdk/internal/ReflectionServiceDefinitionFactory.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,9 @@ public ServiceDefinition create(
4747

4848
Class<?> serviceClazz = serviceInstance.getClass();
4949

50-
boolean hasServiceAnnotation =
51-
ReflectionUtils.findAnnotation(serviceClazz, Service.class) != null;
52-
boolean hasVirtualObjectAnnotation =
53-
ReflectionUtils.findAnnotation(serviceClazz, VirtualObject.class) != null;
54-
boolean hasWorkflowAnnotation =
55-
ReflectionUtils.findAnnotation(serviceClazz, Workflow.class) != null;
50+
boolean hasServiceAnnotation = ReflectionUtils.hasServiceAnnotation(serviceClazz);
51+
boolean hasVirtualObjectAnnotation = ReflectionUtils.hasVirtualObjectAnnotation(serviceClazz);
52+
boolean hasWorkflowAnnotation = ReflectionUtils.hasWorkflowAnnotation(serviceClazz);
5653

5754
boolean hasAnyAnnotation =
5855
hasServiceAnnotation || hasVirtualObjectAnnotation || hasWorkflowAnnotation;

sdk-spring-boot-starter/src/test/java/dev/restate/sdk/springboot/java/GreeterNewApi.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,10 @@
1010

1111
import dev.restate.sdk.annotation.Handler;
1212
import dev.restate.sdk.annotation.Name;
13-
import dev.restate.sdk.annotation.Service;
14-
import dev.restate.sdk.springboot.RestateComponent;
13+
import dev.restate.sdk.springboot.RestateService;
1514
import org.springframework.beans.factory.annotation.Value;
1615

17-
@Service
18-
@RestateComponent
16+
@RestateService
1917
@Name("greeterNewApi")
2018
public class GreeterNewApi {
2119

0 commit comments

Comments
 (0)