Skip to content
Open
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,9 @@ RUN --mount=type=secret,id=confidence_client_secret \
FROM openfeature-provider-java.test AS openfeature-provider-java.test_e2e

RUN --mount=type=secret,id=confidence_client_secret \
--mount=type=secret,id=confidence_client_encryption_key \
CONFIDENCE_CLIENT_SECRET=$(cat /run/secrets/confidence_client_secret) \
CONFIDENCE_CLIENT_ENCRYPTION_KEY=$(cat /run/secrets/confidence_client_encryption_key) \
make test-e2e

# ==============================================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,22 @@ class FlagsAdminStateFetcher implements AccountStateProvider {
"https://confidence-resolver-state-cdn.spotifycdn.com/";

private final String clientSecret;
private final String encryptionKey;
private final HttpClientFactory httpClientFactory;
// ETag for conditional GETs of resolver state
private final AtomicReference<String> etagHolder = new AtomicReference<>();
private final AtomicReference<byte[]> rawResolverStateHolder =
new AtomicReference<>(
com.spotify.confidence.sdk.flags.admin.v1.ResolverState.newBuilder()
.build()
.toByteArray());
private final AtomicReference<byte[]> rawStateHolder = new AtomicReference<>();
private String accountId = "";

public FlagsAdminStateFetcher(String clientSecret, HttpClientFactory httpClientFactory) {
public FlagsAdminStateFetcher(
String clientSecret, HttpClientFactory httpClientFactory, String encryptionKey) {
this.clientSecret = clientSecret;
this.httpClientFactory = httpClientFactory;
}

public AtomicReference<byte[]> rawStateHolder() {
return rawResolverStateHolder;
this.encryptionKey = encryptionKey;
}

@Override
public byte[] provide() {
return rawResolverStateHolder.get();
return rawStateHolder.get();
}

@Override
Expand All @@ -63,33 +57,37 @@ public void reload() {
}
}

boolean isEncrypted() {
return encryptionKey != null;
}

private void fetchAndUpdateStateIfChanged() {
// Build CDN URL using SHA256 hash of client secret
final var cdnUrl = CDN_BASE_URL + sha256Hex(clientSecret);
final String hash = sha256Hex(clientSecret);
final var cdnUrl = CDN_BASE_URL + hash + (isEncrypted() ? ".enc" : "");
try {
final HttpURLConnection conn = httpClientFactory.create(cdnUrl);
final String previousEtag = etagHolder.get();
if (previousEtag != null) {
conn.setRequestProperty("if-none-match", previousEtag);
}
if (conn.getResponseCode() == 304) {
// Not modified
return;
}
final String etag = conn.getHeaderField("etag");
try (final InputStream stream = conn.getInputStream()) {
final byte[] bytes = stream.readAllBytes();

// Parse SetResolverStateRequest from CDN response
final var stateRequest =
com.spotify.confidence.sdk.wasm.Messages.SetResolverStateRequest.parseFrom(bytes);
this.accountId = stateRequest.getAccountId();

// Store the state bytes (already in bytes format)
rawResolverStateHolder.set(stateRequest.getState().toByteArray());
if (isEncrypted()) {
rawStateHolder.set(bytes);
} else {
final var stateRequest =
com.spotify.confidence.sdk.wasm.Messages.SetResolverStateRequest.parseFrom(bytes);
this.accountId = stateRequest.getAccountId();
rawStateHolder.set(stateRequest.getState().toByteArray());
}
etagHolder.set(etag);
}
logger.info("Loaded resolver state for account={}, etag={}", accountId, etag);
logger.info("Loaded resolver state (encrypted={}, etag={})", isEncrypted(), etag);
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ public class LocalProviderConfig {
private final HttpClientFactory httpClientFactory;
private final boolean useRemoteMaterializationStore;
private final int resolverPoolSize;
private final String encryptionKey;

public LocalProviderConfig() {
this(null, null);
Expand All @@ -36,11 +37,21 @@ public LocalProviderConfig(
HttpClientFactory httpClientFactory,
boolean useRemoteMaterializationStore,
int resolverPoolSize) {
this(channelFactory, httpClientFactory, useRemoteMaterializationStore, resolverPoolSize, null);
}

private LocalProviderConfig(
ChannelFactory channelFactory,
HttpClientFactory httpClientFactory,
boolean useRemoteMaterializationStore,
int resolverPoolSize,
String encryptionKey) {
this.channelFactory = channelFactory != null ? channelFactory : new DefaultChannelFactory();
this.httpClientFactory =
httpClientFactory != null ? httpClientFactory : new DefaultHttpClientFactory();
this.useRemoteMaterializationStore = useRemoteMaterializationStore;
this.resolverPoolSize = resolverPoolSize > 0 ? resolverPoolSize : DEFAULT_RESOLVER_POOL_SIZE;
this.encryptionKey = encryptionKey;
}

public ChannelFactory getChannelFactory() {
Expand All @@ -63,6 +74,11 @@ public int getResolverPoolSize() {
return resolverPoolSize;
}

/** Returns the hex-encoded AES-256 encryption key, or {@code null} if unset. */
public String getEncryptionKey() {
return encryptionKey;
}

public static Builder builder() {
return new Builder();
}
Expand All @@ -72,6 +88,7 @@ public static class Builder {
private HttpClientFactory httpClientFactory;
private boolean useRemoteMaterializationStore;
private int resolverPoolSize;
private String encryptionKey;

public Builder channelFactory(ChannelFactory channelFactory) {
this.channelFactory = channelFactory;
Expand Down Expand Up @@ -100,9 +117,19 @@ public Builder resolverPoolSize(int resolverPoolSize) {
return this;
}

/** Sets the hex-encoded AES-256 encryption key for decrypting CDN state. */
public Builder encryptionKey(String encryptionKey) {
this.encryptionKey = encryptionKey;
return this;
}

public LocalProviderConfig build() {
return new LocalProviderConfig(
channelFactory, httpClientFactory, useRemoteMaterializationStore, resolverPoolSize);
channelFactory,
httpClientFactory,
useRemoteMaterializationStore,
resolverPoolSize,
encryptionKey);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ interface LocalResolver {
*/
void setResolverState(byte[] state, String accountId, Sdk sdk);

void setEncryptedResolverState(byte[] encryptedState, byte[] encryptionKey, Sdk sdk);

/**
* Resolves flags. The returned stage completes when all resolution (including any store I/O for
* materializations) has finished.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,11 @@ public void setResolverState(byte[] state, String accountId, Sdk sdk) {
delegate.setResolverState(state, accountId, sdk);
}

@Override
public void setEncryptedResolverState(byte[] encryptedState, byte[] encryptionKey, Sdk sdk) {
delegate.setEncryptedResolverState(encryptedState, encryptionKey, sdk);
}

@Override
public void registerResolve(RegisterResolveRequest request) {
delegate.registerResolve(request);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public class OpenFeatureLocalResolveProvider implements FeatureProvider {
private final AccountStateProvider stateProvider;
private final AtomicReference<ProviderState> state =
new AtomicReference<>(ProviderState.NOT_READY);
private final String encryptionKey;
private volatile boolean initialized = false;
private volatile byte[] lastStateBytes = null;
@VisibleForTesting boolean forcedFetcherShutdown = false;
Expand Down Expand Up @@ -147,8 +148,11 @@ public OpenFeatureLocalResolveProvider(
public OpenFeatureLocalResolveProvider(
LocalProviderConfig config, String clientSecret, MaterializationStore materializationStore) {
this.clientSecret = clientSecret;
this.encryptionKey = config.getEncryptionKey();
this.materializationStore = materializationStore;
this.stateProvider = new FlagsAdminStateFetcher(clientSecret, config.getHttpClientFactory());
this.stateProvider =
new FlagsAdminStateFetcher(
clientSecret, config.getHttpClientFactory(), config.getEncryptionKey());
final var wasmFlagLogger = new GrpcWasmFlagLogger(clientSecret, config.getChannelFactory());
this.flagLogger = wasmFlagLogger;
final int numInstances = PooledResolver.getNumInstances(config.getResolverPoolSize());
Expand All @@ -174,6 +178,7 @@ public OpenFeatureLocalResolveProvider(
MaterializationStore materializationStore,
WasmFlagLogger wasmFlagLogger) {
this.clientSecret = clientSecret;
this.encryptionKey = null;
this.materializationStore = materializationStore;
this.stateProvider = accountStateProvider;
this.flagLogger = wasmFlagLogger;
Expand All @@ -193,14 +198,13 @@ public ProviderState getState() {

@Override
public void initialize(EvaluationContext evaluationContext) {
if (encryptionKey == null) {
log.warn(
"No encryptionKey provided. Falling back to unencrypted state."
+ " An encryption key will be required in an upcoming version.");
}
stateProvider.reload();
final AtomicReference<byte[]> resolverStateProtobuf =
new AtomicReference<>(stateProvider.provide());
final AtomicReference<String> accountIdRef = new AtomicReference<>(stateProvider.accountId());

// Only initialize WASM and set READY if we got valid state (non-empty accountId)
if (!accountIdRef.get().isEmpty()) {
resolver.setResolverState(resolverStateProtobuf.get(), accountIdRef.get(), SDK);
if (pushStateToResolver()) {
initialized = true;
this.state.set(ProviderState.READY);
} else {
Expand All @@ -210,7 +214,7 @@ public void initialize(EvaluationContext evaluationContext) {
}

final long pollIntervalSeconds = getPollIntervalSeconds();
scheduleStateRefresh(resolverStateProtobuf, accountIdRef, pollIntervalSeconds);
scheduleStateRefresh(pollIntervalSeconds);

assignLogExecutor.scheduleAtFixedRate(
() -> {
Expand All @@ -227,53 +231,62 @@ public void initialize(EvaluationContext evaluationContext) {
TimeUnit.MILLISECONDS);
}

private void scheduleStateRefresh(
AtomicReference<byte[]> resolverStateProtobuf,
AtomicReference<String> accountIdRef,
long pollIntervalSeconds) {
private boolean pushStateToResolver() {
final byte[] stateBytes = stateProvider.provide();
if (stateBytes == null || stateBytes.length == 0) {
return false;
}
if (encryptionKey != null) {
resolver.setEncryptedResolverState(stateBytes, hexToBytes(encryptionKey), SDK);
} else {
final String accountId = stateProvider.accountId();
if (accountId == null || accountId.isEmpty()) {
return false;
}
resolver.setResolverState(stateBytes, accountId, SDK);
}
lastStateBytes = stateBytes;
return true;
}

private void scheduleStateRefresh(long pollIntervalSeconds) {
if (flagsFetcherExecutor.isShutdown()) {
return;
}

// Use short retry interval (1s) when not initialized, normal interval otherwise
long delaySeconds = initialized ? pollIntervalSeconds : 1;

flagsFetcherExecutor.schedule(
() -> {
try {
stateProvider.reload();
resolverStateProtobuf.set(stateProvider.provide());
accountIdRef.set(stateProvider.accountId());

if (!accountIdRef.get().isEmpty()) {
final byte[] newState = stateProvider.provide();
if (newState != null && !java.util.Arrays.equals(newState, lastStateBytes)) {
pushStateToResolver();
if (!initialized) {
resolver.setResolverState(resolverStateProtobuf.get(), accountIdRef.get(), SDK);
lastStateBytes = resolverStateProtobuf.get();
initialized = true;
this.state.set(ProviderState.READY);
log.info("Provider recovered and is now READY");
} else {
// Only push state into the wasm instances when it actually changed — the wasm
// execution inside setResolverState is expensive (runs across all pool slots).
final byte[] newState = resolverStateProtobuf.get();
if (!java.util.Arrays.equals(newState, lastStateBytes)) {
resolver.setResolverState(newState, accountIdRef.get(), SDK);
lastStateBytes = newState;
}
// Always flush logs regardless of state change.
resolver.flushAllLogs();
}
}
if (initialized) {
resolver.flushAllLogs();
}
} catch (RuntimeException e) {
log.error("State refresh failed", e);
} finally {
scheduleStateRefresh(resolverStateProtobuf, accountIdRef, pollIntervalSeconds);
scheduleStateRefresh(pollIntervalSeconds);
}
},
delaySeconds,
TimeUnit.SECONDS);
}

private static byte[] hexToBytes(String hex) {
return java.util.HexFormat.of().parseHex(hex);
}

@Override
public Metadata getMetadata() {
return () -> "confidence-sdk-java-local";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ public void setResolverState(byte[] state, String accountId, Sdk sdk) {
maintenance(lr -> lr.setResolverState(state, accountId, sdk));
}

@Override
public void setEncryptedResolverState(byte[] encryptedState, byte[] encryptionKey, Sdk sdk) {
maintenance(lr -> lr.setEncryptedResolverState(encryptedState, encryptionKey, sdk));
}

@Override
public void flushAllLogs() {
maintenance(LocalResolver::flushAllLogs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@
class RecoveringResolver implements LocalResolver {
private static final Logger logger = LoggerFactory.getLogger(RecoveringResolver.class);

private record StateRecord(byte[] state, String accountId, Sdk sdk) {}

private final Supplier<LocalResolver> factory;
private final AtomicReference<LocalResolver> current = new AtomicReference<>();
private final AtomicBoolean broken = new AtomicBoolean(false);
private final AtomicReference<StateRecord> lastState = new AtomicReference<>();
private volatile java.util.function.Consumer<LocalResolver> replayState;

RecoveringResolver(Supplier<LocalResolver> factory) {
this.factory = factory;
Expand All @@ -41,9 +39,8 @@ private void startRecreate() {
try {
final LocalResolver old = current.get();
final LocalResolver newResolver = factory.get();
final StateRecord cached = lastState.get();
if (cached != null) {
newResolver.setResolverState(cached.state(), cached.accountId(), cached.sdk());
if (replayState != null) {
replayState.accept(newResolver);
}
current.set(newResolver);
if (old != null) {
Expand Down Expand Up @@ -83,13 +80,24 @@ private void handleFailure(String opName, ChicoryException e) {
public void setResolverState(byte[] state, String accountId, Sdk sdk) {
try {
current.get().setResolverState(state, accountId, sdk);
lastState.set(new StateRecord(state, accountId, sdk));
replayState = lr -> lr.setResolverState(state, accountId, sdk);
} catch (ChicoryException e) {
handleFailure("setResolverState", e);
throw e;
}
}

@Override
public void setEncryptedResolverState(byte[] encryptedState, byte[] encryptionKey, Sdk sdk) {
try {
current.get().setEncryptedResolverState(encryptedState, encryptionKey, sdk);
replayState = lr -> lr.setEncryptedResolverState(encryptedState, encryptionKey, sdk);
} catch (ChicoryException e) {
handleFailure("setEncryptedResolverState", e);
throw e;
}
}

@Override
public CompletionStage<ResolveProcessResponse> resolveProcess(ResolveProcessRequest request) {
try {
Expand Down
Loading