|
5 | 5 |
|
6 | 6 | import static org.junit.jupiter.api.Assertions.*; |
7 | 7 |
|
8 | | -import com.amazonaws.lambda.durable.exception.CallbackFailedException; |
9 | | -import com.amazonaws.lambda.durable.exception.CallbackTimeoutException; |
| 8 | +import com.amazonaws.lambda.durable.serde.JacksonSerDes; |
| 9 | +import com.amazonaws.lambda.durable.serde.SerDes; |
10 | 10 | import com.amazonaws.lambda.durable.model.ExecutionStatus; |
11 | 11 | import com.amazonaws.lambda.durable.testing.LocalDurableTestRunner; |
12 | 12 | import java.time.Duration; |
| 13 | +import java.util.concurrent.atomic.AtomicInteger; |
13 | 14 | import org.junit.jupiter.api.Test; |
14 | 15 | import software.amazon.awssdk.services.lambda.model.ErrorObject; |
15 | 16 | import software.amazon.awssdk.services.lambda.model.OperationStatus; |
16 | 17 | import software.amazon.awssdk.services.lambda.model.OperationType; |
17 | 18 |
|
18 | 19 | class CallbackIntegrationTest { |
19 | 20 |
|
| 21 | + /** Custom SerDes that tracks deserialization calls for testing. */ |
| 22 | + static class TrackingSerDes implements SerDes { |
| 23 | + private final JacksonSerDes delegate = new JacksonSerDes(); |
| 24 | + private final AtomicInteger deserializeCount = new AtomicInteger(0); |
| 25 | + |
| 26 | + @Override |
| 27 | + public String serialize(Object value) { |
| 28 | + return delegate.serialize(value); |
| 29 | + } |
| 30 | + |
| 31 | + @Override |
| 32 | + public <T> T deserialize(String data, Class<T> type) { |
| 33 | + deserializeCount.incrementAndGet(); |
| 34 | + return delegate.deserialize(data, type); |
| 35 | + } |
| 36 | + |
| 37 | + @Override |
| 38 | + public <T> T deserialize(String data, TypeToken<T> typeToken) { |
| 39 | + deserializeCount.incrementAndGet(); |
| 40 | + return delegate.deserialize(data, typeToken); |
| 41 | + } |
| 42 | + |
| 43 | + public int getDeserializeCount() { |
| 44 | + return deserializeCount.get(); |
| 45 | + } |
| 46 | + } |
| 47 | + |
20 | 48 | @Test |
21 | 49 | void callbackSuccessFlow() { |
22 | 50 | var runner = LocalDurableTestRunner.create(String.class, (input, ctx) -> { |
@@ -146,4 +174,56 @@ void callbackWithSteps() { |
146 | 174 | assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); |
147 | 175 | assertEquals("prepared -> approved -> done", result.getResult(String.class)); |
148 | 176 | } |
| 177 | + |
| 178 | + @Test |
| 179 | + void callbackWithCustomSerDes() { |
| 180 | + var customSerDes = new TrackingSerDes(); |
| 181 | + |
| 182 | + var runner = LocalDurableTestRunner.create(String.class, (input, ctx) -> { |
| 183 | + var cb = ctx.createCallback( |
| 184 | + "approval", |
| 185 | + String.class, |
| 186 | + CallbackConfig.builder().serDes(customSerDes).build()); |
| 187 | + |
| 188 | + return cb.future().get(); |
| 189 | + }); |
| 190 | + |
| 191 | + // First run - creates callback, suspends |
| 192 | + var result = runner.run("test"); |
| 193 | + |
| 194 | + // Complete the callback |
| 195 | + var callbackId = runner.getCallbackId("approval"); |
| 196 | + runner.completeCallback(callbackId, "\"approved\""); |
| 197 | + |
| 198 | + // Second run - callback complete, returns result |
| 199 | + result = runner.run("test"); |
| 200 | + |
| 201 | + assertEquals("approved", result.getResult(String.class)); |
| 202 | + assertTrue(customSerDes.getDeserializeCount() > 0, "Custom SerDes should have been used"); |
| 203 | + } |
| 204 | + |
| 205 | + @Test |
| 206 | + void callbackWithNullSerDesUsesDefault() { |
| 207 | + var runner = LocalDurableTestRunner.create(String.class, (input, ctx) -> { |
| 208 | + // Explicitly pass null SerDes - should use default |
| 209 | + var cb = ctx.createCallback( |
| 210 | + "approval", |
| 211 | + String.class, |
| 212 | + CallbackConfig.builder().serDes(null).build()); |
| 213 | + |
| 214 | + return cb.future().get(); |
| 215 | + }); |
| 216 | + |
| 217 | + // First run - creates callback, suspends |
| 218 | + var result = runner.run("test"); |
| 219 | + |
| 220 | + // Complete the callback |
| 221 | + var callbackId = runner.getCallbackId("approval"); |
| 222 | + runner.completeCallback(callbackId, "\"result\""); |
| 223 | + |
| 224 | + // Second run - callback complete, returns result |
| 225 | + result = runner.run("test"); |
| 226 | + |
| 227 | + assertEquals("result", result.getResult(String.class)); |
| 228 | + } |
149 | 229 | } |
0 commit comments