Skip to content

Commit e918e31

Browse files
committed
Add Custom Auth Provider with support for gRPC, plus tests and exception handling
Signed-off-by: Siqi Ding <dingdd@amazon.com>
1 parent 68b7be2 commit e918e31

10 files changed

Lines changed: 640 additions & 1 deletion
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package org.opensearch.dataprepper;
2+
3+
import com.google.protobuf.Any;
4+
import com.linecorp.armeria.common.RequestContext;
5+
import com.linecorp.armeria.common.annotation.Nullable;
6+
import com.linecorp.armeria.common.grpc.GoogleGrpcExceptionHandlerFunction;
7+
import com.linecorp.armeria.server.RequestTimeoutException;
8+
import io.grpc.Metadata;
9+
import io.grpc.Status;
10+
import io.grpc.StatusRuntimeException;
11+
import io.micrometer.core.instrument.Counter;
12+
13+
import org.opensearch.dataprepper.exceptions.BadRequestException;
14+
import org.opensearch.dataprepper.exceptions.BufferWriteException;
15+
import org.opensearch.dataprepper.exceptions.RequestCancelledException;
16+
import org.opensearch.dataprepper.metrics.PluginMetrics;
17+
import org.opensearch.dataprepper.model.buffer.SizeOverflowException;
18+
import org.slf4j.Logger;
19+
import org.slf4j.LoggerFactory;
20+
21+
import java.time.Duration;
22+
import java.util.concurrent.TimeoutException;
23+
24+
public class CustomAuthenticationExceptionHandler implements GoogleGrpcExceptionHandlerFunction {
25+
private static final Logger LOG = LoggerFactory.getLogger(CustomAuthenticationExceptionHandler.class);
26+
private static final String TIMEOUT_MESSAGE = "Request timed out. Check buffer availability or processing delays.";
27+
28+
public static final String REQUEST_TIMEOUTS = "customAuthRequestTimeouts";
29+
public static final String BAD_REQUESTS = "customAuthBadRequests";
30+
public static final String REQUESTS_TOO_LARGE = "customAuthRequestsTooLarge";
31+
public static final String INTERNAL_SERVER_ERROR = "customAuthInternalServerError";
32+
33+
private final Counter requestTimeoutsCounter;
34+
private final Counter badRequestsCounter;
35+
private final Counter requestsTooLargeCounter;
36+
private final Counter internalServerErrorCounter;
37+
private final GrpcRetryInfoCalculator retryInfoCalculator;
38+
39+
public CustomAuthenticationExceptionHandler(final PluginMetrics pluginMetrics,
40+
final Duration retryInfoMinDelay,
41+
final Duration retryInfoMaxDelay) {
42+
this.requestTimeoutsCounter = pluginMetrics.counter(REQUEST_TIMEOUTS);
43+
this.badRequestsCounter = pluginMetrics.counter(BAD_REQUESTS);
44+
this.requestsTooLargeCounter = pluginMetrics.counter(REQUESTS_TOO_LARGE);
45+
this.internalServerErrorCounter = pluginMetrics.counter(INTERNAL_SERVER_ERROR);
46+
this.retryInfoCalculator = new GrpcRetryInfoCalculator(retryInfoMinDelay, retryInfoMaxDelay);
47+
}
48+
49+
@Override
50+
public com.google.rpc.@Nullable Status applyStatusProto(RequestContext ctx, Throwable throwable, Metadata metadata) {
51+
final Throwable actualCause = (throwable instanceof BufferWriteException)
52+
? throwable.getCause() : throwable;
53+
return handleException(actualCause);
54+
}
55+
56+
private com.google.rpc.Status handleException(Throwable e) {
57+
final String msg = e.getMessage();
58+
if (e instanceof RequestTimeoutException || e instanceof TimeoutException) {
59+
requestTimeoutsCounter.increment();
60+
return buildStatus(e, Status.Code.RESOURCE_EXHAUSTED);
61+
} else if (e instanceof SizeOverflowException) {
62+
requestsTooLargeCounter.increment();
63+
return buildStatus(e, Status.Code.RESOURCE_EXHAUSTED);
64+
} else if (e instanceof BadRequestException) {
65+
badRequestsCounter.increment();
66+
return buildStatus(e, Status.Code.INVALID_ARGUMENT);
67+
} else if ((e instanceof StatusRuntimeException) &&
68+
(msg.contains("Invalid protobuf byte sequence") || msg.contains("Can't decode compressed frame"))) {
69+
badRequestsCounter.increment();
70+
return buildStatus(e, Status.Code.INVALID_ARGUMENT);
71+
} else if (e instanceof RequestCancelledException) {
72+
requestTimeoutsCounter.increment();
73+
return buildStatus(e, Status.Code.CANCELLED);
74+
}
75+
76+
internalServerErrorCounter.increment();
77+
LOG.error("CustomAuth gRPC handler caught unexpected exception", e);
78+
return buildStatus(e, Status.Code.INTERNAL);
79+
}
80+
81+
private com.google.rpc.Status buildStatus(Throwable e, Status.Code code) {
82+
com.google.rpc.Status.Builder builder = com.google.rpc.Status.newBuilder()
83+
.setCode(code.value());
84+
85+
if (e instanceof RequestTimeoutException) {
86+
builder.setMessage(TIMEOUT_MESSAGE);
87+
} else {
88+
builder.setMessage(e.getMessage() != null ? e.getMessage() : code.name());
89+
}
90+
91+
if (code == Status.Code.RESOURCE_EXHAUSTED) {
92+
builder.addDetails(Any.pack(retryInfoCalculator.createRetryInfo()));
93+
}
94+
95+
return builder.build();
96+
}
97+
}
98+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.dataprepper.armeria.authentication;
7+
8+
import com.fasterxml.jackson.annotation.JsonCreator;
9+
import com.fasterxml.jackson.annotation.JsonProperty;
10+
11+
public class CustomAuthenticationConfig {
12+
private final String customToken;
13+
14+
@JsonCreator
15+
public CustomAuthenticationConfig(
16+
@JsonProperty("custom_token") String customToken) {
17+
this.customToken = customToken;
18+
}
19+
20+
public String customToken() {
21+
return customToken;
22+
}
23+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.dataprepper.armeria.authentication;
7+
8+
import com.linecorp.armeria.server.HttpService;
9+
import io.grpc.ServerInterceptor;
10+
11+
import java.util.Optional;
12+
import java.util.function.Function;
13+
14+
public interface CustomAuthenticationProvider {
15+
16+
String UNAUTHENTICATED_PLUGIN_NAME = "unauthenticated";
17+
18+
19+
ServerInterceptor getAuthenticationInterceptor();
20+
21+
default Optional<Function<? super HttpService, ? extends HttpService>> getHttpAuthenticationService() {
22+
return Optional.empty();
23+
}
24+
}
25+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.dataprepper.plugins;
7+
8+
import com.linecorp.armeria.common.HttpResponse;
9+
import com.linecorp.armeria.common.HttpStatus;
10+
import com.linecorp.armeria.common.MediaType;
11+
import com.linecorp.armeria.server.HttpService;
12+
import io.grpc.Metadata;
13+
import io.grpc.ServerCall;
14+
import io.grpc.ServerCallHandler;
15+
import io.grpc.ServerInterceptor;
16+
import io.grpc.Status;
17+
import org.opensearch.dataprepper.armeria.authentication.CustomAuthenticationConfig;
18+
import org.opensearch.dataprepper.armeria.authentication.GrpcAuthenticationProvider;
19+
import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
20+
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;
21+
22+
import java.util.Optional;
23+
import java.util.function.Function;
24+
25+
@DataPrepperPlugin(
26+
name = "custom_auth",
27+
pluginType = GrpcAuthenticationProvider.class,
28+
pluginConfigurationType = CustomAuthenticationConfig.class
29+
)
30+
public class CustomGrpcAuthenticationProvider implements GrpcAuthenticationProvider {
31+
private final String token;
32+
private static final String AUTH_HEADER = "authentication";
33+
34+
35+
@DataPrepperPluginConstructor
36+
public CustomGrpcAuthenticationProvider(final CustomAuthenticationConfig config) {
37+
this.token = config.customToken();
38+
}
39+
40+
@Override
41+
public ServerInterceptor getAuthenticationInterceptor() {
42+
return new ServerInterceptor() {
43+
@Override
44+
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
45+
ServerCall<ReqT, RespT> call,
46+
Metadata headers,
47+
ServerCallHandler<ReqT, RespT> next) {
48+
49+
String auth = headers.get(Metadata.Key.of("authentication", Metadata.ASCII_STRING_MARSHALLER));
50+
51+
if (auth == null || !auth.equals(token)) {
52+
call.close(Status.UNAUTHENTICATED.withDescription("Invalid token"), new Metadata());
53+
return new ServerCall.Listener<>() {};
54+
}
55+
56+
return next.startCall(call, headers);
57+
}
58+
};
59+
}
60+
61+
@Override
62+
public Optional<Function<? super HttpService, ? extends HttpService>> getHttpAuthenticationService() {
63+
return Optional.of(delegate -> (ctx, req) -> {
64+
final String auth = req.headers().get(AUTH_HEADER);
65+
if (auth == null || !auth.equals(token)) {
66+
return HttpResponse.of(
67+
HttpStatus.UNAUTHORIZED,
68+
MediaType.PLAIN_TEXT_UTF_8,
69+
"Unauthorized: Invalid or missing token"
70+
);
71+
}
72+
return delegate.serve(ctx, req);
73+
});
74+
}
75+
}
76+
77+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.dataprepper.plugins;
7+
8+
import io.grpc.ServerInterceptor;
9+
import io.grpc.ServerCall;
10+
import io.grpc.ServerCallHandler;
11+
import io.grpc.Metadata;
12+
import org.opensearch.dataprepper.armeria.authentication.GrpcAuthenticationProvider;
13+
import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
14+
15+
16+
/**
17+
* Plugin that allows unauthenticated gRPC access.
18+
*/
19+
@DataPrepperPlugin(
20+
name = GrpcAuthenticationProvider.UNAUTHENTICATED_PLUGIN_NAME,
21+
pluginType = GrpcAuthenticationProvider.class
22+
)
23+
public class UnauthenticatedCustomGrpcAuthenticationProvider implements GrpcAuthenticationProvider {
24+
25+
@Override
26+
public ServerInterceptor getAuthenticationInterceptor() {
27+
return new ServerInterceptor() {
28+
@Override
29+
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
30+
ServerCall<ReqT, RespT> call,
31+
Metadata headers,
32+
ServerCallHandler<ReqT, RespT> next) {
33+
// No authentication is performed; allow the request to continue
34+
return next.startCall(call, headers);
35+
}
36+
};
37+
}
38+
}
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
package org.opensearch.dataprepper;
2+
3+
import com.google.protobuf.Any;
4+
import com.google.rpc.RetryInfo;
5+
import com.linecorp.armeria.common.RequestContext;
6+
import io.grpc.Metadata;
7+
import io.grpc.Status;
8+
import io.micrometer.core.instrument.Counter;
9+
import org.junit.jupiter.api.BeforeEach;
10+
import org.junit.jupiter.api.Test;
11+
import org.junit.jupiter.api.extension.ExtendWith;
12+
import org.mockito.Mock;
13+
import org.mockito.junit.jupiter.MockitoExtension;
14+
import org.opensearch.dataprepper.exceptions.BadRequestException;
15+
import org.opensearch.dataprepper.exceptions.BufferWriteException;
16+
import org.opensearch.dataprepper.exceptions.RequestCancelledException;
17+
import org.opensearch.dataprepper.metrics.PluginMetrics;
18+
import org.opensearch.dataprepper.model.buffer.SizeOverflowException;
19+
20+
import java.io.IOException;
21+
import java.time.Duration;
22+
import java.util.Optional;
23+
import java.util.UUID;
24+
import java.util.concurrent.TimeoutException;
25+
26+
import static org.hamcrest.MatcherAssert.assertThat;
27+
import static org.hamcrest.Matchers.equalTo;
28+
import static org.junit.jupiter.api.Assertions.assertTrue;
29+
import static org.mockito.Mockito.verify;
30+
import static org.mockito.Mockito.when;
31+
32+
@ExtendWith(MockitoExtension.class)
33+
public class CustomAuthenticationExceptionHandlerTest {
34+
@Mock
35+
private PluginMetrics pluginMetrics;
36+
37+
@Mock
38+
private Counter requestTimeoutsCounter;
39+
40+
@Mock
41+
private Counter badRequestsCounter;
42+
43+
@Mock
44+
private Counter requestsTooLargeCounter;
45+
46+
@Mock
47+
private Counter internalServerErrorCounter;
48+
49+
@Mock
50+
private RequestContext requestContext;
51+
52+
@Mock
53+
private Metadata metadata;
54+
55+
private CustomAuthenticationExceptionHandler handler;
56+
57+
@BeforeEach
58+
public void setUp() {
59+
when(pluginMetrics.counter(CustomAuthenticationExceptionHandler.REQUEST_TIMEOUTS)).thenReturn(requestTimeoutsCounter);
60+
when(pluginMetrics.counter(CustomAuthenticationExceptionHandler.BAD_REQUESTS)).thenReturn(badRequestsCounter);
61+
when(pluginMetrics.counter(CustomAuthenticationExceptionHandler.REQUESTS_TOO_LARGE)).thenReturn(requestsTooLargeCounter);
62+
when(pluginMetrics.counter(CustomAuthenticationExceptionHandler.INTERNAL_SERVER_ERROR)).thenReturn(internalServerErrorCounter);
63+
64+
handler = new CustomAuthenticationExceptionHandler(pluginMetrics, Duration.ofMillis(100), Duration.ofSeconds(2));
65+
}
66+
67+
@Test
68+
public void testBadRequestExceptionHandling() {
69+
final String message = UUID.randomUUID().toString();
70+
BadRequestException exception = new BadRequestException(message, new IOException());
71+
72+
com.google.rpc.Status status = handler.applyStatusProto(requestContext, exception, metadata);
73+
74+
assertThat(status.getCode(), equalTo(Status.Code.INVALID_ARGUMENT.value()));
75+
assertThat(status.getMessage(), equalTo(message));
76+
verify(badRequestsCounter).increment();
77+
}
78+
79+
@Test
80+
public void testTimeoutExceptionHandling() {
81+
TimeoutException timeout = new TimeoutException();
82+
BufferWriteException bufferWriteException = new BufferWriteException("timeout", timeout);
83+
84+
com.google.rpc.Status status = handler.applyStatusProto(requestContext, bufferWriteException, metadata);
85+
86+
assertThat(status.getCode(), equalTo(Status.Code.RESOURCE_EXHAUSTED.value()));
87+
verify(requestTimeoutsCounter).increment();
88+
Optional<Any> retryInfo = status.getDetailsList().stream().filter(d -> d.is(RetryInfo.class)).findFirst();
89+
assertTrue(retryInfo.isPresent());
90+
}
91+
92+
@Test
93+
public void testSizeOverflowExceptionHandling() {
94+
SizeOverflowException overflow = new SizeOverflowException("Overflow");
95+
BufferWriteException bufferWriteException = new BufferWriteException("overflow", overflow);
96+
97+
com.google.rpc.Status status = handler.applyStatusProto(requestContext, bufferWriteException, metadata);
98+
99+
assertThat(status.getCode(), equalTo(Status.Code.RESOURCE_EXHAUSTED.value()));
100+
verify(requestsTooLargeCounter).increment();
101+
}
102+
103+
@Test
104+
public void testCancelledRequestHandling() {
105+
String message = UUID.randomUUID().toString();
106+
RequestCancelledException exception = new RequestCancelledException(message);
107+
108+
com.google.rpc.Status status = handler.applyStatusProto(requestContext, exception, metadata);
109+
110+
assertThat(status.getCode(), equalTo(Status.Code.CANCELLED.value()));
111+
assertThat(status.getMessage(), equalTo(message));
112+
verify(requestTimeoutsCounter).increment();
113+
}
114+
115+
@Test
116+
public void testInternalExceptionHandling() {
117+
String message = UUID.randomUUID().toString();
118+
RuntimeException exception = new RuntimeException(message);
119+
120+
com.google.rpc.Status status = handler.applyStatusProto(requestContext, exception, metadata);
121+
122+
assertThat(status.getCode(), equalTo(Status.Code.INTERNAL.value()));
123+
assertThat(status.getMessage(), equalTo(message));
124+
verify(internalServerErrorCounter).increment();
125+
}
126+
}

0 commit comments

Comments
 (0)