Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,20 @@ public static LambdaAsyncClient createAsyncLambdaClient(
awsCredentialsOptions);
final PluginMetrics awsSdkMetrics = PluginMetrics.fromNames("sdk", "aws");

NettyNioAsyncHttpClient.Builder httpClientBuilder = NettyNioAsyncHttpClient.builder()
.maxConcurrency(clientOptions.getMaxConcurrency())
.connectionTimeout(clientOptions.getConnectionTimeout());

if (clientOptions.getReadTimeout() != null) {
httpClientBuilder.readTimeout(clientOptions.getReadTimeout());
}

return LambdaAsyncClient.builder()
.region(awsAuthenticationOptions.getAwsRegion())
.credentialsProvider(awsCredentialsProvider)
.overrideConfiguration(
createOverrideConfiguration(clientOptions, awsSdkMetrics))
.httpClient(NettyNioAsyncHttpClient.builder()
.maxConcurrency(clientOptions.getMaxConcurrency())
.connectionTimeout(clientOptions.getConnectionTimeout())
.readTimeout(clientOptions.getReadTimeout()).build())
.httpClient(httpClientBuilder.build())
.build();
}

Expand All @@ -56,11 +61,16 @@ private static ClientOverrideConfiguration createOverrideConfiguration(
.backoffStrategy(backoffStrategy)
.build();

return ClientOverrideConfiguration.builder()
ClientOverrideConfiguration.Builder configBuilder = ClientOverrideConfiguration.builder()
.retryPolicy(customRetryPolicy)
.addMetricPublisher(new MicrometerMetricPublisher(awsSdkMetrics))
.apiCallTimeout(clientOptions.getApiCallTimeout())
.build();
.apiCallTimeout(clientOptions.getApiCallTimeout());

if (clientOptions.getApiCallAttemptTimeout() != null) {
configBuilder.apiCallAttemptTimeout(clientOptions.getApiCallAttemptTimeout());
}

return configBuilder.build();
}

public static AwsCredentialsOptions convertToCredentialsOptions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ public class ClientOptions {
public static final int DEFAULT_CONNECTION_RETRIES = 3;
public static final int DEFAULT_MAXIMUM_CONCURRENCY = 200;
public static final Duration DEFAULT_CONNECTION_TIMEOUT = Duration.ofSeconds(60);
public static final Duration DEFAULT_READ_TIMEOUT = Duration.ofSeconds(60);

public static final Duration DEFAULT_API_TIMEOUT = Duration.ofSeconds(60);
public static final Duration DEFAULT_BASE_DELAY = Duration.ofMillis(100);
public static final Duration DEFAULT_MAX_BACKOFF = Duration.ofSeconds(20);
Expand All @@ -24,13 +24,17 @@ public class ClientOptions {
@JsonProperty("api_call_timeout")
private Duration apiCallTimeout = DEFAULT_API_TIMEOUT;

@JsonPropertyDescription("api call attempt timeout defines the time sdk waits for a single attempt before timing out")
@JsonProperty("api_call_attempt_timeout")
private Duration apiCallAttemptTimeout;

@JsonPropertyDescription("sdk timeout defines the time sdk maintains the connection to the client before timing out")
@JsonProperty("connection_timeout")
private Duration connectionTimeout = DEFAULT_CONNECTION_TIMEOUT;

@JsonPropertyDescription("read timeout defines the time sdk waits for data to be read from an established connection")
@JsonProperty("read_timeout")
private Duration readTimeout = DEFAULT_READ_TIMEOUT;
Comment thread
ashrao94 marked this conversation as resolved.
private Duration readTimeout;

@JsonPropertyDescription("max concurrency defined from the client side")
@JsonProperty("max_concurrency")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.time.Duration;

@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT)
class LambdaClientFactoryTest {
Expand Down Expand Up @@ -86,6 +88,58 @@ void testCreateAsyncLambdaClientOverrideConfiguration() {
assertNotNull(overrideConfig.retryPolicy());
assertNotNull(overrideConfig.metricPublishers());
assertFalse(overrideConfig.metricPublishers().isEmpty());
// apiCallAttemptTimeout should not be set when null
assertFalse(overrideConfig.apiCallAttemptTimeout().isPresent());
}

@Test
void testCreateAsyncLambdaClientWithApiCallAttemptTimeout() {
// Arrange
ClientOptions clientOptions = mock(ClientOptions.class);
when(clientOptions.getMaxConcurrency()).thenReturn(200);
when(clientOptions.getConnectionTimeout()).thenReturn(Duration.ofSeconds(60));
when(clientOptions.getReadTimeout()).thenReturn(Duration.ofSeconds(60));
when(clientOptions.getApiCallTimeout()).thenReturn(Duration.ofSeconds(60));
when(clientOptions.getApiCallAttemptTimeout()).thenReturn(Duration.ofSeconds(30));
when(clientOptions.getMaxConnectionRetries()).thenReturn(3);
when(clientOptions.getBaseDelay()).thenReturn(Duration.ofMillis(100));
when(clientOptions.getMaxBackoff()).thenReturn(Duration.ofSeconds(20));

// Act
LambdaAsyncClient client = LambdaClientFactory.createAsyncLambdaClient(
awsAuthenticationOptions,
awsCredentialsSupplier,
clientOptions
);

// Assert
assertNotNull(client);
ClientOverrideConfiguration overrideConfig = client.serviceClientConfiguration().overrideConfiguration();
assertEquals(Duration.ofSeconds(30), overrideConfig.apiCallAttemptTimeout().get());
}

@Test
void testCreateAsyncLambdaClientWithoutReadTimeout() {
// Arrange
ClientOptions clientOptions = mock(ClientOptions.class);
when(clientOptions.getMaxConcurrency()).thenReturn(200);
when(clientOptions.getConnectionTimeout()).thenReturn(Duration.ofSeconds(60));
when(clientOptions.getReadTimeout()).thenReturn(null); // No read timeout
when(clientOptions.getApiCallTimeout()).thenReturn(Duration.ofSeconds(60));
when(clientOptions.getApiCallAttemptTimeout()).thenReturn(null); // No attempt timeout
when(clientOptions.getMaxConnectionRetries()).thenReturn(3);
when(clientOptions.getBaseDelay()).thenReturn(Duration.ofMillis(100));
when(clientOptions.getMaxBackoff()).thenReturn(Duration.ofSeconds(20));

// Act
LambdaAsyncClient client = LambdaClientFactory.createAsyncLambdaClient(
awsAuthenticationOptions,
awsCredentialsSupplier,
clientOptions
);

// Assert - should not throw exception when readTimeout is null
assertNotNull(client);
}

@Test
Expand Down Expand Up @@ -184,4 +238,28 @@ void testRetryConditionFirstFailsAndThenSucceeds() {
assertTrue(successReached, "Should have reached successful completion");
}

@Test
void testClientUsesConfiguredReadTimeout() {
ClientOptions clientOptions = new ClientOptions();
Duration customReadTimeout = Duration.ofSeconds(30);

// Use reflection to set the readTimeout since there's no setter
try {
java.lang.reflect.Field readTimeoutField = ClientOptions.class.getDeclaredField("readTimeout");
readTimeoutField.setAccessible(true);
readTimeoutField.set(clientOptions, customReadTimeout);
} catch (Exception e) {
throw new RuntimeException("Failed to set readTimeout", e);
}

LambdaAsyncClient client = LambdaClientFactory.createAsyncLambdaClient(
awsAuthenticationOptions,
awsCredentialsSupplier,
clientOptions
);

assertNotNull(client);
assertEquals(customReadTimeout, clientOptions.getReadTimeout());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
*/

package org.opensearch.dataprepper.plugins.lambda.common.config;

import org.junit.jupiter.api.Test;


import static org.junit.jupiter.api.Assertions.assertEquals;

class ClientOptionsTest {

@Test
void testDefaultReadTimeout() {
ClientOptions clientOptions = new ClientOptions();
assertEquals(null, clientOptions.getReadTimeout());
}

@Test
void testDefaultApiCallAttemptTimeout() {
ClientOptions clientOptions = new ClientOptions();
assertEquals(null, clientOptions.getApiCallAttemptTimeout());
}
}
Loading