Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ public class AwsSecretManagerConfiguration {
@JsonProperty("disable_refresh")
private boolean disableRefresh = false;

@JsonProperty("skip_validation_on_start")
private boolean skipValidationOnStart = false; // Default: false (validate by default)

public String getAwsSecretId() {
return awsSecretId;
}
Expand All @@ -68,6 +71,10 @@ public boolean isDisableRefresh() {
return disableRefresh;
}

public boolean isSkipValidationOnStart() {
return skipValidationOnStart;
}

public SecretsManagerClient createSecretManagerClient(final AwsCredentialsSupplier awsCredentialsSupplier) {
final AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(AwsCredentialsOptions.builder()
.withRegion(this.awsRegion)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ public class AwsSecretsSupplier implements SecretsSupplier {
static final TypeReference<Map<String, String>> MAP_TYPE_REFERENCE = new TypeReference<>() {
};
private static final Logger LOG = LoggerFactory.getLogger(AwsSecretsSupplier.class);
private static final Object NOT_LOADED_SENTINEL = new Object(); // Sentinel to indicate secret not loaded yet

private final SecretValueDecoder secretValueDecoder;
private final ObjectMapper objectMapper;
private final Map<String, AwsSecretManagerConfiguration> awsSecretManagerConfigurationMap;
Expand Down Expand Up @@ -58,6 +60,14 @@ private ConcurrentMap<String, Object> toSecretMap(
final AwsSecretManagerConfiguration awsSecretManagerConfiguration =
awsSecretManagerConfigurationMap.get(secretConfigurationId);
final SecretsManagerClient secretsManagerClient = entry.getValue();

// Check if validation on start is skipped for this secret
if (awsSecretManagerConfiguration.isSkipValidationOnStart()) {
LOG.info("Skipping secret retrieval on start for secret: {} (skip_validation_on_start=true)",
awsSecretManagerConfiguration.getAwsSecretId());
return NOT_LOADED_SENTINEL; // Mark as not loaded, will be loaded on first access
}

return retrieveSecretsFromSecretManager(awsSecretManagerConfiguration, secretsManagerClient);
}));
}
Expand All @@ -77,7 +87,10 @@ public Object retrieveValue(String secretId, String key) {
if (!secretIdToValue.containsKey(secretId)) {
throw new IllegalArgumentException(String.format("Unable to find secretId: %s", secretId));
}
final Object keyValuePairs = secretIdToValue.get(secretId);

// Load secret if it was skipped on start
final Object keyValuePairs = loadSecretIfNeeded(secretId);

if (!(keyValuePairs instanceof Map)) {
throw new IllegalArgumentException(String.format("The value under secretId: %s is not a valid json.",
secretId));
Expand All @@ -95,8 +108,11 @@ public Object retrieveValue(String secretId) {
if (!secretIdToValue.containsKey(secretId)) {
throw new IllegalArgumentException(String.format("Unable to find secretId: %s", secretId));
}

// Load secret if it was skipped on start
final Object secretValue = loadSecretIfNeeded(secretId);

try {
final Object secretValue = secretIdToValue.get(secretId);
return secretValue instanceof Map ? objectMapper.writeValueAsString(secretValue) :
secretValue;
} catch (JsonProcessingException e) {
Expand All @@ -105,6 +121,25 @@ public Object retrieveValue(String secretId) {
}
}

/**
* Loads a secret if it was skipped on start (lazy-loading).
* Uses {@link ConcurrentMap#compute} to ensure atomicity of the sentinel check and refresh.
*
* @param secretId The secret configuration ID
* @return The loaded secret value
*/
private Object loadSecretIfNeeded(String secretId) {
return secretIdToValue.compute(secretId, (key, currentValue) -> {
if (currentValue == NOT_LOADED_SENTINEL) {
LOG.info("Secret {} was not loaded on start, loading now on first access.", key);
final AwsSecretManagerConfiguration config = awsSecretManagerConfigurationMap.get(key);
final SecretsManagerClient client = secretsManagerClientMap.get(key);
return retrieveSecretsFromSecretManager(config, client);
}
return currentValue;
});
}


@Override
public void refresh(String secretConfigId) {
Expand Down Expand Up @@ -152,6 +187,8 @@ public String updateValue(String secretId, Object newValue) {

@Override
public String updateValue(String secretId, String keyToUpdate, Object newValue) {
// Ensure the secret is loaded before attempting to update
loadSecretIfNeeded(secretId);
Object currentSecretStore = secretIdToValue.get(secretId);
if (currentSecretStore instanceof Map) {
if (keyToUpdate == null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.dataprepper.plugins.aws;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import org.junit.jupiter.api.Test;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;

class AwsSecretManagerConfigurationValidateAtBootstrapTest {

private final ObjectMapper objectMapper = new ObjectMapper(new YAMLFactory())
.registerModule(new JavaTimeModule());

@Test
void testDefaultSkipValidationOnStart() {
final AwsSecretManagerConfiguration config = new AwsSecretManagerConfiguration();

// Default should be false (validate by default)
assertThat(config.isSkipValidationOnStart(), equalTo(false));
}

@Test
void testSkipValidationOnStartFromYaml_Enabled() throws Exception {
final String yaml =
"secret_id: my-secret\n" +
"region: us-east-1\n" +
"refresh_interval: PT1H\n" +
"skip_validation_on_start: true\n";

final AwsSecretManagerConfiguration config =
objectMapper.readValue(yaml, AwsSecretManagerConfiguration.class);

assertThat(config.isSkipValidationOnStart(), equalTo(true));
}

@Test
void testSkipValidationOnStartFromYaml_Disabled() throws Exception {
final String yaml =
"secret_id: my-secret\n" +
"region: us-east-1\n" +
"refresh_interval: PT1H\n" +
"skip_validation_on_start: false\n";

final AwsSecretManagerConfiguration config =
objectMapper.readValue(yaml, AwsSecretManagerConfiguration.class);

assertThat(config.isSkipValidationOnStart(), equalTo(false));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ void testInitializationWithNonNullConfig() {
when(awsSecretPluginConfig.getAwsSecretManagerConfigurationMap()).thenReturn(
Map.of(TEST_SECRET_CONFIG_ID, awsSecretManagerConfiguration));
when(awsSecretManagerConfiguration.getRefreshInterval()).thenReturn(testInterval);
when(awsSecretManagerConfiguration.isSkipValidationOnStart()).thenReturn(false); // Default behavior
when(awsSecretManagerConfiguration.createSecretManagerClient(awsCredentialsSupplier)).thenReturn(secretsManagerClient);
when(awsSecretManagerConfiguration.createGetSecretValueRequest()).thenReturn(getSecretValueRequest);
when(secretsManagerClient.getSecretValue(eq(getSecretValueRequest))).thenReturn(getSecretValueResponse);
Expand Down Expand Up @@ -133,6 +134,7 @@ void testInitializationWithDisableRefresh() {
when(awsSecretPluginConfig.getAwsSecretManagerConfigurationMap()).thenReturn(
Map.of(TEST_SECRET_CONFIG_ID, awsSecretManagerConfiguration));
when(awsSecretManagerConfiguration.isDisableRefresh()).thenReturn(true);
when(awsSecretManagerConfiguration.isSkipValidationOnStart()).thenReturn(false); // Default behavior
when(awsSecretManagerConfiguration.createSecretManagerClient(awsCredentialsSupplier)).thenReturn(secretsManagerClient);
when(awsSecretManagerConfiguration.createGetSecretValueRequest()).thenReturn(getSecretValueRequest);
when(secretsManagerClient.getSecretValue(eq(getSecretValueRequest))).thenReturn(getSecretValueResponse);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.dataprepper.plugins.aws;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueRequest;
import software.amazon.awssdk.services.secretsmanager.model.GetSecretValueResponse;
import software.amazon.awssdk.services.secretsmanager.model.PutSecretValueRequest;
import software.amazon.awssdk.services.secretsmanager.model.PutSecretValueResponse;

import java.util.Map;
import java.util.UUID;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
* Tests for lazy-loading behavior when skip_validation_on_start is true.
*/
@ExtendWith(MockitoExtension.class)
class AwsSecretsSupplierLazyLoadTest {

private ObjectMapper objectMapper;
private String testSecretId;
private String testKey;
private String testValue;

@Mock
private SecretValueDecoder secretValueDecoder;

@Mock
private AwsSecretPluginConfig awsSecretPluginConfig;

@Mock
private AwsSecretManagerConfiguration awsSecretManagerConfiguration;

@Mock
private SecretsManagerClient secretsManagerClient;

@Mock
private GetSecretValueRequest getSecretValueRequest;

@Mock
private GetSecretValueResponse getSecretValueResponse;

@Mock
private PutSecretValueRequest putSecretValueRequest;

@Mock
private PutSecretValueResponse putSecretValueResponse;

@Mock
private AwsCredentialsSupplier awsCredentialsSupplier;

@BeforeEach
void setUp() {
objectMapper = new ObjectMapper();
testSecretId = UUID.randomUUID().toString();
testKey = UUID.randomUUID().toString();
testValue = UUID.randomUUID().toString();
}

@Test
void testSecretWithSkipValidationOnStartTrue_LoadsOnFirstAccess() throws JsonProcessingException {
// Given: Secret configured with skip_validation_on_start=true
when(awsSecretPluginConfig.getAwsSecretManagerConfigurationMap()).thenReturn(
Map.of(testSecretId, awsSecretManagerConfiguration)
);
when(awsSecretManagerConfiguration.isSkipValidationOnStart()).thenReturn(true); // Skip on start
when(awsSecretManagerConfiguration.createSecretManagerClient(awsCredentialsSupplier)).thenReturn(secretsManagerClient);
when(awsSecretManagerConfiguration.createGetSecretValueRequest()).thenReturn(getSecretValueRequest);
when(secretValueDecoder.decode(eq(getSecretValueResponse))).thenReturn(objectMapper.writeValueAsString(
Map.of(testKey, testValue)
));
when(secretsManagerClient.getSecretValue(eq(getSecretValueRequest))).thenReturn(getSecretValueResponse);

// When: AwsSecretsSupplier is constructed
final AwsSecretsSupplier supplier = new AwsSecretsSupplier(
secretValueDecoder, awsSecretPluginConfig, objectMapper, awsCredentialsSupplier
);

// Then: Secret is NOT retrieved at construction time
verify(secretsManagerClient, never()).getSecretValue(eq(getSecretValueRequest));

// When: Secret is accessed for the first time
final Object value = supplier.retrieveValue(testSecretId, testKey);

// Then: Secret is loaded on-demand
verify(secretsManagerClient, times(1)).getSecretValue(eq(getSecretValueRequest));
assertThat(value, equalTo(testValue));
}

@Test
void testSecretWithSkipValidationOnStartFalse_LoadsAtConstruction() throws JsonProcessingException {
// Given: Secret configured with skip_validation_on_start=false (default)
when(awsSecretPluginConfig.getAwsSecretManagerConfigurationMap()).thenReturn(
Map.of(testSecretId, awsSecretManagerConfiguration)
);
when(awsSecretManagerConfiguration.isSkipValidationOnStart()).thenReturn(false); // Load on start
when(awsSecretManagerConfiguration.createSecretManagerClient(awsCredentialsSupplier)).thenReturn(secretsManagerClient);
when(awsSecretManagerConfiguration.createGetSecretValueRequest()).thenReturn(getSecretValueRequest);
when(secretValueDecoder.decode(eq(getSecretValueResponse))).thenReturn(objectMapper.writeValueAsString(
Map.of(testKey, testValue)
));
when(secretsManagerClient.getSecretValue(eq(getSecretValueRequest))).thenReturn(getSecretValueResponse);

// When: AwsSecretsSupplier is constructed
final AwsSecretsSupplier supplier = new AwsSecretsSupplier(
secretValueDecoder, awsSecretPluginConfig, objectMapper, awsCredentialsSupplier
);

// Then: Secret IS retrieved at construction time
verify(secretsManagerClient, times(1)).getSecretValue(eq(getSecretValueRequest));

// When: Secret is accessed
final Object value = supplier.retrieveValue(testSecretId, testKey);

// Then: No additional retrieval (already loaded)
verify(secretsManagerClient, times(1)).getSecretValue(eq(getSecretValueRequest));
assertThat(value, equalTo(testValue));
}

@Test
void testUpdateValue_withSkipValidationOnStart_loadsSecretBeforeUpdate() throws JsonProcessingException {
// Given: Secret configured with skip_validation_on_start=true
when(awsSecretPluginConfig.getAwsSecretManagerConfigurationMap()).thenReturn(
Map.of(testSecretId, awsSecretManagerConfiguration)
);
when(awsSecretManagerConfiguration.isSkipValidationOnStart()).thenReturn(true);
when(awsSecretManagerConfiguration.createSecretManagerClient(awsCredentialsSupplier)).thenReturn(secretsManagerClient);
when(awsSecretManagerConfiguration.createGetSecretValueRequest()).thenReturn(getSecretValueRequest);
when(secretValueDecoder.decode(eq(getSecretValueResponse))).thenReturn(objectMapper.writeValueAsString(
Map.of(testKey, testValue)
));
when(secretsManagerClient.getSecretValue(eq(getSecretValueRequest))).thenReturn(getSecretValueResponse);
when(awsSecretManagerConfiguration.putSecretValueRequest(any())).thenReturn(putSecretValueRequest);
when(secretsManagerClient.putSecretValue(eq(putSecretValueRequest))).thenReturn(putSecretValueResponse);
final String newVersionId = UUID.randomUUID().toString();
when(putSecretValueResponse.versionId()).thenReturn(newVersionId);

final AwsSecretsSupplier supplier = new AwsSecretsSupplier(
secretValueDecoder, awsSecretPluginConfig, objectMapper, awsCredentialsSupplier
);

// Then: Secret is NOT retrieved at construction time
verify(secretsManagerClient, never()).getSecretValue(eq(getSecretValueRequest));

// When: updateValue is called before any retrieveValue
final String versionId = supplier.updateValue(testSecretId, testKey, "newValue");

// Then: Secret was loaded on-demand and update succeeded
verify(secretsManagerClient, times(1)).getSecretValue(eq(getSecretValueRequest));
assertThat(versionId, equalTo(newVersionId));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class AwsSecretsSupplierTest {
@BeforeEach
void setUp() throws JsonProcessingException {
when(awsSecretManagerConfiguration.createGetSecretValueRequest()).thenReturn(getSecretValueRequest);
when(awsSecretManagerConfiguration.isSkipValidationOnStart()).thenReturn(false); // Default: validate on start
when(awsSecretPluginConfig.getAwsSecretManagerConfigurationMap()).thenReturn(
Map.of(TEST_AWS_SECRET_CONFIGURATION_NAME, awsSecretManagerConfiguration)
);
Expand Down
Loading