Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ public interface KnownAddresses {
Address<Long> REQUEST_COMBINED_FILE_SIZE =
new Address<>("server.request.body.combined_file_size");

/**
* Contains the content of each uploaded file in a multipart/form-data request. Each entry in the
* list corresponds positionally to {@link #REQUEST_FILES_FILENAMES}. Content is truncated to a
* maximum size to avoid excessive memory usage. Available only on inspected multipart/form-data
* requests.
*/
Address<List<String>> REQUEST_FILES_CONTENT = new Address<>("server.request.body.files_content");

/**
* The parsed query string.
*
Expand Down Expand Up @@ -205,6 +213,8 @@ static Address<?> forName(String name) {
return REQUEST_FILES_FILENAMES;
case "server.request.body.combined_file_size":
return REQUEST_COMBINED_FILE_SIZE;
case "server.request.body.files_content":
return REQUEST_FILES_CONTENT;
case "server.request.query":
return REQUEST_QUERY;
case "server.request.headers.no_cookies":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ public class GatewayBridge {
private volatile DataSubscriberInfo execCmdSubInfo;
private volatile DataSubscriberInfo shellCmdSubInfo;
private volatile DataSubscriberInfo requestFilesFilenamesSubInfo;
private volatile DataSubscriberInfo requestFilesContentSubInfo;

public GatewayBridge(
SubscriptionService subscriptionService,
Expand Down Expand Up @@ -208,6 +209,10 @@ public void init() {
subscriptionService.registerCallback(
EVENTS.requestFilesFilenames(), this::onRequestFilesFilenames);
}
if (additionalIGEvents.contains(EVENTS.requestFilesContent())) {
subscriptionService.registerCallback(
EVENTS.requestFilesContent(), this::onRequestFilesContent);
}
}

/**
Expand Down Expand Up @@ -235,6 +240,7 @@ public void reset() {
execCmdSubInfo = null;
shellCmdSubInfo = null;
requestFilesFilenamesSubInfo = null;
requestFilesContentSubInfo = null;
}

private Flow<Void> onUser(final RequestContext ctx_, final String user) {
Expand Down Expand Up @@ -605,6 +611,31 @@ private Flow<Void> onRequestFilesFilenames(RequestContext ctx_, List<String> fil
}
}

private Flow<Void> onRequestFilesContent(RequestContext ctx_, List<String> filesContent) {
AppSecRequestContext ctx = ctx_.getData(RequestContextSlot.APPSEC);
if (ctx == null || filesContent == null || filesContent.isEmpty()) {
return NoopFlow.INSTANCE;
}
while (true) {
DataSubscriberInfo subInfo = requestFilesContentSubInfo;
if (subInfo == null) {
subInfo = producerService.getDataSubscribers(KnownAddresses.REQUEST_FILES_CONTENT);
requestFilesContentSubInfo = subInfo;
}
if (subInfo == null || subInfo.isEmpty()) {
return NoopFlow.INSTANCE;
}
DataBundle bundle =
new SingletonDataBundle<>(KnownAddresses.REQUEST_FILES_CONTENT, filesContent);
try {
GatewayContext gwCtx = new GatewayContext(false);
return producerService.publishDataEvent(subInfo, ctx, bundle, gwCtx);
} catch (ExpiredSubscriberInfoException e) {
requestFilesContentSubInfo = null;
}
}
}

private Flow<Void> onDatabaseSqlQuery(RequestContext ctx_, String sql) {
AppSecRequestContext ctx = ctx_.getData(RequestContextSlot.APPSEC);
if (ctx == null) {
Expand Down Expand Up @@ -1464,6 +1495,7 @@ private static class IGAppSecEventDependencies {
DATA_DEPENDENCIES.put(KnownAddresses.REQUEST_BODY_OBJECT, l(EVENTS.requestBodyProcessed()));
DATA_DEPENDENCIES.put(
KnownAddresses.REQUEST_FILES_FILENAMES, l(EVENTS.requestFilesFilenames()));
DATA_DEPENDENCIES.put(KnownAddresses.REQUEST_FILES_CONTENT, l(EVENTS.requestFilesContent()));
}

private static Collection<datadog.trace.api.gateway.EventType<?>> l(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class KnownAddressesSpecificationForkedTest extends Specification {
'server.request.body.files_field_names',
'server.request.body.filenames',
'server.request.body.combined_file_size',
'server.request.body.files_content',
'server.request.query',
'server.request.headers.no_cookies',
'grpc.server.method',
Expand Down Expand Up @@ -58,7 +59,7 @@ class KnownAddressesSpecificationForkedTest extends Specification {

void 'number of known addresses is expected number'() {
expect:
Address.instanceCount() == 46
Address.instanceCount() == 47
KnownAddresses.WAF_CONTEXT_PROCESSOR.serial == Address.instanceCount() - 1
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,15 @@ class GatewayBridgeIGRegistrationSpecification extends DDSpecification {
then:
1 * ig.registerCallback(Events.REQUEST_BODY_DONE, _)
}

void 'requestFilesContent is registered via data address'() {
given:
1 * eventDispatcher.allSubscribedDataAddresses() >> [KnownAddresses.REQUEST_FILES_CONTENT]

when:
bridge.init()

then:
1 * ig.registerCallback(Events.get().requestFilesContent(), _)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ class GatewayBridgeSpecification extends DDSpecification {
BiFunction<RequestContext, String, Flow<Void>> fileLoadedCB
BiFunction<RequestContext, String, Flow<Void>> fileWrittenCB
BiFunction<RequestContext, List<String>, Flow<Void>> requestFilesFilenamesCB
BiFunction<RequestContext, List<String>, Flow<Void>> requestFilesContentCB
BiFunction<RequestContext, String, Flow<Void>> requestSessionCB
BiFunction<RequestContext, String[], Flow<Void>> execCmdCB
BiFunction<RequestContext, String, Flow<Void>> shellCmdCB
Expand Down Expand Up @@ -463,7 +464,7 @@ class GatewayBridgeSpecification extends DDSpecification {

void callInitAndCaptureCBs() {
// force all callbacks to be registered
_ * eventDispatcher.allSubscribedDataAddresses() >> [KnownAddresses.REQUEST_PATH_PARAMS, KnownAddresses.REQUEST_BODY_OBJECT, KnownAddresses.REQUEST_FILES_FILENAMES]
_ * eventDispatcher.allSubscribedDataAddresses() >> [KnownAddresses.REQUEST_PATH_PARAMS, KnownAddresses.REQUEST_BODY_OBJECT, KnownAddresses.REQUEST_FILES_FILENAMES, KnownAddresses.REQUEST_FILES_CONTENT]

1 * ig.registerCallback(EVENTS.requestStarted(), _) >> {
requestStartedCB = it[1]; null
Expand Down Expand Up @@ -561,6 +562,9 @@ class GatewayBridgeSpecification extends DDSpecification {
1 * ig.registerCallback(EVENTS.requestFilesFilenames(), _) >> {
requestFilesFilenamesCB = it[1]; null
}
1 * ig.registerCallback(EVENTS.requestFilesContent(), _) >> {
requestFilesContentCB = it[1]; null
}
0 * ig.registerCallback(_, _)

bridge.init()
Expand Down Expand Up @@ -1142,6 +1146,38 @@ class GatewayBridgeSpecification extends DDSpecification {
0 * eventDispatcher.publishDataEvent(*_)
}

void 'process request files content'() {
setup:
final filesContent = ['%PDF-1.4 malicious content', '#!/bin/bash\nrm -rf /']
eventDispatcher.getDataSubscribers({
KnownAddresses.REQUEST_FILES_CONTENT in it
}) >> nonEmptyDsInfo
DataBundle bundle
GatewayContext gatewayContext

when:
Flow<?> flow = requestFilesContentCB.apply(ctx, filesContent)

then:
1 * eventDispatcher.publishDataEvent(nonEmptyDsInfo, ctx.data, _ as DataBundle, _ as GatewayContext) >> {
a, b, db, gw -> bundle = db; gatewayContext = gw; NoopFlow.INSTANCE
}
bundle.get(KnownAddresses.REQUEST_FILES_CONTENT) == filesContent
flow.result == null
flow.action == Flow.Action.Noop.INSTANCE
gatewayContext.isTransient == false
gatewayContext.isRasp == false
}

void 'process request files content with empty list returns noop'() {
when:
Flow<?> flow = requestFilesContentCB.apply(ctx, [])

then:
flow == NoopFlow.INSTANCE
0 * eventDispatcher.publishDataEvent(*_)
}

void 'process exec cmd'() {
setup:
final cmd = ['/bin/../usr/bin/reboot', '-f'] as String[]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ public String instrumentedType() {
return "org.apache.commons.fileupload.servlet.ServletFileUpload";
}

@Override
public String[] helperClassNames() {
return new String[] {
"datadog.trace.instrumentation.commons.fileupload.FileItemContentReader",
};
}

@Override
public void methodAdvice(MethodTransformer transformer) {
transformer.applyAdvice(
Expand All @@ -47,6 +54,7 @@ public void methodAdvice(MethodTransformer transformer) {

@RequiresRequestContext(RequestContextSlot.APPSEC)
public static class ParseRequestAdvice {

@Advice.OnMethodExit(suppress = Throwable.class, onThrowable = Throwable.class)
static void after(
@Advice.Return final List<FileItem> fileItems,
Expand All @@ -57,9 +65,11 @@ static void after(
}

CallbackProvider cbp = AgentTracer.get().getCallbackProvider(RequestContextSlot.APPSEC);
BiFunction<RequestContext, List<String>, Flow<Void>> callback =
BiFunction<RequestContext, List<String>, Flow<Void>> filenamesCallback =
cbp.getCallback(EVENTS.requestFilesFilenames());
if (callback == null) {
BiFunction<RequestContext, List<String>, Flow<Void>> contentCallback =
cbp.getCallback(EVENTS.requestFilesContent());
if (filenamesCallback == null && contentCallback == null) {
return;
}

Expand All @@ -77,14 +87,48 @@ static void after(
return;
}

Flow<Void> flow = callback.apply(reqCtx, filenames);
Flow.Action action = flow.getAction();
if (action instanceof Flow.Action.RequestBlockingAction) {
Flow.Action.RequestBlockingAction rba = (Flow.Action.RequestBlockingAction) action;
// Fire filenames event
if (filenamesCallback != null) {
Flow<Void> flow = filenamesCallback.apply(reqCtx, filenames);
Flow.Action action = flow.getAction();
if (action instanceof Flow.Action.RequestBlockingAction) {
Flow.Action.RequestBlockingAction rba = (Flow.Action.RequestBlockingAction) action;
BlockResponseFunction brf = reqCtx.getBlockResponseFunction();
if (brf != null) {
brf.tryCommitBlockingResponse(reqCtx.getTraceSegment(), rba);
t = new BlockingException("Blocked request (multipart file upload)");
reqCtx.getTraceSegment().effectivelyBlocked();
return;
}
}
}

// Fire content event only if not blocked
if (contentCallback == null) {
return;
}
List<String> filesContent = new ArrayList<>();
for (FileItem fileItem : fileItems) {
if (fileItem.isFormField()) {
continue;
}
String name = fileItem.getName();
if (name == null || name.isEmpty()) {
continue;
}
filesContent.add(FileItemContentReader.readContent(fileItem));
}
if (filesContent.isEmpty()) {
return;
}
Flow<Void> contentFlow = contentCallback.apply(reqCtx, filesContent);
Flow.Action contentAction = contentFlow.getAction();
if (contentAction instanceof Flow.Action.RequestBlockingAction) {
Flow.Action.RequestBlockingAction rba = (Flow.Action.RequestBlockingAction) contentAction;
BlockResponseFunction brf = reqCtx.getBlockResponseFunction();
if (brf != null) {
brf.tryCommitBlockingResponse(reqCtx.getTraceSegment(), rba);
t = new BlockingException("Blocked request (multipart file upload)");
t = new BlockingException("Blocked request (multipart file upload content)");
reqCtx.getTraceSegment().effectivelyBlocked();
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package datadog.trace.instrumentation.commons.fileupload;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import org.apache.commons.fileupload.FileItem;

/** Helper class injected into the application classloader by the AppSec instrumentation. */
public final class FileItemContentReader {
public static final int MAX_CONTENT_BYTES = 4096;

public static String readContent(FileItem fileItem) {
try (InputStream is = fileItem.getInputStream()) {
byte[] buf = new byte[MAX_CONTENT_BYTES];
int total = 0;
int n;
while (total < MAX_CONTENT_BYTES
&& (n = is.read(buf, total, MAX_CONTENT_BYTES - total)) != -1) {
total += n;
}
return new String(buf, 0, total, StandardCharsets.ISO_8859_1);
} catch (IOException ignored) {
return "";
}
}

private FileItemContentReader() {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import datadog.trace.instrumentation.commons.fileupload.FileItemContentReader
import org.apache.commons.fileupload.FileItem
import spock.lang.Specification

class CommonsFileUploadAppSecModuleTest extends Specification {

def "readContent returns full content when smaller than limit"() {
given:
def content = 'Hello, World!'
def item = fileItem(content)

expect:
FileItemContentReader.readContent(item) == content
}

def "readContent truncates content to MAX_CONTENT_BYTES"() {
given:
def largeContent = 'X' * (FileItemContentReader.MAX_CONTENT_BYTES + 500)
def item = fileItem(largeContent)

when:
def result = FileItemContentReader.readContent(item)

then:
result.length() == FileItemContentReader.MAX_CONTENT_BYTES
result == 'X' * FileItemContentReader.MAX_CONTENT_BYTES
}

def "readContent returns empty string when getInputStream throws"() {
given:
FileItem item = Stub(FileItem)
item.getInputStream() >> { throw new IOException('simulated error') }

expect:
FileItemContentReader.readContent(item) == ''
}

def "readContent returns empty string for empty content"() {
given:
def item = fileItem('')

expect:
FileItemContentReader.readContent(item) == ''
}

def "readContent reads exactly MAX_CONTENT_BYTES when content equals the limit"() {
given:
def content = 'A' * FileItemContentReader.MAX_CONTENT_BYTES
def item = fileItem(content)

when:
def result = FileItemContentReader.readContent(item)

then:
result.length() == FileItemContentReader.MAX_CONTENT_BYTES
result == content
}

private FileItem fileItem(String content) {
FileItem item = Stub(FileItem)
item.getInputStream() >> new ByteArrayInputStream(content.getBytes('ISO-8859-1'))
return item
}
}
Loading
Loading