Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/feature-AWSSDKforJavav2-dd9f8bf.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "feature",
"category": "AWS SDK for Java v2",
"contributor": "",
"description": "Avoid extra byte array copying when downloading to memory with AsyncResponseTransformer"
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import software.amazon.awssdk.annotations.SdkPublicApi;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
Expand Down Expand Up @@ -68,6 +69,33 @@ public static <ResponseT> ResponseBytes<ResponseT> fromByteArrayUnsafe(ResponseT
return new ResponseBytes<>(response, bytes);
}

/**
* Creates ResponseBytes from a ByteBuffer without copying the underlying data.
*
* @param response the response object containing metadata
* @param buffer the ByteBuffer containing the response body data
* @return ResponseBytes wrapping the buffer data
*/
public static <ResponseT> ResponseBytes<ResponseT> fromByteBufferUnsafe(ResponseT response, ByteBuffer buffer) {
byte[] array;
if (buffer.hasArray()) {
array = buffer.array();
int offset = buffer.arrayOffset() + buffer.position();
int length = buffer.remaining();
if (offset == 0 && length == array.length) {
// Perfect match - use array directly
} else {
// Create view of the relevant portion
array = Arrays.copyOfRange(array, offset, offset + length);
}
} else {
// Direct buffer - must copy to array
Comment thread
davidh44 marked this conversation as resolved.
array = new byte[buffer.remaining()];
buffer.get(array);
}
return new ResponseBytes<>(response, array);
}

/**
* @return the unmarshalled response object from the service.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@
public final class ByteArrayAsyncResponseTransformer<ResponseT> implements
AsyncResponseTransformer<ResponseT, ResponseBytes<ResponseT>> {

private volatile CompletableFuture<byte[]> cf;
private volatile CompletableFuture<ByteBuffer> cf;
private volatile ResponseT response;

@Override
public CompletableFuture<ResponseBytes<ResponseT>> prepare() {
cf = new CompletableFuture<>();
// Using fromByteArrayUnsafe() to avoid unnecessary extra copying of byte array. The data writing has completed and the
// byte array will not be further modified so this is safe
return cf.thenApply(arr -> ResponseBytes.fromByteArrayUnsafe(response, arr));
// Using fromByteBufferUnsafe() to avoid unnecessary extra copying of byte array. The data writing has completed and the
// byte buffer will not be further modified so this is safe
return cf.thenApply(buffer -> ResponseBytes.fromByteBufferUnsafe(response, buffer));
}

@Override
Expand All @@ -73,13 +73,11 @@ public String name() {
}

static class BaosSubscriber implements Subscriber<ByteBuffer> {
private final CompletableFuture<byte[]> resultFuture;

private ByteArrayOutputStream baos = new ByteArrayOutputStream();

private final CompletableFuture<ByteBuffer> resultFuture;
private DirectAccessByteArrayOutputStream directAccessOutputStream = new DirectAccessByteArrayOutputStream();
private Subscription subscription;

BaosSubscriber(CompletableFuture<byte[]> resultFuture) {
BaosSubscriber(CompletableFuture<ByteBuffer> resultFuture) {
this.resultFuture = resultFuture;
}

Expand All @@ -95,19 +93,38 @@ public void onSubscribe(Subscription s) {

@Override
public void onNext(ByteBuffer byteBuffer) {
invokeSafely(() -> baos.write(BinaryUtils.copyBytesFrom(byteBuffer)));
subscription.request(1);
Comment thread
davidh44 marked this conversation as resolved.
invokeSafely(() -> {
if (byteBuffer.hasArray()) {
directAccessOutputStream.write(byteBuffer.array(), byteBuffer.arrayOffset() + byteBuffer.position(),
byteBuffer.remaining());
} else {
directAccessOutputStream.write(BinaryUtils.copyBytesFrom(byteBuffer));
}
});
}

@Override
public void onError(Throwable throwable) {
baos = null;
directAccessOutputStream = null;
resultFuture.completeExceptionally(throwable);
}

@Override
public void onComplete() {
resultFuture.complete(baos.toByteArray());
resultFuture.complete(directAccessOutputStream.toByteBuffer());
}
}

/**
* Custom ByteArrayOutputStream that exposes internal buffer without copying
*/
static class DirectAccessByteArrayOutputStream extends ByteArrayOutputStream {

/**
* Returns the internal buffer wrapped as ByteBuffer with length set to count.
*/
ByteBuffer toByteBuffer() {
return ByteBuffer.wrap(buf, 0, count);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@

import static org.assertj.core.api.Assertions.assertThat;

import java.nio.ByteBuffer;
import org.junit.jupiter.api.Test;

public class ResponseBytesTest {
private static final Object OBJECT = new Object();
@Test
public void fromByteArrayCreatesCopy() {
byte[] input = new byte[] { 'a' };
byte[] input = {'a'};
byte[] output = ResponseBytes.fromByteArray(OBJECT, input).asByteArrayUnsafe();

input[0] = 'b';
Expand All @@ -32,7 +33,7 @@ public void fromByteArrayCreatesCopy() {

@Test
public void asByteArrayCreatesCopy() {
byte[] input = new byte[] { 'a' };
byte[] input = {'a'};
byte[] output = ResponseBytes.fromByteArrayUnsafe(OBJECT, input).asByteArray();

input[0] = 'b';
Expand All @@ -41,9 +42,27 @@ public void asByteArrayCreatesCopy() {

@Test
public void fromByteArrayUnsafeAndAsByteArrayUnsafeDoNotCopy() {
byte[] input = new byte[] { 'a' };
byte[] input = {'a'};
byte[] output = ResponseBytes.fromByteArrayUnsafe(OBJECT, input).asByteArrayUnsafe();

assertThat(output).isSameAs(input);
}

@Test
public void fromByteBufferUnsafe_doNotCopy() {
byte[] inputBytes = {'a'};
ByteBuffer inputBuffer = ByteBuffer.wrap(inputBytes);

ResponseBytes<Object> responseBytes = ResponseBytes.fromByteBufferUnsafe(OBJECT, inputBuffer);

ByteBuffer outputBuffer = responseBytes.asByteBuffer();
byte[] outputBytes = responseBytes.asByteArrayUnsafe();

assertThat(outputBuffer).isEqualTo(inputBuffer);
assertThat(outputBytes).isSameAs(inputBytes);

inputBytes[0] = 'b';
assertThat(outputBuffer).isEqualTo(inputBuffer);
assertThat(outputBytes).isEqualTo(inputBytes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static software.amazon.awssdk.core.internal.async.SplittingPublisherTestUtils.verifyIndividualAsyncRequestBody;
import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely;

import java.io.ByteArrayOutputStream;
Expand All @@ -38,12 +37,9 @@
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.testutils.RandomTempFile;
import software.amazon.awssdk.utils.BinaryUtils;

Expand Down Expand Up @@ -236,7 +232,7 @@ public void changingFile_fileGetsDeleted_failsBecauseDeleted() throws Exception

@Test
public void positionNotZero_shouldReadFromPosition() throws Exception {
CompletableFuture<byte[]> future = new CompletableFuture<>();
CompletableFuture<ByteBuffer> future = new CompletableFuture<>();
long position = 20L;
AsyncRequestBody asyncRequestBody = FileAsyncRequestBody.builder()
.path(smallFile)
Expand All @@ -249,7 +245,9 @@ public void positionNotZero_shouldReadFromPosition() throws Exception {
asyncRequestBody.subscribe(baosSubscriber);
assertThat(asyncRequestBody.contentLength()).contains(80L);

byte[] bytes = future.get(1, TimeUnit.SECONDS);
ByteBuffer buffer = future.get(1, TimeUnit.SECONDS);
byte[] bytes = new byte[buffer.remaining()];
buffer.get(bytes);

byte[] expected = new byte[80];
try(FileInputStream inputStream = new FileInputStream(smallFile.toFile())) {
Expand All @@ -262,7 +260,7 @@ public void positionNotZero_shouldReadFromPosition() throws Exception {

@Test
public void bothPositionAndNumBytesToReadConfigured_shouldHonor() throws Exception {
CompletableFuture<byte[]> future = new CompletableFuture<>();
CompletableFuture<ByteBuffer> future = new CompletableFuture<>();
long position = 20L;
long numBytesToRead = 5L;
AsyncRequestBody asyncRequestBody = FileAsyncRequestBody.builder()
Expand All @@ -277,7 +275,9 @@ public void bothPositionAndNumBytesToReadConfigured_shouldHonor() throws Excepti
asyncRequestBody.subscribe(baosSubscriber);
assertThat(asyncRequestBody.contentLength()).contains(numBytesToRead);

byte[] bytes = future.get(1, TimeUnit.SECONDS);
ByteBuffer buffer = future.get(1, TimeUnit.SECONDS);
byte[] bytes = new byte[buffer.remaining()];
buffer.get(bytes);

byte[] expected = new byte[5];
try (FileInputStream inputStream = new FileInputStream(smallFile.toFile())) {
Expand All @@ -290,7 +290,7 @@ public void bothPositionAndNumBytesToReadConfigured_shouldHonor() throws Excepti

@Test
public void numBytesToReadConfigured_shouldHonor() throws Exception {
CompletableFuture<byte[]> future = new CompletableFuture<>();
CompletableFuture<ByteBuffer> future = new CompletableFuture<>();
AsyncRequestBody asyncRequestBody = FileAsyncRequestBody.builder()
.path(smallFile)
.numBytesToRead(5L)
Expand All @@ -302,7 +302,9 @@ public void numBytesToReadConfigured_shouldHonor() throws Exception {
asyncRequestBody.subscribe(baosSubscriber);
assertThat(asyncRequestBody.contentLength()).contains(5L);

byte[] bytes = future.get(1, TimeUnit.SECONDS);
ByteBuffer buffer = future.get(1, TimeUnit.SECONDS);
byte[] bytes = new byte[buffer.remaining()];
buffer.get(bytes);

byte[] expected = new byte[5];
try (FileInputStream inputStream = new FileInputStream(smallFile.toFile())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,13 @@ void failedStream_completesExceptionally() {
}

private static String drainPublisherToStr(SdkPublisher<ByteBuffer> publisher) throws Exception {
CompletableFuture<byte[]> bodyFuture = new CompletableFuture<>();
CompletableFuture<ByteBuffer> bodyFuture = new CompletableFuture<>();
publisher.subscribe(new BaosSubscriber(bodyFuture));
byte[] body = bodyFuture.get();
return new String(body);

ByteBuffer buffer = bodyFuture.get();
byte[] bytes = new byte[buffer.remaining()];
buffer.get(bytes);

return new String(bytes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,26 @@

package software.amazon.awssdk.core.internal.async;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;

import java.io.File;
import java.io.FileInputStream;
import java.nio.ByteBuffer;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.assertj.core.api.Assertions;
import org.reactivestreams.Publisher;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.core.internal.async.ByteArrayAsyncResponseTransformer;
import software.amazon.awssdk.core.internal.async.SplittingPublisherTest;

public final class SplittingPublisherTestUtils {

public static void verifyIndividualAsyncRequestBody(SdkPublisher<AsyncRequestBody> publisher,
Path file,
int chunkSize) throws Exception {

List<CompletableFuture<byte[]>> futures = new ArrayList<>();
List<CompletableFuture<ByteBuffer>> futures = new ArrayList<>();
publisher.subscribe(requestBody -> {
CompletableFuture<byte[]> baosFuture = new CompletableFuture<>();
CompletableFuture<ByteBuffer> baosFuture = new CompletableFuture<>();
ByteArrayAsyncResponseTransformer.BaosSubscriber subscriber =
new ByteArrayAsyncResponseTransformer.BaosSubscriber(baosFuture);
requestBody.subscribe(subscriber);
Expand All @@ -62,7 +55,10 @@ public static void verifyIndividualAsyncRequestBody(SdkPublisher<AsyncRequestBod
}
fileInputStream.skip(i * chunkSize);
fileInputStream.read(expected);
byte[] actualBytes = futures.get(i).join();
ByteBuffer actualByteBuffer = futures.get(i).join();
byte[] actualBytes = new byte[actualByteBuffer.remaining()];
actualByteBuffer.get(actualBytes);

Assertions.assertThat(actualBytes).isEqualTo(expected);
}
}
Expand Down
Loading