|
| 1 | +package io.temporal.workflow.nexus; |
| 2 | + |
| 3 | +import com.google.protobuf.ByteString; |
| 4 | +import io.temporal.api.common.v1.Payload; |
| 5 | +import io.temporal.nexus.Nexus; |
| 6 | +import io.temporal.payload.codec.PayloadCodec; |
| 7 | +import java.util.List; |
| 8 | +import java.util.stream.Collectors; |
| 9 | +import javax.annotation.Nonnull; |
| 10 | + |
| 11 | +/** |
| 12 | + * A simulated per-endpoint encryption codec that demonstrates how to select different encryption |
| 13 | + * keys based on the Nexus endpoint being called. |
| 14 | + * |
| 15 | + * <p>This codec uses a simple prefix-based transformation (not real encryption) to keep focus on |
| 16 | + * the Nexus wiring pattern rather than cryptographic details. The design is asymmetric: |
| 17 | + * |
| 18 | + * <ul> |
| 19 | + * <li>{@code encode()}: Auto-detects the key ID by checking (1) the Nexus operation context if on |
| 20 | + * a handler thread, (2) the thread-local set by {@link EncryptionKeyContextPropagator} if on |
| 21 | + * a workflow thread, or (3) a default key. This means handler and workflow code need zero |
| 22 | + * awareness of encryption. |
| 23 | + * <li>{@code decode()}: Reads the key ID from the payload's own metadata ({@code |
| 24 | + * encryption-key-id} field). This is self-describing and does not depend on any external |
| 25 | + * context, which is necessary because handler-side input deserialization occurs before |
| 26 | + * handler code runs. |
| 27 | + * </ul> |
| 28 | + */ |
| 29 | +public class PerEndpointEncryptionCodec implements PayloadCodec { |
| 30 | + |
| 31 | + static final String METADATA_ENCRYPTED = "encrypted"; |
| 32 | + static final ByteString METADATA_ENCRYPTED_VALUE = ByteString.copyFromUtf8("true"); |
| 33 | + static final String METADATA_KEY_ID = "encryption-key-id"; |
| 34 | + static final String ENC_PREFIX = "ENC:"; |
| 35 | + |
| 36 | + private static final ThreadLocal<String> CURRENT_KEY_ID = new ThreadLocal<>(); |
| 37 | + |
| 38 | + private final String defaultKeyId; |
| 39 | + |
| 40 | + /** |
| 41 | + * @param defaultKeyId the default key ID used when neither Nexus context nor thread-local is |
| 42 | + * available (e.g., on the test thread or internal SDK serialization paths) |
| 43 | + */ |
| 44 | + public PerEndpointEncryptionCodec(String defaultKeyId) { |
| 45 | + this.defaultKeyId = defaultKeyId; |
| 46 | + } |
| 47 | + |
| 48 | + /** Sets the current encryption key ID on the calling thread (used by ContextPropagator). */ |
| 49 | + public static void setCurrentKeyId(String keyId) { |
| 50 | + CURRENT_KEY_ID.set(keyId); |
| 51 | + } |
| 52 | + |
| 53 | + /** Clears the current encryption key ID on the calling thread. */ |
| 54 | + public static void clearCurrentKeyId() { |
| 55 | + CURRENT_KEY_ID.remove(); |
| 56 | + } |
| 57 | + |
| 58 | + /** Returns the current key ID for the calling thread, or null if not set. */ |
| 59 | + static String getCurrentKeyId() { |
| 60 | + return CURRENT_KEY_ID.get(); |
| 61 | + } |
| 62 | + |
| 63 | + @Nonnull |
| 64 | + @Override |
| 65 | + public List<Payload> encode(@Nonnull List<Payload> payloads) { |
| 66 | + return payloads.stream().map(this::encodePayload).collect(Collectors.toList()); |
| 67 | + } |
| 68 | + |
| 69 | + @Nonnull |
| 70 | + @Override |
| 71 | + public List<Payload> decode(@Nonnull List<Payload> payloads) { |
| 72 | + return payloads.stream().map(this::decodePayload).collect(Collectors.toList()); |
| 73 | + } |
| 74 | + |
| 75 | + /** |
| 76 | + * Resolves the key ID to use for encoding. Checks sources in priority order: |
| 77 | + * |
| 78 | + * <ol> |
| 79 | + * <li>Nexus operation context (available on handler threads during serialization) |
| 80 | + * <li>Thread-local (set by {@link EncryptionKeyContextPropagator} on workflow threads) |
| 81 | + * <li>Default key (fallback for test threads and internal SDK operations) |
| 82 | + * </ol> |
| 83 | + */ |
| 84 | + private String resolveKeyId() { |
| 85 | + // 1. Check if we're on a Nexus handler thread — the endpoint is the key ID |
| 86 | + if (Nexus.isInOperationHandler()) { |
| 87 | + return Nexus.getOperationContext().getInfo().getEndpoint(); |
| 88 | + } |
| 89 | + |
| 90 | + // 2. Check thread-local (set by ContextPropagator for workflow threads) |
| 91 | + String threadLocalKey = CURRENT_KEY_ID.get(); |
| 92 | + if (threadLocalKey != null) { |
| 93 | + return threadLocalKey; |
| 94 | + } |
| 95 | + |
| 96 | + // 3. Default key |
| 97 | + return defaultKeyId; |
| 98 | + } |
| 99 | + |
| 100 | + private Payload encodePayload(Payload payload) { |
| 101 | + String keyId = resolveKeyId(); |
| 102 | + |
| 103 | + // Simulated encryption: prefix data with "ENC:<keyId>:" to make it visibly non-plaintext. |
| 104 | + // Use toBuilder() to preserve existing metadata (e.g., "encoding: json/plain"). |
| 105 | + ByteString prefix = ByteString.copyFromUtf8(ENC_PREFIX + keyId + ":"); |
| 106 | + ByteString encodedData = prefix.concat(payload.getData()); |
| 107 | + |
| 108 | + return payload.toBuilder() |
| 109 | + .putMetadata(METADATA_ENCRYPTED, METADATA_ENCRYPTED_VALUE) |
| 110 | + .putMetadata(METADATA_KEY_ID, ByteString.copyFromUtf8(keyId)) |
| 111 | + .setData(encodedData) |
| 112 | + .build(); |
| 113 | + } |
| 114 | + |
| 115 | + private Payload decodePayload(Payload payload) { |
| 116 | + // Self-describing: check our encryption marker, then read key-id from metadata. |
| 117 | + ByteString encryptedMeta = payload.getMetadataOrDefault(METADATA_ENCRYPTED, null); |
| 118 | + if (encryptedMeta == null || !encryptedMeta.equals(METADATA_ENCRYPTED_VALUE)) { |
| 119 | + return payload; |
| 120 | + } |
| 121 | + |
| 122 | + ByteString keyIdBytes = payload.getMetadataOrDefault(METADATA_KEY_ID, null); |
| 123 | + if (keyIdBytes == null) { |
| 124 | + throw new IllegalStateException("Encrypted payload missing encryption-key-id metadata"); |
| 125 | + } |
| 126 | + String keyId = keyIdBytes.toStringUtf8(); |
| 127 | + |
| 128 | + // Reverse the simulated encryption: strip the "ENC:<keyId>:" prefix |
| 129 | + String expectedPrefix = ENC_PREFIX + keyId + ":"; |
| 130 | + ByteString expectedPrefixBytes = ByteString.copyFromUtf8(expectedPrefix); |
| 131 | + ByteString data = payload.getData(); |
| 132 | + if (!data.startsWith(expectedPrefixBytes)) { |
| 133 | + throw new IllegalStateException( |
| 134 | + "Encrypted payload data does not start with expected prefix: " + expectedPrefix); |
| 135 | + } |
| 136 | + ByteString decodedData = data.substring(expectedPrefixBytes.size()); |
| 137 | + |
| 138 | + return payload.toBuilder() |
| 139 | + .removeMetadata(METADATA_ENCRYPTED) |
| 140 | + .removeMetadata(METADATA_KEY_ID) |
| 141 | + .setData(decodedData) |
| 142 | + .build(); |
| 143 | + } |
| 144 | +} |
0 commit comments