Skip to content

Commit 321a7ed

Browse files
committed
Add Custom Auth Provider with support for gRPC, plus tests and exception (#5578)
Add Custom Auth Provider with support for gRPC, plus tests and exception handling Signed-off-by: Siqi Ding <109874435+Davidding4718@users.noreply.github.com>
1 parent 68b7be2 commit 321a7ed

8 files changed

Lines changed: 451 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.dataprepper.plugins.testcustomauth;
7+
8+
import com.fasterxml.jackson.annotation.JsonCreator;
9+
import com.fasterxml.jackson.annotation.JsonProperty;
10+
11+
public class TestCustomAuthenticationConfig {
12+
private final String customToken;
13+
private final String header;
14+
15+
@JsonCreator
16+
public TestCustomAuthenticationConfig(
17+
@JsonProperty("custom_token") String customToken,
18+
@JsonProperty("header") String header) {
19+
this.customToken = customToken;
20+
this.header = header != null ? header : "authentication";
21+
}
22+
23+
public String customToken() {
24+
return customToken;
25+
}
26+
27+
public String header() {
28+
return header;
29+
}
30+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.dataprepper.plugins.testcustomauth;
7+
8+
import com.linecorp.armeria.server.HttpService;
9+
import io.grpc.ServerInterceptor;
10+
11+
import java.util.Optional;
12+
import java.util.function.Function;
13+
14+
public interface TestCustomAuthenticationProvider {
15+
16+
String UNAUTHENTICATED_PLUGIN_NAME = "unauthenticated";
17+
18+
19+
ServerInterceptor getAuthenticationInterceptor();
20+
21+
default Optional<Function<? super HttpService, ? extends HttpService>> getHttpAuthenticationService() {
22+
return Optional.empty();
23+
}
24+
}
25+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package org.opensearch.dataprepper.plugins.testcustomauth;
2+
3+
import org.junit.jupiter.api.Assertions;
4+
import org.junit.jupiter.api.BeforeEach;
5+
import org.junit.jupiter.api.Test;
6+
import org.junit.jupiter.api.extension.ExtendWith;
7+
import org.mockito.Mock;
8+
import org.mockito.junit.jupiter.MockitoExtension;
9+
10+
import static org.mockito.Mockito.when;
11+
12+
@ExtendWith(MockitoExtension.class)
13+
public class TestCustomAuthenticationProviderTest {
14+
15+
private static final String TOKEN = "test-token";
16+
private static final String HEADER = "authentication";
17+
18+
@Mock
19+
private TestCustomAuthenticationConfig config;
20+
21+
private TestCustomGrpcAuthenticationProvider provider;
22+
23+
@BeforeEach
24+
void setUp() {
25+
when(config.customToken()).thenReturn(TOKEN);
26+
when(config.header()).thenReturn(HEADER);
27+
28+
provider = new TestCustomGrpcAuthenticationProvider(config);
29+
}
30+
31+
@Test
32+
void testGetHttpAuthenticationService_shouldReturnValidOptional() {
33+
var optionalService = provider.getHttpAuthenticationService();
34+
Assertions.assertTrue(optionalService.isPresent());
35+
}
36+
37+
@Test
38+
void testGetAuthenticationInterceptor_shouldReturnNonNull() {
39+
var interceptor = provider.getAuthenticationInterceptor();
40+
Assertions.assertNotNull(interceptor);
41+
}
42+
}
43+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package org.opensearch.dataprepper.plugins.testcustomauth;
2+
3+
import com.linecorp.armeria.client.WebClient;
4+
import com.linecorp.armeria.common.AggregatedHttpResponse;
5+
import com.linecorp.armeria.common.HttpData;
6+
import com.linecorp.armeria.common.HttpMethod;
7+
import com.linecorp.armeria.common.HttpRequest;
8+
import com.linecorp.armeria.common.HttpStatus;
9+
import com.linecorp.armeria.common.MediaType;
10+
import com.linecorp.armeria.common.RequestHeaders;
11+
import com.linecorp.armeria.server.ServerBuilder;
12+
import com.linecorp.armeria.server.grpc.GrpcService;
13+
import com.linecorp.armeria.server.grpc.GrpcServiceBuilder;
14+
import com.linecorp.armeria.testing.junit5.server.ServerExtension;
15+
import io.grpc.ServerInterceptors;
16+
import io.grpc.health.v1.HealthCheckRequest;
17+
import io.grpc.health.v1.HealthCheckResponse;
18+
import io.grpc.health.v1.HealthGrpc;
19+
import io.grpc.stub.StreamObserver;
20+
import org.junit.jupiter.api.BeforeEach;
21+
import org.junit.jupiter.api.Nested;
22+
import org.junit.jupiter.api.Test;
23+
import org.junit.jupiter.api.extension.RegisterExtension;
24+
import org.opensearch.dataprepper.armeria.authentication.GrpcAuthenticationProvider;
25+
26+
import java.nio.charset.Charset;
27+
import java.util.Collections;
28+
import java.util.UUID;
29+
30+
import static org.hamcrest.MatcherAssert.assertThat;
31+
import static org.hamcrest.Matchers.equalTo;
32+
import static org.junit.jupiter.api.Assertions.assertThrows;
33+
import static org.mockito.Mockito.mock;
34+
import static org.mockito.Mockito.when;
35+
36+
public class TestCustomBasicAuthenticationProviderTest {
37+
private static final String TOKEN = UUID.randomUUID().toString();
38+
private static final String HEADER_NAME = "x-" + UUID.randomUUID();
39+
private static GrpcAuthenticationProvider grpcAuthenticationProvider;
40+
41+
@RegisterExtension
42+
static ServerExtension server = new ServerExtension() {
43+
@Override
44+
protected void configure(ServerBuilder sb) {
45+
TestCustomAuthenticationConfig config = mock(TestCustomAuthenticationConfig.class);
46+
when(config.customToken()).thenReturn(TOKEN);
47+
when(config.header()).thenReturn(HEADER_NAME);
48+
49+
grpcAuthenticationProvider = new TestCustomGrpcAuthenticationProvider(config);
50+
51+
GrpcServiceBuilder grpcServiceBuilder = GrpcService.builder()
52+
.enableUnframedRequests(true)
53+
.addService(ServerInterceptors.intercept(
54+
new SampleHealthGrpcService(),
55+
Collections.singletonList(grpcAuthenticationProvider.getAuthenticationInterceptor())));
56+
57+
sb.service(grpcServiceBuilder.build());
58+
}
59+
};
60+
61+
private static class SampleHealthGrpcService extends HealthGrpc.HealthImplBase {
62+
@Override
63+
public void check(HealthCheckRequest request, StreamObserver<HealthCheckResponse> responseObserver) {
64+
responseObserver.onNext(
65+
HealthCheckResponse.newBuilder().setStatus(HealthCheckResponse.ServingStatus.SERVING).build());
66+
responseObserver.onCompleted();
67+
}
68+
}
69+
70+
@Nested
71+
class ConstructorTests {
72+
TestCustomAuthenticationConfig config;
73+
74+
@BeforeEach
75+
void setUp() {
76+
config = mock(TestCustomAuthenticationConfig.class);
77+
}
78+
79+
@Test
80+
void constructor_with_null_config_throws() {
81+
assertThrows(NullPointerException.class, () -> new TestCustomGrpcAuthenticationProvider(null));
82+
}
83+
}
84+
85+
@Nested
86+
class WithServer {
87+
@Test
88+
void request_without_token_responds_Unauthorized() {
89+
WebClient client = WebClient.of(server.httpUri());
90+
HttpRequest request = HttpRequest.of(RequestHeaders.builder()
91+
.method(HttpMethod.POST)
92+
.path("/grpc.health.v1.Health/Check")
93+
.contentType(MediaType.JSON_UTF_8)
94+
.build());
95+
96+
final AggregatedHttpResponse httpResponse = client.execute(request).aggregate().join();
97+
98+
assertThat(httpResponse.status(), equalTo(HttpStatus.UNAUTHORIZED));
99+
}
100+
101+
@Test
102+
void request_with_invalid_token_responds_Unauthorized() {
103+
WebClient client = WebClient.builder(server.httpUri())
104+
.addHeader(HEADER_NAME, "invalid-token")
105+
.build();
106+
107+
HttpRequest request = HttpRequest.of(RequestHeaders.builder()
108+
.method(HttpMethod.POST)
109+
.path("/grpc.health.v1.Health/Check")
110+
.contentType(MediaType.JSON_UTF_8)
111+
.build());
112+
113+
final AggregatedHttpResponse httpResponse = client.execute(request).aggregate().join();
114+
115+
assertThat(httpResponse.status(), equalTo(HttpStatus.UNAUTHORIZED));
116+
}
117+
118+
@Test
119+
void request_with_valid_token_responds_OK() {
120+
WebClient client = WebClient.builder(server.httpUri())
121+
.addHeader(HEADER_NAME, TOKEN)
122+
.build();
123+
124+
HttpRequest request = HttpRequest.of(RequestHeaders.builder()
125+
.method(HttpMethod.POST)
126+
.path("/grpc.health.v1.Health/Check")
127+
.contentType(MediaType.JSON_UTF_8)
128+
.build(),
129+
HttpData.of(Charset.defaultCharset(), "{\"healthCheckConfig\":{\"serviceName\": \"test\"} }"));
130+
131+
132+
final AggregatedHttpResponse httpResponse = client.execute(request).aggregate().join();
133+
134+
assertThat(httpResponse.status(), equalTo(HttpStatus.OK));
135+
}
136+
}
137+
}
138+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.dataprepper.plugins.testcustomauth;
7+
8+
import com.linecorp.armeria.common.HttpResponse;
9+
import com.linecorp.armeria.common.HttpStatus;
10+
import com.linecorp.armeria.common.MediaType;
11+
import com.linecorp.armeria.server.HttpService;
12+
import io.grpc.Metadata;
13+
import io.grpc.ServerCall;
14+
import io.grpc.ServerCallHandler;
15+
import io.grpc.ServerInterceptor;
16+
import io.grpc.Status;
17+
import org.opensearch.dataprepper.armeria.authentication.GrpcAuthenticationProvider;
18+
import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
19+
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;
20+
21+
import java.util.Optional;
22+
import java.util.function.Function;
23+
24+
@DataPrepperPlugin(
25+
name = "test_custom_auth",
26+
pluginType = GrpcAuthenticationProvider.class,
27+
pluginConfigurationType = TestCustomAuthenticationConfig.class
28+
)
29+
public class TestCustomGrpcAuthenticationProvider implements GrpcAuthenticationProvider {
30+
private final String token;
31+
private final String header;
32+
33+
@DataPrepperPluginConstructor
34+
public TestCustomGrpcAuthenticationProvider(final TestCustomAuthenticationConfig config) {
35+
this.token = config.customToken();
36+
this.header = config.header();
37+
}
38+
39+
@Override
40+
public ServerInterceptor getAuthenticationInterceptor() {
41+
return new ServerInterceptor() {
42+
@Override
43+
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
44+
ServerCall<ReqT, RespT> call,
45+
Metadata headers,
46+
ServerCallHandler<ReqT, RespT> next) {
47+
48+
String auth = headers.get(Metadata.Key.of(header, Metadata.ASCII_STRING_MARSHALLER));
49+
50+
if (!isValid(auth)) {
51+
call.close(Status.UNAUTHENTICATED.withDescription("Invalid token"), new Metadata());
52+
return new ServerCall.Listener<>() {};
53+
}
54+
55+
return next.startCall(call, headers);
56+
}
57+
};
58+
}
59+
60+
@Override
61+
public Optional<Function<? super HttpService, ? extends HttpService>> getHttpAuthenticationService() {
62+
return Optional.of(delegate -> (ctx, req) -> {
63+
final String auth = req.headers().get(header);
64+
if (!isValid(auth)) {
65+
return HttpResponse.of(
66+
HttpStatus.UNAUTHORIZED,
67+
MediaType.PLAIN_TEXT_UTF_8,
68+
"Unauthorized: Invalid or missing token"
69+
);
70+
}
71+
return delegate.serve(ctx, req);
72+
});
73+
}
74+
75+
/**
76+
* Checks if the provided authentication token is valid.
77+
*
78+
* @param authHeader the value of the authentication header
79+
* @return true if valid, false otherwise
80+
*/
81+
private boolean isValid(final String authHeader) {
82+
return authHeader != null && authHeader.equals(token);
83+
}
84+
}
85+
86+

0 commit comments

Comments
 (0)