diff --git a/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSource.java b/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSource.java index 31769782d2..7631e3af10 100644 --- a/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSource.java +++ b/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSource.java @@ -30,6 +30,7 @@ import org.slf4j.LoggerFactory; import java.util.Collections; +import java.util.List; import java.util.concurrent.ExecutionException; @DataPrepperPlugin(name = "http", pluginType = Source.class, pluginConfigurationType = HTTPSourceConfig.class) @@ -50,6 +51,8 @@ public class HTTPSource implements Source> { private static final String HTTP_HEALTH_CHECK_PATH = "/health"; private ByteDecoder byteDecoder; private final InputCodec codec; + private final List metadataHeaders; + private final HttpHeaderExtractor httpHeaderExtractor; @DataPrepperPluginConstructor public HTTPSource(final HTTPSourceConfig sourceConfig, final PluginMetrics pluginMetrics, final PluginFactory pluginFactory, @@ -59,6 +62,7 @@ public HTTPSource(final HTTPSourceConfig sourceConfig, final PluginMetrics plugi this.pipelineName = pipelineDescription.getPipelineName(); this.byteDecoder = new JsonDecoder(); this.certificateProviderFactory = new CertificateProviderFactory(sourceConfig); + this.metadataHeaders = sourceConfig.getMetadataHeaders(); final PluginModel authenticationConfiguration = sourceConfig.getAuthentication(); final PluginSetting authenticationPluginSetting; @@ -84,6 +88,7 @@ public HTTPSource(final HTTPSourceConfig sourceConfig, final PluginMetrics plugi final PluginSetting codecPluginSettings = new PluginSetting(codecConfiguration.getPluginName(), codecConfiguration.getPluginSettings()); codec = pluginFactory.loadPlugin(InputCodec.class, codecPluginSettings); } + httpHeaderExtractor = new HttpHeaderExtractor(metadataHeaders); } @Override @@ -94,7 +99,7 @@ public void start(final Buffer> buffer) { if (server == null) { ServerConfiguration serverConfiguration = ConvertConfiguration.convertConfiguration(sourceConfig); CreateServer createServer = new CreateServer(serverConfiguration, LOG, pluginMetrics, PLUGIN_NAME, pipelineName); - final LogHTTPService logHTTPService = new LogHTTPService(serverConfiguration.getBufferTimeoutInMillis(), buffer, pluginMetrics, codec); + final LogHTTPService logHTTPService = new LogHTTPService(serverConfiguration.getBufferTimeoutInMillis(), buffer, pluginMetrics, codec, httpHeaderExtractor); server = createServer.createHTTPServer(buffer, certificateProviderFactory, authenticationProvider, httpRequestExceptionHandler, logHTTPService); pluginMetrics.gauge(SERVER_CONNECTIONS, server, Server::numConnections); } diff --git a/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSourceConfig.java b/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSourceConfig.java index fc83d59a2b..a0a9074f6a 100644 --- a/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSourceConfig.java +++ b/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSourceConfig.java @@ -9,6 +9,9 @@ import org.opensearch.dataprepper.http.BaseHttpServerConfig; import org.opensearch.dataprepper.model.configuration.PluginModel; +import java.util.Collections; +import java.util.List; + public class HTTPSourceConfig extends BaseHttpServerConfig { static final String DEFAULT_LOG_INGEST_URI = "/log/ingest"; @@ -27,7 +30,15 @@ public String getDefaultPath() { @JsonProperty("codec") private PluginModel codec; + @JsonProperty("metadata_headers") + private List metadataHeaders = Collections.emptyList(); + public PluginModel getCodec() { return codec; } + + public List getMetadataHeaders() { + return metadataHeaders; + } + } diff --git a/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/HttpHeaderExtractor.java b/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/HttpHeaderExtractor.java new file mode 100644 index 0000000000..78b21977fd --- /dev/null +++ b/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/HttpHeaderExtractor.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.loghttp; + +import com.linecorp.armeria.common.AggregatedHttpRequest; + +import javax.annotation.Nonnull; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +public class HttpHeaderExtractor { + + static final Set SENSITIVE_HEADERS = Set.of( + "authorization", + "proxy-authorization", + "cookie", + "set-cookie", + "www-authenticate", + "proxy-authenticate", + "x-api-key", + "x-csrf-token", + "x-xsrf-token", + "x-auth-token", + "x-amz-security-token", + "x-amz-credential" + ); + + private final Collection metadataHeaders; + + public HttpHeaderExtractor(@Nonnull final Collection metadataHeaders) { + this.metadataHeaders = metadataHeaders; + } + + public Map extractHeaders(final AggregatedHttpRequest aggregatedHttpRequest) { + if (metadataHeaders.isEmpty()) { + return Collections.emptyMap(); + } + + final Set headerNames = metadataHeaders.stream() + .map(String::toLowerCase) + .collect(Collectors.toCollection(LinkedHashSet::new)); + + final Map headers = new HashMap<>(); + for (String headerName : headerNames) { + if (isSensitiveHeader(headerName)) { + continue; + } + List values = aggregatedHttpRequest.headers().getAll(headerName); + if (!values.isEmpty()) { + headers.put(headerName, values.size() == 1 ? values.get(0) : Collections.unmodifiableList(values)); + } + } + + return headers; + } + + static boolean isSensitiveHeader(final String headerName) { + return SENSITIVE_HEADERS.contains(headerName.toLowerCase()); + } +} diff --git a/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/LogHTTPService.java b/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/LogHTTPService.java index 4125bdb1da..c309ac9e9a 100644 --- a/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/LogHTTPService.java +++ b/data-prepper-plugins/http-source/src/main/java/org/opensearch/dataprepper/plugins/source/loghttp/LogHTTPService.java @@ -28,13 +28,15 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.Collections; +import java.util.Map; import java.util.UUID; import java.util.stream.Collectors; - /* * A HTTP service for log ingestion to be executed by BlockingTaskExecutor. */ + @Blocking public class LogHTTPService { private static final int SERIALIZATION_OVERHEAD = 1024; @@ -60,16 +62,19 @@ public class LogHTTPService { private final Timer requestProcessDuration; private Integer bufferMaxRequestLength; private Integer bufferOptimalRequestLength; + private final HttpHeaderExtractor httpHeaderExtractor; public LogHTTPService(final int bufferWriteTimeoutInMillis, final Buffer> buffer, final PluginMetrics pluginMetrics, - final InputCodec codec) { + final InputCodec codec, + final HttpHeaderExtractor httpHeaderExtractor) { this.buffer = buffer; this.bufferWriteTimeoutInMillis = bufferWriteTimeoutInMillis; this.bufferMaxRequestLength = buffer.getMaxRequestSize().isPresent() ? buffer.getMaxRequestSize().get(): null; this.bufferOptimalRequestLength = buffer.getOptimalRequestSize().isPresent() ? buffer.getOptimalRequestSize().get(): null; this.codec = codec; + this.httpHeaderExtractor = httpHeaderExtractor; requestsReceivedCounter = pluginMetrics.counter(REQUESTS_RECEIVED); successRequestsCounter = pluginMetrics.counter(SUCCESS_REQUESTS); requestsOverOptimalSizeCounter = pluginMetrics.counter(REQUESTS_OVER_OPTIMAL_SIZE); @@ -78,6 +83,13 @@ public LogHTTPService(final int bufferWriteTimeoutInMillis, requestProcessDuration = pluginMetrics.timer(REQUEST_PROCESS_DURATION); } + public LogHTTPService(final int bufferWriteTimeoutInMillis, + final Buffer> buffer, + final PluginMetrics pluginMetrics, + final InputCodec codec) { + this(bufferWriteTimeoutInMillis, buffer, pluginMetrics, codec, new HttpHeaderExtractor(Collections.emptySet())); + } + @Post public HttpResponse doPost(final ServiceRequestContext serviceRequestContext, final AggregatedHttpRequest aggregatedHttpRequest) throws Exception { requestsReceivedCounter.increment(); @@ -92,6 +104,7 @@ public HttpResponse doPost(final ServiceRequestContext serviceRequestContext, fi HttpResponse processRequest(final AggregatedHttpRequest aggregatedHttpRequest) throws Exception { final HttpData content = aggregatedHttpRequest.content(); + final Map extractedHeaders = Collections.unmodifiableMap(httpHeaderExtractor.extractHeaders(aggregatedHttpRequest)); if (buffer.isByteBuffer()) { if (bufferMaxRequestLength != null && bufferOptimalRequestLength != null && content.array().length > bufferOptimalRequestLength) { @@ -140,6 +153,12 @@ HttpResponse processRequest(final AggregatedHttpRequest aggregatedHttpRequest) t ); } + if (!extractedHeaders.isEmpty()) { + for (Record record : records) { + record.getData().getMetadata().setAttribute("headers", extractedHeaders); + } + } + try { buffer.writeAll(records, bufferWriteTimeoutInMillis); } catch (Exception e) { @@ -171,13 +190,10 @@ private void writeChunkedBody(final String chunk) { } } - private Record buildRecordLog(String json) { - - final JacksonLog log = JacksonLog.builder() + private Record buildRecordLog(final String json) { + final JacksonLog.Builder builder = JacksonLog.builder() .withData(json) - .getThis() - .build(); - - return new Record<>(log); + .getThis(); + return new Record<>(builder.build()); } } diff --git a/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSourceConfigTest.java b/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSourceConfigTest.java index 70051bff18..127f05023a 100644 --- a/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSourceConfigTest.java +++ b/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSourceConfigTest.java @@ -5,12 +5,18 @@ package org.opensearch.dataprepper.plugins.source.loghttp; +import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.jupiter.api.Test; +import java.util.Collections; +import java.util.List; + import static org.junit.jupiter.api.Assertions.assertEquals; public class HTTPSourceConfigTest { + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + @Test void testDefault() { // Prepare @@ -21,5 +27,14 @@ void testDefault() { assertEquals(HTTPSourceConfig.DEFAULT_LOG_INGEST_URI, sourceConfig.getPath()); assertEquals(HTTPSourceConfig.DEFAULT_PORT, sourceConfig.getDefaultPort()); assertEquals(HTTPSourceConfig.DEFAULT_LOG_INGEST_URI, sourceConfig.getDefaultPath()); + assertEquals(sourceConfig.getMetadataHeaders(), Collections.emptyList()); + } + + @Test + void testSetMetadataHeaders() throws Exception { + final String json = "{\"metadata_headers\": [\"X-Tenant-Id\", \"X-Region\"]}"; + final HTTPSourceConfig sourceConfig = OBJECT_MAPPER.readValue(json, HTTPSourceConfig.class); + + assertEquals(List.of("X-Tenant-Id", "X-Region"), sourceConfig.getMetadataHeaders()); } } diff --git a/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSourceTest.java b/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSourceTest.java index 7f073d2311..8f398fa605 100644 --- a/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSourceTest.java +++ b/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/HTTPSourceTest.java @@ -75,6 +75,7 @@ import java.util.Map; import java.util.Random; import java.util.StringJoiner; +import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; @@ -92,6 +93,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; @@ -1020,6 +1022,58 @@ public void testHTTPJsonCodec() throws IOException { assertEquals(testPayloadSize, payloadSizeMax.getValue()); } + @Test + public void testHTTPJsonResponse200WithMetadataHeaders() throws JsonProcessingException { + final String tenantId = UUID.randomUUID().toString(); + final String testData = "[{\"log\": \"somelog\"}]"; + + when(sourceConfig.getMetadataHeaders()).thenReturn(List.of("X-Tenant-Id")); + HTTPSourceUnderTest = new HTTPSource(sourceConfig, pluginMetrics, pluginFactory, pipelineDescription); + testBuffer = getBuffer(1, 1); + HTTPSourceUnderTest.start(testBuffer); + + WebClient.of().execute(RequestHeaders.builder() + .scheme(SessionProtocol.HTTP) + .authority("127.0.0.1:2021") + .method(HttpMethod.POST) + .path("/log/ingest") + .contentType(MediaType.JSON_UTF_8) + .add("X-Tenant-Id", tenantId) + .build(), + HttpData.ofUtf8(testData)) + .aggregate() + .whenComplete((i, ex) -> assertSecureResponseWithStatusCode(i, HttpStatus.OK)).join(); + + final Map.Entry>, CheckpointState> result = testBuffer.read(100); + List> records = new ArrayList<>(result.getKey()); + assertEquals(1, records.size()); + assertEquals(tenantId, records.get(0).getData().getMetadata().getAttribute("headers/x-tenant-id")); + } + + @Test + public void testHTTPJsonResponse200WithNoMetadataHeaders() { + final String testData = "[{\"log\": \"somelog\"}]"; + + HTTPSourceUnderTest.start(testBuffer); + + WebClient.of().execute(RequestHeaders.builder() + .scheme(SessionProtocol.HTTP) + .authority("127.0.0.1:2021") + .method(HttpMethod.POST) + .path("/log/ingest") + .contentType(MediaType.JSON_UTF_8) + .add("X-Tenant-Id", UUID.randomUUID().toString()) + .build(), + HttpData.ofUtf8(testData)) + .aggregate() + .whenComplete((i, ex) -> assertSecureResponseWithStatusCode(i, HttpStatus.OK)).join(); + + final Map.Entry>, CheckpointState> result = testBuffer.read(100); + List> records = new ArrayList<>(result.getKey()); + assertEquals(1, records.size()); + assertNull(records.get(0).getData().getMetadata().getAttribute("headers/x-tenant-id")); + } + private void assertCommonFields(Record record) { assertEquals("111111111111", record.getData().get("owner", String.class)); assertEquals("CloudTrail/logs", record.getData().get("logGroup", String.class)); diff --git a/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/HttpHeaderExtractorTest.java b/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/HttpHeaderExtractorTest.java new file mode 100644 index 0000000000..efa49830dc --- /dev/null +++ b/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/HttpHeaderExtractorTest.java @@ -0,0 +1,173 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.dataprepper.plugins.source.loghttp; + +import com.linecorp.armeria.common.AggregatedHttpRequest; +import com.linecorp.armeria.common.HttpData; +import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.MediaType; +import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.RequestHeadersBuilder; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ExecutionException; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.anEmptyMap; + +class HttpHeaderExtractorTest { + + @ParameterizedTest + @CsvSource({ + "AUTHORIZATION, true", + "proxy-authorization, true", + "X-AMZ-SECURITY-TOKEN, true", + "cookie, true", + "set-cookie, true", + "x-api-key, true", + "x-csrf-token, true", + "x-auth-token, true", + "X-Tenant-Id, false", + "Content-Type, false", + "X-Request-Id, false" + }) + void testIsSensitiveHeader(String headerName, boolean expected) { + assertThat(HttpHeaderExtractor.isSensitiveHeader(headerName), equalTo(expected)); + } + + @Test + void extractHeaders_returnsEmptyMap_whenMetadataHeadersIsEmpty() throws Exception { + final HttpHeaderExtractor extractor = new HttpHeaderExtractor(Collections.emptyList()); + final AggregatedHttpRequest request = buildRequest(1, Map.of("X-Tenant-Id", "test")); + + assertThat(extractor.extractHeaders(request), anEmptyMap()); + } + + @Test + void extractHeaders_extractsConfiguredHeaders() throws Exception { + final String tenantId = UUID.randomUUID().toString(); + final String region = UUID.randomUUID().toString(); + final HttpHeaderExtractor extractor = new HttpHeaderExtractor(List.of("X-Tenant-Id", "X-Region")); + final AggregatedHttpRequest request = buildRequest(1, + Map.of("X-Tenant-Id", tenantId, "X-Region", region)); + + final Map headers = extractor.extractHeaders(request); + + assertThat(headers.get("x-tenant-id"), equalTo(tenantId)); + assertThat(headers.get("x-region"), equalTo(region)); + } + + @Test + void extractHeaders_ignoresNonConfiguredHeaders() throws Exception { + final HttpHeaderExtractor extractor = new HttpHeaderExtractor(List.of("X-Tenant-Id")); + final AggregatedHttpRequest request = buildRequest(1, + Map.of("X-Tenant-Id", "val1", "X-Other", "val2")); + + final Map headers = extractor.extractHeaders(request); + + assertThat(headers.size(), equalTo(1)); + assertThat(headers.containsKey("x-other"), equalTo(false)); + } + + @Test + void extractHeaders_ignoresMissingConfiguredHeaders() throws Exception { + final HttpHeaderExtractor extractor = new HttpHeaderExtractor(List.of("X-Tenant-Id", "X-Missing")); + final AggregatedHttpRequest request = buildRequest(1, Map.of("X-Tenant-Id", "val1")); + + final Map headers = extractor.extractHeaders(request); + + assertThat(headers.size(), equalTo(1)); + assertThat(headers.containsKey("x-missing"), equalTo(false)); + } + + @Test + void extractHeaders_filtersSensitiveHeaders() throws Exception { + final String tenantId = UUID.randomUUID().toString(); + final String authValue = UUID.randomUUID().toString(); + final HttpHeaderExtractor extractor = new HttpHeaderExtractor(List.of("X-Tenant-Id", "authorization")); + final AggregatedHttpRequest request = buildRequest(1, + Map.of("X-Tenant-Id", tenantId, "authorization", authValue)); + + final Map headers = extractor.extractHeaders(request); + + assertThat(headers.get("x-tenant-id"), equalTo(tenantId)); + assertThat(headers.containsKey("authorization"), equalTo(false)); + } + + @Test + void extractHeaders_normalizesHeaderKeysToLowercase() throws Exception { + final String value = UUID.randomUUID().toString(); + final HttpHeaderExtractor extractor = new HttpHeaderExtractor(List.of("X-Tenant-Id")); + final AggregatedHttpRequest request = buildRequest(1, Map.of("X-Tenant-Id", value)); + + final Map headers = extractor.extractHeaders(request); + + assertThat(headers.containsKey("x-tenant-id"), equalTo(true)); + assertThat(headers.get("x-tenant-id"), equalTo(value)); + } + + @Test + void extractHeaders_storesMultiValueHeaderAsList() throws Exception { + final String ip1 = UUID.randomUUID().toString(); + final String ip2 = UUID.randomUUID().toString(); + final HttpHeaderExtractor extractor = new HttpHeaderExtractor(List.of("X-Forwarded-For")); + + RequestHeadersBuilder headersBuilder = RequestHeaders.builder() + .contentType(MediaType.JSON) + .method(HttpMethod.POST) + .path("/log/ingest") + .add("X-Forwarded-For", ip1) + .add("X-Forwarded-For", ip2); + AggregatedHttpRequest request = HttpRequest.of(headersBuilder.build(), + HttpData.ofUtf8("[{\"log\":\"test\"}]")).aggregate().get(); + + final Map headers = extractor.extractHeaders(request); + + assertThat(headers.get("x-forwarded-for"), equalTo(List.of(ip1, ip2))); + } + + @Test + void extractHeaders_storesSingleValueHeaderAsString() throws Exception { + final String value = UUID.randomUUID().toString(); + final HttpHeaderExtractor extractor = new HttpHeaderExtractor(List.of("X-Tenant-Id")); + final AggregatedHttpRequest request = buildRequest(1, Map.of("X-Tenant-Id", value)); + + final Map headers = extractor.extractHeaders(request); + + assertThat(headers.get("x-tenant-id") instanceof String, equalTo(true)); + } + + private AggregatedHttpRequest buildRequest(int numJson, Map customHeaders) + throws ExecutionException, InterruptedException { + RequestHeadersBuilder headersBuilder = RequestHeaders.builder() + .contentType(MediaType.JSON) + .method(HttpMethod.POST) + .path("/log/ingest"); + for (Map.Entry entry : customHeaders.entrySet()) { + headersBuilder.add(entry.getKey(), entry.getValue()); + } + StringBuilder sb = new StringBuilder("["); + for (int i = 0; i < numJson; i++) { + if (i > 0) sb.append(","); + sb.append("{\"log\":\"").append(UUID.randomUUID()).append("\"}"); + } + sb.append("]"); + return HttpRequest.of(headersBuilder.build(), HttpData.ofUtf8(sb.toString())).aggregate().get(); + } +} diff --git a/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/LogHTTPServiceTest.java b/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/LogHTTPServiceTest.java index 4c9e9691c8..7ee286a058 100644 --- a/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/LogHTTPServiceTest.java +++ b/data-prepper-plugins/http-source/src/test/java/org/opensearch/dataprepper/plugins/source/loghttp/LogHTTPServiceTest.java @@ -16,6 +16,7 @@ import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.MediaType; import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.RequestHeadersBuilder; import com.linecorp.armeria.server.ServiceRequestContext; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.DistributionSummary; @@ -33,6 +34,8 @@ import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.buffer.SizeOverflowException; import org.opensearch.dataprepper.model.codec.InputCodec; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.log.JacksonLog; import org.opensearch.dataprepper.model.log.Log; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.plugins.buffer.blockingbuffer.BlockingBuffer; @@ -50,6 +53,7 @@ import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeoutException; +import java.util.function.Consumer; import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.equalTo; @@ -60,6 +64,7 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -99,6 +104,12 @@ class LogHTTPServiceTest { @Mock private InputCodec codec; + @Mock + private Buffer> blockingBuffer; + + @Mock + private HttpHeaderExtractor httpHeaderExtractor; + @BeforeEach public void setUp() throws Exception { when(pluginMetrics.counter(LogHTTPService.REQUESTS_RECEIVED)).thenReturn(requestsReceivedCounter); @@ -390,6 +401,59 @@ void chunking_with_4mb() throws Exception { } } + + @Test + public void processRequestAttachesHeadersToEventMetadata() throws Exception { + when(blockingBuffer.getMaxRequestSize()).thenReturn(Optional.empty()); + when(blockingBuffer.getOptimalRequestSize()).thenReturn(Optional.empty()); + Map headers = Map.of("x-tenant-id", "tenant-abc", "x-region", "us-west-2"); + when(httpHeaderExtractor.extractHeaders(any(AggregatedHttpRequest.class))).thenReturn(headers); + logHTTPService = new LogHTTPService(TEST_TIMEOUT_IN_MILLIS, blockingBuffer, pluginMetrics, null, httpHeaderExtractor); + AggregatedHttpRequest testRequest = generateRequestWithHeaders(2, Map.of("X-Tenant-Id", "tenant-abc", "X-Region", "us-west-2")); + + logHTTPService.processRequest(testRequest); + + verify(httpHeaderExtractor).extractHeaders(testRequest); + ArgumentCaptor>> captor = ArgumentCaptor.forClass(List.class); + verify(blockingBuffer).writeAll(captor.capture(), eq(TEST_TIMEOUT_IN_MILLIS)); + List> records = captor.getValue(); + assertThat(records.size(), equalTo(2)); + for (Record record : records) { + assertThat(record.getData().getMetadata().getAttribute("headers/x-tenant-id"), equalTo("tenant-abc")); + assertThat(record.getData().getMetadata().getAttribute("headers/x-region"), equalTo("us-west-2")); + } + } + + @Test + public void processRequestWithCodecAttachesHeadersToEventMetadata() throws Exception { + when(blockingBuffer.getMaxRequestSize()).thenReturn(Optional.empty()); + when(blockingBuffer.getOptimalRequestSize()).thenReturn(Optional.empty()); + Map headers = Map.of("x-tenant-id", "tenant-xyz"); + when(httpHeaderExtractor.extractHeaders(any(AggregatedHttpRequest.class))).thenReturn(headers); + logHTTPService = new LogHTTPService(TEST_TIMEOUT_IN_MILLIS, blockingBuffer, pluginMetrics, codec, httpHeaderExtractor); + + doAnswer(invocation -> { + Consumer> consumer = invocation.getArgument(1); + Log log1 = JacksonLog.builder().withData(Map.of("msg", "log1")).getThis().build(); + Log log2 = JacksonLog.builder().withData(Map.of("msg", "log2")).getThis().build(); + consumer.accept(new Record<>(log1)); + consumer.accept(new Record<>(log2)); + return null; + }).when(codec).parse(any(InputStream.class), any()); + + AggregatedHttpRequest testRequest = generateRequestWithHeaders(1, Map.of("X-Tenant-Id", "tenant-xyz")); + + logHTTPService.processRequest(testRequest); + + verify(httpHeaderExtractor).extractHeaders(testRequest); + ArgumentCaptor>> captor = ArgumentCaptor.forClass(List.class); + verify(blockingBuffer).writeAll(captor.capture(), eq(TEST_TIMEOUT_IN_MILLIS)); + List> records = captor.getValue(); + assertThat(records.size(), equalTo(2)); + for (Record record : records) { + assertThat(record.getData().getMetadata().getAttribute("headers/x-tenant-id"), equalTo("tenant-xyz")); + } + } private AggregatedHttpRequest generateRandomValidHTTPRequest(int numJson) throws JsonProcessingException, ExecutionException, InterruptedException { @@ -416,4 +480,22 @@ private AggregatedHttpRequest generateBadHTTPRequest() throws ExecutionException HttpData httpData = HttpData.ofUtf8("{"); return HttpRequest.of(requestHeaders, httpData).aggregate().get(); } + + private AggregatedHttpRequest generateRequestWithHeaders(int numJson, Map customHeaders) + throws JsonProcessingException, ExecutionException, InterruptedException { + RequestHeadersBuilder headersBuilder = RequestHeaders.builder() + .contentType(MediaType.JSON) + .method(HttpMethod.POST) + .path("/log/ingest"); + for (Map.Entry entry : customHeaders.entrySet()) { + headersBuilder.add(entry.getKey(), entry.getValue()); + } + List> jsonList = new ArrayList<>(); + for (int i = 0; i < numJson; i++) { + jsonList.add(Collections.singletonMap("log", UUID.randomUUID().toString())); + } + String content = mapper.writeValueAsString(jsonList); + HttpData httpData = HttpData.ofUtf8(content); + return HttpRequest.of(headersBuilder.build(), httpData).aggregate().get(); + } }