Skip to content

Commit 3fbbee0

Browse files
Show how to do per endpoint encryption
1 parent d660166 commit 3fbbee0

5 files changed

Lines changed: 574 additions & 0 deletions

File tree

temporal-sdk/src/main/java/io/temporal/internal/nexus/NexusInternal.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,8 @@ private NexusInternal() {}
88
public static NexusOperationContext getOperationContext() {
99
return CurrentNexusOperationContext.get().getUserFacingContext();
1010
}
11+
12+
public static boolean isInOperationHandler() {
13+
return CurrentNexusOperationContext.isNexusContext();
14+
}
1115
}

temporal-sdk/src/main/java/io/temporal/nexus/Nexus.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@ public static NexusOperationContext getOperationContext() {
1313
return NexusInternal.getOperationContext();
1414
}
1515

16+
/**
17+
* Returns true if the current thread is executing inside a Nexus operation handler. Useful for
18+
* context-aware components (such as codecs or interceptors) that need to behave differently
19+
* inside vs outside a Nexus handler.
20+
*/
21+
public static boolean isInOperationHandler() {
22+
return NexusInternal.isInOperationHandler();
23+
}
24+
1625
/**
1726
* Use this to rethrow a checked exception from a Nexus Operation instead of adding the exception
1827
* to a method signature.
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
package io.temporal.workflow.nexus;
2+
3+
import io.temporal.api.common.v1.Payload;
4+
import io.temporal.common.context.ContextPropagator;
5+
import io.temporal.common.converter.DefaultDataConverter;
6+
import io.temporal.nexus.Nexus;
7+
import java.util.Collections;
8+
import java.util.HashMap;
9+
import java.util.Map;
10+
11+
/**
12+
* A {@link ContextPropagator} that carries the encryption key identifier in workflow headers.
13+
*
14+
* <p>This propagator bridges the async Nexus operation path: when a Nexus handler starts a
15+
* workflow, this propagator auto-detects the endpoint from the Nexus operation context via {@link
16+
* #getCurrentContext()} (called on the handler thread), serializes it into the workflow header, and
17+
* restores it via {@link #setCurrentContext(Object)} (called on the workflow execution thread). The
18+
* codec then reads the thread-local to select the correct encryption key.
19+
*
20+
* <p>Neither the handler nor the workflow code needs any awareness of encryption. The propagator
21+
* automatically discovers the key ID from the Nexus context or the existing thread-local.
22+
*/
23+
public class EncryptionKeyContextPropagator implements ContextPropagator {
24+
25+
static final String HEADER_KEY = "x-encryption-key-id";
26+
27+
@Override
28+
public String getName() {
29+
return "encryption-key-propagator";
30+
}
31+
32+
/**
33+
* Called on the caller/handler thread to capture context for propagation. Auto-detects the key ID
34+
* from: (1) Nexus operation context (on handler threads), (2) existing thread-local (on workflow
35+
* threads).
36+
*/
37+
@Override
38+
public Object getCurrentContext() {
39+
String keyId = null;
40+
41+
// Try Nexus context first (available on handler threads)
42+
try {
43+
keyId = Nexus.getOperationContext().getInfo().getEndpoint();
44+
} catch (Exception e) {
45+
// Not on a Nexus handler thread
46+
}
47+
48+
// Fall back to thread-local (set by setCurrentContext on workflow threads)
49+
if (keyId == null) {
50+
keyId = PerEndpointEncryptionCodec.getCurrentKeyId();
51+
}
52+
53+
if (keyId == null) {
54+
return Collections.emptyMap();
55+
}
56+
Map<String, String> context = new HashMap<>();
57+
context.put(HEADER_KEY, keyId);
58+
return context;
59+
}
60+
61+
/**
62+
* Called on the workflow execution thread before workflow code runs. Sets the encryption key ID
63+
* on the thread-local so the codec can read it during {@code encode()}.
64+
*/
65+
@Override
66+
@SuppressWarnings("unchecked")
67+
public void setCurrentContext(Object context) {
68+
if (context instanceof Map) {
69+
Map<String, String> contextMap = (Map<String, String>) context;
70+
String keyId = contextMap.get(HEADER_KEY);
71+
if (keyId != null) {
72+
PerEndpointEncryptionCodec.setCurrentKeyId(keyId);
73+
}
74+
}
75+
}
76+
77+
@Override
78+
public Map<String, Payload> serializeContext(Object context) {
79+
if (context instanceof Map) {
80+
@SuppressWarnings("unchecked")
81+
Map<String, String> contextMap = (Map<String, String>) context;
82+
String keyId = contextMap.get(HEADER_KEY);
83+
if (keyId != null) {
84+
Map<String, Payload> serialized = new HashMap<>();
85+
serialized.put(
86+
HEADER_KEY, DefaultDataConverter.newDefaultInstance().toPayload(keyId).get());
87+
return serialized;
88+
}
89+
}
90+
return Collections.emptyMap();
91+
}
92+
93+
@Override
94+
public Object deserializeContext(Map<String, Payload> header) {
95+
Map<String, String> context = new HashMap<>();
96+
Payload keyIdPayload = header.get(HEADER_KEY);
97+
if (keyIdPayload != null) {
98+
String keyId =
99+
DefaultDataConverter.newDefaultInstance()
100+
.fromPayload(keyIdPayload, String.class, String.class);
101+
context.put(HEADER_KEY, keyId);
102+
}
103+
return context;
104+
}
105+
}
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
package io.temporal.workflow.nexus;
2+
3+
import com.google.protobuf.ByteString;
4+
import io.temporal.api.common.v1.Payload;
5+
import io.temporal.nexus.Nexus;
6+
import io.temporal.payload.codec.PayloadCodec;
7+
import java.util.List;
8+
import java.util.stream.Collectors;
9+
import javax.annotation.Nonnull;
10+
11+
/**
12+
* A simulated per-endpoint encryption codec that demonstrates how to select different encryption
13+
* keys based on the Nexus endpoint being called.
14+
*
15+
* <p>This codec uses a simple prefix-based transformation (not real encryption) to keep focus on
16+
* the Nexus wiring pattern rather than cryptographic details. The design is asymmetric:
17+
*
18+
* <ul>
19+
* <li>{@code encode()}: Auto-detects the key ID by checking (1) the Nexus operation context if on
20+
* a handler thread, (2) the thread-local set by {@link EncryptionKeyContextPropagator} if on
21+
* a workflow thread, or (3) a default key. This means handler and workflow code need zero
22+
* awareness of encryption.
23+
* <li>{@code decode()}: Reads the key ID from the payload's own metadata ({@code
24+
* encryption-key-id} field). This is self-describing and does not depend on any external
25+
* context, which is necessary because handler-side input deserialization occurs before
26+
* handler code runs.
27+
* </ul>
28+
*/
29+
public class PerEndpointEncryptionCodec implements PayloadCodec {
30+
31+
static final String METADATA_ENCRYPTED = "encrypted";
32+
static final ByteString METADATA_ENCRYPTED_VALUE = ByteString.copyFromUtf8("true");
33+
static final String METADATA_KEY_ID = "encryption-key-id";
34+
static final String ENC_PREFIX = "ENC:";
35+
36+
private static final ThreadLocal<String> CURRENT_KEY_ID = new ThreadLocal<>();
37+
38+
private final String defaultKeyId;
39+
40+
/**
41+
* @param defaultKeyId the default key ID used when neither Nexus context nor thread-local is
42+
* available (e.g., on the test thread or internal SDK serialization paths)
43+
*/
44+
public PerEndpointEncryptionCodec(String defaultKeyId) {
45+
this.defaultKeyId = defaultKeyId;
46+
}
47+
48+
/** Sets the current encryption key ID on the calling thread (used by ContextPropagator). */
49+
public static void setCurrentKeyId(String keyId) {
50+
CURRENT_KEY_ID.set(keyId);
51+
}
52+
53+
/** Clears the current encryption key ID on the calling thread. */
54+
public static void clearCurrentKeyId() {
55+
CURRENT_KEY_ID.remove();
56+
}
57+
58+
/** Returns the current key ID for the calling thread, or null if not set. */
59+
static String getCurrentKeyId() {
60+
return CURRENT_KEY_ID.get();
61+
}
62+
63+
@Nonnull
64+
@Override
65+
public List<Payload> encode(@Nonnull List<Payload> payloads) {
66+
return payloads.stream().map(this::encodePayload).collect(Collectors.toList());
67+
}
68+
69+
@Nonnull
70+
@Override
71+
public List<Payload> decode(@Nonnull List<Payload> payloads) {
72+
return payloads.stream().map(this::decodePayload).collect(Collectors.toList());
73+
}
74+
75+
/**
76+
* Resolves the key ID to use for encoding. Checks sources in priority order:
77+
*
78+
* <ol>
79+
* <li>Nexus operation context (available on handler threads during serialization)
80+
* <li>Thread-local (set by {@link EncryptionKeyContextPropagator} on workflow threads)
81+
* <li>Default key (fallback for test threads and internal SDK operations)
82+
* </ol>
83+
*/
84+
private String resolveKeyId() {
85+
// 1. Check if we're on a Nexus handler thread — the endpoint is the key ID
86+
if (Nexus.isInOperationHandler()) {
87+
return Nexus.getOperationContext().getInfo().getEndpoint();
88+
}
89+
90+
// 2. Check thread-local (set by ContextPropagator for workflow threads)
91+
String threadLocalKey = CURRENT_KEY_ID.get();
92+
if (threadLocalKey != null) {
93+
return threadLocalKey;
94+
}
95+
96+
// 3. Default key
97+
return defaultKeyId;
98+
}
99+
100+
private Payload encodePayload(Payload payload) {
101+
String keyId = resolveKeyId();
102+
103+
// Simulated encryption: prefix data with "ENC:<keyId>:" to make it visibly non-plaintext.
104+
// Use toBuilder() to preserve existing metadata (e.g., "encoding: json/plain").
105+
ByteString prefix = ByteString.copyFromUtf8(ENC_PREFIX + keyId + ":");
106+
ByteString encodedData = prefix.concat(payload.getData());
107+
108+
return payload.toBuilder()
109+
.putMetadata(METADATA_ENCRYPTED, METADATA_ENCRYPTED_VALUE)
110+
.putMetadata(METADATA_KEY_ID, ByteString.copyFromUtf8(keyId))
111+
.setData(encodedData)
112+
.build();
113+
}
114+
115+
private Payload decodePayload(Payload payload) {
116+
// Self-describing: check our encryption marker, then read key-id from metadata.
117+
ByteString encryptedMeta = payload.getMetadataOrDefault(METADATA_ENCRYPTED, null);
118+
if (encryptedMeta == null || !encryptedMeta.equals(METADATA_ENCRYPTED_VALUE)) {
119+
return payload;
120+
}
121+
122+
ByteString keyIdBytes = payload.getMetadataOrDefault(METADATA_KEY_ID, null);
123+
if (keyIdBytes == null) {
124+
throw new IllegalStateException("Encrypted payload missing encryption-key-id metadata");
125+
}
126+
String keyId = keyIdBytes.toStringUtf8();
127+
128+
// Reverse the simulated encryption: strip the "ENC:<keyId>:" prefix
129+
String expectedPrefix = ENC_PREFIX + keyId + ":";
130+
ByteString expectedPrefixBytes = ByteString.copyFromUtf8(expectedPrefix);
131+
ByteString data = payload.getData();
132+
if (!data.startsWith(expectedPrefixBytes)) {
133+
throw new IllegalStateException(
134+
"Encrypted payload data does not start with expected prefix: " + expectedPrefix);
135+
}
136+
ByteString decodedData = data.substring(expectedPrefixBytes.size());
137+
138+
return payload.toBuilder()
139+
.removeMetadata(METADATA_ENCRYPTED)
140+
.removeMetadata(METADATA_KEY_ID)
141+
.setData(decodedData)
142+
.build();
143+
}
144+
}

0 commit comments

Comments
 (0)