Skip to content

Commit da48519

Browse files
Make sure Nexus operation input is also encrypted
1 parent 3fbbee0 commit da48519

4 files changed

Lines changed: 144 additions & 110 deletions

File tree

temporal-sdk/src/test/java/io/temporal/workflow/nexus/EncryptionKeyContextPropagator.java

Lines changed: 0 additions & 105 deletions
This file was deleted.
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package io.temporal.workflow.nexus;
2+
3+
import io.temporal.api.common.v1.WorkflowExecution;
4+
import io.temporal.client.ActivityCompletionClient;
5+
import io.temporal.client.WorkflowOptions;
6+
import io.temporal.client.WorkflowStub;
7+
import io.temporal.common.converter.DefaultDataConverter;
8+
import io.temporal.common.interceptors.WorkflowClientCallsInterceptor;
9+
import io.temporal.common.interceptors.WorkflowClientCallsInterceptorBase;
10+
import io.temporal.common.interceptors.WorkflowClientInterceptor;
11+
import io.temporal.nexus.Nexus;
12+
import java.util.Optional;
13+
14+
/**
15+
* Client interceptor for per-endpoint Nexus encryption.
16+
*
17+
* <p>When a workflow is started from a Nexus operation handler, injects the endpoint name into the
18+
* workflow's header. The {@link PerEndpointEncryptionWorkerInterceptor} reads it on the workflow
19+
* thread and sets the codec's thread-local key, ensuring the async workflow result is encrypted
20+
* with the correct per-endpoint key.
21+
*/
22+
public class PerEndpointEncryptionClientInterceptor implements WorkflowClientInterceptor {
23+
24+
static final String ENDPOINT_HEADER_KEY = "x-encryption-endpoint";
25+
26+
@Override
27+
public WorkflowClientCallsInterceptor workflowClientCallsInterceptor(
28+
WorkflowClientCallsInterceptor next) {
29+
return new WorkflowClientCallsInterceptorBase(next) {
30+
@Override
31+
public WorkflowStartOutput start(WorkflowStartInput input) {
32+
// If we're on a Nexus handler thread, inject the endpoint into the workflow header.
33+
if (Nexus.isInOperationHandler()) {
34+
String endpoint = Nexus.getOperationContext().getInfo().getEndpoint();
35+
input
36+
.getHeader()
37+
.getValues()
38+
.put(
39+
ENDPOINT_HEADER_KEY,
40+
DefaultDataConverter.newDefaultInstance().toPayload(endpoint).get());
41+
}
42+
return super.start(input);
43+
}
44+
};
45+
}
46+
47+
@Override
48+
@Deprecated
49+
public WorkflowStub newUntypedWorkflowStub(
50+
String workflowType, WorkflowOptions options, WorkflowStub next) {
51+
return next;
52+
}
53+
54+
@Override
55+
@Deprecated
56+
public WorkflowStub newUntypedWorkflowStub(
57+
WorkflowExecution execution, Optional<String> workflowType, WorkflowStub next) {
58+
return next;
59+
}
60+
61+
@Override
62+
public ActivityCompletionClient newActivityCompletionClient(ActivityCompletionClient next) {
63+
return next;
64+
}
65+
}

temporal-sdk/src/test/java/io/temporal/workflow/nexus/PerEndpointEncryptionTest.java

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import io.temporal.nexus.Nexus;
2121
import io.temporal.nexus.WorkflowRunOperation;
2222
import io.temporal.testing.internal.SDKTestWorkflowRule;
23+
import io.temporal.worker.WorkerFactoryOptions;
2324
import io.temporal.workflow.*;
2425
import java.time.Duration;
2526
import java.util.Collections;
@@ -37,8 +38,9 @@
3738
*
3839
* <ul>
3940
* <li>The {@link PerEndpointEncryptionCodec} auto-detects the endpoint from Nexus context on
40-
* handler threads, or from the thread-local set by the {@link EncryptionKeyContextPropagator}
41-
* on workflow threads. Falls back to a default key for other contexts.
41+
* handler threads, or from the thread-local set by the {@link
42+
* PerEndpointEncryptionInterceptor} on workflow threads. Falls back to a default key for
43+
* other contexts.
4244
* <li>{@code decode()} reads the key ID from payload metadata (self-describing), so handler-side
4345
* input deserialization works before handler code runs.
4446
* <li>Nexus handler code is pure business logic — no encryption concerns.
@@ -60,8 +62,11 @@ public class PerEndpointEncryptionTest extends BaseNexusTest {
6062
DefaultDataConverter.newDefaultInstance(),
6163
Collections.singletonList(new PerEndpointEncryptionCodec(DEFAULT_KEY_ID)),
6264
true))
63-
.setContextPropagators(
64-
Collections.singletonList(new EncryptionKeyContextPropagator()))
65+
.setInterceptors(new PerEndpointEncryptionClientInterceptor())
66+
.build())
67+
.setWorkerFactoryOptions(
68+
WorkerFactoryOptions.newBuilder()
69+
.setWorkerInterceptors(new PerEndpointEncryptionWorkerInterceptor())
6570
.build())
6671
.build();
6772

@@ -202,6 +207,15 @@ public void testPayloadsEncryptedWithCorrectKeys() {
202207
if (event.hasNexusOperationScheduledEventAttributes()) {
203208
Payload nexusInput = event.getNexusOperationScheduledEventAttributes().getInput();
204209
assertPayloadIsEncrypted(nexusInput, "NexusOperationScheduled input");
210+
// Verify the Nexus operation input uses the endpoint key, not the default key.
211+
String nexusInputKeyId =
212+
nexusInput
213+
.getMetadataOrThrow(PerEndpointEncryptionCodec.METADATA_KEY_ID)
214+
.toStringUtf8();
215+
assertEquals(
216+
"Nexus operation input must use endpoint key, not default",
217+
expectedEndpointKey,
218+
nexusInputKeyId);
205219
foundEncryptedPayload = true;
206220
}
207221
if (event.hasWorkflowExecutionCompletedEventAttributes()) {
@@ -236,7 +250,7 @@ public void testPayloadsEncryptedWithCorrectKeys() {
236250
assertPayloadIsEncrypted(resultPayload, "AsyncWorkflow result");
237251

238252
// This is the critical assertion: the async workflow's result must be encrypted
239-
// with the endpoint-specific key, proving the ContextPropagator carried it through.
253+
// with the endpoint-specific key, proving the interceptor carried it through.
240254
String keyId =
241255
resultPayload
242256
.getMetadataOrThrow(PerEndpointEncryptionCodec.METADATA_KEY_ID)
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package io.temporal.workflow.nexus;
2+
3+
import io.temporal.api.common.v1.Payload;
4+
import io.temporal.common.converter.DefaultDataConverter;
5+
import io.temporal.common.interceptors.*;
6+
7+
/**
8+
* Worker interceptor for per-endpoint Nexus encryption.
9+
*
10+
* <p>Two interception points:
11+
*
12+
* <ul>
13+
* <li><b>Caller workflow (outbound)</b>: Before a Nexus operation executes, sets the codec's
14+
* thread-local key from the endpoint name. This ensures the codec encrypts the Nexus
15+
* operation input with the correct per-endpoint key.
16+
* <li><b>Async workflow (inbound)</b>: When a workflow started by a Nexus handler begins, reads
17+
* the endpoint from a header (injected by {@link PerEndpointEncryptionClientInterceptor}) and
18+
* sets the thread-local. This ensures the codec encrypts the workflow result with the
19+
* endpoint key.
20+
* </ul>
21+
*/
22+
public class PerEndpointEncryptionWorkerInterceptor extends WorkerInterceptorBase {
23+
24+
@Override
25+
public WorkflowInboundCallsInterceptor interceptWorkflow(WorkflowInboundCallsInterceptor next) {
26+
return new WorkflowInboundCallsInterceptorBase(next) {
27+
@Override
28+
public void init(WorkflowOutboundCallsInterceptor outboundCalls) {
29+
next.init(
30+
new WorkflowOutboundCallsInterceptorBase(outboundCalls) {
31+
@Override
32+
public <R> ExecuteNexusOperationOutput<R> executeNexusOperation(
33+
ExecuteNexusOperationInput<R> input) {
34+
// Set the key from the endpoint BEFORE the input is serialized.
35+
PerEndpointEncryptionCodec.setCurrentKeyId(input.getEndpoint());
36+
return super.executeNexusOperation(input);
37+
}
38+
});
39+
}
40+
41+
@Override
42+
public WorkflowOutput execute(WorkflowInput input) {
43+
// If this workflow was started by a Nexus handler, the client interceptor injected
44+
// the endpoint into the header. Read it and set the thread-local.
45+
Payload endpointPayload =
46+
input
47+
.getHeader()
48+
.getValues()
49+
.get(PerEndpointEncryptionClientInterceptor.ENDPOINT_HEADER_KEY);
50+
if (endpointPayload != null) {
51+
String endpoint =
52+
DefaultDataConverter.newDefaultInstance()
53+
.fromPayload(endpointPayload, String.class, String.class);
54+
PerEndpointEncryptionCodec.setCurrentKeyId(endpoint);
55+
}
56+
return super.execute(input);
57+
}
58+
};
59+
}
60+
}

0 commit comments

Comments
 (0)