-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathHandlerRunner.java
More file actions
227 lines (204 loc) · 8.09 KB
/
HandlerRunner.java
File metadata and controls
227 lines (204 loc) · 8.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH
//
// This file is part of the Restate Java SDK,
// which is released under the MIT license.
//
// You can find a copy of the license in file LICENSE in the root
// directory of this repository or package, or at
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.sdk;
import dev.restate.common.Slice;
import dev.restate.common.function.*;
import dev.restate.sdk.common.TerminalException;
import dev.restate.sdk.endpoint.definition.HandlerContext;
import dev.restate.sdk.internal.ContextThreadLocal;
import dev.restate.serde.Serde;
import dev.restate.serde.SerdeFactory;
import io.opentelemetry.context.Scope;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.jspecify.annotations.Nullable;
/**
 * Adapter class for {@link dev.restate.sdk.endpoint.definition.HandlerRunner} to use the Java API.
 *
 * <p>Each call to {@link #run} schedules the handler on the executor configured in {@link
 * Options}. The scheduled task deserializes the request with the request {@link Serde}, invokes the
 * user handler with a fresh {@link Context}, and serializes the result with the response {@link
 * Serde}. The outcome is delivered through the returned {@link CompletableFuture}.
 */
public class HandlerRunner<REQ, RES>
    implements dev.restate.sdk.endpoint.definition.HandlerRunner<REQ, RES> {
  private final ThrowingBiFunction<Context, REQ, RES> runner;
  private final SerdeFactory contextSerdeFactory;
  private final Options options;

  private static final Logger LOG = LogManager.getLogger(HandlerRunner.class);

  // The cast below only widens the declared context parameter type. At runtime the runner
  // receives the ContextImpl built in executeHandler; that object is assumed to implement
  // whatever Context subtype the runner expects.
  @SuppressWarnings("unchecked")
  HandlerRunner(
      ThrowingBiFunction<? extends Context, REQ, RES> runner,
      SerdeFactory contextSerdeFactory,
      @Nullable Options options) {
    //noinspection unchecked
    this.runner = (ThrowingBiFunction<Context, REQ, RES>) runner;
    this.contextSerdeFactory = contextSerdeFactory;
    this.options = (options != null) ? options : Options.DEFAULT;
  }

  /**
   * Schedules the handler on the configured executor and returns a future for its serialized
   * result.
   *
   * <p>The returned future completes exceptionally in these cases:
   *
   * <ul>
   *   <li>with a {@link TerminalException} ({@code BAD_REQUEST_CODE}) if the input cannot be
   *       deserialized;
   *   <li>with the handler's own throwable if the handler fails;
   *   <li>with a {@link TerminalException} ({@code INTERNAL_SERVER_ERROR_CODE}) if the output
   *       cannot be serialized;
   *   <li>with any other throwable raised while running the scheduled task.
   * </ul>
   *
   * @param handlerContext the invocation context, exposed to handler code via thread locals
   * @param requestSerde serde used to decode the request body
   * @param responseSerde serde used to encode the handler result
   * @param onClosedInvocationStreamHook not used by this implementation
   * @return a future completed with the serialized response, or exceptionally as described above
   */
  @Override
  public CompletableFuture<Slice> run(
      HandlerContext handlerContext,
      Serde<REQ> requestSerde,
      Serde<RES> responseSerde,
      AtomicReference<Runnable> onClosedInvocationStreamHook) {
    CompletableFuture<Slice> returnFuture = new CompletableFuture<>();

    Executor serviceExecutor = contextPropagatingExecutor(handlerContext);

    serviceExecutor.execute(
        () -> {
          try {
            executeHandler(
                handlerContext, serviceExecutor, requestSerde, responseSerde, returnFuture);
          } catch (Throwable e) {
            // Defensive: anything escaping the guarded sections (e.g. context construction)
            // must still fail the invocation instead of leaving returnFuture pending forever.
            // This is a no-op if the future was already completed.
            returnFuture.completeExceptionally(e);
          }
        });

    return returnFuture;
  }

  /**
   * Wraps the configured executor so every task runs with this invocation's handler context and
   * OpenTelemetry context installed.
   *
   * <p>The previous thread-local value is restored afterwards, not just removed. This keeps an
   * enclosing handler's context intact when the executor runs tasks inline on the calling thread.
   */
  private Executor contextPropagatingExecutor(HandlerContext handlerContext) {
    return runnable ->
        options.executor.execute(
            () -> {
              HandlerContext previous = HANDLER_CONTEXT_THREAD_LOCAL.get();
              HANDLER_CONTEXT_THREAD_LOCAL.set(handlerContext);
              try (Scope ignored =
                  handlerContext.request().openTelemetryContext().makeCurrent()) {
                runnable.run();
              } finally {
                if (previous != null) {
                  HANDLER_CONTEXT_THREAD_LOCAL.set(previous);
                } else {
                  HANDLER_CONTEXT_THREAD_LOCAL.remove();
                }
              }
            });
  }

  /**
   * Runs one invocation on the current (executor) thread.
   *
   * <p>Steps: deserialize the input, run the user code, serialize the output, then complete {@code
   * returnFuture}.
   */
  private void executeHandler(
      HandlerContext handlerContext,
      Executor serviceExecutor,
      Serde<REQ> requestSerde,
      Serde<RES> responseSerde,
      CompletableFuture<Slice> returnFuture) {
    // Any context switching, if necessary, will be done by ResolvedEndpointHandler
    Context ctx = new ContextImpl(handlerContext, serviceExecutor, contextSerdeFactory);

    // Parse input
    REQ req;
    try {
      req = requestSerde.deserialize(handlerContext.request().body());
    } catch (Throwable e) {
      LOG.warn("Cannot deserialize input", e);
      returnFuture.completeExceptionally(
          new TerminalException(
              TerminalException.BAD_REQUEST_CODE, "Cannot deserialize input: " + e.getMessage()));
      return;
    }

    // Execute user code. The Context is exposed via ContextThreadLocal only for the duration
    // of the handler call.
    RES res = null;
    Throwable error = null;
    try {
      ContextThreadLocal.setContext(ctx);
      res = this.runner.apply(ctx, req);
    } catch (Throwable e) {
      error = e;
    } finally {
      ContextThreadLocal.clearContext();
    }

    // If error, just return now
    if (error != null) {
      returnFuture.completeExceptionally(error);
      return;
    }

    // Serialize output
    Slice serializedResult;
    try {
      serializedResult = responseSerde.serialize(res);
    } catch (Throwable e) {
      LOG.warn("Cannot serialize output", e);
      returnFuture.completeExceptionally(
          new TerminalException(
              TerminalException.INTERNAL_SERVER_ERROR_CODE,
              "Cannot serialize output: " + e.getMessage()));
      return;
    }

    // Complete callback
    returnFuture.complete(serializedResult);
  }

  /** Factory method for {@link HandlerRunner}, used by codegen */
  public static <CTX extends Context, REQ, RES> HandlerRunner<REQ, RES> of(
      ThrowingBiFunction<CTX, REQ, RES> runner,
      SerdeFactory contextSerdeFactory,
      @Nullable Options options) {
    return new HandlerRunner<>(runner, contextSerdeFactory, options);
  }

  /** Factory method for {@link HandlerRunner}, used by codegen */
  @SuppressWarnings("unchecked")
  public static <CTX extends Context, RES> HandlerRunner<Void, RES> of(
      ThrowingFunction<CTX, RES> runner,
      SerdeFactory contextSerdeFactory,
      @Nullable Options options) {
    return new HandlerRunner<>(
        (context, o) -> runner.apply((CTX) context), contextSerdeFactory, options);
  }

  /** Factory method for {@link HandlerRunner}, used by codegen */
  @SuppressWarnings("unchecked")
  public static <CTX extends Context, REQ> HandlerRunner<REQ, Void> of(
      ThrowingBiConsumer<CTX, REQ> runner,
      SerdeFactory contextSerdeFactory,
      @Nullable Options options) {
    return new HandlerRunner<>(
        (context, o) -> {
          runner.accept((CTX) context, o);
          return null;
        },
        contextSerdeFactory,
        options);
  }

  /** Factory method for {@link HandlerRunner}, used by codegen */
  @SuppressWarnings("unchecked")
  public static <CTX extends Context> HandlerRunner<Void, Void> of(
      ThrowingConsumer<CTX> runner, SerdeFactory contextSerdeFactory, @Nullable Options options) {
    return new HandlerRunner<>(
        (ctx, o) -> {
          runner.accept((CTX) ctx);
          return null;
        },
        contextSerdeFactory,
        options);
  }

  /**
   * {@link HandlerRunner} options. You can override the default options to configure the executor
   * where to run the handlers.
   *
   * <p>You can run on virtual threads by using the executor {@code
   * Executors.newVirtualThreadPerTaskExecutor()}.
   */
  public static final class Options
      implements dev.restate.sdk.endpoint.definition.HandlerRunner.Options {

    /**
     * Default options will use virtual threads on Java 21+, or fall back to {@link
     * Executors#newCachedThreadPool()} (an unbounded pool) when virtual threads are unavailable.
     * The default executor is created once and shared among all {@link HandlerRunner} instances
     * using these options. It is used by {@link Restate#run}/{@link Context#run} as well.
     */
    public static final Options DEFAULT = new Options(createDefaultExecutor());

    private final Executor executor;

    private Options(Executor executor) {
      this.executor = executor;
    }

    /**
     * Create an instance of {@link Options} with the given {@code executor}.
     *
     * <p>The given executor is used for running the handler code, and {@link Restate#run}/{@link
     * Context#run} as well.
     */
    public static Options withExecutor(Executor executor) {
      return new Options(executor);
    }

    /**
     * Creates the executor behind {@link #DEFAULT}.
     *
     * <p>Reflection is used so this code still compiles and runs on runtimes without {@code
     * Executors.newVirtualThreadPerTaskExecutor()}. Any failure to look up or invoke it falls back
     * to a cached thread pool.
     */
    private static ExecutorService createDefaultExecutor() {
      // Try to use virtual threads if available (Java 21+)
      try {
        return (ExecutorService)
            Executors.class.getMethod("newVirtualThreadPerTaskExecutor").invoke(null);
      } catch (Exception e) {
        LOG.debug(
            "Virtual threads not available, using unbounded thread pool. "
                + "If you need to customize the thread pool used by your restate handlers, "
                + "use HandlerRunner.Options.withExecutor() with Endpoint.bind()");
        return Executors.newCachedThreadPool();
      }
    }
  }

  /**
   * Returns the {@link HandlerContext} installed on the current thread by the executor wrapper.
   *
   * @throws NullPointerException if called outside a Restate handler
   */
  static HandlerContext getHandlerContext() {
    return Objects.requireNonNull(
        dev.restate.sdk.endpoint.definition.HandlerRunner.HANDLER_CONTEXT_THREAD_LOCAL.get(),
        "Restate methods must be invoked from within a Restate handler");
  }
}