diff --git a/data-prepper-plugins/armeria-common/src/main/java/org/opensearch/dataprepper/GrpcRequestExceptionHandler.java b/data-prepper-plugins/armeria-common/src/main/java/org/opensearch/dataprepper/GrpcRequestExceptionHandler.java index 1b7f591f24..964bb36373 100644 --- a/data-prepper-plugins/armeria-common/src/main/java/org/opensearch/dataprepper/GrpcRequestExceptionHandler.java +++ b/data-prepper-plugins/armeria-common/src/main/java/org/opensearch/dataprepper/GrpcRequestExceptionHandler.java @@ -39,14 +39,14 @@ public class GrpcRequestExceptionHandler implements GoogleGrpcExceptionHandlerFu private final Counter badRequestsCounter; private final Counter requestsTooLargeCounter; private final Counter internalServerErrorCounter; - private final GrpcRetryInfoCalculator retryInfoCalculator; + private final RetryInfoCalculator retryInfoCalculator; public GrpcRequestExceptionHandler(final PluginMetrics pluginMetrics, Duration retryInfoMinDelay, Duration retryInfoMaxDelay) { requestTimeoutsCounter = pluginMetrics.counter(REQUEST_TIMEOUTS); badRequestsCounter = pluginMetrics.counter(BAD_REQUESTS); requestsTooLargeCounter = pluginMetrics.counter(REQUESTS_TOO_LARGE); internalServerErrorCounter = pluginMetrics.counter(INTERNAL_SERVER_ERROR); - retryInfoCalculator = new GrpcRetryInfoCalculator(retryInfoMinDelay, retryInfoMaxDelay); + retryInfoCalculator = new RetryInfoCalculator(retryInfoMinDelay, retryInfoMaxDelay); } @Override diff --git a/data-prepper-plugins/armeria-common/src/main/java/org/opensearch/dataprepper/GrpcRetryInfoCalculator.java b/data-prepper-plugins/armeria-common/src/main/java/org/opensearch/dataprepper/RetryInfoCalculator.java similarity index 92% rename from data-prepper-plugins/armeria-common/src/main/java/org/opensearch/dataprepper/GrpcRetryInfoCalculator.java rename to data-prepper-plugins/armeria-common/src/main/java/org/opensearch/dataprepper/RetryInfoCalculator.java index 2b74b0b4bc..33c4af3f07 100644 --- a/data-prepper-plugins/armeria-common/src/main/java/org/opensearch/dataprepper/GrpcRetryInfoCalculator.java +++ b/data-prepper-plugins/armeria-common/src/main/java/org/opensearch/dataprepper/RetryInfoCalculator.java @@ -6,7 +6,7 @@ import java.time.Instant; import java.util.concurrent.atomic.AtomicReference; -class GrpcRetryInfoCalculator { +public class RetryInfoCalculator { private final Duration minimumDelay; private final Duration maximumDelay; @@ -14,7 +14,7 @@ class GrpcRetryInfoCalculator { private final AtomicReference lastTimeCalled; private final AtomicReference nextDelay; - GrpcRetryInfoCalculator(Duration minimumDelay, Duration maximumDelay) { + public RetryInfoCalculator(Duration minimumDelay, Duration maximumDelay) { this.minimumDelay = minimumDelay; this.maximumDelay = maximumDelay; // Create a cushion so that the calculator treats a first quick exception (after prepper startup) as normal request (e.g. does not calculate a backoff) @@ -34,7 +34,7 @@ private static com.google.protobuf.Duration.Builder mapDuration(Duration duratio return com.google.protobuf.Duration.newBuilder().setSeconds(duration.getSeconds()).setNanos(duration.getNano()); } - RetryInfo createRetryInfo() { + public RetryInfo createRetryInfo() { Instant now = Instant.now(); // Is the last time we got called longer ago than the next delay? if (lastTimeCalled.getAndSet(now).isBefore(now.minus(nextDelay.get()))) { diff --git a/data-prepper-plugins/armeria-common/src/test/java/org/opensearch/dataprepper/GrpcRetryInfoCalculatorTest.java b/data-prepper-plugins/armeria-common/src/test/java/org/opensearch/dataprepper/GrpcRetryInfoCalculatorTest.java index 5611826ef7..5cd79a3c1b 100644 --- a/data-prepper-plugins/armeria-common/src/test/java/org/opensearch/dataprepper/GrpcRetryInfoCalculatorTest.java +++ b/data-prepper-plugins/armeria-common/src/test/java/org/opensearch/dataprepper/GrpcRetryInfoCalculatorTest.java @@ -12,7 +12,7 @@ public class GrpcRetryInfoCalculatorTest { @Test public void testMinimumDelayOnFirstCall() { - RetryInfo retryInfo = new GrpcRetryInfoCalculator(Duration.ofMillis(100), Duration.ofSeconds(1)).createRetryInfo(); + RetryInfo retryInfo = new RetryInfoCalculator(Duration.ofMillis(100), Duration.ofSeconds(1)).createRetryInfo(); assertThat(retryInfo.getRetryDelay().getNanos(), equalTo(100_000_000)); assertThat(retryInfo.getRetryDelay().getSeconds(), equalTo(0L)); @@ -20,8 +20,8 @@ public void testMinimumDelayOnFirstCall() { @Test public void testExponentialBackoff() { - GrpcRetryInfoCalculator calculator = - new GrpcRetryInfoCalculator(Duration.ofSeconds(1), Duration.ofSeconds(10)); + RetryInfoCalculator calculator = + new RetryInfoCalculator(Duration.ofSeconds(1), Duration.ofSeconds(10)); RetryInfo retryInfo1 = calculator.createRetryInfo(); RetryInfo retryInfo2 = calculator.createRetryInfo(); RetryInfo retryInfo3 = calculator.createRetryInfo(); @@ -35,8 +35,8 @@ public void testExponentialBackoff() { @Test public void testUsesMaximumAsLongestDelay() { - GrpcRetryInfoCalculator calculator = - new GrpcRetryInfoCalculator(Duration.ofSeconds(1), Duration.ofSeconds(2)); + RetryInfoCalculator calculator = + new RetryInfoCalculator(Duration.ofSeconds(1), Duration.ofSeconds(2)); RetryInfo retryInfo1 = calculator.createRetryInfo(); RetryInfo retryInfo2 = calculator.createRetryInfo(); RetryInfo retryInfo3 = calculator.createRetryInfo(); @@ -49,8 +49,8 @@ public void testUsesMaximumAsLongestDelay() { @Test public void testResetAfterDelayWearsOff() throws InterruptedException { int minDelayNanos = 1_000_000; - GrpcRetryInfoCalculator calculator = - new GrpcRetryInfoCalculator(Duration.ofNanos(minDelayNanos), Duration.ofSeconds(1)); + RetryInfoCalculator calculator = + new RetryInfoCalculator(Duration.ofNanos(minDelayNanos), Duration.ofSeconds(1)); RetryInfo retryInfo1 = calculator.createRetryInfo(); RetryInfo retryInfo2 = calculator.createRetryInfo(); @@ -66,8 +66,8 @@ public void testResetAfterDelayWearsOff() throws InterruptedException { @Test public void testQuickFirstExceptionDoesNotTriggerBackoffCalculationEvenWithLongMinDelay() throws InterruptedException { - GrpcRetryInfoCalculator calculator = - new GrpcRetryInfoCalculator(Duration.ofSeconds(10), Duration.ofSeconds(20)); + RetryInfoCalculator calculator = + new RetryInfoCalculator(Duration.ofSeconds(10), Duration.ofSeconds(20)); RetryInfo retryInfo1 = calculator.createRetryInfo(); RetryInfo retryInfo2 = calculator.createRetryInfo(); diff --git a/data-prepper-plugins/otel-metrics-source/src/test/java/org/opensearch/dataprepper/plugins/source/otelmetrics/OTelMetricsSourceTest.java b/data-prepper-plugins/otel-metrics-source/src/test/java/org/opensearch/dataprepper/plugins/source/otelmetrics/OTelMetricsSourceTest.java index 81214e3c10..8e41255b0e 100644 --- a/data-prepper-plugins/otel-metrics-source/src/test/java/org/opensearch/dataprepper/plugins/source/otelmetrics/OTelMetricsSourceTest.java +++ b/data-prepper-plugins/otel-metrics-source/src/test/java/org/opensearch/dataprepper/plugins/source/otelmetrics/OTelMetricsSourceTest.java @@ -350,7 +350,6 @@ void testHttpFullJsonWithCustomPathAndUnframedRequests() throws InvalidProtocolB .join(); } - @Test void testHttpFullJsonWithCustomPathAndAuthHeader_with_successful_response() throws InvalidProtocolBufferException { when(httpBasicAuthenticationConfig.getUsername()).thenReturn(USERNAME); @@ -420,7 +419,7 @@ void testHttpRequestWithInvalidCredentials_with_unsuccessful_response() throws I when(httpBasicAuthenticationConfig.getUsername()).thenReturn(USERNAME); when(httpBasicAuthenticationConfig.getPassword()).thenReturn(PASSWORD); final GrpcAuthenticationProvider grpcAuthenticationProvider = new GrpcBasicAuthenticationProvider(httpBasicAuthenticationConfig); - + when(pluginFactory.loadPlugin(eq(GrpcAuthenticationProvider.class), any(PluginSetting.class))) .thenReturn(grpcAuthenticationProvider); when(oTelMetricsSourceConfig.getAuthentication()).thenReturn(new PluginModel("http_basic", @@ -430,17 +429,17 @@ void testHttpRequestWithInvalidCredentials_with_unsuccessful_response() throws I ))); when(oTelMetricsSourceConfig.enableUnframedRequests()).thenReturn(true); when(oTelMetricsSourceConfig.getPath()).thenReturn(TEST_PATH); - + configureObjectUnderTest(); SOURCE.start(buffer); - + final String invalidUsername = "wrong_user"; final String invalidPassword = "wrong_password"; final String invalidCredentials = Base64.getEncoder() .encodeToString(String.format("%s:%s", invalidUsername, invalidPassword).getBytes(StandardCharsets.UTF_8)); - + final String transformedPath = "/" + TEST_PIPELINE_NAME + "/v1/metrics"; - + WebClient.of().prepare() .post("http://127.0.0.1:21891" + transformedPath) .content(MediaType.JSON_UTF_8, JsonFormat.printer().print(createExportMetricsRequest()).getBytes()) @@ -450,7 +449,7 @@ void testHttpRequestWithInvalidCredentials_with_unsuccessful_response() throws I .whenComplete((response, throwable) -> assertSecureResponseWithStatusCode(response, HttpStatus.UNAUTHORIZED, throwable)) .join(); } - + @Test void testGrpcRequestWithInvalidCredentials_with_unsuccessful_response() throws Exception { when(httpBasicAuthenticationConfig.getUsername()).thenReturn(USERNAME); @@ -489,10 +488,10 @@ void testHttpWithoutSslFailsWhenSslIsEnabled() throws InvalidProtocolBufferExcep when(oTelMetricsSourceConfig.getSslKeyFile()).thenReturn("data/certificate/test_decrypted_key.key"); configureObjectUnderTest(); SOURCE.start(buffer); - + WebClient client = WebClient.builder("http://127.0.0.1:21891") .build(); - + CompletionException exception = assertThrows(CompletionException.class, () -> client.execute(RequestHeaders.builder() .scheme(SessionProtocol.HTTP) .authority("127.0.0.1:21891") @@ -503,10 +502,10 @@ void testHttpWithoutSslFailsWhenSslIsEnabled() throws InvalidProtocolBufferExcep HttpData.copyOf(JsonFormat.printer().print(createExportMetricsRequest()).getBytes())) .aggregate() .join()); - + assertThat(exception.getCause(), instanceOf(ClosedSessionException.class)); } - + @Test void testGrpcFailsIfSslIsEnabledAndNoTls() { when(oTelMetricsSourceConfig.isSsl()).thenReturn(true); @@ -514,17 +513,17 @@ void testGrpcFailsIfSslIsEnabledAndNoTls() { when(oTelMetricsSourceConfig.getSslKeyFile()).thenReturn("data/certificate/test_decrypted_key.key"); configureObjectUnderTest(); SOURCE.start(buffer); - + MetricsServiceGrpc.MetricsServiceBlockingStub client = Clients.builder(GRPC_ENDPOINT) .build(MetricsServiceGrpc.MetricsServiceBlockingStub.class); - + StatusRuntimeException actualException = assertThrows(StatusRuntimeException.class, () -> client.export(createExportMetricsRequest())); - + assertThat(actualException.getStatus(), notNullValue()); assertThat(actualException.getStatus().getCode(), equalTo(Status.Code.UNKNOWN)); } - - + + @Test void testServerStartCertFileSuccess() throws IOException { try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class)) { diff --git a/data-prepper-plugins/otel-trace-source/build.gradle b/data-prepper-plugins/otel-trace-source/build.gradle index 474c19ed2f..47e3c5c6e0 100644 --- a/data-prepper-plugins/otel-trace-source/build.gradle +++ b/data-prepper-plugins/otel-trace-source/build.gradle @@ -32,6 +32,7 @@ dependencies { testImplementation 'org.assertj:assertj-core:3.27.3' testImplementation testLibs.slf4j.simple testImplementation libs.commons.io + testImplementation 'com.jayway.jsonpath:json-path-assert:2.6.0' testImplementation 'org.skyscreamer:jsonassert:1.5.3' } diff --git a/data-prepper-plugins/otel-trace-source/src/main/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSource.java b/data-prepper-plugins/otel-trace-source/src/main/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSource.java index b353f7a6a2..1d8e750492 100644 --- a/data-prepper-plugins/otel-trace-source/src/main/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSource.java +++ b/data-prepper-plugins/otel-trace-source/src/main/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSource.java @@ -5,46 +5,48 @@ package org.opensearch.dataprepper.plugins.source.oteltrace; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.util.BlockingTaskExecutor; import com.linecorp.armeria.server.Server; -import io.grpc.MethodDescriptor; -import io.opentelemetry.proto.collector.trace.v1.ExportTraceServiceRequest; -import io.opentelemetry.proto.collector.trace.v1.ExportTraceServiceResponse; -import io.opentelemetry.proto.collector.trace.v1.TraceServiceGrpc; -import org.opensearch.dataprepper.armeria.authentication.GrpcAuthenticationProvider; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.encoding.DecodingService; +import com.linecorp.armeria.server.healthcheck.HealthCheckService; + import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin; import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor; import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.codec.ByteDecoder; import org.opensearch.dataprepper.model.configuration.PipelineDescription; -import org.opensearch.dataprepper.model.configuration.PluginModel; -import org.opensearch.dataprepper.model.configuration.PluginSetting; import org.opensearch.dataprepper.model.plugin.PluginFactory; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.model.source.Source; import org.opensearch.dataprepper.plugins.certificate.CertificateProvider; -import org.opensearch.dataprepper.plugins.otel.codec.OTelOutputFormat; -import org.opensearch.dataprepper.plugins.otel.codec.OTelProtoOpensearchCodec; -import org.opensearch.dataprepper.plugins.otel.codec.OTelProtoStandardCodec; -import org.opensearch.dataprepper.plugins.otel.codec.OTelTraceDecoder; -import org.opensearch.dataprepper.plugins.server.CreateServer; -import org.opensearch.dataprepper.plugins.server.ServerConfiguration; +import org.opensearch.dataprepper.plugins.certificate.model.Certificate; +import org.opensearch.dataprepper.plugins.codec.CompressionOption; import org.opensearch.dataprepper.plugins.source.oteltrace.certificate.CertificateProviderFactory; +import org.opensearch.dataprepper.plugins.otel.codec.OTelTraceDecoder; +import org.opensearch.dataprepper.plugins.source.oteltrace.grpc.GrpcService; +import org.opensearch.dataprepper.plugins.source.oteltrace.http.HttpService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.Collections; +import java.io.ByteArrayInputStream; +import java.nio.charset.StandardCharsets; import java.util.concurrent.ExecutionException; @DataPrepperPlugin(name = "otel_trace_source", pluginType = Source.class, pluginConfigurationType = OTelTraceSourceConfig.class) public class OTelTraceSource implements Source> { private static final String PLUGIN_NAME = "otel_trace_source"; private static final Logger LOG = LoggerFactory.getLogger(OTelTraceSource.class); + + + private static final String HTTP_HEALTH_CHECK_PATH = "/health"; static final String SERVER_CONNECTIONS = "serverConnections"; private final OTelTraceSourceConfig oTelTraceSourceConfig; private final PluginMetrics pluginMetrics; - private final GrpcAuthenticationProvider authenticationProvider; + private final PluginFactory pluginFactory; private final CertificateProviderFactory certificateProviderFactory; private final String pipelineName; private Server server; @@ -62,9 +64,9 @@ public OTelTraceSource(final OTelTraceSourceConfig oTelTraceSourceConfig, final oTelTraceSourceConfig.validateAndInitializeCertAndKeyFileInS3(); this.oTelTraceSourceConfig = oTelTraceSourceConfig; this.pluginMetrics = pluginMetrics; + this.pluginFactory = pluginFactory; this.certificateProviderFactory = certificateProviderFactory; this.pipelineName = pipelineDescription.getPipelineName(); - this.authenticationProvider = createAuthenticationProvider(pluginFactory); this.byteDecoder = new OTelTraceDecoder(oTelTraceSourceConfig.getOutputFormat()); } @@ -80,39 +82,99 @@ public void start(Buffer> buffer) { } if (server == null) { + ServerBuilder serverBuilder = Server.builder().port(oTelTraceSourceConfig.getPort(), inferProtocolFromConfig()); - final OTelTraceGrpcService oTelTraceGrpcService = new OTelTraceGrpcService( - (int)(oTelTraceSourceConfig.getRequestTimeoutInMillis() * 0.8), - oTelTraceSourceConfig.getOutputFormat() == OTelOutputFormat.OPENSEARCH ? new OTelProtoOpensearchCodec.OTelProtoDecoder() : new OTelProtoStandardCodec.OTelProtoDecoder(), - buffer, - pluginMetrics, - null - ); + configureHeadersAndHealthCheck(serverBuilder); + configureTLS(serverBuilder); + configureTaskExecutor(serverBuilder); - ServerConfiguration serverConfiguration = ConvertConfiguration.convertConfiguration(oTelTraceSourceConfig); - CreateServer createServer = new CreateServer(serverConfiguration, LOG, pluginMetrics, PLUGIN_NAME, pipelineName); - CertificateProvider certificateProvider = null; - if (oTelTraceSourceConfig.isSsl() || oTelTraceSourceConfig.useAcmCertForSSL()) { - certificateProvider = certificateProviderFactory.getCertificateProvider(); + configureGrpcService(serverBuilder, buffer); + + // needed until clarified if unframedRequests should survive + if (!oTelTraceSourceConfig.enableUnframedRequests()) { + configureHttpService(serverBuilder, buffer); } - final MethodDescriptor methodDescriptor = TraceServiceGrpc.getExportMethod(); - server = createServer.createGRPCServer(authenticationProvider, oTelTraceGrpcService, certificateProvider, methodDescriptor); + + server = serverBuilder.build(); pluginMetrics.gauge(SERVER_CONNECTIONS, server, Server::numConnections); } try { server.start().get(); } catch (ExecutionException ex) { - if (ex.getCause() != null && ex.getCause() instanceof RuntimeException) { - throw (RuntimeException) ex.getCause(); - } else { - throw new RuntimeException(ex); - } + handleExecutionException(ex); } catch (InterruptedException ex) { Thread.currentThread().interrupt(); throw new RuntimeException(ex); } - LOG.info("Started otel_trace_source on port " + oTelTraceSourceConfig.getPort() + "..."); + LOG.info("Started otel_trace_source on port {}...", oTelTraceSourceConfig.getPort()); + } + + private SessionProtocol inferProtocolFromConfig() { + if (oTelTraceSourceConfig.isSsl()) { + return SessionProtocol.HTTPS; + } else { + return SessionProtocol.HTTP; + } + } + + private void handleExecutionException(ExecutionException ex) { + if (ex.getCause() != null && ex.getCause() instanceof RuntimeException) { + throw (RuntimeException) ex.getCause(); + } else { + throw new RuntimeException(ex); + } + } + + private void configureGrpcService(ServerBuilder serverBuilder, Buffer> buffer) { + com.linecorp.armeria.server.grpc.GrpcService grpcService = new GrpcService(pluginFactory, oTelTraceSourceConfig, pluginMetrics, pipelineName).create(buffer, serverBuilder); + + if (CompressionOption.NONE.equals(oTelTraceSourceConfig.getCompression())) { + serverBuilder.service(grpcService); + } else { + serverBuilder.service(grpcService, DecodingService.newDecorator()); + } + } + + private void configureHttpService(ServerBuilder serverBuilder, Buffer> buffer) { + new HttpService(pluginMetrics, oTelTraceSourceConfig, pluginFactory).create(serverBuilder, buffer); + } + + private void configureHeadersAndHealthCheck(ServerBuilder serverBuilder) { + serverBuilder.disableServerHeader(); + if (oTelTraceSourceConfig.enableHttpHealthCheck()) { + serverBuilder.service(HTTP_HEALTH_CHECK_PATH, HealthCheckService.builder().longPolling(0).build()); + } + serverBuilder.requestTimeoutMillis(oTelTraceSourceConfig.getRequestTimeoutInMillis()); + if(oTelTraceSourceConfig.getMaxRequestLength() != null) { + serverBuilder.maxRequestLength(oTelTraceSourceConfig.getMaxRequestLength().getBytes()); + } + serverBuilder.maxNumConnections(oTelTraceSourceConfig.getMaxConnectionCount()); + } + + private void configureTLS(ServerBuilder serverBuilder) { + if (oTelTraceSourceConfig.isSsl() || oTelTraceSourceConfig.useAcmCertForSSL()) { + LOG.info("SSL/TLS is enabled."); + final CertificateProvider certificateProvider = certificateProviderFactory.getCertificateProvider(); + final Certificate certificate = certificateProvider.getCertificate(); + serverBuilder.https(oTelTraceSourceConfig.getPort()).tls( + new ByteArrayInputStream(certificate.getCertificate().getBytes(StandardCharsets.UTF_8)), + new ByteArrayInputStream(certificate.getPrivateKey().getBytes(StandardCharsets.UTF_8) + ) + ); + } else { + LOG.warn("Creating otel_trace_source without SSL/TLS. This is not secure."); + LOG.warn("In order to set up TLS for the otel_trace_source, go here: https://github.com/opensearch-project/data-prepper/tree/main/data-prepper-plugins/otel-trace-source#ssl"); + serverBuilder.http(oTelTraceSourceConfig.getPort()); + } + } + + private void configureTaskExecutor(ServerBuilder serverBuilder) { + final BlockingTaskExecutor blockingTaskExecutor = BlockingTaskExecutor.builder() + .numThreads(oTelTraceSourceConfig.getThreadCount()) + .threadNamePrefix(pipelineName + "-otel_trace") + .build(); + serverBuilder.blockingTaskExecutor(blockingTaskExecutor, true); } @Override @@ -133,22 +195,4 @@ public void stop() { } LOG.info("Stopped otel_trace_source."); } - - private GrpcAuthenticationProvider createAuthenticationProvider(final PluginFactory pluginFactory) { - final PluginModel authenticationConfiguration = oTelTraceSourceConfig.getAuthentication(); - - if (authenticationConfiguration == null || authenticationConfiguration.getPluginName().equals(GrpcAuthenticationProvider.UNAUTHENTICATED_PLUGIN_NAME)) { - LOG.warn("Creating otel-trace-source without authentication. This is not secure."); - LOG.warn("In order to set up Http Basic authentication for the otel-trace-source, go here: https://github.com/opensearch-project/data-prepper/tree/main/data-prepper-plugins/otel-trace-source#authentication-configurations"); - } - - final PluginSetting authenticationPluginSetting; - if (authenticationConfiguration != null) { - authenticationPluginSetting = new PluginSetting(authenticationConfiguration.getPluginName(), authenticationConfiguration.getPluginSettings()); - } else { - authenticationPluginSetting = new PluginSetting(GrpcAuthenticationProvider.UNAUTHENTICATED_PLUGIN_NAME, Collections.emptyMap()); - } - authenticationPluginSetting.setPipelineName(pipelineName); - return pluginFactory.loadPlugin(GrpcAuthenticationProvider.class, authenticationPluginSetting); - } } diff --git a/data-prepper-plugins/otel-trace-source/src/main/java/org/opensearch/dataprepper/plugins/source/oteltrace/grpc/GrpcService.java b/data-prepper-plugins/otel-trace-source/src/main/java/org/opensearch/dataprepper/plugins/source/oteltrace/grpc/GrpcService.java new file mode 100644 index 0000000000..9254712a60 --- /dev/null +++ b/data-prepper-plugins/otel-trace-source/src/main/java/org/opensearch/dataprepper/plugins/source/oteltrace/grpc/GrpcService.java @@ -0,0 +1,148 @@ +package org.opensearch.dataprepper.plugins.source.oteltrace.grpc; + +import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.function.Function; + +import org.opensearch.dataprepper.GrpcRequestExceptionHandler; +import org.opensearch.dataprepper.armeria.authentication.GrpcAuthenticationProvider; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.configuration.PluginModel; +import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.otel.codec.OTelProtoOpensearchCodec; +import org.opensearch.dataprepper.plugins.server.HealthGrpcService; +import org.opensearch.dataprepper.plugins.server.RetryInfoConfig; +import org.opensearch.dataprepper.plugins.source.oteltrace.OTelTraceGrpcService; +import org.opensearch.dataprepper.plugins.source.oteltrace.OTelTraceSourceConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.linecorp.armeria.common.grpc.GrpcExceptionHandlerFunction; +import com.linecorp.armeria.server.HttpService; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.grpc.GrpcServiceBuilder; + +import io.grpc.MethodDescriptor; +import io.grpc.ServerInterceptor; +import io.grpc.ServerInterceptors; +import io.grpc.protobuf.services.ProtoReflectionService; +import io.opentelemetry.proto.collector.trace.v1.ExportTraceServiceRequest; +import io.opentelemetry.proto.collector.trace.v1.ExportTraceServiceResponse; +import io.opentelemetry.proto.collector.trace.v1.TraceServiceGrpc; + +public class GrpcService { + private static final Logger LOG = LoggerFactory.getLogger(GrpcService.class); + + // Default RetryInfo with minimum 100ms and maximum 2s + private static final RetryInfoConfig DEFAULT_RETRY_INFO = new RetryInfoConfig(Duration.ofMillis(100), Duration.ofMillis(2000)); + + private static final String PIPELINE_NAME_PLACEHOLDER = "${pipelineName}"; + public static final String REGEX_HEALTH = "regex:^/(?!health$).*$"; + + private final OTelTraceSourceConfig oTelTraceSourceConfig; + private final GrpcAuthenticationProvider authenticationProvider; + private final PluginMetrics pluginMetrics; + private final String pipelineName; + + public GrpcService(PluginFactory pluginFactory, OTelTraceSourceConfig oTelTraceSourceConfig, PluginMetrics pluginMetrics, String pipelineName) { + this.oTelTraceSourceConfig = oTelTraceSourceConfig; + this.pluginMetrics = pluginMetrics; + this.pipelineName = pipelineName; + this.authenticationProvider = createAuthenticationProvider(pluginFactory, oTelTraceSourceConfig); + } + + public com.linecorp.armeria.server.grpc.GrpcService create(Buffer> buffer, ServerBuilder serverBuilder) { + + final OTelTraceGrpcService oTelTraceGrpcService = new OTelTraceGrpcService( + (int)(oTelTraceSourceConfig.getRequestTimeoutInMillis() * 0.8), + new OTelProtoOpensearchCodec.OTelProtoDecoder(), + buffer, + pluginMetrics, + null + ); + + final List serverInterceptors = getAuthenticationInterceptor(); + + final GrpcServiceBuilder grpcServiceBuilder = com.linecorp.armeria.server.grpc.GrpcService + .builder() + .useClientTimeoutHeader(false) + .useBlockingTaskExecutor(true) + .exceptionHandler(createGrpExceptionHandler()); + + final MethodDescriptor methodDescriptor = TraceServiceGrpc.getExportMethod(); + final String oTelTraceSourcePath = oTelTraceSourceConfig.getPath(); + if (oTelTraceSourcePath != null) { + final String transformedOTelTraceSourcePath = oTelTraceSourcePath.replace(PIPELINE_NAME_PLACEHOLDER, pipelineName); + grpcServiceBuilder.addService(transformedOTelTraceSourcePath, + ServerInterceptors.intercept(oTelTraceGrpcService, serverInterceptors), methodDescriptor); + } else { + grpcServiceBuilder.addService(ServerInterceptors.intercept(oTelTraceGrpcService, serverInterceptors)); + } + + if (oTelTraceSourceConfig.hasHealthCheck()) { + LOG.info("Health check is enabled"); + grpcServiceBuilder.addService(new HealthGrpcService()); + } + + if (oTelTraceSourceConfig.hasProtoReflectionService()) { + LOG.info("Proto reflection service is enabled"); + grpcServiceBuilder.addService(ProtoReflectionService.newInstance()); + } + + // todo still needed with new http-service? + grpcServiceBuilder.enableUnframedRequests(oTelTraceSourceConfig.enableUnframedRequests()); + + if (oTelTraceSourceConfig.getAuthentication() != null) { + final Optional> optionalHttpAuthenticationService = + authenticationProvider.getHttpAuthenticationService(); + + if (oTelTraceSourceConfig.isUnauthenticatedHealthCheck()) { + optionalHttpAuthenticationService.ifPresent(httpAuthenticationService -> + serverBuilder.decorator(REGEX_HEALTH, httpAuthenticationService)); + } else { + optionalHttpAuthenticationService.ifPresent(serverBuilder::decorator); + } + } + + return grpcServiceBuilder.build(); + } + + private List getAuthenticationInterceptor() { + final ServerInterceptor authenticationInterceptor = authenticationProvider.getAuthenticationInterceptor(); + if (authenticationInterceptor == null) { + return Collections.emptyList(); + } + return Collections.singletonList(authenticationInterceptor); + } + + private GrpcAuthenticationProvider createAuthenticationProvider(final PluginFactory pluginFactory, final OTelTraceSourceConfig oTelTraceSourceConfig) { + final PluginModel authenticationConfiguration = oTelTraceSourceConfig.getAuthentication(); + + if (authenticationConfiguration == null || authenticationConfiguration.getPluginName().equals(GrpcAuthenticationProvider.UNAUTHENTICATED_PLUGIN_NAME)) { + LOG.warn("Creating otel_trace_source grpc service without authentication. This is not secure."); + LOG.warn("In order to set up Http Basic authentication for the otel-trace-source, go here: https://github.com/opensearch-project/data-prepper/tree/main/data-prepper-plugins/otel-trace-source#authentication-configurations"); + } + + final PluginSetting authenticationPluginSetting; + if (authenticationConfiguration != null) { + authenticationPluginSetting = new PluginSetting(authenticationConfiguration.getPluginName(), authenticationConfiguration.getPluginSettings()); + } else { + authenticationPluginSetting = new PluginSetting(GrpcAuthenticationProvider.UNAUTHENTICATED_PLUGIN_NAME, Collections.emptyMap()); + } + authenticationPluginSetting.setPipelineName(pipelineName); + return pluginFactory.loadPlugin(GrpcAuthenticationProvider.class, authenticationPluginSetting); + } + + private GrpcExceptionHandlerFunction createGrpExceptionHandler() { + RetryInfoConfig retryInfo = oTelTraceSourceConfig.getRetryInfo() != null + ? oTelTraceSourceConfig.getRetryInfo() + : DEFAULT_RETRY_INFO; + + return new GrpcRequestExceptionHandler(pluginMetrics, retryInfo.getMinDelay(), retryInfo.getMaxDelay()); + } +} diff --git a/data-prepper-plugins/otel-trace-source/src/main/java/org/opensearch/dataprepper/plugins/source/oteltrace/http/ArmeriaHttpService.java b/data-prepper-plugins/otel-trace-source/src/main/java/org/opensearch/dataprepper/plugins/source/oteltrace/http/ArmeriaHttpService.java new file mode 100644 index 0000000000..9dc3949a08 --- /dev/null +++ b/data-prepper-plugins/otel-trace-source/src/main/java/org/opensearch/dataprepper/plugins/source/oteltrace/http/ArmeriaHttpService.java @@ -0,0 +1,113 @@ +package org.opensearch.dataprepper.plugins.source.oteltrace.http; + +import java.time.Instant; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.opensearch.dataprepper.exceptions.BadRequestException; +import org.opensearch.dataprepper.exceptions.BufferWriteException; +import org.opensearch.dataprepper.logging.DataPrepperMarkers; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.trace.Span; +import org.opensearch.dataprepper.plugins.otel.codec.OTelProtoCodec; +import org.opensearch.dataprepper.plugins.otel.codec.OTelProtoOpensearchCodec; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.linecorp.armeria.server.ServiceRequestContext; +import com.linecorp.armeria.server.annotation.Consumes; +import com.linecorp.armeria.server.annotation.Post; + +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.DistributionSummary; +import io.micrometer.core.instrument.Timer; +import io.opentelemetry.proto.collector.trace.v1.ExportTraceServiceRequest; +import io.opentelemetry.proto.collector.trace.v1.ExportTraceServiceResponse; + +public class ArmeriaHttpService { + private static final Logger LOG = LoggerFactory.getLogger(ArmeriaHttpService.class); + + public static final String REQUEST_TIMEOUTS = "requestTimeouts"; + public static final String REQUESTS_RECEIVED = "requestsReceived"; + public static final String BAD_REQUESTS = "badRequests"; + public static final String REQUESTS_TOO_LARGE = "requestsTooLarge"; + public static final String INTERNAL_SERVER_ERROR = "internalServerError"; + public static final String SUCCESS_REQUESTS = "successRequests"; + public static final String PAYLOAD_SIZE = "payloadSize"; + public static final String REQUEST_PROCESS_DURATION = "requestProcessDuration"; + + private final OTelProtoCodec.OTelProtoDecoder oTelProtoDecoder; + private final Buffer> buffer; + + private final int bufferWriteTimeoutInMillis; + + private final Counter requestsReceivedCounter; + private final Counter successRequestsCounter; + private final DistributionSummary payloadSizeSummary; + private final Timer requestProcessDuration; + + public ArmeriaHttpService(Buffer> buffer, final PluginMetrics pluginMetrics, final int bufferWriteTimeoutInMillis) { + this.buffer = buffer; + this.oTelProtoDecoder = new OTelProtoOpensearchCodec.OTelProtoDecoder(); + this.bufferWriteTimeoutInMillis = bufferWriteTimeoutInMillis; + + requestsReceivedCounter = pluginMetrics.counter(REQUESTS_RECEIVED); + successRequestsCounter = pluginMetrics.counter(SUCCESS_REQUESTS); + payloadSizeSummary = pluginMetrics.summary(PAYLOAD_SIZE); + requestProcessDuration = pluginMetrics.timer(REQUEST_PROCESS_DURATION); + } + + // todo make path configurable + @Post("/opentelemetry.proto.collector.trace.v1.TraceService/Export") + @Consumes(value = "application/json") + public ExportTraceServiceResponse exportTrace(ExportTraceServiceRequest request) { + requestsReceivedCounter.increment(); + payloadSizeSummary.record(request.getSerializedSize()); + + requestProcessDuration.record(() -> processRequest(request)); + + return ExportTraceServiceResponse.newBuilder().build(); + } + + private void processRequest(final ExportTraceServiceRequest request) { + final Collection spans; + + try { + spans = oTelProtoDecoder.parseExportTraceServiceRequest(request, Instant.now()); + } catch (final Exception e) { + LOG.warn(DataPrepperMarkers.SENSITIVE, "Failed to parse request with error '{}'. Request body: {}.", e.getMessage(), request); + throw new BadRequestException(e.getMessage(), e); + } + + try { + if (buffer.isByteBuffer()) { + Map requestsMap = oTelProtoDecoder.splitExportTraceServiceRequestByTraceId(request); + for (Map.Entry entry: requestsMap.entrySet()) { + buffer.writeBytes(entry.getValue().toByteArray(), entry.getKey(), bufferWriteTimeoutInMillis); + } + } else { + final List> records = spans.stream().map(span -> new Record(span)).collect(Collectors.toList()); + buffer.writeAll(records, bufferWriteTimeoutInMillis); + } + } catch (final Exception e) { + if (ServiceRequestContext.current().isTimedOut()) { + LOG.warn("Exception writing to buffer but request already timed out.", e); + return; + } + + LOG.error("Failed to write the request of size {} due to:", request.toString().length(), e); + throw new BufferWriteException(e.getMessage(), e); + } + + if (ServiceRequestContext.current().isTimedOut()) { + LOG.warn("Buffer write completed successfully but request already timed out."); + return; + } + + successRequestsCounter.increment(); + } +} diff --git a/data-prepper-plugins/otel-trace-source/src/main/java/org/opensearch/dataprepper/plugins/source/oteltrace/http/HttpExceptionHandler.java b/data-prepper-plugins/otel-trace-source/src/main/java/org/opensearch/dataprepper/plugins/source/oteltrace/http/HttpExceptionHandler.java new file mode 100644 index 0000000000..c06691d673 --- /dev/null +++ b/data-prepper-plugins/otel-trace-source/src/main/java/org/opensearch/dataprepper/plugins/source/oteltrace/http/HttpExceptionHandler.java @@ -0,0 +1,138 @@ +package org.opensearch.dataprepper.plugins.source.oteltrace.http; + + +import java.time.Duration; +import java.util.concurrent.TimeoutException; + +import org.opensearch.dataprepper.RetryInfoCalculator; +import org.opensearch.dataprepper.exceptions.BadRequestException; +import org.opensearch.dataprepper.exceptions.BufferWriteException; +import org.opensearch.dataprepper.exceptions.RequestCancelledException; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.buffer.SizeOverflowException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.util.JsonFormat; +import com.google.rpc.RetryInfo; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.MediaType; +import com.linecorp.armeria.server.RequestTimeoutException; +import com.linecorp.armeria.server.ServiceRequestContext; +import com.linecorp.armeria.server.annotation.ExceptionHandlerFunction; + +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.micrometer.core.instrument.Counter; + +public class HttpExceptionHandler implements ExceptionHandlerFunction { + private static final Logger LOG = LoggerFactory.getLogger(HttpExceptionHandler.class); + + static final String ARMERIA_REQUEST_TIMEOUT_MESSAGE = "Timeout waiting for request to be served. This is usually due to the buffer being full."; + public static final String REQUEST_TIMEOUTS = "requestTimeouts"; + public static final String BAD_REQUESTS = "badRequests"; + public static final String REQUESTS_TOO_LARGE = "requestsTooLarge"; + public static final String INTERNAL_SERVER_ERROR = "internalServerError"; + + private final Counter requestTimeoutsCounter; + private final Counter badRequestsCounter; + private final Counter requestsTooLargeCounter; + private final Counter internalServerErrorCounter; + private final RetryInfoCalculator retryInfoCalculator; + + public HttpExceptionHandler(final PluginMetrics pluginMetrics, Duration retryInfoMinDelay, Duration retryInfoMaxDelay) { + requestTimeoutsCounter = pluginMetrics.counter(REQUEST_TIMEOUTS); + badRequestsCounter = pluginMetrics.counter(BAD_REQUESTS); + requestsTooLargeCounter = pluginMetrics.counter(REQUESTS_TOO_LARGE); + internalServerErrorCounter = pluginMetrics.counter(INTERNAL_SERVER_ERROR); + this.retryInfoCalculator = new RetryInfoCalculator(retryInfoMinDelay, retryInfoMaxDelay); + } + + @Override + public HttpResponse handleException(final ServiceRequestContext ctx, + final HttpRequest req, + final Throwable e) { + final Throwable exceptionCause = e instanceof BufferWriteException ? e.getCause() : e; + StatusHolder statusHolder = createStatus(exceptionCause); + + try { + JsonFormat.TypeRegistry typeRegistry = JsonFormat.TypeRegistry.newBuilder() + .add(RetryInfo.getDescriptor()) + .build(); + + JsonFormat.Printer printer = JsonFormat.printer().usingTypeRegistry(typeRegistry); + return HttpResponse.of(statusHolder.getHttpStatus(), MediaType.JSON, printer.print(statusHolder.getStatus())); + } catch (InvalidProtocolBufferException ipbe) { + throw new RuntimeException(ipbe); + } + } + + private StatusHolder createStatus(Throwable e) { + if (e instanceof RequestTimeoutException || e instanceof TimeoutException) { + requestTimeoutsCounter.increment(); + return new StatusHolder(createStatus(e, Status.Code.RESOURCE_EXHAUSTED), createHttpStatusFromProtoBufStatus(Status.Code.RESOURCE_EXHAUSTED)); + } else if (e instanceof SizeOverflowException) { + requestsTooLargeCounter.increment(); + return new StatusHolder(createStatus(e, Status.Code.RESOURCE_EXHAUSTED), createHttpStatusFromProtoBufStatus(Status.Code.RESOURCE_EXHAUSTED)); + } else if (e instanceof BadRequestException) { + badRequestsCounter.increment(); + return new StatusHolder(createStatus(e, Status.Code.INVALID_ARGUMENT), createHttpStatusFromProtoBufStatus(Status.Code.INVALID_ARGUMENT)); + } else if ((e instanceof StatusRuntimeException) && (e.getMessage().contains("Invalid protobuf byte sequence") || e.getMessage().contains("Can't decode compressed frame"))) { + badRequestsCounter.increment(); + return new StatusHolder(createStatus(e, Status.Code.INVALID_ARGUMENT), createHttpStatusFromProtoBufStatus(Status.Code.INVALID_ARGUMENT)); + } else if (e instanceof RequestCancelledException) { + requestTimeoutsCounter.increment(); + return new StatusHolder(createStatus(e, Status.Code.CANCELLED), createHttpStatusFromProtoBufStatus(Status.Code.CANCELLED)); + } else { + LOG.error("Unexpected exception handling http request", e); + internalServerErrorCounter.increment(); + return new StatusHolder(createStatus(e, Status.Code.INTERNAL), createHttpStatusFromProtoBufStatus(Status.Code.INTERNAL)); + } + } + + private HttpStatus createHttpStatusFromProtoBufStatus(Status.Code status) { + if (status == Status.Code.RESOURCE_EXHAUSTED) { + return HttpStatus.INSUFFICIENT_STORAGE; + } else if (status == Status.Code.INVALID_ARGUMENT) { + return HttpStatus.BAD_REQUEST; + } else { + return HttpStatus.INTERNAL_SERVER_ERROR; + } + } + + private com.google.rpc.Status createStatus(final Throwable e, final Status.Code code) { + com.google.rpc.Status.Builder builder = com.google.rpc.Status.newBuilder().setCode(code.value()); + if (e instanceof RequestTimeoutException) { + builder.setMessage(ARMERIA_REQUEST_TIMEOUT_MESSAGE); + } else { + builder.setMessage(e.getMessage() == null ? code.name() :e.getMessage()); + } + if (code == Status.Code.RESOURCE_EXHAUSTED) { + builder.addDetails(Any.pack(retryInfoCalculator.createRetryInfo())); + } + return builder.build(); + } + + private static class StatusHolder { + private final HttpStatus httpStatus; + private final com.google.rpc.Status status; + + public StatusHolder(com.google.rpc.Status status, HttpStatus httpStatus) { + this.httpStatus = httpStatus; + this.status = status; + } + + public HttpStatus getHttpStatus() { + return httpStatus; + } + + public com.google.rpc.Status getStatus() { + return status; + } + } + +} diff --git a/data-prepper-plugins/otel-trace-source/src/main/java/org/opensearch/dataprepper/plugins/source/oteltrace/http/HttpService.java b/data-prepper-plugins/otel-trace-source/src/main/java/org/opensearch/dataprepper/plugins/source/oteltrace/http/HttpService.java new file mode 100644 index 0000000000..107010a8b5 --- /dev/null +++ b/data-prepper-plugins/otel-trace-source/src/main/java/org/opensearch/dataprepper/plugins/source/oteltrace/http/HttpService.java @@ -0,0 +1,70 @@ +package org.opensearch.dataprepper.plugins.source.oteltrace.http; + +import static org.opensearch.dataprepper.armeria.authentication.ArmeriaHttpAuthenticationProvider.UNAUTHENTICATED_PLUGIN_NAME; + +import java.time.Duration; +import java.util.Map; + +import org.opensearch.dataprepper.armeria.authentication.ArmeriaHttpAuthenticationProvider; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.configuration.PluginModel; +import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.codec.CompressionOption; +import org.opensearch.dataprepper.plugins.server.RetryInfoConfig; +import org.opensearch.dataprepper.plugins.source.oteltrace.OTelTraceSourceConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.encoding.DecodingService; + +public class HttpService { + private static final Logger LOG = LoggerFactory.getLogger(HttpService.class); + private static final RetryInfoConfig DEFAULT_RETRY_INFO = new RetryInfoConfig(Duration.ofMillis(100), Duration.ofMillis(2000)); + + private final PluginMetrics pluginMetrics; + private final OTelTraceSourceConfig oTelTraceSourceConfig; + private final PluginFactory pluginFactory; + + public HttpService(PluginMetrics pluginMetrics, OTelTraceSourceConfig oTelTraceSourceConfig, PluginFactory pluginFactory) { + this.pluginMetrics = pluginMetrics; + this.oTelTraceSourceConfig = oTelTraceSourceConfig; + this.pluginFactory = pluginFactory; + } + + public ArmeriaHttpService create(ServerBuilder serverBuilder, Buffer> buffer) { + RetryInfoConfig retryInfo = oTelTraceSourceConfig.getRetryInfo() != null + ? oTelTraceSourceConfig.getRetryInfo() + : DEFAULT_RETRY_INFO; + ArmeriaHttpService httpService = new ArmeriaHttpService(buffer, pluginMetrics, oTelTraceSourceConfig.getRequestTimeoutInMillis()); + HttpExceptionHandler httpExceptionHandler = new HttpExceptionHandler(pluginMetrics, retryInfo.getMinDelay(), retryInfo.getMaxDelay()); + + configureAuthentication(serverBuilder); + + if (CompressionOption.NONE.equals(oTelTraceSourceConfig.getCompression())) { + serverBuilder.annotatedService(httpService, httpExceptionHandler); + } else { + serverBuilder.annotatedService(httpService, DecodingService.newDecorator(), httpExceptionHandler); + } + + return httpService; + } + + private void configureAuthentication(ServerBuilder serverBuilder) { + if (oTelTraceSourceConfig.getAuthentication() == null || oTelTraceSourceConfig.getAuthentication().getPluginName().equals(UNAUTHENTICATED_PLUGIN_NAME)) { + LOG.warn("Creating otel_trace_source http service without authentication. This is not secure."); + LOG.warn("In order to set up Http Basic authentication for the otel-trace-source, go here: https://github.com/opensearch-project/data-prepper/tree/main/data-prepper-plugins/otel-trace-source#authentication-configurations"); + } else { + ArmeriaHttpAuthenticationProvider authenticationProvider = createAuthenticationProvider(oTelTraceSourceConfig.getAuthentication()); + authenticationProvider.getAuthenticationDecorator().ifPresent(serverBuilder::decorator); + } + } + + private ArmeriaHttpAuthenticationProvider createAuthenticationProvider(final PluginModel authenticationConfiguration) { + Map pluginSettings = authenticationConfiguration.getPluginSettings(); + return pluginFactory.loadPlugin(ArmeriaHttpAuthenticationProvider.class, new PluginSetting(authenticationConfiguration.getPluginName(), pluginSettings)); + } +} diff --git a/data-prepper-plugins/otel-trace-source/src/test/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSourceTest.java b/data-prepper-plugins/otel-trace-source/src/test/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSourceTest.java index 39ef2a9b42..7a923ff0dd 100644 --- a/data-prepper-plugins/otel-trace-source/src/test/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSourceTest.java +++ b/data-prepper-plugins/otel-trace-source/src/test/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSourceTest.java @@ -10,19 +10,15 @@ import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.util.JsonFormat; -import com.linecorp.armeria.client.ClientFactory; import com.linecorp.armeria.client.Clients; import com.linecorp.armeria.client.WebClient; import com.linecorp.armeria.common.AggregatedHttpResponse; -import com.linecorp.armeria.common.ClosedSessionException; import com.linecorp.armeria.common.HttpData; -import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpMethod; import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.MediaType; import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.SessionProtocol; -import com.linecorp.armeria.server.HttpService; import com.linecorp.armeria.server.Server; import com.linecorp.armeria.server.ServerBuilder; import com.linecorp.armeria.server.grpc.GrpcService; @@ -36,7 +32,6 @@ import io.micrometer.core.instrument.Statistic; import io.netty.util.AsciiString; import io.opentelemetry.proto.collector.trace.v1.ExportTraceServiceRequest; -import io.opentelemetry.proto.collector.trace.v1.ExportTraceServiceResponse; import io.opentelemetry.proto.collector.trace.v1.TraceServiceGrpc; import io.opentelemetry.proto.trace.v1.ResourceSpans; import io.opentelemetry.proto.trace.v1.ScopeSpans; @@ -46,24 +41,20 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.api.extension.ExtensionContext; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.ArgumentsProvider; -import org.junit.jupiter.params.provider.ArgumentsSource; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.Mock; import org.mockito.MockedStatic; import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.dataprepper.GrpcRequestExceptionHandler; +import org.opensearch.dataprepper.armeria.authentication.ArmeriaHttpAuthenticationProvider; import org.opensearch.dataprepper.armeria.authentication.GrpcAuthenticationProvider; import org.opensearch.dataprepper.armeria.authentication.HttpBasicAuthenticationConfig; import org.opensearch.dataprepper.metrics.MetricNames; import org.opensearch.dataprepper.metrics.MetricsTestUtil; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.buffer.Buffer; -import org.opensearch.dataprepper.model.buffer.SizeOverflowException; import org.opensearch.dataprepper.model.configuration.PipelineDescription; import org.opensearch.dataprepper.model.configuration.PluginModel; import org.opensearch.dataprepper.model.configuration.PluginSetting; @@ -71,6 +62,7 @@ import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.model.types.ByteCount; import org.opensearch.dataprepper.plugins.GrpcBasicAuthenticationProvider; +import org.opensearch.dataprepper.plugins.HttpBasicArmeriaHttpAuthenticationProvider; import org.opensearch.dataprepper.plugins.certificate.CertificateProvider; import org.opensearch.dataprepper.plugins.certificate.model.Certificate; import org.opensearch.dataprepper.plugins.codec.CompressionOption; @@ -78,7 +70,6 @@ import org.opensearch.dataprepper.plugins.server.RetryInfoConfig; import org.opensearch.dataprepper.plugins.source.oteltrace.certificate.CertificateProviderFactory; -import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; @@ -86,28 +77,20 @@ import java.nio.file.Path; import java.time.Duration; import java.util.Base64; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.StringJoiner; import java.util.UUID; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeoutException; import java.util.function.Function; import java.util.stream.Collectors; -import java.util.stream.Stream; -import java.util.zip.GZIPOutputStream; -import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.notNullValue; @@ -115,20 +98,16 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.params.provider.Arguments.arguments; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.ArgumentMatchers.anyCollection; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import static org.opensearch.dataprepper.plugins.source.oteltrace.OTelTraceSourceConfig.DEFAULT_PORT; import static org.opensearch.dataprepper.plugins.source.oteltrace.OTelTraceSourceConfig.DEFAULT_REQUEST_TIMEOUT_MS; @@ -143,14 +122,6 @@ class OTelTraceSourceTest { private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper().registerModule(new JavaTimeModule()); private static final String TEST_PIPELINE_NAME = "test_pipeline"; private static final RetryInfoConfig TEST_RETRY_INFO = new RetryInfoConfig(Duration.ofMillis(50), Duration.ofMillis(2000)); - private static final ExportTraceServiceRequest SUCCESS_REQUEST = ExportTraceServiceRequest.newBuilder() - .addResourceSpans(ResourceSpans.newBuilder() - .addScopeSpans(ScopeSpans.newBuilder() - .addSpans(io.opentelemetry.proto.trace.v1.Span.newBuilder().setTraceState("SUCCESS").build())).build()).build(); - private static final ExportTraceServiceRequest FAILURE_REQUEST = ExportTraceServiceRequest.newBuilder() - .addResourceSpans(ResourceSpans.newBuilder() - .addScopeSpans(ScopeSpans.newBuilder() - .addSpans(io.opentelemetry.proto.trace.v1.Span.newBuilder().setTraceState("FAILURE").build())).build()).build(); @Mock private ServerBuilder serverBuilder; @@ -182,6 +153,9 @@ class OTelTraceSourceTest { @Mock private GrpcBasicAuthenticationProvider authenticationProvider; + @Mock + private HttpBasicArmeriaHttpAuthenticationProvider armeriaHttpAuthenticationProvider; + @Mock(lenient = true) private OTelTraceSourceConfig oTelTraceSourceConfig; @@ -191,7 +165,6 @@ class OTelTraceSourceTest { @Mock private HttpBasicAuthenticationConfig httpBasicAuthenticationConfig; - private PluginSetting pluginSetting; private PluginSetting testPluginSetting; private PluginMetrics pluginMetrics; private PipelineDescription pipelineDescription; @@ -199,11 +172,13 @@ class OTelTraceSourceTest { @BeforeEach void beforeEach() { + lenient().when(serverBuilder.port(anyInt(), ArgumentMatchers.any())).thenReturn(serverBuilder); lenient().when(serverBuilder.service(any(GrpcService.class))).thenReturn(serverBuilder); lenient().when(serverBuilder.service(any(GrpcService.class), any(Function.class))).thenReturn(serverBuilder); lenient().when(serverBuilder.http(anyInt())).thenReturn(serverBuilder); lenient().when(serverBuilder.https(anyInt())).thenReturn(serverBuilder); lenient().when(serverBuilder.build()).thenReturn(server); + lenient().when(server.start()).thenReturn(completableFuture); lenient().when(grpcServiceBuilder.addService(any(BindableService.class))).thenReturn(grpcServiceBuilder); @@ -223,8 +198,10 @@ void beforeEach() { when(oTelTraceSourceConfig.getCompression()).thenReturn(CompressionOption.NONE); when(oTelTraceSourceConfig.getRetryInfo()).thenReturn(TEST_RETRY_INFO); - when(pluginFactory.loadPlugin(eq(GrpcAuthenticationProvider.class), any(PluginSetting.class))) + lenient().when(pluginFactory.loadPlugin(eq(GrpcAuthenticationProvider.class), any(PluginSetting.class))) .thenReturn(authenticationProvider); + lenient().when(pluginFactory.loadPlugin(eq(ArmeriaHttpAuthenticationProvider.class), any(PluginSetting.class))) + .thenReturn(armeriaHttpAuthenticationProvider); configureObjectUnderTest(); pipelineDescription = mock(PipelineDescription.class); lenient().when(pipelineDescription.getPipelineName()).thenReturn(TEST_PIPELINE_NAME); @@ -244,164 +221,6 @@ private void configureObjectUnderTest() { SOURCE = new OTelTraceSource(oTelTraceSourceConfig, pluginMetrics, pluginFactory, pipelineDescription); } - @Test - void testHttpFullJsonWithNonUnframedRequests() throws InvalidProtocolBufferException { - configureObjectUnderTest(); - SOURCE.start(buffer); - WebClient.of().execute(RequestHeaders.builder() - .scheme(SessionProtocol.HTTP) - .authority("127.0.0.1:21890") - .method(HttpMethod.POST) - .path("/opentelemetry.proto.collector.trace.v1.TraceService/Export") - .contentType(MediaType.JSON_UTF_8) - .build(), - HttpData.copyOf(JsonFormat.printer().print(SUCCESS_REQUEST).getBytes())) - .aggregate() - .whenComplete((response, throwable) -> assertSecureResponseWithStatusCode(response, HttpStatus.UNSUPPORTED_MEDIA_TYPE, throwable)) - .join(); - WebClient.of().execute(RequestHeaders.builder() - .scheme(SessionProtocol.HTTP) - .authority("127.0.0.1:21890") - .method(HttpMethod.POST) - .path("/opentelemetry.proto.collector.trace.v1.TraceService/Export") - .contentType(MediaType.JSON_UTF_8) - .build(), - HttpData.copyOf(JsonFormat.printer().print(FAILURE_REQUEST).getBytes())) - .aggregate() - .whenComplete((response, throwable) -> assertSecureResponseWithStatusCode(response, HttpStatus.UNSUPPORTED_MEDIA_TYPE, throwable)) - .join(); - } - - @Test - void testHttpsFullJsonWithNonUnframedRequests() throws InvalidProtocolBufferException { - - final Map settingsMap = new HashMap<>(); - settingsMap.put("request_timeout", 5); - settingsMap.put(SSL, true); - settingsMap.put("useAcmCertForSSL", false); - settingsMap.put("sslKeyCertChainFile", "data/certificate/test_cert.crt"); - settingsMap.put("sslKeyFile", "data/certificate/test_decrypted_key.key"); - pluginSetting = new PluginSetting("otel_trace", settingsMap); - pluginSetting.setPipelineName("pipeline"); - - oTelTraceSourceConfig = OBJECT_MAPPER.convertValue(pluginSetting.getSettings(), OTelTraceSourceConfig.class); - SOURCE = new OTelTraceSource(oTelTraceSourceConfig, pluginMetrics, pluginFactory, pipelineDescription); - - SOURCE.start(buffer); - - WebClient.builder().factory(ClientFactory.insecure()).build().execute(RequestHeaders.builder() - .scheme(SessionProtocol.HTTPS) - .authority("127.0.0.1:21890") - .method(HttpMethod.POST) - .path("/opentelemetry.proto.collector.trace.v1.TraceService/Export") - .contentType(MediaType.JSON_UTF_8) - .build(), - HttpData.copyOf(JsonFormat.printer().print(SUCCESS_REQUEST).getBytes())) - .aggregate() - .whenComplete((response, throwable) -> assertSecureResponseWithStatusCode(response, HttpStatus.UNSUPPORTED_MEDIA_TYPE, throwable)) - .join(); - WebClient.builder().factory(ClientFactory.insecure()).build().execute(RequestHeaders.builder() - .scheme(SessionProtocol.HTTPS) - .authority("127.0.0.1:21890") - .method(HttpMethod.POST) - .path("/opentelemetry.proto.collector.trace.v1.TraceService/Export") - .contentType(MediaType.JSON_UTF_8) - .build(), - HttpData.copyOf(JsonFormat.printer().print(FAILURE_REQUEST).getBytes())) - .aggregate() - .whenComplete((response, throwable) -> assertSecureResponseWithStatusCode(response, HttpStatus.UNSUPPORTED_MEDIA_TYPE, throwable)) - .join(); - } - - @Test - void testHttpFullBytesWithNonUnframedRequests() { - configureObjectUnderTest(); - SOURCE.start(buffer); - WebClient.of().execute(RequestHeaders.builder() - .scheme(SessionProtocol.HTTP) - .authority("127.0.0.1:21890") - .method(HttpMethod.POST) - .path("/opentelemetry.proto.collector.trace.v1.TraceService/Export") - .contentType(MediaType.PROTOBUF) - .build(), - HttpData.copyOf(SUCCESS_REQUEST.toByteArray())) - .aggregate() - .whenComplete((response, throwable) -> assertSecureResponseWithStatusCode(response, HttpStatus.UNSUPPORTED_MEDIA_TYPE, throwable)) - .join(); - WebClient.of().execute(RequestHeaders.builder() - .scheme(SessionProtocol.HTTP) - .authority("127.0.0.1:21890") - .method(HttpMethod.POST) - .path("/opentelemetry.proto.collector.trace.v1.TraceService/Export") - .contentType(MediaType.PROTOBUF) - .build(), - HttpData.copyOf(FAILURE_REQUEST.toByteArray())) - .aggregate() - .whenComplete((response, throwable) -> assertSecureResponseWithStatusCode(response, HttpStatus.UNSUPPORTED_MEDIA_TYPE, throwable)) - .join(); - } - - @Test - void testHttpFullJsonWithUnframedRequests() throws InvalidProtocolBufferException { - when(oTelTraceSourceConfig.enableUnframedRequests()).thenReturn(true); - configureObjectUnderTest(); - SOURCE.start(buffer); - - WebClient.of().execute(RequestHeaders.builder() - .scheme(SessionProtocol.HTTP) - .authority("127.0.0.1:21890") - .method(HttpMethod.POST) - .path("/opentelemetry.proto.collector.trace.v1.TraceService/Export") - .contentType(MediaType.JSON_UTF_8) - .build(), - HttpData.copyOf(JsonFormat.printer().print(createExportTraceRequest()).getBytes())) - .aggregate() - .whenComplete((response, throwable) -> assertSecureResponseWithStatusCode(response, HttpStatus.OK, throwable)) - .join(); - } - - @Test - void testHttpCompressionWithUnframedRequests() throws IOException { - when(oTelTraceSourceConfig.enableUnframedRequests()).thenReturn(true); - when(oTelTraceSourceConfig.getCompression()).thenReturn(CompressionOption.GZIP); - configureObjectUnderTest(); - SOURCE.start(buffer); - - WebClient.of().execute(RequestHeaders.builder() - .scheme(SessionProtocol.HTTP) - .authority("127.0.0.1:21890") - .method(HttpMethod.POST) - .path("/opentelemetry.proto.collector.trace.v1.TraceService/Export") - .contentType(MediaType.JSON_UTF_8) - .add(HttpHeaderNames.CONTENT_ENCODING, "gzip") - .build(), - createGZipCompressedPayload(JsonFormat.printer().print(createExportTraceRequest()))) - .aggregate() - .whenComplete((response, throwable) -> assertSecureResponseWithStatusCode(response, HttpStatus.OK, throwable)) - .join(); - } - - @Test - void testHttpFullJsonWithCustomPathAndUnframedRequests() throws InvalidProtocolBufferException { - when(oTelTraceSourceConfig.enableUnframedRequests()).thenReturn(true); - when(oTelTraceSourceConfig.getPath()).thenReturn(TEST_PATH); - configureObjectUnderTest(); - SOURCE.start(buffer); - - final String transformedPath = "/" + TEST_PIPELINE_NAME + "/v1/traces"; - WebClient.of().execute(RequestHeaders.builder() - .scheme(SessionProtocol.HTTP) - .authority("127.0.0.1:21890") - .method(HttpMethod.POST) - .path(transformedPath) - .contentType(MediaType.JSON_UTF_8) - .build(), - HttpData.copyOf(JsonFormat.printer().print(createExportTraceRequest()).getBytes())) - .aggregate() - .whenComplete((response, throwable) -> assertSecureResponseWithStatusCode(response, HttpStatus.OK, throwable)) - .join(); - } - @Test void testHttpFullJsonWithCustomPathAndAuthHeader_with_successful_response() throws InvalidProtocolBufferException { when(httpBasicAuthenticationConfig.getUsername()).thenReturn(USERNAME); @@ -526,31 +345,7 @@ void testGrpcRequestWithoutAuthentication_with_unsuccessful_response() throws Ex assertThat(actualException.getStatus().getCode(), equalTo(Status.Code.UNAUTHENTICATED)); } - @Test - void testHttpWithoutSslFailsWhenSslIsEnabled() throws InvalidProtocolBufferException { - when(oTelTraceSourceConfig.isSsl()).thenReturn(true); - when(oTelTraceSourceConfig.getSslKeyCertChainFile()).thenReturn("data/certificate/test_cert.crt"); - when(oTelTraceSourceConfig.getSslKeyFile()).thenReturn("data/certificate/test_decrypted_key.key"); - configureObjectUnderTest(); - SOURCE.start(buffer); - - WebClient client = WebClient.builder("http://127.0.0.1:21890") - .build(); - - CompletionException exception = assertThrows(CompletionException.class, () -> client.execute(RequestHeaders.builder() - .scheme(SessionProtocol.HTTP) - .authority("127.0.0.1:21890") - .method(HttpMethod.POST) - .path("/opentelemetry.proto.collector.trace.v1.TraceService/Export") - .contentType(MediaType.JSON_UTF_8) - .build(), - HttpData.copyOf(JsonFormat.printer().print(createExportTraceRequest()).getBytes())) - .aggregate() - .join()); - - assertThat(exception.getCause(), instanceOf(ClosedSessionException.class)); - } - + @Test void testGrpcFailsIfSslIsEnabledAndNoTls() { when(oTelTraceSourceConfig.isSsl()).thenReturn(true); @@ -682,49 +477,6 @@ void start_with_Health_configured_includes_HealthCheck_service() throws IOExcept verify(serverBuilder, never()).service(eq("/health"),isA(HealthCheckService.class)); } - @Test - void start_with_Health_configured_unframed_requests_includes_HTTPHealthCheck_service() throws IOException { - try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class); - MockedStatic grpcServerMock = Mockito.mockStatic(GrpcService.class)) { - armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); - grpcServerMock.when(GrpcService::builder).thenReturn(grpcServiceBuilder); - when(grpcServiceBuilder.addService(any(ServerServiceDefinition.class))).thenReturn(grpcServiceBuilder); - when(grpcServiceBuilder.useClientTimeoutHeader(anyBoolean())).thenReturn(grpcServiceBuilder); - - when(server.stop()).thenReturn(completableFuture); - final Path certFilePath = Path.of("data/certificate/test_cert.crt"); - final Path keyFilePath = Path.of("data/certificate/test_decrypted_key.key"); - final String certAsString = Files.readString(certFilePath); - final String keyAsString = Files.readString(keyFilePath); - when(certificate.getCertificate()).thenReturn(certAsString); - when(certificate.getPrivateKey()).thenReturn(keyAsString); - when(certificateProvider.getCertificate()).thenReturn(certificate); - when(certificateProviderFactory.getCertificateProvider()).thenReturn(certificateProvider); - final Map settingsMap = new HashMap<>(); - settingsMap.put(SSL, true); - settingsMap.put("useAcmCertForSSL", true); - settingsMap.put("awsRegion", "us-east-1"); - settingsMap.put("acmCertificateArn", "arn:aws:acm:us-east-1:account:certificate/1234-567-856456"); - settingsMap.put("sslKeyCertChainFile", "data/certificate/test_cert.crt"); - settingsMap.put("sslKeyFile", "data/certificate/test_decrypted_key.key"); - settingsMap.put("health_check_service", "true"); - settingsMap.put("unframed_requests", "true"); - - testPluginSetting = new PluginSetting(null, settingsMap); - testPluginSetting.setPipelineName("pipeline"); - - oTelTraceSourceConfig = OBJECT_MAPPER.convertValue(testPluginSetting.getSettings(), OTelTraceSourceConfig.class); - final OTelTraceSource source = new OTelTraceSource(oTelTraceSourceConfig, pluginMetrics, pluginFactory, certificateProviderFactory, pipelineDescription); - source.start(buffer); - source.stop(); - } - - verify(grpcServiceBuilder, times(1)).useClientTimeoutHeader(false); - verify(grpcServiceBuilder, times(1)).useBlockingTaskExecutor(true); - verify(grpcServiceBuilder).addService(isA(HealthGrpcService.class)); - verify(serverBuilder).service(eq("/health"), isA(HealthCheckService.class)); - } - @Test void start_without_Health_configured_does_not_include_HealthCheck_service() throws IOException { try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class); @@ -766,48 +518,6 @@ void start_without_Health_configured_does_not_include_HealthCheck_service() thro verify(serverBuilder, never()).service(eq("/health"),isA(HealthCheckService.class)); } - @Test - void start_without_Health_configured_unframed_requests_does_not_include_HealthCheck_service() throws IOException { - try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class); - MockedStatic grpcServerMock = Mockito.mockStatic(GrpcService.class)) { - armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); - grpcServerMock.when(GrpcService::builder).thenReturn(grpcServiceBuilder); - when(grpcServiceBuilder.addService(any(ServerServiceDefinition.class))).thenReturn(grpcServiceBuilder); - when(grpcServiceBuilder.useClientTimeoutHeader(anyBoolean())).thenReturn(grpcServiceBuilder); - - when(server.stop()).thenReturn(completableFuture); - final Path certFilePath = Path.of("data/certificate/test_cert.crt"); - final Path keyFilePath = Path.of("data/certificate/test_decrypted_key.key"); - final String certAsString = Files.readString(certFilePath); - final String keyAsString = Files.readString(keyFilePath); - when(certificate.getCertificate()).thenReturn(certAsString); - when(certificate.getPrivateKey()).thenReturn(keyAsString); - when(certificateProvider.getCertificate()).thenReturn(certificate); - when(certificateProviderFactory.getCertificateProvider()).thenReturn(certificateProvider); - final Map settingsMap = new HashMap<>(); - settingsMap.put(SSL, true); - settingsMap.put("useAcmCertForSSL", true); - settingsMap.put("awsRegion", "us-east-1"); - settingsMap.put("acmCertificateArn", "arn:aws:acm:us-east-1:account:certificate/1234-567-856456"); - settingsMap.put("sslKeyCertChainFile", "data/certificate/test_cert.crt"); - settingsMap.put("sslKeyFile", "data/certificate/test_decrypted_key.key"); - settingsMap.put("health_check_service", "false"); - settingsMap.put("unframed_requests", "true"); - - testPluginSetting = new PluginSetting(null, settingsMap); - testPluginSetting.setPipelineName("pipeline"); - oTelTraceSourceConfig = OBJECT_MAPPER.convertValue(testPluginSetting.getSettings(), OTelTraceSourceConfig.class); - final OTelTraceSource source = new OTelTraceSource(oTelTraceSourceConfig, pluginMetrics, pluginFactory, certificateProviderFactory, pipelineDescription); - source.start(buffer); - source.stop(); - } - - verify(grpcServiceBuilder, times(1)).useClientTimeoutHeader(false); - verify(grpcServiceBuilder, times(1)).useBlockingTaskExecutor(true); - verify(grpcServiceBuilder, never()).addService(isA(HealthGrpcService.class)); - verify(serverBuilder, never()).service(eq("/health"),isA(HealthCheckService.class)); - } - @Test void testHealthCheckUnauthNotAllowed() { // Prepare @@ -894,60 +604,6 @@ void testOptionalHttpAuthServiceNotInPlace() { verify(serverBuilder, never()).decorator(isA(Function.class)); } - @Test - void testOptionalHttpAuthServiceInPlace() { - final Optional> function = Optional.of(httpService -> httpService); - - final Map settingsMap = new HashMap<>(); - settingsMap.put("authentication", new PluginModel("test", null)); - settingsMap.put("unauthenticated_health_check", true); - - settingsMap.put(SSL, false); - - testPluginSetting = new PluginSetting(null, settingsMap); - testPluginSetting.setPipelineName("pipeline"); - oTelTraceSourceConfig = OBJECT_MAPPER.convertValue(testPluginSetting.getSettings(), OTelTraceSourceConfig.class); - - when(authenticationProvider.getHttpAuthenticationService()).thenReturn(function); - - final OTelTraceSource source = new OTelTraceSource(oTelTraceSourceConfig, pluginMetrics, pluginFactory, certificateProviderFactory, pipelineDescription); - - try (final MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class)) { - armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); - source.start(buffer); - } - - verify(serverBuilder).service(isA(GrpcService.class)); - verify(serverBuilder).decorator(isA(String.class), isA(Function.class)); - } - - @Test - void testOptionalHttpAuthServiceInPlaceWithUnauthenticatedDisabled() { - final Optional> function = Optional.of(httpService -> httpService); - - final Map settingsMap = new HashMap<>(); - settingsMap.put("authentication", new PluginModel("test", null)); - settingsMap.put("unauthenticated_health_check", false); - - settingsMap.put(SSL, false); - - testPluginSetting = new PluginSetting(null, settingsMap); - testPluginSetting.setPipelineName("pipeline"); - oTelTraceSourceConfig = OBJECT_MAPPER.convertValue(testPluginSetting.getSettings(), OTelTraceSourceConfig.class); - - when(authenticationProvider.getHttpAuthenticationService()).thenReturn(function); - - final OTelTraceSource source = new OTelTraceSource(oTelTraceSourceConfig, pluginMetrics, pluginFactory, certificateProviderFactory, pipelineDescription); - - try (final MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class)) { - armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); - source.start(buffer); - } - - verify(serverBuilder).service(isA(GrpcService.class)); - verify(serverBuilder).decorator(isA(Function.class)); - } - @Test void testDoubleStart() { // starting server @@ -1070,112 +726,6 @@ void testStopWithInterruptedException() throws ExecutionException, InterruptedEx } } - @Test - void gRPC_request_writes_to_buffer_with_successful_response() throws Exception { - configureObjectUnderTest(); - SOURCE.start(buffer); - - final TraceServiceGrpc.TraceServiceBlockingStub client = Clients.builder(GRPC_ENDPOINT) - .build(TraceServiceGrpc.TraceServiceBlockingStub.class); - final ExportTraceServiceResponse exportResponse = client.export(createExportTraceRequest()); - assertThat(exportResponse, notNullValue()); - - final ArgumentCaptor>> bufferWriteArgumentCaptor = ArgumentCaptor.forClass(Collection.class); - verify(buffer).writeAll(bufferWriteArgumentCaptor.capture(), anyInt()); - - final Collection> actualBufferWrites = bufferWriteArgumentCaptor.getValue(); - assertThat(actualBufferWrites, notNullValue()); - assertThat(actualBufferWrites, hasSize(1)); - } - - @Test - void gRPC_with_auth_request_writes_to_buffer_with_successful_response() throws Exception { - when(httpBasicAuthenticationConfig.getUsername()).thenReturn(USERNAME); - when(httpBasicAuthenticationConfig.getPassword()).thenReturn(PASSWORD); - final GrpcAuthenticationProvider grpcAuthenticationProvider = new GrpcBasicAuthenticationProvider(httpBasicAuthenticationConfig); - - when(pluginFactory.loadPlugin(eq(GrpcAuthenticationProvider.class), any(PluginSetting.class))) - .thenReturn(grpcAuthenticationProvider); - when(oTelTraceSourceConfig.enableUnframedRequests()).thenReturn(true); - when(oTelTraceSourceConfig.getAuthentication()).thenReturn(new PluginModel("http_basic", - Map.of( - "username", USERNAME, - "password", PASSWORD - ))); - configureObjectUnderTest(); - SOURCE.start(buffer); - - final String encodeToString = Base64.getEncoder() - .encodeToString(String.format("%s:%s", USERNAME, PASSWORD).getBytes(StandardCharsets.UTF_8)); - - final TraceServiceGrpc.TraceServiceBlockingStub client = Clients.builder(GRPC_ENDPOINT) - .addHeader("Authorization", "Basic " + encodeToString) - .build(TraceServiceGrpc.TraceServiceBlockingStub.class); - final ExportTraceServiceResponse exportResponse = client.export(createExportTraceRequest()); - assertThat(exportResponse, notNullValue()); - - final ArgumentCaptor>> bufferWriteArgumentCaptor = ArgumentCaptor.forClass(Collection.class); - verify(buffer).writeAll(bufferWriteArgumentCaptor.capture(), anyInt()); - - final Collection> actualBufferWrites = bufferWriteArgumentCaptor.getValue(); - assertThat(actualBufferWrites, notNullValue()); - assertThat(actualBufferWrites, hasSize(1)); - } - - @Test - void gRPC_request_with_custom_path_throws_when_written_to_default_path() { - when(oTelTraceSourceConfig.getPath()).thenReturn(TEST_PATH); - when(oTelTraceSourceConfig.enableUnframedRequests()).thenReturn(true); - - configureObjectUnderTest(); - SOURCE.start(buffer); - - final TraceServiceGrpc.TraceServiceBlockingStub client = Clients.builder(GRPC_ENDPOINT) - .build(TraceServiceGrpc.TraceServiceBlockingStub.class); - - final StatusRuntimeException actualException = assertThrows(StatusRuntimeException.class, () -> client.export(createExportTraceRequest())); - assertThat(actualException.getStatus(), notNullValue()); - assertThat(actualException.getStatus().getCode(), equalTo(Status.UNIMPLEMENTED.getCode())); - } - - @ParameterizedTest - @ArgumentsSource(BufferExceptionToStatusArgumentsProvider.class) - void gRPC_request_returns_expected_status_for_exceptions_from_buffer( - final Class bufferExceptionClass, - final Status.Code expectedStatusCode) throws Exception { - configureObjectUnderTest(); - SOURCE.start(buffer); - - final TraceServiceGrpc.TraceServiceBlockingStub client = Clients.builder(GRPC_ENDPOINT) - .build(TraceServiceGrpc.TraceServiceBlockingStub.class); - - doThrow(bufferExceptionClass) - .when(buffer) - .writeAll(anyCollection(), anyInt()); - final ExportTraceServiceRequest exportTraceRequest = createExportTraceRequest(); - final StatusRuntimeException actualException = assertThrows(StatusRuntimeException.class, () -> client.export(exportTraceRequest)); - - assertThat(actualException.getStatus(), notNullValue()); - assertThat(actualException.getStatus().getCode(), equalTo(expectedStatusCode)); - } - - @Test - void gRPC_request_throws_InvalidArgument_for_malformed_trace_data() { - configureObjectUnderTest(); - SOURCE.start(buffer); - - final TraceServiceGrpc.TraceServiceBlockingStub client = Clients.builder(GRPC_ENDPOINT) - .build(TraceServiceGrpc.TraceServiceBlockingStub.class); - - final ExportTraceServiceRequest exportTraceRequest = createInvalidExportTraceRequest(); - final StatusRuntimeException actualException = assertThrows(StatusRuntimeException.class, () -> client.export(exportTraceRequest)); - - assertThat(actualException.getStatus(), notNullValue()); - assertThat(actualException.getStatus().getCode(), equalTo(Status.Code.INVALID_ARGUMENT)); - - verifyNoInteractions(buffer); - } - @Test void request_that_exceeds_maxRequestLength_returns_413() throws InvalidProtocolBufferException { when(oTelTraceSourceConfig.enableUnframedRequests()).thenReturn(true); @@ -1232,29 +782,6 @@ void testServerConnectionsMetric() throws InvalidProtocolBufferException { assertEquals(1.0, serverConnectionsMeasurement.getValue()); } - static class BufferExceptionToStatusArgumentsProvider implements ArgumentsProvider { - @Override - public Stream provideArguments(final ExtensionContext context) { - return Stream.of( - arguments(TimeoutException.class, Status.Code.RESOURCE_EXHAUSTED), - arguments(SizeOverflowException.class, Status.Code.RESOURCE_EXHAUSTED), - arguments(Exception.class, Status.Code.INTERNAL), - arguments(RuntimeException.class, Status.Code.INTERNAL) - ); - } - } - - private ExportTraceServiceRequest createInvalidExportTraceRequest() { - final io.opentelemetry.proto.trace.v1.Span testSpan = Span.newBuilder() - .setTraceState("SUCCESS").build(); - final ExportTraceServiceRequest successRequest = ExportTraceServiceRequest.newBuilder() - .addResourceSpans(ResourceSpans.newBuilder() - .addScopeSpans(ScopeSpans.newBuilder().addSpans(testSpan)).build()) - .build(); - - return successRequest; - } - private ExportTraceServiceRequest createExportTraceRequest() { final io.opentelemetry.proto.trace.v1.Span testSpan = Span.newBuilder() .setTraceId(ByteString.copyFromUtf8(UUID.randomUUID().toString())) @@ -1284,14 +811,4 @@ private void assertSecureResponseWithStatusCode(final AggregatedHttpResponse res .collect(Collectors.toList()); assertThat("Response Header Keys", headerKeys, not(contains("server"))); } - - private byte[] createGZipCompressedPayload(final String payload) throws IOException { - // Create a GZip compressed request body - final ByteArrayOutputStream byteStream = new ByteArrayOutputStream(); - try (final GZIPOutputStream gzipStream = new GZIPOutputStream(byteStream)) { - gzipStream.write(payload.getBytes(StandardCharsets.UTF_8)); - } - return byteStream.toByteArray(); - } - } diff --git a/data-prepper-plugins/otel-trace-source/src/test/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSource_GrpcRequestTest.java b/data-prepper-plugins/otel-trace-source/src/test/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSource_GrpcRequestTest.java new file mode 100644 index 0000000000..876e68bb53 --- /dev/null +++ b/data-prepper-plugins/otel-trace-source/src/test/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSource_GrpcRequestTest.java @@ -0,0 +1,388 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.oteltrace; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.params.provider.Arguments.arguments; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyCollection; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; +import static org.opensearch.dataprepper.plugins.source.oteltrace.OTelTraceSourceConfig.DEFAULT_PORT; +import static org.opensearch.dataprepper.plugins.source.oteltrace.OTelTraceSourceConfig.DEFAULT_REQUEST_TIMEOUT_MS; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Base64; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeoutException; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.ArgumentsProvider; +import org.junit.jupiter.params.provider.ArgumentsSource; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.GrpcRequestExceptionHandler; +import org.opensearch.dataprepper.armeria.authentication.GrpcAuthenticationProvider; +import org.opensearch.dataprepper.armeria.authentication.HttpBasicAuthenticationConfig; +import org.opensearch.dataprepper.metrics.MetricsTestUtil; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.buffer.SizeOverflowException; +import org.opensearch.dataprepper.model.configuration.PipelineDescription; +import org.opensearch.dataprepper.model.configuration.PluginModel; +import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.model.types.ByteCount; +import org.opensearch.dataprepper.plugins.GrpcBasicAuthenticationProvider; +import org.opensearch.dataprepper.plugins.certificate.CertificateProvider; +import org.opensearch.dataprepper.plugins.certificate.model.Certificate; +import org.opensearch.dataprepper.plugins.codec.CompressionOption; +import org.opensearch.dataprepper.plugins.server.RetryInfoConfig; +import org.opensearch.dataprepper.plugins.source.oteltrace.certificate.CertificateProviderFactory; + +import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.util.JsonFormat; +import com.linecorp.armeria.client.Clients; +import com.linecorp.armeria.client.WebClient; +import com.linecorp.armeria.common.AggregatedHttpResponse; +import com.linecorp.armeria.common.HttpData; +import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.MediaType; +import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.server.Server; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.grpc.GrpcService; +import com.linecorp.armeria.server.grpc.GrpcServiceBuilder; + +import io.grpc.BindableService; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.netty.util.AsciiString; +import io.opentelemetry.proto.collector.trace.v1.ExportTraceServiceRequest; +import io.opentelemetry.proto.collector.trace.v1.ExportTraceServiceResponse; +import io.opentelemetry.proto.collector.trace.v1.TraceServiceGrpc; +import io.opentelemetry.proto.trace.v1.ResourceSpans; +import io.opentelemetry.proto.trace.v1.ScopeSpans; +import io.opentelemetry.proto.trace.v1.Span; + +@ExtendWith(MockitoExtension.class) +class OTelTraceSource_GrpcRequestTest { + private static final String GRPC_ENDPOINT = "gproto+http://127.0.0.1:21890/"; + private static final String USERNAME = "test_user"; + private static final String PASSWORD = "test_password"; + private static final String TEST_PATH = "${pipelineName}/v1/traces"; + private static final String TEST_PIPELINE_NAME = "test_pipeline"; + private static final RetryInfoConfig TEST_RETRY_INFO = new RetryInfoConfig(Duration.ofMillis(50), Duration.ofMillis(2000)); + + + @Mock + private ServerBuilder serverBuilder; + + @Mock + private Server server; + + @Mock + private GrpcServiceBuilder grpcServiceBuilder; + + @Mock + private GrpcService grpcService; + + @Mock + private CertificateProviderFactory certificateProviderFactory; + + @Mock + private CertificateProvider certificateProvider; + + @Mock + private Certificate certificate; + + @Mock + private CompletableFuture completableFuture; + + @Mock + private PluginFactory pluginFactory; + + @Mock + private GrpcBasicAuthenticationProvider authenticationProvider; + + @Mock(lenient = true) + private OTelTraceSourceConfig oTelTraceSourceConfig; + + @Mock + private Buffer> buffer; + + @Mock + private HttpBasicAuthenticationConfig httpBasicAuthenticationConfig; + + + private PluginMetrics pluginMetrics; + private PipelineDescription pipelineDescription; + private OTelTraceSource SOURCE; + + @BeforeEach + void beforeEach() { + lenient().when(serverBuilder.port(anyInt(), ArgumentMatchers.any())).thenReturn(serverBuilder); + lenient().when(serverBuilder.service(any(GrpcService.class))).thenReturn(serverBuilder); + lenient().when(serverBuilder.service(any(GrpcService.class), any(Function.class))).thenReturn(serverBuilder); + lenient().when(serverBuilder.http(anyInt())).thenReturn(serverBuilder); + lenient().when(serverBuilder.https(anyInt())).thenReturn(serverBuilder); + lenient().when(serverBuilder.build()).thenReturn(server); + + lenient().when(server.start()).thenReturn(completableFuture); + + lenient().when(grpcServiceBuilder.addService(any(BindableService.class))).thenReturn(grpcServiceBuilder); + lenient().when(grpcServiceBuilder.useClientTimeoutHeader(anyBoolean())).thenReturn(grpcServiceBuilder); + lenient().when(grpcServiceBuilder.useBlockingTaskExecutor(anyBoolean())).thenReturn(grpcServiceBuilder); + lenient().when(grpcServiceBuilder.exceptionHandler(any(GrpcRequestExceptionHandler.class))).thenReturn(grpcServiceBuilder); + lenient().when(grpcServiceBuilder.build()).thenReturn(grpcService); + + lenient().when(authenticationProvider.getHttpAuthenticationService()).thenCallRealMethod(); + + when(oTelTraceSourceConfig.getPort()).thenReturn(DEFAULT_PORT); + when(oTelTraceSourceConfig.isSsl()).thenReturn(false); + when(oTelTraceSourceConfig.getRequestTimeoutInMillis()).thenReturn(DEFAULT_REQUEST_TIMEOUT_MS); + when(oTelTraceSourceConfig.getMaxConnectionCount()).thenReturn(10); + when(oTelTraceSourceConfig.getThreadCount()).thenReturn(5); + when(oTelTraceSourceConfig.getCompression()).thenReturn(CompressionOption.NONE); + when(oTelTraceSourceConfig.getRetryInfo()).thenReturn(TEST_RETRY_INFO); + + lenient().when(pluginFactory.loadPlugin(eq(GrpcAuthenticationProvider.class), any(PluginSetting.class))) + .thenReturn(authenticationProvider); + configureObjectUnderTest(); + pipelineDescription = mock(PipelineDescription.class); + lenient().when(pipelineDescription.getPipelineName()).thenReturn(TEST_PIPELINE_NAME); + } + + @AfterEach + void afterEach() { + SOURCE.stop(); + } + + private void configureObjectUnderTest() { + MetricsTestUtil.initMetrics(); + pluginMetrics = PluginMetrics.fromNames("otel_trace", "pipeline"); + + pipelineDescription = mock(PipelineDescription.class); + when(pipelineDescription.getPipelineName()).thenReturn(TEST_PIPELINE_NAME); + SOURCE = new OTelTraceSource(oTelTraceSourceConfig, pluginMetrics, pluginFactory, pipelineDescription); + } + + + @Test + void gRPC_request_writes_to_buffer_with_successful_response() throws Exception { + configureObjectUnderTest(); + SOURCE.start(buffer); + + final TraceServiceGrpc.TraceServiceBlockingStub client = Clients.builder(GRPC_ENDPOINT) + .build(TraceServiceGrpc.TraceServiceBlockingStub.class); + final ExportTraceServiceResponse exportResponse = client.export(createExportTraceRequest()); + assertThat(exportResponse, notNullValue()); + + final ArgumentCaptor>> bufferWriteArgumentCaptor = ArgumentCaptor.forClass(Collection.class); + verify(buffer).writeAll(bufferWriteArgumentCaptor.capture(), anyInt()); + + final Collection> actualBufferWrites = bufferWriteArgumentCaptor.getValue(); + assertThat(actualBufferWrites, notNullValue()); + assertThat(actualBufferWrites, hasSize(1)); + } + + @Test + void gRPC_with_auth_request_writes_to_buffer_with_successful_response() throws Exception { + when(httpBasicAuthenticationConfig.getUsername()).thenReturn(USERNAME); + when(httpBasicAuthenticationConfig.getPassword()).thenReturn(PASSWORD); + final GrpcAuthenticationProvider grpcAuthenticationProvider = new GrpcBasicAuthenticationProvider(httpBasicAuthenticationConfig); + + when(pluginFactory.loadPlugin(eq(GrpcAuthenticationProvider.class), any(PluginSetting.class))) + .thenReturn(grpcAuthenticationProvider); + when(oTelTraceSourceConfig.enableUnframedRequests()).thenReturn(true); + when(oTelTraceSourceConfig.getAuthentication()).thenReturn(new PluginModel("http_basic", + Map.of( + "username", USERNAME, + "password", PASSWORD + ))); + configureObjectUnderTest(); + SOURCE.start(buffer); + + final String encodeToString = Base64.getEncoder() + .encodeToString(String.format("%s:%s", USERNAME, PASSWORD).getBytes(StandardCharsets.UTF_8)); + + final TraceServiceGrpc.TraceServiceBlockingStub client = Clients.builder(GRPC_ENDPOINT) + .addHeader("Authorization", "Basic " + encodeToString) + .build(TraceServiceGrpc.TraceServiceBlockingStub.class); + final ExportTraceServiceResponse exportResponse = client.export(createExportTraceRequest()); + assertThat(exportResponse, notNullValue()); + + final ArgumentCaptor>> bufferWriteArgumentCaptor = ArgumentCaptor.forClass(Collection.class); + verify(buffer).writeAll(bufferWriteArgumentCaptor.capture(), anyInt()); + + final Collection> actualBufferWrites = bufferWriteArgumentCaptor.getValue(); + assertThat(actualBufferWrites, notNullValue()); + assertThat(actualBufferWrites, hasSize(1)); + } + + @Test + void gRPC_request_with_custom_path_throws_when_written_to_default_path() { + when(oTelTraceSourceConfig.getPath()).thenReturn(TEST_PATH); + when(oTelTraceSourceConfig.enableUnframedRequests()).thenReturn(true); + + configureObjectUnderTest(); + SOURCE.start(buffer); + + final TraceServiceGrpc.TraceServiceBlockingStub client = Clients.builder(GRPC_ENDPOINT) + .build(TraceServiceGrpc.TraceServiceBlockingStub.class); + + final StatusRuntimeException actualException = assertThrows(StatusRuntimeException.class, () -> client.export(createExportTraceRequest())); + assertThat(actualException.getStatus(), notNullValue()); + assertThat(actualException.getMessage(), actualException.getStatus().getCode(), equalTo(Status.UNIMPLEMENTED.getCode())); + } + + @ParameterizedTest + @ArgumentsSource(BufferExceptionToStatusArgumentsProvider.class) + void gRPC_request_returns_expected_status_for_exceptions_from_buffer( + final Class bufferExceptionClass, + final Status.Code expectedStatusCode) throws Exception { + configureObjectUnderTest(); + SOURCE.start(buffer); + + final TraceServiceGrpc.TraceServiceBlockingStub client = Clients.builder(GRPC_ENDPOINT) + .build(TraceServiceGrpc.TraceServiceBlockingStub.class); + + doThrow(bufferExceptionClass) + .when(buffer) + .writeAll(anyCollection(), anyInt()); + final ExportTraceServiceRequest exportTraceRequest = createExportTraceRequest(); + final StatusRuntimeException actualException = assertThrows(StatusRuntimeException.class, () -> client.export(exportTraceRequest)); + + assertThat(actualException.getStatus(), notNullValue()); + assertThat(actualException.getMessage(), actualException.getStatus().getCode(), equalTo(expectedStatusCode)); + } + + @Test + void gRPC_request_throws_InvalidArgument_for_malformed_trace_data() { + configureObjectUnderTest(); + SOURCE.start(buffer); + + final TraceServiceGrpc.TraceServiceBlockingStub client = Clients.builder(GRPC_ENDPOINT) + .build(TraceServiceGrpc.TraceServiceBlockingStub.class); + + final ExportTraceServiceRequest exportTraceRequest = createInvalidExportTraceRequest(); + final StatusRuntimeException actualException = assertThrows(StatusRuntimeException.class, () -> client.export(exportTraceRequest)); + + assertThat(actualException.getStatus(), notNullValue()); + assertThat(actualException.getMessage(), actualException.getStatus().getCode(), equalTo(Status.Code.INVALID_ARGUMENT)); + + verifyNoInteractions(buffer); + } + + @Test + void request_that_exceeds_maxRequestLength_returns_413() throws InvalidProtocolBufferException { + when(oTelTraceSourceConfig.enableUnframedRequests()).thenReturn(true); + when(oTelTraceSourceConfig.getMaxRequestLength()).thenReturn(ByteCount.ofBytes(4)); + configureObjectUnderTest(); + SOURCE.start(buffer); + + WebClient.of().execute(RequestHeaders.builder() + .scheme(SessionProtocol.HTTP) + .authority("127.0.0.1:21890") + .method(HttpMethod.POST) + .path("/opentelemetry.proto.collector.trace.v1.TraceService/Export") + .contentType(MediaType.JSON_UTF_8) + .build(), + HttpData.copyOf(JsonFormat.printer().print(createExportTraceRequest()).getBytes())) + .aggregate() + .whenComplete((response, throwable) -> assertSecureResponseWithStatusCode(response, HttpStatus.REQUEST_ENTITY_TOO_LARGE, throwable)) + .join(); + } + + + static class BufferExceptionToStatusArgumentsProvider implements ArgumentsProvider { + @Override + public Stream provideArguments(final ExtensionContext context) { + return Stream.of( + arguments(TimeoutException.class, Status.Code.RESOURCE_EXHAUSTED), + arguments(SizeOverflowException.class, Status.Code.RESOURCE_EXHAUSTED), + arguments(Exception.class, Status.Code.INTERNAL), + arguments(RuntimeException.class, Status.Code.INTERNAL) + ); + } + } + + private ExportTraceServiceRequest createInvalidExportTraceRequest() { + final Span testSpan = Span.newBuilder() + .setTraceState("SUCCESS").build(); + final ExportTraceServiceRequest successRequest = ExportTraceServiceRequest.newBuilder() + .addResourceSpans(ResourceSpans.newBuilder() + .addScopeSpans(ScopeSpans.newBuilder().addSpans(testSpan)).build()) + .build(); + + return successRequest; + } + + private ExportTraceServiceRequest createExportTraceRequest() { + final Span testSpan = Span.newBuilder() + .setTraceId(ByteString.copyFromUtf8(UUID.randomUUID().toString())) + .setSpanId(ByteString.copyFromUtf8(UUID.randomUUID().toString())) + .setName(UUID.randomUUID().toString()) + .setKind(Span.SpanKind.SPAN_KIND_SERVER) + .setStartTimeUnixNano(100) + .setEndTimeUnixNano(101) + .setTraceState("SUCCESS").build(); + + return ExportTraceServiceRequest.newBuilder() + .addResourceSpans(ResourceSpans.newBuilder() + .addScopeSpans(ScopeSpans.newBuilder().addSpans(testSpan)).build()) + .build(); + } + + private void assertSecureResponseWithStatusCode(final AggregatedHttpResponse response, + final HttpStatus expectedStatus, + final Throwable throwable) { + assertThat("Http Status", response.status(), equalTo(expectedStatus)); + assertThat("Http Response Throwable", throwable, is(nullValue())); + + final List headerKeys = response.headers() + .stream() + .map(Map.Entry::getKey) + .map(AsciiString::toString) + .collect(Collectors.toList()); + assertThat("Response Header Keys", headerKeys, not(contains("server"))); + } +} diff --git a/data-prepper-plugins/otel-trace-source/src/test/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSource_HttpServiceTest.java b/data-prepper-plugins/otel-trace-source/src/test/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSource_HttpServiceTest.java new file mode 100644 index 0000000000..639117cac6 --- /dev/null +++ b/data-prepper-plugins/otel-trace-source/src/test/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSource_HttpServiceTest.java @@ -0,0 +1,385 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.oteltrace; + +import static com.jayway.jsonpath.matchers.JsonPathMatchers.hasJsonPath; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Named.named; +import static org.junit.jupiter.params.provider.Arguments.arguments; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; +import static org.opensearch.dataprepper.plugins.source.oteltrace.OTelTraceSourceConfig.DEFAULT_PORT; +import static org.opensearch.dataprepper.plugins.source.oteltrace.OTelTraceSourceConfig.DEFAULT_REQUEST_TIMEOUT_MS; + +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.function.BiConsumer; +import java.util.function.Function; +import java.util.stream.Stream; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.GrpcRequestExceptionHandler; +import org.opensearch.dataprepper.armeria.authentication.ArmeriaHttpAuthenticationProvider; +import org.opensearch.dataprepper.armeria.authentication.GrpcAuthenticationProvider; +import org.opensearch.dataprepper.armeria.authentication.HttpBasicAuthenticationConfig; +import org.opensearch.dataprepper.metrics.MetricsTestUtil; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.buffer.SizeOverflowException; +import org.opensearch.dataprepper.model.configuration.PipelineDescription; +import org.opensearch.dataprepper.model.configuration.PluginModel; +import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.HttpBasicArmeriaHttpAuthenticationProvider; +import org.opensearch.dataprepper.plugins.codec.CompressionOption; + +import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.util.JsonFormat; +import com.linecorp.armeria.client.WebClient; +import com.linecorp.armeria.common.AggregatedHttpResponse; +import com.linecorp.armeria.common.ClosedSessionException; +import com.linecorp.armeria.common.HttpData; +import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.MediaType; +import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.internal.shaded.bouncycastle.util.encoders.Base64; +import com.linecorp.armeria.server.Server; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.grpc.GrpcService; +import com.linecorp.armeria.server.grpc.GrpcServiceBuilder; + +import io.grpc.BindableService; +import io.opentelemetry.proto.collector.trace.v1.ExportTraceServiceRequest; +import io.opentelemetry.proto.trace.v1.ResourceSpans; +import io.opentelemetry.proto.trace.v1.ScopeSpans; +import io.opentelemetry.proto.trace.v1.Span; + +@ExtendWith(MockitoExtension.class) +class OTelTraceSource_HttpServiceTest { + private static final String TEST_PIPELINE_NAME = "test_pipeline"; + + @Mock + private ServerBuilder serverBuilder; + + @Mock + private Server server; + + @Mock + private GrpcServiceBuilder grpcServiceBuilder; + + @Mock + private GrpcService grpcService; + + @Mock + private CompletableFuture completableFuture; + + @Mock + private PluginFactory pluginFactory; + + @Mock(lenient = true) + private OTelTraceSourceConfig oTelTraceSourceConfig; + + @Mock + private Buffer> buffer; + + @Mock + private GrpcAuthenticationProvider grpcAuthProvider; + + @Captor + ArgumentCaptor bytesCaptor; + + private PluginMetrics pluginMetrics; + private PipelineDescription pipelineDescription; + private OTelTraceSource SOURCE; + + private static final HttpBasicAuthenticationConfig PROVIDED_CONFIG = new HttpBasicAuthenticationConfig("username", "password"); + + + @BeforeEach + void beforeEach() { + lenient().when(serverBuilder.service(any(GrpcService.class))).thenReturn(serverBuilder); + lenient().when(serverBuilder.service(any(GrpcService.class), any(Function.class))).thenReturn(serverBuilder); + lenient().when(serverBuilder.http(anyInt())).thenReturn(serverBuilder); + lenient().when(serverBuilder.https(anyInt())).thenReturn(serverBuilder); + lenient().when(serverBuilder.build()).thenReturn(server); + lenient().when(server.start()).thenReturn(completableFuture); + + lenient().when(pluginFactory.loadPlugin(eq(GrpcAuthenticationProvider.class), any(PluginSetting.class))).thenReturn(grpcAuthProvider); + + lenient().when(grpcServiceBuilder.addService(any(BindableService.class))).thenReturn(grpcServiceBuilder); + lenient().when(grpcServiceBuilder.useClientTimeoutHeader(anyBoolean())).thenReturn(grpcServiceBuilder); + lenient().when(grpcServiceBuilder.useBlockingTaskExecutor(anyBoolean())).thenReturn(grpcServiceBuilder); + lenient().when(grpcServiceBuilder.exceptionHandler(any(GrpcRequestExceptionHandler.class))).thenReturn(grpcServiceBuilder); + lenient().when(grpcServiceBuilder.build()).thenReturn(grpcService); + + when(oTelTraceSourceConfig.getPort()).thenReturn(DEFAULT_PORT); + when(oTelTraceSourceConfig.isSsl()).thenReturn(false); + when(oTelTraceSourceConfig.getRequestTimeoutInMillis()).thenReturn(DEFAULT_REQUEST_TIMEOUT_MS); + when(oTelTraceSourceConfig.getMaxConnectionCount()).thenReturn(10); + when(oTelTraceSourceConfig.getThreadCount()).thenReturn(5); + when(oTelTraceSourceConfig.getCompression()).thenReturn(CompressionOption.NONE); + + // default: we don't want authentication + when(oTelTraceSourceConfig.getAuthentication()).thenReturn(null); + + configureObjectUnderTest(); + pipelineDescription = mock(PipelineDescription.class); + lenient().when(pipelineDescription.getPipelineName()).thenReturn(TEST_PIPELINE_NAME); + } + + @AfterEach + void afterEach() { + SOURCE.stop(); + } + + private void configureObjectUnderTest() { + MetricsTestUtil.initMetrics(); + pluginMetrics = PluginMetrics.fromNames("otel_trace", "pipeline"); + + pipelineDescription = mock(PipelineDescription.class); + when(pipelineDescription.getPipelineName()).thenReturn(TEST_PIPELINE_NAME); + SOURCE = new OTelTraceSource(oTelTraceSourceConfig, pluginMetrics, pluginFactory, pipelineDescription); + } + + @Test + void healthcheck_is_up() { + when(oTelTraceSourceConfig.enableHttpHealthCheck()).thenReturn(true); + SOURCE.start(buffer); + + WebClient.of().execute(RequestHeaders.builder() + .scheme(SessionProtocol.HTTP) + .authority("127.0.0.1:21890") + .method(HttpMethod.HEAD) + .path("/health") + .contentType(MediaType.JSON_UTF_8) + .build()) + .aggregate() + .whenComplete((response, throwable) -> assertThat(response.status(), is(HttpStatus.OK))) + .join(); + + verifyNoInteractions(buffer); + } + + @Test + void request_fails_because_of_invalid_payload() throws Exception { + ExportTraceServiceRequest request = createInvalidExportTraceRequest(); + SOURCE.start(buffer); + + WebClient.of().execute(RequestHeaders.builder() + .scheme(SessionProtocol.HTTP) + .authority("127.0.0.1:21890") + .method(HttpMethod.POST) + .path("/opentelemetry.proto.collector.trace.v1.TraceService/Export") + .contentType(MediaType.JSON_UTF_8) + .build(), HttpData.copyOf(JsonFormat.printer().print(request).getBytes())) + .aggregate() + .whenComplete((response, throwable) -> assertThat(response.status(), is(HttpStatus.BAD_REQUEST))) + .join(); + + verifyNoInteractions(buffer); + } + + @Test + void request_that_is_successful() throws Exception { + when(buffer.isByteBuffer()).thenReturn(true); + ExportTraceServiceRequest request = createExportTraceRequest(); + SOURCE.start(buffer); + + WebClient.of().execute(RequestHeaders.builder() + .scheme(SessionProtocol.HTTP) + .authority("127.0.0.1:21890") + .method(HttpMethod.POST) + .path("/opentelemetry.proto.collector.trace.v1.TraceService/Export") + .contentType(MediaType.JSON_UTF_8) + .build(), HttpData.copyOf(JsonFormat.printer().print(request).getBytes())) + .aggregate() + .whenComplete((response, throwable) -> assertThat(response.status(), is(HttpStatus.OK))) + .join(); + + verify(buffer, times(1)).writeBytes(bytesCaptor.capture(), anyString(), anyInt()); + } + + @Test + void providing_unauthenticated_via_config_does_not_add_the_auth_decorator() { + when(oTelTraceSourceConfig.getAuthentication()).thenReturn(new PluginModel(ArmeriaHttpAuthenticationProvider.UNAUTHENTICATED_PLUGIN_NAME, Map.of())); + SOURCE.start(buffer); + + verify(serverBuilder, times(0)).decorator(any(Function.class)); + } + + @Test + void request_that_causes_overflow_exception_should_not_be_written_to_buffer_and_return_retry_information() throws Exception { + Mockito.lenient().doThrow(SizeOverflowException.class).when(buffer).writeAll(any(), anyInt()); + SOURCE.start(buffer); + + makeRequestAndAssertResponse("/opentelemetry.proto.collector.trace.v1.TraceService/Export", createExportTraceRequest(), (response, throwable) -> { + assertThat(response.status(), is(HttpStatus.INSUFFICIENT_STORAGE)); + assertResponseBodyForRetryInformation(response, "0.100s"); + }); + } + + @Test + void request_over_http_with_ssl_enabled_fails() { + when(oTelTraceSourceConfig.isSsl()).thenReturn(true); + when(oTelTraceSourceConfig.getSslKeyCertChainFile()).thenReturn("data/certificate/test_cert.crt"); + when(oTelTraceSourceConfig.getSslKeyFile()).thenReturn("data/certificate/test_decrypted_key.key"); + configureObjectUnderTest(); + SOURCE.start(buffer); + + WebClient client = WebClient.builder("http://127.0.0.1:21890") + .build(); + + CompletionException exception = assertThrows(CompletionException.class, () -> client.execute(RequestHeaders.builder() + .scheme(SessionProtocol.HTTP) + .authority("127.0.0.1:21890") + .method(HttpMethod.POST) + .path("/opentelemetry.proto.collector.trace.v1.TraceService/Export") + .contentType(MediaType.JSON_UTF_8) + .build(), + HttpData.copyOf(JsonFormat.printer().print(createExportTraceRequest()).getBytes())) + .aggregate() + .join()); + + assertThat(exception.getCause(), instanceOf(ClosedSessionException.class)); + } + + @ParameterizedTest + @MethodSource("generateCredentials") + void request_with_credentials_returns_expected_status_code(AuthTestDataHolder testData) throws InvalidProtocolBufferException { + when(oTelTraceSourceConfig.getAuthentication()).thenReturn(new PluginModel("http_basic", Map.of("username", PROVIDED_CONFIG.getUsername(), "password",PROVIDED_CONFIG.getPassword()))); + HttpBasicArmeriaHttpAuthenticationProvider authProvider = new HttpBasicArmeriaHttpAuthenticationProvider(new HttpBasicAuthenticationConfig(PROVIDED_CONFIG.getUsername(), PROVIDED_CONFIG.getPassword())); + lenient().when(pluginFactory.loadPlugin(eq(ArmeriaHttpAuthenticationProvider.class), any(PluginSetting.class))).thenReturn(authProvider); + SOURCE.start(buffer); + + makeRequestWithCredentialsAndAssertResponse("/opentelemetry.proto.collector.trace.v1.TraceService/Export", + createExportTraceRequest(), + testData.providedCredentials.getOrDefault("username", null), + testData.providedCredentials.getOrDefault("password", null), + (response, throwable) -> assertThat(response.status(), is(testData.expectedStatus)) + ); + } + + private static Stream generateCredentials() { + return Stream.of( + arguments(named("valid credentials", new AuthTestDataHolder(Map.of("username", "username", "password","password"), HttpStatus.OK))), + arguments(named("wrong credentials", new AuthTestDataHolder(Map.of("username", "wrong-username", "password","wrong-password"), HttpStatus.UNAUTHORIZED))), + arguments(named("no credentials provided", new AuthTestDataHolder(Map.of(), HttpStatus.UNAUTHORIZED))) + ); + } + + static class AuthTestDataHolder { + Map providedCredentials; + HttpStatus expectedStatus; + + public AuthTestDataHolder(Map providedCredentials, HttpStatus expectedStatus) { + this.providedCredentials = providedCredentials; + this.expectedStatus = expectedStatus; + } + } + + void makeRequestWithCredentialsAndAssertResponse( + String path, + ExportTraceServiceRequest request, + String username, + String password, + BiConsumer assertionFunction) throws InvalidProtocolBufferException { + + WebClient.of().execute(RequestHeaders.builder().add("Authorization", "Basic " + new String(Base64.encode(String.format("%s:%s", username, password).getBytes()))) + .scheme(SessionProtocol.HTTP) + .authority("127.0.0.1:21890") + .method(HttpMethod.POST) + .path(path) + .contentType(MediaType.JSON_UTF_8) + .build(), HttpData.copyOf(JsonFormat.printer().print(request).getBytes())) + .aggregate() + .whenComplete(assertionFunction) + .join(); + } + + private void makeRequestAndAssertResponse(String path, ExportTraceServiceRequest request, BiConsumer assertionFunction) throws InvalidProtocolBufferException { + + WebClient.of().execute(RequestHeaders.builder() + .scheme(SessionProtocol.HTTP) + .authority("127.0.0.1:21890") + .method(HttpMethod.POST) + .path(path) + .contentType(MediaType.JSON_UTF_8) + .build(), HttpData.copyOf(JsonFormat.printer().print(request).getBytes())) + .aggregate() + .whenComplete(assertionFunction) + .join(); + } + + private ExportTraceServiceRequest createInvalidExportTraceRequest() { + final io.opentelemetry.proto.trace.v1.Span testSpan = Span.newBuilder() +// .setTraceId(ByteString.copyFromUtf8(UUID.randomUUID().toString())) +// .setSpanId(ByteString.copyFromUtf8(UUID.randomUUID().toString())) + .setName(UUID.randomUUID().toString()) + .setKind(Span.SpanKind.SPAN_KIND_SERVER) + .setStartTimeUnixNano(100) + .setEndTimeUnixNano(101) + .setTraceState("SUCCESS").build(); + + return ExportTraceServiceRequest.newBuilder() + .addResourceSpans(ResourceSpans.newBuilder() + .addScopeSpans(ScopeSpans.newBuilder().addSpans(testSpan)).build()) + .build(); + } + + private ExportTraceServiceRequest createExportTraceRequest() { + final io.opentelemetry.proto.trace.v1.Span testSpan = Span.newBuilder() + .setTraceId(ByteString.copyFromUtf8(UUID.randomUUID().toString())) + .setSpanId(ByteString.copyFromUtf8(UUID.randomUUID().toString())) + .setName(UUID.randomUUID().toString()) + .setKind(Span.SpanKind.SPAN_KIND_SERVER) + .setStartTimeUnixNano(100) + .setEndTimeUnixNano(101) + .setTraceState("SUCCESS").build(); + + return ExportTraceServiceRequest.newBuilder() + .addResourceSpans(ResourceSpans.newBuilder() + .addScopeSpans(ScopeSpans.newBuilder().addSpans(testSpan)).build()) + .build(); + } + + private void assertResponseBodyForRetryInformation(final AggregatedHttpResponse response, String expectedDelay) { + String body = response.content(StandardCharsets.UTF_8); + + // todo map to numeric value when creating status in exception handler + assertThat(body, hasJsonPath("$.details[0].retryDelay", equalTo(expectedDelay))); + } +} diff --git a/data-prepper-plugins/otel-trace-source/src/test/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSource_RetryInfoTest.java b/data-prepper-plugins/otel-trace-source/src/test/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSource_RetryInfoTest.java index e2613d5dc8..4d59e53800 100644 --- a/data-prepper-plugins/otel-trace-source/src/test/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSource_RetryInfoTest.java +++ b/data-prepper-plugins/otel-trace-source/src/test/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSource_RetryInfoTest.java @@ -155,8 +155,9 @@ private ExportTraceServiceRequest createExportTraceRequest() { .setEndTimeUnixNano(101) .setTraceState("SUCCESS").build(); - ScopeSpans scopeSpan = ScopeSpans.newBuilder().addSpans(testSpan).build(); return ExportTraceServiceRequest.newBuilder() - .addResourceSpans(ResourceSpans.newBuilder().addScopeSpans(scopeSpan)).build(); + .addResourceSpans(ResourceSpans.newBuilder() + .addScopeSpans(ScopeSpans.newBuilder().addSpans(testSpan)).build()) + .build(); } } diff --git a/data-prepper-plugins/otel-trace-source/src/test/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSource_UnframedRequestsTest.java b/data-prepper-plugins/otel-trace-source/src/test/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSource_UnframedRequestsTest.java new file mode 100644 index 0000000000..88a4a7e550 --- /dev/null +++ b/data-prepper-plugins/otel-trace-source/src/test/java/org/opensearch/dataprepper/plugins/source/oteltrace/OTelTraceSource_UnframedRequestsTest.java @@ -0,0 +1,496 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.source.oteltrace; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.dataprepper.plugins.source.oteltrace.OTelTraceSourceConfig.DEFAULT_PORT; +import static org.opensearch.dataprepper.plugins.source.oteltrace.OTelTraceSourceConfig.DEFAULT_REQUEST_TIMEOUT_MS; +import static org.opensearch.dataprepper.plugins.source.oteltrace.OTelTraceSourceConfig.SSL; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.Duration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.zip.GZIPOutputStream; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentMatchers; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.GrpcRequestExceptionHandler; +import org.opensearch.dataprepper.armeria.authentication.GrpcAuthenticationProvider; +import org.opensearch.dataprepper.metrics.MetricsTestUtil; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.configuration.PipelineDescription; +import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.GrpcBasicAuthenticationProvider; +import org.opensearch.dataprepper.plugins.certificate.CertificateProvider; +import org.opensearch.dataprepper.plugins.certificate.model.Certificate; +import org.opensearch.dataprepper.plugins.codec.CompressionOption; +import org.opensearch.dataprepper.plugins.server.HealthGrpcService; +import org.opensearch.dataprepper.plugins.server.RetryInfoConfig; +import org.opensearch.dataprepper.plugins.source.oteltrace.certificate.CertificateProviderFactory; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.util.JsonFormat; +import com.linecorp.armeria.client.ClientFactory; +import com.linecorp.armeria.client.WebClient; +import com.linecorp.armeria.common.AggregatedHttpResponse; +import com.linecorp.armeria.common.HttpData; +import com.linecorp.armeria.common.HttpHeaderNames; +import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.MediaType; +import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.server.Server; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.grpc.GrpcService; +import com.linecorp.armeria.server.grpc.GrpcServiceBuilder; +import com.linecorp.armeria.server.healthcheck.HealthCheckService; + +import io.grpc.BindableService; +import io.grpc.ServerServiceDefinition; +import io.netty.util.AsciiString; +import io.opentelemetry.proto.collector.trace.v1.ExportTraceServiceRequest; +import io.opentelemetry.proto.trace.v1.ResourceSpans; +import io.opentelemetry.proto.trace.v1.ScopeSpans; +import io.opentelemetry.proto.trace.v1.Span; + + +// todo check if unframed requests are still needed. If not, remove this whole test class +@ExtendWith(MockitoExtension.class) +class OTelTraceSource_UnframedRequestsTest { + // used to configure the path for unframed requests and make sure not to use the same path + // as the http service + private static final String UNFRAMED_REQUESTS_PATH = "/unframed"; + private static final String TEST_PATH = "${pipelineName}/v1/traces"; + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper().registerModule(new JavaTimeModule()); + private static final String TEST_PIPELINE_NAME = "test_pipeline"; + private static final RetryInfoConfig TEST_RETRY_INFO = new RetryInfoConfig(Duration.ofMillis(50), Duration.ofMillis(2000)); + private static final ExportTraceServiceRequest SUCCESS_REQUEST = ExportTraceServiceRequest.newBuilder() + .addResourceSpans(ResourceSpans.newBuilder() + .addScopeSpans(ScopeSpans.newBuilder() + .addSpans(Span.newBuilder().setTraceState("SUCCESS").build())).build()).build(); + private static final ExportTraceServiceRequest FAILURE_REQUEST = ExportTraceServiceRequest.newBuilder() + .addResourceSpans(ResourceSpans.newBuilder() + .addScopeSpans(ScopeSpans.newBuilder() + .addSpans(Span.newBuilder().setTraceState("FAILURE").build())).build()).build(); + + @Mock + private ServerBuilder serverBuilder; + + @Mock + private Server server; + + @Mock + private GrpcServiceBuilder grpcServiceBuilder; + + @Mock + private GrpcService grpcService; + + @Mock + private CertificateProviderFactory certificateProviderFactory; + + @Mock + private CertificateProvider certificateProvider; + + @Mock + private Certificate certificate; + + @Mock + private CompletableFuture completableFuture; + + @Mock + private PluginFactory pluginFactory; + + @Mock + private GrpcBasicAuthenticationProvider authenticationProvider; + + @Mock(lenient = true) + private OTelTraceSourceConfig oTelTraceSourceConfig; + + @Mock + private Buffer> buffer; + + private PluginSetting pluginSetting; + private PluginSetting testPluginSetting; + private PluginMetrics pluginMetrics; + private PipelineDescription pipelineDescription; + private OTelTraceSource SOURCE; + + @BeforeEach + void beforeEach() { + lenient().when(serverBuilder.port(anyInt(), ArgumentMatchers.any())).thenReturn(serverBuilder); + lenient().when(serverBuilder.service(any(GrpcService.class))).thenReturn(serverBuilder); + lenient().when(serverBuilder.service(any(GrpcService.class), any(Function.class))).thenReturn(serverBuilder); + lenient().when(serverBuilder.http(anyInt())).thenReturn(serverBuilder); + lenient().when(serverBuilder.https(anyInt())).thenReturn(serverBuilder); + lenient().when(serverBuilder.build()).thenReturn(server); + + lenient().when(server.start()).thenReturn(completableFuture); + + lenient().when(grpcServiceBuilder.addService(any(BindableService.class))).thenReturn(grpcServiceBuilder); + lenient().when(grpcServiceBuilder.useClientTimeoutHeader(anyBoolean())).thenReturn(grpcServiceBuilder); + lenient().when(grpcServiceBuilder.useBlockingTaskExecutor(anyBoolean())).thenReturn(grpcServiceBuilder); + lenient().when(grpcServiceBuilder.exceptionHandler(any( + GrpcRequestExceptionHandler.class))).thenReturn(grpcServiceBuilder); + lenient().when(grpcServiceBuilder.build()).thenReturn(grpcService); + + lenient().when(authenticationProvider.getHttpAuthenticationService()).thenCallRealMethod(); + + when(oTelTraceSourceConfig.getPath()).thenReturn(UNFRAMED_REQUESTS_PATH); + when(oTelTraceSourceConfig.getPort()).thenReturn(DEFAULT_PORT); + when(oTelTraceSourceConfig.isSsl()).thenReturn(false); + when(oTelTraceSourceConfig.getRequestTimeoutInMillis()).thenReturn(DEFAULT_REQUEST_TIMEOUT_MS); + when(oTelTraceSourceConfig.getMaxConnectionCount()).thenReturn(10); + when(oTelTraceSourceConfig.getThreadCount()).thenReturn(5); + when(oTelTraceSourceConfig.getCompression()).thenReturn(CompressionOption.NONE); + when(oTelTraceSourceConfig.getRetryInfo()).thenReturn(TEST_RETRY_INFO); + + lenient().when(pluginFactory.loadPlugin(eq(GrpcAuthenticationProvider.class), any(PluginSetting.class))) + .thenReturn(authenticationProvider); + configureObjectUnderTest(); + pipelineDescription = mock(PipelineDescription.class); + lenient().when(pipelineDescription.getPipelineName()).thenReturn(TEST_PIPELINE_NAME); + } + + @AfterEach + void afterEach() { + SOURCE.stop(); + } + + private void configureObjectUnderTest() { + MetricsTestUtil.initMetrics(); + pluginMetrics = PluginMetrics.fromNames("otel_trace", "pipeline"); + + pipelineDescription = mock(PipelineDescription.class); + when(pipelineDescription.getPipelineName()).thenReturn(TEST_PIPELINE_NAME); + SOURCE = new OTelTraceSource(oTelTraceSourceConfig, pluginMetrics, pluginFactory, pipelineDescription); + } + + @Test + void testHttpFullJsonWithNonUnframedRequests() throws InvalidProtocolBufferException { + configureObjectUnderTest(); + SOURCE.start(buffer); + WebClient.of().execute(RequestHeaders.builder() + .scheme(SessionProtocol.HTTP) + .authority("127.0.0.1:21890") + .method(HttpMethod.POST) + .path(UNFRAMED_REQUESTS_PATH) + .contentType(MediaType.JSON_UTF_8) + .build(), + HttpData.copyOf(JsonFormat.printer().print(SUCCESS_REQUEST).getBytes())) + .aggregate() + .whenComplete((response, throwable) -> assertSecureResponseWithStatusCode(response, HttpStatus.UNSUPPORTED_MEDIA_TYPE, throwable)) + .join(); + WebClient.of().execute(RequestHeaders.builder() + .scheme(SessionProtocol.HTTP) + .authority("127.0.0.1:21890") + .method(HttpMethod.POST) + .path(UNFRAMED_REQUESTS_PATH) + .contentType(MediaType.JSON_UTF_8) + .build(), + HttpData.copyOf(JsonFormat.printer().print(FAILURE_REQUEST).getBytes())) + .aggregate() + .whenComplete((response, throwable) -> assertSecureResponseWithStatusCode(response, HttpStatus.UNSUPPORTED_MEDIA_TYPE, throwable)) + .join(); + } + + @Test + void testHttpsFullJsonWithNonUnframedRequests() throws InvalidProtocolBufferException { + + final Map settingsMap = new HashMap<>(); + settingsMap.put("request_timeout", 5); + settingsMap.put(SSL, true); + settingsMap.put("useAcmCertForSSL", false); + settingsMap.put("sslKeyCertChainFile", "data/certificate/test_cert.crt"); + settingsMap.put("sslKeyFile", "data/certificate/test_decrypted_key.key"); + settingsMap.put("path", UNFRAMED_REQUESTS_PATH); + pluginSetting = new PluginSetting("otel_trace", settingsMap); + pluginSetting.setPipelineName("pipeline"); + + oTelTraceSourceConfig = OBJECT_MAPPER.convertValue(pluginSetting.getSettings(), OTelTraceSourceConfig.class); + SOURCE = new OTelTraceSource(oTelTraceSourceConfig, pluginMetrics, pluginFactory, pipelineDescription); + + SOURCE.start(buffer); + + WebClient.builder().factory(ClientFactory.insecure()).build().execute(RequestHeaders.builder() + .scheme(SessionProtocol.HTTPS) + .authority("127.0.0.1:21890") + .method(HttpMethod.POST) + .path(UNFRAMED_REQUESTS_PATH) + .contentType(MediaType.JSON_UTF_8) + .build(), + HttpData.copyOf(JsonFormat.printer().print(SUCCESS_REQUEST).getBytes())) + .aggregate() + .whenComplete((response, throwable) -> assertSecureResponseWithStatusCode(response, HttpStatus.UNSUPPORTED_MEDIA_TYPE, throwable)) + .join(); + WebClient.builder().factory(ClientFactory.insecure()).build().execute(RequestHeaders.builder() + .scheme(SessionProtocol.HTTPS) + .authority("127.0.0.1:21890") + .method(HttpMethod.POST) + .path(UNFRAMED_REQUESTS_PATH) + .contentType(MediaType.JSON_UTF_8) + .build(), + HttpData.copyOf(JsonFormat.printer().print(FAILURE_REQUEST).getBytes())) + .aggregate() + .whenComplete((response, throwable) -> assertSecureResponseWithStatusCode(response, HttpStatus.UNSUPPORTED_MEDIA_TYPE, throwable)) + .join(); + } + + @Test + void testHttpFullBytesWithNonUnframedRequests() { + configureObjectUnderTest(); + SOURCE.start(buffer); + WebClient.of().execute(RequestHeaders.builder() + .scheme(SessionProtocol.HTTP) + .authority("127.0.0.1:21890") + .method(HttpMethod.POST) + .path(UNFRAMED_REQUESTS_PATH) + .contentType(MediaType.PROTOBUF) + .build(), + HttpData.copyOf(SUCCESS_REQUEST.toByteArray())) + .aggregate() + .whenComplete((response, throwable) -> assertSecureResponseWithStatusCode(response, HttpStatus.UNSUPPORTED_MEDIA_TYPE, throwable)) + .join(); + WebClient.of().execute(RequestHeaders.builder() + .scheme(SessionProtocol.HTTP) + .authority("127.0.0.1:21890") + .method(HttpMethod.POST) + .path(UNFRAMED_REQUESTS_PATH) + .contentType(MediaType.PROTOBUF) + .build(), + HttpData.copyOf(FAILURE_REQUEST.toByteArray())) + .aggregate() + .whenComplete((response, throwable) -> assertSecureResponseWithStatusCode(response, HttpStatus.UNSUPPORTED_MEDIA_TYPE, throwable)) + .join(); + } + + @Test + void testHttpFullJsonWithUnframedRequests() throws InvalidProtocolBufferException { + when(oTelTraceSourceConfig.enableUnframedRequests()).thenReturn(true); + configureObjectUnderTest(); + SOURCE.start(buffer); + + WebClient.of().execute(RequestHeaders.builder() + .scheme(SessionProtocol.HTTP) + .authority("127.0.0.1:21890") + .method(HttpMethod.POST) + .path(UNFRAMED_REQUESTS_PATH) + .contentType(MediaType.JSON_UTF_8) + .build(), + HttpData.copyOf(JsonFormat.printer().print(createExportTraceRequest()).getBytes())) + .aggregate() + .whenComplete((response, throwable) -> assertSecureResponseWithStatusCode(response, HttpStatus.OK, throwable)) + .join(); + } + + @Test + void testHttpCompressionWithUnframedRequests() throws IOException { + when(oTelTraceSourceConfig.enableUnframedRequests()).thenReturn(true); + when(oTelTraceSourceConfig.getCompression()).thenReturn(CompressionOption.GZIP); + configureObjectUnderTest(); + SOURCE.start(buffer); + + WebClient.of().execute(RequestHeaders.builder() + .scheme(SessionProtocol.HTTP) + .authority("127.0.0.1:21890") + .method(HttpMethod.POST) + .path(UNFRAMED_REQUESTS_PATH) + .contentType(MediaType.JSON_UTF_8) + .add(HttpHeaderNames.CONTENT_ENCODING, "gzip") + .build(), + createGZipCompressedPayload(JsonFormat.printer().print(createExportTraceRequest()))) + .aggregate() + .whenComplete((response, throwable) -> assertSecureResponseWithStatusCode(response, HttpStatus.OK, throwable)) + .join(); + } + + @Test + void testHttpFullJsonWithCustomPathAndUnframedRequests() throws InvalidProtocolBufferException { + when(oTelTraceSourceConfig.enableUnframedRequests()).thenReturn(true); + when(oTelTraceSourceConfig.getPath()).thenReturn(TEST_PATH); + configureObjectUnderTest(); + SOURCE.start(buffer); + + final String transformedPath = "/" + TEST_PIPELINE_NAME + "/v1/traces"; + WebClient.of().execute(RequestHeaders.builder() + .scheme(SessionProtocol.HTTP) + .authority("127.0.0.1:21890") + .method(HttpMethod.POST) + .path(transformedPath) + .contentType(MediaType.JSON_UTF_8) + .build(), + HttpData.copyOf(JsonFormat.printer().print(createExportTraceRequest()).getBytes())) + .aggregate() + .whenComplete((response, throwable) -> assertSecureResponseWithStatusCode(response, HttpStatus.OK, throwable)) + .join(); + } + + + @Test + void start_with_Health_configured_unframed_requests_includes_HTTPHealthCheck_service() throws IOException { + try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class); + MockedStatic grpcServerMock = Mockito.mockStatic(GrpcService.class)) { + armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); + grpcServerMock.when(GrpcService::builder).thenReturn(grpcServiceBuilder); + when(grpcServiceBuilder.addService(any(ServerServiceDefinition.class))).thenReturn(grpcServiceBuilder); + when(grpcServiceBuilder.useClientTimeoutHeader(anyBoolean())).thenReturn(grpcServiceBuilder); + + when(server.stop()).thenReturn(completableFuture); + final Path certFilePath = Path.of("data/certificate/test_cert.crt"); + final Path keyFilePath = Path.of("data/certificate/test_decrypted_key.key"); + final String certAsString = Files.readString(certFilePath); + final String keyAsString = Files.readString(keyFilePath); + when(certificate.getCertificate()).thenReturn(certAsString); + when(certificate.getPrivateKey()).thenReturn(keyAsString); + when(certificateProvider.getCertificate()).thenReturn(certificate); + when(certificateProviderFactory.getCertificateProvider()).thenReturn(certificateProvider); + final Map settingsMap = new HashMap<>(); + settingsMap.put(SSL, true); + settingsMap.put("useAcmCertForSSL", true); + settingsMap.put("awsRegion", "us-east-1"); + settingsMap.put("acmCertificateArn", "arn:aws:acm:us-east-1:account:certificate/1234-567-856456"); + settingsMap.put("sslKeyCertChainFile", "data/certificate/test_cert.crt"); + settingsMap.put("sslKeyFile", "data/certificate/test_decrypted_key.key"); + settingsMap.put("health_check_service", "true"); + settingsMap.put("unframed_requests", "true"); + + testPluginSetting = new PluginSetting(null, settingsMap); + testPluginSetting.setPipelineName("pipeline"); + + oTelTraceSourceConfig = OBJECT_MAPPER.convertValue(testPluginSetting.getSettings(), OTelTraceSourceConfig.class); + final OTelTraceSource source = new OTelTraceSource(oTelTraceSourceConfig, pluginMetrics, pluginFactory, certificateProviderFactory, pipelineDescription); + source.start(buffer); + source.stop(); + } + + verify(grpcServiceBuilder, times(1)).useClientTimeoutHeader(false); + verify(grpcServiceBuilder, times(1)).useBlockingTaskExecutor(true); + verify(grpcServiceBuilder).addService(isA(HealthGrpcService.class)); + verify(serverBuilder).service(eq("/health"), isA(HealthCheckService.class)); + } + + + @Test + void start_without_Health_configured_unframed_requests_does_not_include_HealthCheck_service() throws IOException { + try (MockedStatic armeriaServerMock = Mockito.mockStatic(Server.class); + MockedStatic grpcServerMock = Mockito.mockStatic(GrpcService.class)) { + armeriaServerMock.when(Server::builder).thenReturn(serverBuilder); + grpcServerMock.when(GrpcService::builder).thenReturn(grpcServiceBuilder); + when(grpcServiceBuilder.addService(any(ServerServiceDefinition.class))).thenReturn(grpcServiceBuilder); + when(grpcServiceBuilder.useClientTimeoutHeader(anyBoolean())).thenReturn(grpcServiceBuilder); + + when(server.stop()).thenReturn(completableFuture); + final Path certFilePath = Path.of("data/certificate/test_cert.crt"); + final Path keyFilePath = Path.of("data/certificate/test_decrypted_key.key"); + final String certAsString = Files.readString(certFilePath); + final String keyAsString = Files.readString(keyFilePath); + when(certificate.getCertificate()).thenReturn(certAsString); + when(certificate.getPrivateKey()).thenReturn(keyAsString); + when(certificateProvider.getCertificate()).thenReturn(certificate); + when(certificateProviderFactory.getCertificateProvider()).thenReturn(certificateProvider); + final Map settingsMap = new HashMap<>(); + settingsMap.put(SSL, true); + settingsMap.put("useAcmCertForSSL", true); + settingsMap.put("awsRegion", "us-east-1"); + settingsMap.put("acmCertificateArn", "arn:aws:acm:us-east-1:account:certificate/1234-567-856456"); + settingsMap.put("sslKeyCertChainFile", "data/certificate/test_cert.crt"); + settingsMap.put("sslKeyFile", "data/certificate/test_decrypted_key.key"); + settingsMap.put("health_check_service", "false"); + settingsMap.put("unframed_requests", "true"); + + testPluginSetting = new PluginSetting(null, settingsMap); + testPluginSetting.setPipelineName("pipeline"); + oTelTraceSourceConfig = OBJECT_MAPPER.convertValue(testPluginSetting.getSettings(), OTelTraceSourceConfig.class); + final OTelTraceSource source = new OTelTraceSource(oTelTraceSourceConfig, pluginMetrics, pluginFactory, certificateProviderFactory, pipelineDescription); + source.start(buffer); + source.stop(); + } + + verify(grpcServiceBuilder, times(1)).useClientTimeoutHeader(false); + verify(grpcServiceBuilder, times(1)).useBlockingTaskExecutor(true); + verify(grpcServiceBuilder, never()).addService(isA(HealthGrpcService.class)); + verify(serverBuilder, never()).service(eq("/health"),isA(HealthCheckService.class)); + } + + private ExportTraceServiceRequest createExportTraceRequest() { + final Span testSpan = Span.newBuilder() + .setTraceId(ByteString.copyFromUtf8(UUID.randomUUID().toString())) + .setSpanId(ByteString.copyFromUtf8(UUID.randomUUID().toString())) + .setName(UUID.randomUUID().toString()) + .setKind(Span.SpanKind.SPAN_KIND_SERVER) + .setStartTimeUnixNano(100) + .setEndTimeUnixNano(101) + .setTraceState("SUCCESS").build(); + + return ExportTraceServiceRequest.newBuilder() + .addResourceSpans(ResourceSpans.newBuilder() + .addScopeSpans(ScopeSpans.newBuilder().addSpans(testSpan)).build()) + .build(); + } + + private void assertSecureResponseWithStatusCode(final AggregatedHttpResponse response, + final HttpStatus expectedStatus, + final Throwable throwable) { + assertThat("Http Status", response.status(), equalTo(expectedStatus)); + assertThat("Http Response Throwable", throwable, is(nullValue())); + + final List headerKeys = response.headers() + .stream() + .map(Map.Entry::getKey) + .map(AsciiString::toString) + .collect(Collectors.toList()); + assertThat("Response Header Keys", headerKeys, not(contains("server"))); + } + + private byte[] createGZipCompressedPayload(final String payload) throws IOException { + // Create a GZip compressed request body + final ByteArrayOutputStream byteStream = new ByteArrayOutputStream(); + try (final GZIPOutputStream gzipStream = new GZIPOutputStream(byteStream)) { + gzipStream.write(payload.getBytes(StandardCharsets.UTF_8)); + } + return byteStream.toByteArray(); + } + +} diff --git a/data-prepper-plugins/otel-trace-source/src/test/java/org/opensearch/dataprepper/plugins/source/oteltrace/http/HttpExceptionHandlerTest.java b/data-prepper-plugins/otel-trace-source/src/test/java/org/opensearch/dataprepper/plugins/source/oteltrace/http/HttpExceptionHandlerTest.java new file mode 100644 index 0000000000..7023902e85 --- /dev/null +++ b/data-prepper-plugins/otel-trace-source/src/test/java/org/opensearch/dataprepper/plugins/source/oteltrace/http/HttpExceptionHandlerTest.java @@ -0,0 +1,97 @@ +package org.opensearch.dataprepper.plugins.source.oteltrace.http; + + +import java.time.Duration; +import java.util.concurrent.TimeoutException; + +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.HttpRequestExceptionHandler; +import org.opensearch.dataprepper.exceptions.BadRequestException; +import org.opensearch.dataprepper.exceptions.BufferWriteException; +import org.opensearch.dataprepper.exceptions.RequestCancelledException; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.buffer.SizeOverflowException; + +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.server.RequestTimeoutException; +import com.linecorp.armeria.server.ServiceRequestContext; + +import io.micrometer.core.instrument.Counter; + +@ExtendWith(MockitoExtension.class) +class HttpExceptionHandlerTest { + HttpExceptionHandler httpExceptionHandler; + + @Mock + private PluginMetrics pluginMetrics; + + @Mock + private ServiceRequestContext requestContext; + + @Mock + private HttpRequest httpRequest; + + @Mock + private Counter requestTimeoutsCounter; + + @Mock + private Counter badRequestsCounter; + + @Mock + private Counter requestsTooLargeCounter; + + @Mock + private Counter internalServerErrorCounter; + @BeforeEach + public void setUp() { + when(pluginMetrics.counter(HttpRequestExceptionHandler.REQUEST_TIMEOUTS)).thenReturn(requestTimeoutsCounter); + when(pluginMetrics.counter(HttpRequestExceptionHandler.BAD_REQUESTS)).thenReturn(badRequestsCounter); + when(pluginMetrics.counter(HttpRequestExceptionHandler.REQUESTS_TOO_LARGE)).thenReturn(requestsTooLargeCounter); + when(pluginMetrics.counter(HttpRequestExceptionHandler.INTERNAL_SERVER_ERROR)).thenReturn(internalServerErrorCounter); + httpExceptionHandler = new HttpExceptionHandler(pluginMetrics, Duration.ofMillis(100), Duration.ofSeconds(2)); + } + + @Test + public void testHandleBadRequestException() { + httpExceptionHandler.handleException(requestContext, httpRequest, new BadRequestException("msg", null)); + verify(badRequestsCounter).increment(); + } + + @Test + public void testHandleTimeoutException() { + httpExceptionHandler.handleException(requestContext, httpRequest, new BufferWriteException(null, new TimeoutException())); + verify(requestTimeoutsCounter, times(1)).increment(); + } + + @Test + public void testHandleArmeriaTimeoutException() { + httpExceptionHandler.handleException(requestContext, httpRequest, RequestTimeoutException.get()); + verify(requestTimeoutsCounter, times(1)).increment(); + } + + @Test + public void testHandleSizeOverflowException() { + httpExceptionHandler.handleException(requestContext, httpRequest, new SizeOverflowException("msg")); + verify(requestsTooLargeCounter).increment(); + } + + @Test + public void testHandleRequestCancelledException() { + httpExceptionHandler.handleException(requestContext, httpRequest, new RequestCancelledException("msg")); + verify(requestTimeoutsCounter, times(1)).increment(); + } + + @Test + public void testHandleInternalServerException() { + httpExceptionHandler.handleException(requestContext, httpRequest, new RuntimeException("msg")); + verify(internalServerErrorCounter, times(1)).increment(); + } +} diff --git a/e2e-test/trace/build.gradle b/e2e-test/trace/build.gradle index b46882be93..c5cf041364 100644 --- a/e2e-test/trace/build.gradle +++ b/e2e-test/trace/build.gradle @@ -207,6 +207,7 @@ dependencies { integrationTestImplementation project(':data-prepper-plugins:aws-plugin-api') integrationTestImplementation project(':data-prepper-plugins:otel-trace-group-processor') integrationTestImplementation testLibs.awaitility + integrationTestImplementation testLibs.assertj integrationTestImplementation "io.opentelemetry.proto:opentelemetry-proto:${targetOpenTelemetryVersion}" integrationTestImplementation libs.protobuf.util integrationTestImplementation libs.armeria.core diff --git a/e2e-test/trace/src/integrationTest/java/org/opensearch/dataprepper/integration/trace/EndToEndRawSpanTest.java b/e2e-test/trace/src/integrationTest/java/org/opensearch/dataprepper/integration/trace/EndToEndRawSpanTest.java index f7c956417a..571f5f45af 100644 --- a/e2e-test/trace/src/integrationTest/java/org/opensearch/dataprepper/integration/trace/EndToEndRawSpanTest.java +++ b/e2e-test/trace/src/integrationTest/java/org/opensearch/dataprepper/integration/trace/EndToEndRawSpanTest.java @@ -43,12 +43,12 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.function.Function; import static org.awaitility.Awaitility.await; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.assertj.core.api.Assertions.assertThat; public class EndToEndRawSpanTest { private static final int DATA_PREPPER_PORT_1 = 21890; @@ -88,7 +88,59 @@ public class EndToEndRawSpanTest { @Test public void testPipelineEndToEnd() { - //Send data to otel trace source + final List> expectedDocuments = sendTracesToOpenSearchAndReturnExpectedDocuments(); + final RestHighLevelClient restHighLevelClient = createRestClientForSearch(); + final SearchSourceBuilder sourceBuilder = createSearchSourceBuilder(); + final SearchRequest searchRequest = new SearchRequest(INDEX_NAME).source(sourceBuilder); + + // Wait for data to flow through pipeline and be indexed by ES + await().atLeast(3, TimeUnit.SECONDS).atMost(30, TimeUnit.SECONDS).untilAsserted( + () -> { + refreshIndices(restHighLevelClient); + final SearchResponse searchResponse = restHighLevelClient.search(searchRequest, RequestOptions.DEFAULT); + final List> foundSources = getSourcesFromSearchHits(searchResponse.getHits()); + + assertThat(foundSources).hasSize(expectedDocuments.size()); + assertThatFoundDocumentsContainAllFieldsFromExpectedDocuments(expectedDocuments, foundSources); + } + ); + } + + private void assertThatFoundDocumentsContainAllFieldsFromExpectedDocuments(List> expectedDocuments, List> foundDocuments) { + /** + * Our raw trace prepper add more fields than the actual sent object. These are defaults from the proto. + * So assertion is done if all the expected fields exists. + * + * TODO: Can we do better? + */ + expectedDocuments.forEach(expectedDoc -> { + Set> foundEntrySet = foundDocuments.stream() + .filter(i -> i.get("spanId").equals(expectedDoc.get("spanId"))) + .findFirst().get() + .entrySet(); + + assertThat(foundEntrySet).containsAll(expectedDoc.entrySet()); + }); + + } + + private SearchSourceBuilder createSearchSourceBuilder() { + return SearchSourceBuilder.searchSource() + .size(100) + .fetchField(TraceGroup.TRACE_GROUP_STATUS_CODE_FIELD) + .fetchField(TraceGroup.TRACE_GROUP_END_TIME_FIELD, "strict_date_time") + .fetchField(TraceGroup.TRACE_GROUP_DURATION_IN_NANOS_FIELD); + } + + private RestHighLevelClient createRestClientForSearch() { + return new ConnectionConfiguration.Builder(Collections.singletonList("https://127.0.0.1:9200")) + .withUsername("admin") + .withPassword("admin") + .withInsecure(true) + .build().createClient(null); + } + + private List> sendTracesToOpenSearchAndReturnExpectedDocuments() { final ExportTraceServiceRequest exportTraceServiceRequestTrace1BatchWithRoot = getExportTraceServiceRequest( getResourceSpansBatch(TEST_SPAN_SET_1_WITH_ROOT_SPAN) ); @@ -108,45 +160,9 @@ public void testPipelineEndToEnd() { sendExportTraceServiceRequestToSource(DATA_PREPPER_PORT_2, exportTraceServiceRequestTrace1BatchNoRoot); //Verify data in OpenSearch backend - final List> expectedDocuments = getExpectedDocuments( + return getExpectedDocuments( exportTraceServiceRequestTrace1BatchWithRoot, exportTraceServiceRequestTrace1BatchNoRoot, exportTraceServiceRequestTrace2BatchWithRoot, exportTraceServiceRequestTrace2BatchNoRoot); - final ConnectionConfiguration.Builder builder = new ConnectionConfiguration.Builder( - Collections.singletonList("https://127.0.0.1:9200")); - builder.withUsername("admin"); - builder.withPassword("admin"); - builder.withInsecure(true); - final RestHighLevelClient restHighLevelClient = builder.build().createClient(null); - // Wait for data to flow through pipeline and be indexed by ES - await().atLeast(3, TimeUnit.SECONDS).atMost(20, TimeUnit.SECONDS).untilAsserted( - () -> { - refreshIndices(restHighLevelClient); - final SearchRequest searchRequest = new SearchRequest(INDEX_NAME); - searchRequest.source( - SearchSourceBuilder.searchSource() - .size(100) - .fetchField(TraceGroup.TRACE_GROUP_STATUS_CODE_FIELD) - .fetchField(TraceGroup.TRACE_GROUP_END_TIME_FIELD, "strict_date_time") - .fetchField(TraceGroup.TRACE_GROUP_DURATION_IN_NANOS_FIELD) - ); - final SearchResponse searchResponse = restHighLevelClient.search(searchRequest, RequestOptions.DEFAULT); - final List> foundSources = getSourcesFromSearchHits(searchResponse.getHits()); - assertEquals(expectedDocuments.size(), foundSources.size()); - /** - * Our raw trace prepper add more fields than the actual sent object. These are defaults from the proto. - * So assertion is done if all the expected fields exists. - * - * TODO: Can we do better? - * - */ - expectedDocuments.forEach(expectedDoc -> { - assertTrue(foundSources.stream() - .filter(i -> i.get("spanId").equals(expectedDoc.get("spanId"))) - .findFirst().get() - .entrySet().containsAll(expectedDoc.entrySet())); - }); - } - ); } private void refreshIndices(final RestHighLevelClient restHighLevelClient) throws IOException { diff --git a/settings.gradle b/settings.gradle index 633ee20214..9fb1d6c80b 100644 --- a/settings.gradle +++ b/settings.gradle @@ -80,6 +80,7 @@ dependencyResolutionManagement { version('awaitility', '4.2.0') version('spring', '5.3.28') version('slf4j', '2.0.6') + version('assertj', '3.27.3') library('junit-core', 'org.junit.jupiter', 'junit-jupiter').versionRef('junit') library('junit-params', 'org.junit.jupiter', 'junit-jupiter-params').versionRef('junit') library('junit-engine', 'org.junit.jupiter', 'junit-jupiter-engine').versionRef('junit') @@ -93,6 +94,7 @@ dependencyResolutionManagement { library('awaitility', 'org.awaitility', 'awaitility').versionRef('awaitility') library('spring-test', 'org.springframework', 'spring-test').versionRef('spring') library('slf4j-simple', 'org.slf4j', 'slf4j-simple').versionRef('slf4j') + library('assertj', 'org.assertj', 'assertj-core').versionRef('assertj') } } }