diff --git a/.changes/next-release/feature-AmazonS3-d8d7a87.json b/.changes/next-release/feature-AmazonS3-d8d7a87.json new file mode 100644 index 000000000000..f6a03df1bfe0 --- /dev/null +++ b/.changes/next-release/feature-AmazonS3-d8d7a87.json @@ -0,0 +1,6 @@ +{ + "type": "feature", + "category": "Amazon S3", + "contributor": "", + "description": "Add support for maxInFlightParts to multipart upload (PutObject) in MultipartS3AsyncClient." +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java index d86005d85bc4..fddcf1cb843b 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java @@ -48,7 +48,7 @@ import software.amazon.awssdk.utils.Pair; @SdkInternalApi -public class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber { +public class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber { private static final Logger log = Logger.loggerFor(KnownContentLengthAsyncRequestBodySubscriber.class); @@ -70,6 +70,8 @@ public class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber< private final AtomicReferenceArray completedParts; private final Map existingParts; private final PublisherListener progressListener; + private final int maxInFlightParts; + private final Object subscriptionLock = new Object(); private Subscription subscription; private volatile boolean isDone; private volatile boolean isPaused; @@ -80,8 +82,9 @@ public class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber< private volatile CompletableFuture completeMpuFuture; KnownContentLengthAsyncRequestBodySubscriber(MpuRequestContext mpuRequestContext, - CompletableFuture returnFuture, - MultipartUploadHelper multipartUploadHelper) { + CompletableFuture returnFuture, + MultipartUploadHelper multipartUploadHelper, + int maxInFlightParts) { this.totalSize = mpuRequestContext.contentLength(); this.partSize = mpuRequestContext.partSize(); this.expectedNumParts = mpuRequestContext.expectedNumParts(); @@ -92,8 +95,10 @@ public class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber< this.existingNumParts = NumericUtils.saturatedCast(mpuRequestContext.numPartsCompleted()); this.completedParts = new AtomicReferenceArray<>(expectedNumParts); this.multipartUploadHelper = multipartUploadHelper; - this.progressListener = putObjectRequest.overrideConfiguration().map(c -> c.executionAttributes() - .getAttribute(JAVA_PROGRESS_LISTENER)) + this.maxInFlightParts = maxInFlightParts; + this.progressListener = putObjectRequest.overrideConfiguration() + .map(c -> c.executionAttributes() + .getAttribute(JAVA_PROGRESS_LISTENER)) .orElseGet(PublisherListener::noOp); } @@ -133,7 +138,7 @@ public void onSubscribe(Subscription s) { return; } this.subscription = s; - s.request(1); + s.request(maxInFlightParts); returnFuture.whenComplete((r, t) -> { if (t != null) { s.cancel(); @@ -153,23 +158,26 @@ public void onNext(CloseableAsyncRequestBody asyncRequestBody) { int currentPartNum = partNumber.getAndIncrement(); log.debug(() -> String.format("Received asyncRequestBody for part number %d with length %s", currentPartNum, - asyncRequestBody.contentLength())); + asyncRequestBody.contentLength())); if (existingParts.containsKey(currentPartNum)) { asyncRequestBody.subscribe(new CancelledSubscriber<>()); asyncRequestBody.contentLength().ifPresent(progressListener::subscriberOnNext); asyncRequestBody.close(); - subscription.request(1); + + synchronized (subscriptionLock) { + subscription.request(1); + } return; } Optional sdkClientException = validatePart(asyncRequestBody, currentPartNum); if (sdkClientException.isPresent()) { multipartUploadHelper.failRequestsElegantly(futures, - sdkClientException.get(), - uploadId, - returnFuture, - putObjectRequest); + sdkClientException.get(), + uploadId, + returnFuture, + putObjectRequest); subscription.cancel(); return; } @@ -179,8 +187,9 @@ public void onNext(CloseableAsyncRequestBody asyncRequestBody) { currentPartNum, uploadId); - Consumer completedPartConsumer = completedPart -> completedParts.set(completedPart.partNumber() - 1, - completedPart); + Consumer completedPartConsumer = completedPart -> completedParts.set( + completedPart.partNumber() - 1, + completedPart); multipartUploadHelper.sendIndividualUploadPartRequest(uploadId, completedPartConsumer, futures, Pair.of(uploadRequest, asyncRequestBody), progressListener) .whenComplete((r, t) -> { @@ -192,10 +201,15 @@ public void onNext(CloseableAsyncRequestBody asyncRequestBody) { subscription.cancel(); } } else { - completeMultipartUploadIfFinished(asyncRequestBodyInFlight.decrementAndGet()); + int inFlight = asyncRequestBodyInFlight.decrementAndGet(); + if (!isDone && inFlight < maxInFlightParts) { + synchronized (subscriptionLock) { + subscription.request(1); + } + } + completeMultipartUploadIfFinished(inFlight); } }); - subscription.request(1); } private Optional validatePart(AsyncRequestBody asyncRequestBody, int currentPartNum) { @@ -258,10 +272,9 @@ private void completeMultipartUploadIfFinished(int requestsInFlight) { CompletedPart[] parts; if (existingParts.isEmpty()) { - parts = - IntStream.range(0, completedParts.length()) - .mapToObj(completedParts::get) - .toArray(CompletedPart[]::new); + parts = IntStream.range(0, completedParts.length()) + .mapToObj(completedParts::get) + .toArray(CompletedPart[]::new); } else { // List of CompletedParts needs to be in ascending order parts = mergeCompletedParts(); @@ -274,7 +287,8 @@ private void completeMultipartUploadIfFinished(int requestsInFlight) { return; } - completeMpuFuture = multipartUploadHelper.completeMultipartUpload(returnFuture, uploadId, parts, putObjectRequest, + completeMpuFuture = multipartUploadHelper.completeMultipartUpload(returnFuture, uploadId, parts, + putObjectRequest, totalSize); } } @@ -283,8 +297,8 @@ private CompletedPart[] mergeCompletedParts() { CompletedPart[] merged = new CompletedPart[expectedNumParts]; int currPart = 1; while (currPart < expectedNumParts + 1) { - CompletedPart completedPart = existingParts.containsKey(currPart) ? existingParts.get(currPart) : - completedParts.get(currPart - 1); + CompletedPart completedPart = existingParts.containsKey(currPart) ? existingParts.get(currPart) + : completedParts.get(currPart - 1); merged[currPart - 1] = completedPart; currPart++; } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UnknownContentLengthAsyncRequestBodySubscriber.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UnknownContentLengthAsyncRequestBodySubscriber.java new file mode 100644 index 000000000000..f264307f1984 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UnknownContentLengthAsyncRequestBodySubscriber.java @@ -0,0 +1,294 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static software.amazon.awssdk.services.s3.internal.multipart.MultipartUploadHelper.contentLengthMismatchForPart; +import static software.amazon.awssdk.services.s3.internal.multipart.MultipartUploadHelper.contentLengthMissingForPart; +import static software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute.JAVA_PROGRESS_LISTENER; + +import java.util.Collection; +import java.util.Comparator; +import java.util.Optional; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.CloseableAsyncRequestBody; +import software.amazon.awssdk.core.async.listener.PublisherListener; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Pair; + +@SdkInternalApi +public class UnknownContentLengthAsyncRequestBodySubscriber implements Subscriber { + private static final Logger log = Logger.loggerFor(UnknownContentLengthAsyncRequestBodySubscriber.class); + + /** + * Indicates whether this is the first async request body or not. + */ + private final AtomicBoolean firstAsyncRequestBodyReceived = new AtomicBoolean(false); + + /** + * Indicates whether CreateMultipartUpload has been initiated or not + */ + private final AtomicBoolean createMultipartUploadInitiated = new AtomicBoolean(false); + + /** + * Indicates whether CompleteMultipart has been initiated or not. + */ + private final AtomicBoolean completedMultipartInitiated = new AtomicBoolean(false); + + /** + * The number of AsyncRequestBody has been received but yet to be processed + */ + private final AtomicInteger asyncRequestBodyInFlight = new AtomicInteger(0); + + private final AtomicBoolean failureActionInitiated = new AtomicBoolean(false); + + private final AtomicInteger partNumber = new AtomicInteger(0); + private final AtomicLong contentLength = new AtomicLong(0); + + private final Queue completedParts = new ConcurrentLinkedQueue<>(); + private final Collection> futures = new ConcurrentLinkedQueue<>(); + + private final CompletableFuture uploadIdFuture = new CompletableFuture<>(); + + private final long partSizeInBytes; + private final PutObjectRequest putObjectRequest; + private final CompletableFuture returnFuture; + private final PublisherListener progressListener; + private final MultipartUploadHelper multipartUploadHelper; + private final GenericMultipartHelper genericMultipartHelper; + private final int maxInFlightParts; + + private final Object subscriptionLock = new Object(); + private Subscription subscription; + private CloseableAsyncRequestBody firstRequestBody; + private String uploadId; + private volatile boolean isDone; + + UnknownContentLengthAsyncRequestBodySubscriber( + long partSizeInBytes, + PutObjectRequest putObjectRequest, + CompletableFuture returnFuture, + MultipartUploadHelper multipartUploadHelper, + GenericMultipartHelper genericMultipartHelper, + int maxInFlightParts) { + this.partSizeInBytes = partSizeInBytes; + this.putObjectRequest = putObjectRequest; + this.returnFuture = returnFuture; + this.multipartUploadHelper = multipartUploadHelper; + this.genericMultipartHelper = genericMultipartHelper; + this.maxInFlightParts = maxInFlightParts; + this.progressListener = putObjectRequest.overrideConfiguration() + .map(c -> c.executionAttributes().getAttribute(JAVA_PROGRESS_LISTENER)) + .orElseGet(PublisherListener::noOp); + } + + @Override + public void onSubscribe(Subscription s) { + if (this.subscription != null) { + log.warn(() -> "The subscriber has already been subscribed. Cancelling the incoming subscription"); + subscription.cancel(); + return; + } + this.subscription = s; + s.request(1); + returnFuture.whenComplete((r, t) -> { + if (t != null) { + s.cancel(); + MultipartUploadHelper.cancelingOtherOngoingRequests(futures, t); + } + }); + } + + @Override + public void onNext(CloseableAsyncRequestBody asyncRequestBody) { + if (asyncRequestBody == null) { + NullPointerException exception = new NullPointerException("asyncRequestBody passed to onNext MUST NOT be null."); + multipartUploadHelper.failRequestsElegantly(futures, exception, uploadId, returnFuture, putObjectRequest); + throw exception; + } + + if (isDone) { + return; + } + + int currentPartNum = partNumber.incrementAndGet(); + log.trace(() -> "Received asyncRequestBody " + asyncRequestBody.contentLength()); + asyncRequestBodyInFlight.incrementAndGet(); + + Optional sdkClientException = validatePart(asyncRequestBody, currentPartNum); + if (sdkClientException.isPresent()) { + multipartUploadHelper.failRequestsElegantly( + futures, sdkClientException.get(), uploadId, returnFuture, putObjectRequest); + subscription.cancel(); + return; + } + + if (firstAsyncRequestBodyReceived.compareAndSet(false, true)) { + log.trace(() -> "Received first async request body"); + firstRequestBody = asyncRequestBody; + // If this is the first AsyncRequestBody received, request another one because we don't know if there is more + synchronized (subscriptionLock) { + subscription.request(1); + } + return; + } + + // If there are more than 1 AsyncRequestBodies, then we know we need to upload using MPU + if (createMultipartUploadInitiated.compareAndSet(false, true)) { + log.debug(() -> "Starting the upload as multipart upload request"); + CompletableFuture createMultipartUploadFuture = multipartUploadHelper + .createMultipartUpload(putObjectRequest, returnFuture); + + createMultipartUploadFuture.whenComplete((createMultipartUploadResponse, throwable) -> { + if (throwable != null) { + genericMultipartHelper.handleException(returnFuture, () -> "Failed to initiate multipart upload", throwable); + subscription.cancel(); + } else { + uploadId = createMultipartUploadResponse.uploadId(); + log.debug(() -> "Initiated a new multipart upload, uploadId: " + uploadId); + + sendUploadPartRequest(uploadId, firstRequestBody, 1); + sendUploadPartRequest(uploadId, asyncRequestBody, 2); + uploadIdFuture.complete(uploadId); + + // 2 parts already in flight, request the rest of our max in flight + int additionalDemand = maxInFlightParts - 2; + if (additionalDemand > 0) { + synchronized (subscriptionLock) { + subscription.request(additionalDemand); + } + } + } + }); + CompletableFutureUtils.forwardExceptionTo(returnFuture, createMultipartUploadFuture); + } else { + uploadIdFuture.whenComplete((r, t) -> { + sendUploadPartRequest(uploadId, asyncRequestBody, currentPartNum); + }); + } + } + + private Optional validatePart(AsyncRequestBody asyncRequestBody, int currentPartNum) { + Optional contentLength = asyncRequestBody.contentLength(); + if (!contentLength.isPresent()) { + return Optional.of(contentLengthMissingForPart(currentPartNum)); + } + + Long contentLengthCurrentPart = contentLength.get(); + if (contentLengthCurrentPart > partSizeInBytes) { + return Optional.of(contentLengthMismatchForPart(partSizeInBytes, contentLengthCurrentPart, currentPartNum)); + } + return Optional.empty(); + } + + private void sendUploadPartRequest(String uploadId, + CloseableAsyncRequestBody asyncRequestBody, + int currentPartNum) { + Long contentLengthCurrentPart = asyncRequestBody.contentLength().get(); + this.contentLength.getAndAdd(contentLengthCurrentPart); + + multipartUploadHelper + .sendIndividualUploadPartRequest(uploadId, completedParts::add, futures, + uploadPart(asyncRequestBody, currentPartNum), progressListener) + .whenComplete((r, t) -> { + asyncRequestBody.close(); + if (t != null) { + if (failureActionInitiated.compareAndSet(false, true)) { + multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, + putObjectRequest); + } + } else { + int inFlight = asyncRequestBodyInFlight.decrementAndGet(); + if (!isDone && inFlight < maxInFlightParts) { + synchronized (subscriptionLock) { + subscription.request(1); + } + } + completeMultipartUploadIfFinish(inFlight); + } + }); + } + + private Pair uploadPart(AsyncRequestBody asyncRequestBody, int partNum) { + UploadPartRequest uploadRequest = SdkPojoConversionUtils.toUploadPartRequest(putObjectRequest, + partNum, + uploadId); + + return Pair.of(uploadRequest, asyncRequestBody); + } + + @Override + public void onError(Throwable t) { + log.debug(() -> "Received onError() ", t); + if (failureActionInitiated.compareAndSet(false, true)) { + isDone = true; + multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest); + } + } + + @Override + public void onComplete() { + log.debug(() -> "Received onComplete()"); + // If CreateMultipartUpload has not been initiated at this point, we know this + // is a single object upload, and if no async request body has been received, it's an empty stream + if (createMultipartUploadInitiated.get() == false) { + log.debug(() -> "Starting the upload as a single object upload request"); + AsyncRequestBody entireRequestBody = firstAsyncRequestBodyReceived.get() ? firstRequestBody + : AsyncRequestBody.empty(); + multipartUploadHelper.uploadInOneChunk(putObjectRequest, entireRequestBody, returnFuture); + } else { + isDone = true; + completeMultipartUploadIfFinish(asyncRequestBodyInFlight.get()); + } + } + + private void completeMultipartUploadIfFinish(int requestsInFlight) { + if (isDone && requestsInFlight == 0 && completedMultipartInitiated.compareAndSet(false, true)) { + CompletedPart[] parts = completedParts.stream() + .sorted(Comparator.comparingInt(CompletedPart::partNumber)) + .toArray(CompletedPart[]::new); + + long totalLength = contentLength.get(); + int expectedNumParts = genericMultipartHelper.determinePartCount(totalLength, partSizeInBytes); + if (parts.length != expectedNumParts) { + SdkClientException exception = SdkClientException.create( + String.format( + "The number of UploadParts requests is not equal to the expected number of parts. " + + "Expected: %d, Actual: %d", + expectedNumParts, parts.length)); + multipartUploadHelper.failRequestsElegantly(futures, exception, uploadId, returnFuture, putObjectRequest); + return; + } + + multipartUploadHelper.completeMultipartUpload(returnFuture, uploadId, parts, putObjectRequest, totalLength); + } + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelper.java index 5f3162ffe8a3..f82905381857 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadObjectHelper.java @@ -48,14 +48,17 @@ public UploadObjectHelper(S3AsyncClient s3AsyncClient, SdkPojoConversionUtils::toPutObjectResponse); this.apiCallBufferSize = resolver.apiCallBufferSize(); this.multipartUploadThresholdInBytes = resolver.thresholdInBytes(); + int maxInFlightParts = resolver.maxInFlightParts(); this.uploadWithKnownContentLength = new UploadWithKnownContentLengthHelper(s3AsyncClient, partSizeInBytes, multipartUploadThresholdInBytes, - apiCallBufferSize); + apiCallBufferSize, + maxInFlightParts); this.uploadWithUnknownContentLength = new UploadWithUnknownContentLengthHelper(s3AsyncClient, partSizeInBytes, multipartUploadThresholdInBytes, - apiCallBufferSize); + apiCallBufferSize, + maxInFlightParts); } public CompletableFuture uploadObject(PutObjectRequest putObjectRequest, diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java index 0fdeb1674798..9ce0db889a2d 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java @@ -50,11 +50,13 @@ public final class UploadWithKnownContentLengthHelper { private final long maxMemoryUsageInBytes; private final long multipartUploadThresholdInBytes; private final MultipartUploadHelper multipartUploadHelper; + private final int maxInFlightParts; public UploadWithKnownContentLengthHelper(S3AsyncClient s3AsyncClient, long partSizeInBytes, long multipartUploadThresholdInBytes, - long maxMemoryUsageInBytes) { + long maxMemoryUsageInBytes, + int maxInFlightParts) { this.s3AsyncClient = s3AsyncClient; this.partSizeInBytes = partSizeInBytes; this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, @@ -64,6 +66,7 @@ public UploadWithKnownContentLengthHelper(S3AsyncClient s3AsyncClient, this.multipartUploadThresholdInBytes = multipartUploadThresholdInBytes; this.multipartUploadHelper = new MultipartUploadHelper(s3AsyncClient, multipartUploadThresholdInBytes, maxMemoryUsageInBytes); + this.maxInFlightParts = maxInFlightParts; } public CompletableFuture uploadObject(PutObjectRequest putObjectRequest, @@ -181,7 +184,8 @@ private void resumePausedUpload(ResumeRequestContext resumeContext) { private void splitAndSubscribe(MpuRequestContext mpuRequestContext, CompletableFuture returnFuture) { KnownContentLengthAsyncRequestBodySubscriber subscriber = - new KnownContentLengthAsyncRequestBodySubscriber(mpuRequestContext, returnFuture, multipartUploadHelper); + new KnownContentLengthAsyncRequestBodySubscriber(mpuRequestContext, returnFuture, multipartUploadHelper, + maxInFlightParts); attachSubscriberToObservable(subscriber, mpuRequestContext.request().left()); diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java index 3239d3ec95ab..aba0c8e63221 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java @@ -15,37 +15,15 @@ package software.amazon.awssdk.services.s3.internal.multipart; - -import static software.amazon.awssdk.services.s3.internal.multipart.MultipartUploadHelper.contentLengthMismatchForPart; -import static software.amazon.awssdk.services.s3.internal.multipart.MultipartUploadHelper.contentLengthMissingForPart; -import static software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute.JAVA_PROGRESS_LISTENER; - -import java.util.Collection; -import java.util.Comparator; -import java.util.Optional; -import java.util.Queue; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.CloseableAsyncRequestBody; import software.amazon.awssdk.core.async.SdkPublisher; -import software.amazon.awssdk.core.async.listener.PublisherListener; -import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.S3AsyncClient; -import software.amazon.awssdk.services.s3.model.CompletedPart; -import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; -import software.amazon.awssdk.services.s3.model.UploadPartRequest; -import software.amazon.awssdk.utils.CompletableFutureUtils; import software.amazon.awssdk.utils.Logger; -import software.amazon.awssdk.utils.Pair; /** * An internal helper class that uploads streams with unknown content length. @@ -57,16 +35,16 @@ public final class UploadWithUnknownContentLengthHelper { private final S3AsyncClient s3AsyncClient; private final long partSizeInBytes; private final GenericMultipartHelper genericMultipartHelper; - private final long maxMemoryUsageInBytes; private final long multipartUploadThresholdInBytes; - private final MultipartUploadHelper multipartUploadHelper; + private final int maxInFlightParts; public UploadWithUnknownContentLengthHelper(S3AsyncClient s3AsyncClient, long partSizeInBytes, long multipartUploadThresholdInBytes, - long maxMemoryUsageInBytes) { + long maxMemoryUsageInBytes, + int maxInFlightParts) { this.s3AsyncClient = s3AsyncClient; this.partSizeInBytes = partSizeInBytes; this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, @@ -76,6 +54,7 @@ public UploadWithUnknownContentLengthHelper(S3AsyncClient s3AsyncClient, this.multipartUploadThresholdInBytes = multipartUploadThresholdInBytes; this.multipartUploadHelper = new MultipartUploadHelper(s3AsyncClient, multipartUploadThresholdInBytes, maxMemoryUsageInBytes); + this.maxInFlightParts = maxInFlightParts; } public CompletableFuture uploadObject(PutObjectRequest putObjectRequest, @@ -86,236 +65,13 @@ public CompletableFuture uploadObject(PutObjectRequest putObj asyncRequestBody.splitCloseable(b -> b.chunkSizeInBytes(partSizeInBytes) .bufferSizeInBytes(maxMemoryUsageInBytes)); - splitAsyncRequestBodyResponse.subscribe(new UnknownContentLengthAsyncRequestBodySubscriber(partSizeInBytes, - putObjectRequest, - returnFuture)); + splitAsyncRequestBodyResponse.subscribe( + new UnknownContentLengthAsyncRequestBodySubscriber(partSizeInBytes, + putObjectRequest, + returnFuture, + multipartUploadHelper, + genericMultipartHelper, + maxInFlightParts)); return returnFuture; } - - final class UnknownContentLengthAsyncRequestBodySubscriber implements Subscriber { - /** - * Indicates whether this is the first async request body or not. - */ - private final AtomicBoolean firstAsyncRequestBodyReceived = new AtomicBoolean(false); - - /** - * Indicates whether CreateMultipartUpload has been initiated or not - */ - private final AtomicBoolean createMultipartUploadInitiated = new AtomicBoolean(false); - - /** - * Indicates whether CompleteMultipart has been initiated or not. - */ - private final AtomicBoolean completedMultipartInitiated = new AtomicBoolean(false); - - /** - * The number of AsyncRequestBody has been received but yet to be processed - */ - private final AtomicInteger asyncRequestBodyInFlight = new AtomicInteger(0); - - private final AtomicBoolean failureActionInitiated = new AtomicBoolean(false); - - private AtomicInteger partNumber = new AtomicInteger(0); - private AtomicLong contentLength = new AtomicLong(0); - - private final Queue completedParts = new ConcurrentLinkedQueue<>(); - private final Collection> futures = new ConcurrentLinkedQueue<>(); - - private final CompletableFuture uploadIdFuture = new CompletableFuture<>(); - - private final long maximumChunkSizeInByte; - private final PutObjectRequest putObjectRequest; - private final CompletableFuture returnFuture; - private final PublisherListener progressListener; - private Subscription subscription; - private CloseableAsyncRequestBody firstRequestBody; - - private String uploadId; - private volatile boolean isDone; - - UnknownContentLengthAsyncRequestBodySubscriber(long maximumChunkSizeInByte, - PutObjectRequest putObjectRequest, - CompletableFuture returnFuture) { - this.maximumChunkSizeInByte = maximumChunkSizeInByte; - this.putObjectRequest = putObjectRequest; - this.returnFuture = returnFuture; - this.progressListener = putObjectRequest.overrideConfiguration() - .map(c -> c.executionAttributes().getAttribute(JAVA_PROGRESS_LISTENER)) - .orElseGet(PublisherListener::noOp); - } - - @Override - public void onSubscribe(Subscription s) { - if (this.subscription != null) { - log.warn(() -> "The subscriber has already been subscribed. Cancelling the incoming subscription"); - subscription.cancel(); - return; - } - this.subscription = s; - s.request(1); - returnFuture.whenComplete((r, t) -> { - if (t != null) { - s.cancel(); - multipartUploadHelper.cancelingOtherOngoingRequests(futures, t); - } - }); - } - - @Override - public void onNext(CloseableAsyncRequestBody asyncRequestBody) { - if (asyncRequestBody == null) { - NullPointerException exception = new NullPointerException("asyncRequestBody passed to onNext MUST NOT be null."); - multipartUploadHelper.failRequestsElegantly(futures, - exception, uploadId, returnFuture, putObjectRequest); - throw exception; - } - - if (isDone) { - return; - } - - int currentPartNum = partNumber.incrementAndGet(); - log.trace(() -> "Received asyncRequestBody " + asyncRequestBody.contentLength()); - asyncRequestBodyInFlight.incrementAndGet(); - - Optional sdkClientException = validatePart(asyncRequestBody, currentPartNum); - if (sdkClientException.isPresent()) { - multipartUploadHelper.failRequestsElegantly(futures, sdkClientException.get(), uploadId, returnFuture, - putObjectRequest); - subscription.cancel(); - return; - } - - if (firstAsyncRequestBodyReceived.compareAndSet(false, true)) { - log.trace(() -> "Received first async request body"); - // If this is the first AsyncRequestBody received, request another one because we don't know if there is more - firstRequestBody = asyncRequestBody; - subscription.request(1); - return; - } - - // If there are more than 1 AsyncRequestBodies, then we know we need to upload this - // object using MPU - if (createMultipartUploadInitiated.compareAndSet(false, true)) { - log.debug(() -> "Starting the upload as multipart upload request"); - CompletableFuture createMultipartUploadFuture = - multipartUploadHelper.createMultipartUpload(putObjectRequest, returnFuture); - - createMultipartUploadFuture.whenComplete((createMultipartUploadResponse, throwable) -> { - if (throwable != null) { - genericMultipartHelper.handleException(returnFuture, () -> "Failed to initiate multipart upload", - throwable); - subscription.cancel(); - } else { - uploadId = createMultipartUploadResponse.uploadId(); - log.debug(() -> "Initiated a new multipart upload, uploadId: " + uploadId); - - sendUploadPartRequest(uploadId, firstRequestBody, 1); - sendUploadPartRequest(uploadId, asyncRequestBody, 2); - - // We need to complete the uploadIdFuture *after* the first two requests have been sent - uploadIdFuture.complete(uploadId); - } - }); - CompletableFutureUtils.forwardExceptionTo(returnFuture, createMultipartUploadFuture); - } else { - uploadIdFuture.whenComplete((r, t) -> { - sendUploadPartRequest(uploadId, asyncRequestBody, currentPartNum); - }); - } - } - - private Optional validatePart(AsyncRequestBody asyncRequestBody, int currentPartNum) { - Optional contentLength = asyncRequestBody.contentLength(); - if (!contentLength.isPresent()) { - return Optional.of(contentLengthMissingForPart(currentPartNum)); - } - - Long contentLengthCurrentPart = contentLength.get(); - if (contentLengthCurrentPart > partSizeInBytes) { - return Optional.of(contentLengthMismatchForPart(partSizeInBytes, contentLengthCurrentPart, currentPartNum)); - - } - return Optional.empty(); - } - - private void sendUploadPartRequest(String uploadId, - CloseableAsyncRequestBody asyncRequestBody, - int currentPartNum) { - Long contentLengthCurrentPart = asyncRequestBody.contentLength().get(); - this.contentLength.getAndAdd(contentLengthCurrentPart); - - multipartUploadHelper - .sendIndividualUploadPartRequest(uploadId, completedParts::add, futures, - uploadPart(asyncRequestBody, currentPartNum), progressListener) - .whenComplete((r, t) -> { - asyncRequestBody.close(); - if (t != null) { - if (failureActionInitiated.compareAndSet(false, true)) { - multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest); - } - } else { - completeMultipartUploadIfFinish(asyncRequestBodyInFlight.decrementAndGet()); - } - }); - synchronized (this) { - subscription.request(1); - }; - } - - private Pair uploadPart(AsyncRequestBody asyncRequestBody, int partNum) { - UploadPartRequest uploadRequest = - SdkPojoConversionUtils.toUploadPartRequest(putObjectRequest, - partNum, - uploadId); - - return Pair.of(uploadRequest, asyncRequestBody); - } - - @Override - public void onError(Throwable t) { - log.debug(() -> "Received onError() ", t); - if (failureActionInitiated.compareAndSet(false, true)) { - isDone = true; - multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest); - } - } - - @Override - public void onComplete() { - log.debug(() -> "Received onComplete()"); - // If CreateMultipartUpload has not been initiated at this point, we know this is a single object upload, and if no - // async request body has been received, it's an empty stream - if (createMultipartUploadInitiated.get() == false) { - log.debug(() -> "Starting the upload as a single object upload request"); - AsyncRequestBody entireRequestBody = firstAsyncRequestBodyReceived.get() ? firstRequestBody : - AsyncRequestBody.empty(); - multipartUploadHelper.uploadInOneChunk(putObjectRequest, entireRequestBody, returnFuture); - } else { - isDone = true; - completeMultipartUploadIfFinish(asyncRequestBodyInFlight.get()); - } - } - - private void completeMultipartUploadIfFinish(int requestsInFlight) { - if (isDone && requestsInFlight == 0 && completedMultipartInitiated.compareAndSet(false, true)) { - CompletedPart[] parts = completedParts.stream() - .sorted(Comparator.comparingInt(CompletedPart::partNumber)) - .toArray(CompletedPart[]::new); - - long totalLength = contentLength.get(); - int expectedNumParts = genericMultipartHelper.determinePartCount(totalLength, partSizeInBytes); - if (parts.length != expectedNumParts) { - SdkClientException exception = SdkClientException.create( - String.format("The number of UploadParts requests is not equal to the expected number of parts. " - + "Expected: %d, Actual: %d", expectedNumParts, parts.length)); - multipartUploadHelper.failRequestsElegantly(futures, exception, uploadId, returnFuture, putObjectRequest); - return; - } - - multipartUploadHelper.completeMultipartUpload(returnFuture, uploadId, parts, putObjectRequest, - totalLength); - } - } - } -} \ No newline at end of file +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/ParallelConfiguration.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/ParallelConfiguration.java index a3816cd97fa2..dea1592cbe6f 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/ParallelConfiguration.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/ParallelConfiguration.java @@ -39,8 +39,13 @@ public static Builder builder() { } /** - * The maximum number of concurrent GetObject the that are allowed for multipart download. - * @return The value for the maximum number of concurrent GetObject the that are allowed for multipart download. + * The maximum number of concurrent part requests that are allowed for multipart operations, including both multipart + * download (GetObject) and multipart upload (PutObject). This limits the number of parts that can be in flight at any + * given time, preventing the client from overwhelming the HTTP connection pool when transferring large objects. For + * getObject it applies only when the {@link AsyncResponseTransformer} supports parallel split. + * Defaults to 50. + * + * @return The value for the maximum number of concurrent part requests. */ public Integer maxInFlightParts() { return maxInFlightParts; diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java index c18f088f1cd9..3c93a38b635b 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java @@ -19,6 +19,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -89,7 +90,7 @@ public void beforeEach() { subscription = mock(Subscription.class); when(multipartUploadHelper.sendIndividualUploadPartRequest(eq(UPLOAD_ID), any(), any(), any(), any())) - .thenReturn(CompletableFuture.completedFuture(CompletedPart.builder().build())); + .thenReturn(CompletableFuture.completedFuture(CompletedPart.builder().build())); subscriber = createSubscriber(createDefaultMpuRequestContext()); subscriber.onSubscribe(subscription); @@ -112,7 +113,8 @@ void validateLastPartSize_withIncorrectSize_shouldFailRequest() { long expectedLastPartSize = MPU_CONTENT_SIZE % PART_SIZE; long incorrectLastPartSize = expectedLastPartSize + 1; - KnownContentLengthAsyncRequestBodySubscriber lastPartSubscriber = createSubscriber(createDefaultMpuRequestContext()); + KnownContentLengthAsyncRequestBodySubscriber lastPartSubscriber = createSubscriber( + createDefaultMpuRequestContext()); lastPartSubscriber.onSubscribe(subscription); for (int i = 0; i < TOTAL_NUM_PARTS - 1; i++) { @@ -128,18 +130,19 @@ void validateLastPartSize_withIncorrectSize_shouldFailRequest() { void validateTotalPartNum_receivedMoreParts_shouldFail() { long expectedLastPartSize = MPU_CONTENT_SIZE % PART_SIZE; - KnownContentLengthAsyncRequestBodySubscriber lastPartSubscriber = createSubscriber(createDefaultMpuRequestContext()); + KnownContentLengthAsyncRequestBodySubscriber lastPartSubscriber = createSubscriber( + createDefaultMpuRequestContext()); lastPartSubscriber.onSubscribe(subscription); for (int i = 0; i < TOTAL_NUM_PARTS - 1; i++) { CloseableAsyncRequestBody regularPart = createMockAsyncRequestBody(PART_SIZE); when(multipartUploadHelper.sendIndividualUploadPartRequest(eq(UPLOAD_ID), any(), any(), any(), any())) - .thenReturn(CompletableFuture.completedFuture(null)); + .thenReturn(CompletableFuture.completedFuture(null)); lastPartSubscriber.onNext(regularPart); } when(multipartUploadHelper.sendIndividualUploadPartRequest(eq(UPLOAD_ID), any(), any(), any(), any())) - .thenReturn(CompletableFuture.completedFuture(null)); + .thenReturn(CompletableFuture.completedFuture(null)); lastPartSubscriber.onNext(createMockAsyncRequestBody(expectedLastPartSize)); lastPartSubscriber.onNext(createMockAsyncRequestBody(expectedLastPartSize)); @@ -156,12 +159,12 @@ void validateLastPartSize_withCorrectSize_shouldNotFail() { for (int i = 0; i < TOTAL_NUM_PARTS - 1; i++) { CloseableAsyncRequestBody regularPart = createMockAsyncRequestBody(PART_SIZE); when(multipartUploadHelper.sendIndividualUploadPartRequest(eq(UPLOAD_ID), any(), any(), any(), any())) - .thenReturn(CompletableFuture.completedFuture(null)); + .thenReturn(CompletableFuture.completedFuture(null)); subscriber.onNext(regularPart); } when(multipartUploadHelper.sendIndividualUploadPartRequest(eq(UPLOAD_ID), any(), any(), any(), any())) - .thenReturn(CompletableFuture.completedFuture(null)); + .thenReturn(CompletableFuture.completedFuture(null)); subscriber.onNext(createMockAsyncRequestBody(expectedLastPartSize)); subscriber.onComplete(); @@ -181,8 +184,8 @@ void pause_withOngoingCompleteMpuFuture_shouldReturnTokenAndCancelFuture() { @Test void pause_withCompletedCompleteMpuFuture_shouldReturnNullToken() { - CompletableFuture completeMpuFuture = - CompletableFuture.completedFuture(CompleteMultipartUploadResponse.builder().build()); + CompletableFuture completeMpuFuture = CompletableFuture + .completedFuture(CompleteMultipartUploadResponse.builder().build()); int numExistingParts = 2; S3ResumeToken resumeToken = testPauseScenario(numExistingParts, completeMpuFuture); @@ -200,14 +203,14 @@ void pause_withUninitiatedCompleteMpuFuture_shouldReturnToken() { } private S3ResumeToken testPauseScenario(int numExistingParts, - CompletableFuture completeMpuFuture) { - KnownContentLengthAsyncRequestBodySubscriber subscriber = - createSubscriber(createMpuRequestContextWithExistingParts(numExistingParts)); + CompletableFuture completeMpuFuture) { + KnownContentLengthAsyncRequestBodySubscriber subscriber = createSubscriber( + createMpuRequestContextWithExistingParts(numExistingParts)); when(multipartUploadHelper.completeMultipartUpload(any(CompletableFuture.class), any(String.class), - any(CompletedPart[].class), any(PutObjectRequest.class), - any(Long.class))) - .thenReturn(completeMpuFuture); + any(CompletedPart[].class), any(PutObjectRequest.class), + any(Long.class))) + .thenReturn(completeMpuFuture); simulateOnNextForAllParts(subscriber); subscriber.onComplete(); @@ -215,32 +218,82 @@ private S3ResumeToken testPauseScenario(int numExistingParts, return subscriber.pause(); } + @Test + void maxInFlightPutObjectParts_shouldLimitConcurrentUploads() { + int maxInFlight = 2; + long contentSize = 5 * PART_SIZE; + int totalParts = 5; + + MpuRequestContext context = MpuRequestContext.builder() + .request(Pair.of(putObjectRequest, asyncRequestBody)) + .contentLength(contentSize) + .partSize(PART_SIZE) + .uploadId(UPLOAD_ID) + .numPartsCompleted(0L) + .expectedNumParts(totalParts) + .build(); + + // Use non-completing futures to simulate slow uploads so parts stay in-flight + CompletableFuture pendingFuture1 = new CompletableFuture<>(); + CompletableFuture pendingFuture2 = new CompletableFuture<>(); + CompletableFuture pendingFuture3 = new CompletableFuture<>(); + + when(multipartUploadHelper.sendIndividualUploadPartRequest(eq(UPLOAD_ID), any(), any(), any(), any())) + .thenReturn(pendingFuture1) + .thenReturn(pendingFuture2) + .thenReturn(pendingFuture3); + + KnownContentLengthAsyncRequestBodySubscriber sub = createSubscriber(context, maxInFlight); + Subscription mockSubscription = mock(Subscription.class); + sub.onSubscribe(mockSubscription); + + // onSubscribe requests maxInFlightParts(2) upfront + verify(mockSubscription, times(1)).request(maxInFlight); + + // First onNext: in-flight goes to 1. Demand driven by completion callbacks. + sub.onNext(createMockAsyncRequestBody(PART_SIZE)); + + // Second onNext: in-flight goes to 2 + sub.onNext(createMockAsyncRequestBody(PART_SIZE)); + + // Complete the first part — callback decrements to 1, sees 1 < 2, calls request(1) + pendingFuture1.complete(CompletedPart.builder().partNumber(1).build()); + verify(mockSubscription, times(1)).request(1); + } + private MpuRequestContext createDefaultMpuRequestContext() { return MpuRequestContext.builder() - .request(Pair.of(putObjectRequest, AsyncRequestBody.fromFile(testFile))) - .contentLength(MPU_CONTENT_SIZE) - .partSize(PART_SIZE) - .uploadId(UPLOAD_ID) - .numPartsCompleted(0L) - .expectedNumParts(TOTAL_NUM_PARTS) - .build(); + .request(Pair.of(putObjectRequest, AsyncRequestBody.fromFile(testFile))) + .contentLength(MPU_CONTENT_SIZE) + .partSize(PART_SIZE) + .uploadId(UPLOAD_ID) + .numPartsCompleted(0L) + .expectedNumParts(TOTAL_NUM_PARTS) + .build(); } private MpuRequestContext createMpuRequestContextWithExistingParts(int numExistingParts) { Map existingParts = createExistingParts(numExistingParts); return MpuRequestContext.builder() - .request(Pair.of(putObjectRequest, asyncRequestBody)) - .contentLength(MPU_CONTENT_SIZE) - .partSize(PART_SIZE) - .uploadId(UPLOAD_ID) - .existingParts(existingParts) - .expectedNumParts(TOTAL_NUM_PARTS) - .numPartsCompleted((long) existingParts.size()) - .build(); + .request(Pair.of(putObjectRequest, asyncRequestBody)) + .contentLength(MPU_CONTENT_SIZE) + .partSize(PART_SIZE) + .uploadId(UPLOAD_ID) + .existingParts(existingParts) + .expectedNumParts(TOTAL_NUM_PARTS) + .numPartsCompleted((long) existingParts.size()) + .build(); } private KnownContentLengthAsyncRequestBodySubscriber createSubscriber(MpuRequestContext mpuRequestContext) { - return new KnownContentLengthAsyncRequestBodySubscriber(mpuRequestContext, returnFuture, multipartUploadHelper); + return new KnownContentLengthAsyncRequestBodySubscriber(mpuRequestContext, returnFuture, multipartUploadHelper, + 50); + } + + private KnownContentLengthAsyncRequestBodySubscriber createSubscriber(MpuRequestContext mpuRequestContext, + int maxInFlightPutObjectParts) { + return new KnownContentLengthAsyncRequestBodySubscriber(mpuRequestContext, returnFuture, multipartUploadHelper, + maxInFlightPutObjectParts); } private CloseableAsyncRequestBody createMockAsyncRequestBody(long contentLength) { @@ -257,7 +310,8 @@ private CloseableAsyncRequestBody createMockAsyncRequestBodyWithEmptyContentLeng private void verifyFailRequestsElegantly(String expectedErrorMessage) { ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Throwable.class); - verify(multipartUploadHelper).failRequestsElegantly(any(), exceptionCaptor.capture(), eq(UPLOAD_ID), eq(returnFuture), eq(putObjectRequest)); + verify(multipartUploadHelper).failRequestsElegantly(any(), exceptionCaptor.capture(), eq(UPLOAD_ID), + eq(returnFuture), eq(putObjectRequest)); Throwable exception = exceptionCaptor.getValue(); assertThat(exception).isInstanceOf(SdkClientException.class); @@ -266,10 +320,9 @@ private void verifyFailRequestsElegantly(String expectedErrorMessage) { } private Map createExistingParts(int numExistingParts) { - Map existingParts = - IntStream.range(0, numExistingParts) - .boxed().collect(Collectors.toMap(Function.identity(), - i -> CompletedPart.builder().partNumber(i).build(), (a, b) -> b)); + Map existingParts = IntStream.range(0, numExistingParts) + .boxed().collect(Collectors.toMap(Function.identity(), + i -> CompletedPart.builder().partNumber(i).build(), (a, b) -> b)); return existingParts; } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UnknownContentLengthAsyncRequestBodySubscriberTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UnknownContentLengthAsyncRequestBodySubscriberTest.java new file mode 100644 index 000000000000..d6cae74afa61 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UnknownContentLengthAsyncRequestBodySubscriberTest.java @@ -0,0 +1,203 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.core.async.CloseableAsyncRequestBody; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; + +public class UnknownContentLengthAsyncRequestBodySubscriberTest { + + private static final long PART_SIZE = 8 * 1024; + private static final String UPLOAD_ID = "1234"; + + private MultipartUploadHelper multipartUploadHelper; + private GenericMultipartHelper genericMultipartHelper; + private PutObjectRequest putObjectRequest; + private CompletableFuture returnFuture; + private Subscription subscription; + + @BeforeEach + public void beforeEach() { + multipartUploadHelper = mock(MultipartUploadHelper.class); + genericMultipartHelper = mock(GenericMultipartHelper.class); + putObjectRequest = PutObjectRequest.builder() + .bucket("bucket") + .key("key") + .build(); + returnFuture = new CompletableFuture<>(); + subscription = mock(Subscription.class); + } + + @Test + void validatePart_withMissingContentLength_shouldFailRequest() { + UnknownContentLengthAsyncRequestBodySubscriber subscriber = createSubscriber(50); + subscriber.onSubscribe(subscription); + + // First onNext with valid body (held as firstRequestBody) + subscriber.onNext(createMockAsyncRequestBody(PART_SIZE)); + + // Second onNext triggers CreateMultipartUpload path + stubSuccessfulCreateMultipartCall(); + when(multipartUploadHelper.sendIndividualUploadPartRequest(any(), any(), any(), any(), any())) + .thenReturn(CompletableFuture.completedFuture(CompletedPart.builder().build())); + + subscriber.onNext(createMockAsyncRequestBody(PART_SIZE)); + + // Third onNext with missing content length + subscriber.onNext(createMockAsyncRequestBodyWithEmptyContentLength()); + + verifyFailRequestsElegantly("Content length is missing on the AsyncRequestBody"); + } + + @Test + void validatePart_withPartSizeExceedingLimit_shouldFailRequest() { + UnknownContentLengthAsyncRequestBodySubscriber subscriber = createSubscriber(50); + subscriber.onSubscribe(subscription); + + // First onNext with valid body + subscriber.onNext(createMockAsyncRequestBody(PART_SIZE)); + + // Second onNext with oversized body triggers failure + subscriber.onNext(createMockAsyncRequestBody(PART_SIZE + 1)); + + verifyFailRequestsElegantly("Content length must not be greater than part size"); + } + + @Test + void onNext_withNullBody_shouldThrowNullPointerException() { + UnknownContentLengthAsyncRequestBodySubscriber subscriber = createSubscriber(50); + subscriber.onSubscribe(subscription); + + assertThatThrownBy(() -> subscriber.onNext(null)) + .isInstanceOf(NullPointerException.class) + .hasMessageContaining("MUST NOT be null"); + + verify(multipartUploadHelper).failRequestsElegantly( + any(), any(NullPointerException.class), any(), eq(returnFuture), eq(putObjectRequest)); + } + + @Test + void maxInFlightParts_shouldLimitConcurrentUploads() { + int maxInFlight = 4; + UnknownContentLengthAsyncRequestBodySubscriber subscriber = createSubscriber(maxInFlight); + Subscription mockSubscription = mock(Subscription.class); + subscriber.onSubscribe(mockSubscription); + + // onSubscribe requests 1 + verify(mockSubscription, times(1)).request(1); + + // First onNext: holds the first body, requests 1 more to decide single vs multipart + subscriber.onNext(createMockAsyncRequestBody(PART_SIZE)); + verify(mockSubscription, times(2)).request(1); + + // Second onNext: triggers CreateMultipartUpload, sends parts 1 and 2, + // then bootstraps pipeline with request(maxInFlight - 2) = request(2) + stubSuccessfulCreateMultipartCall(); + + CompletableFuture pendingFuture1 = new CompletableFuture<>(); + CompletableFuture pendingFuture2 = new CompletableFuture<>(); + when(multipartUploadHelper.sendIndividualUploadPartRequest(any(), any(), any(), any(), any())) + .thenReturn(pendingFuture1) + .thenReturn(pendingFuture2); + + subscriber.onNext(createMockAsyncRequestBody(PART_SIZE)); + + // After sending 2 parts, bootstraps with request(maxInFlight - 2) = request(2) + verify(mockSubscription, times(1)).request(2); + + // Complete part 1 — inFlight drops to 1, which is < 4, so request(1) is called + pendingFuture1.complete(CompletedPart.builder().partNumber(1).build()); + verify(mockSubscription, times(3)).request(1); + } + + @Test + void onComplete_withSinglePart_shouldUploadInOneChunk() { + UnknownContentLengthAsyncRequestBodySubscriber subscriber = createSubscriber(50); + subscriber.onSubscribe(subscription); + + // Only one onNext — single part, no multipart needed + subscriber.onNext(createMockAsyncRequestBody(PART_SIZE)); + subscriber.onComplete(); + + verify(multipartUploadHelper).uploadInOneChunk(eq(putObjectRequest), any(), eq(returnFuture)); + } + + @Test + void onComplete_withNoParts_shouldUploadEmptyBody() { + UnknownContentLengthAsyncRequestBodySubscriber subscriber = createSubscriber(50); + subscriber.onSubscribe(subscription); + + // No onNext at all — empty stream + subscriber.onComplete(); + + verify(multipartUploadHelper).uploadInOneChunk(eq(putObjectRequest), any(), eq(returnFuture)); + } + + private UnknownContentLengthAsyncRequestBodySubscriber createSubscriber(int maxInFlightParts) { + return new UnknownContentLengthAsyncRequestBodySubscriber( + PART_SIZE, putObjectRequest, returnFuture, + multipartUploadHelper, genericMultipartHelper, maxInFlightParts); + } + + private void stubSuccessfulCreateMultipartCall() { + when(multipartUploadHelper.createMultipartUpload(any(), any())) + .thenReturn(CompletableFuture.completedFuture( + software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse.builder() + .uploadId(UPLOAD_ID) + .build())); + } + + private CloseableAsyncRequestBody createMockAsyncRequestBody(long contentLength) { + CloseableAsyncRequestBody mockBody = mock(CloseableAsyncRequestBody.class); + when(mockBody.contentLength()).thenReturn(Optional.of(contentLength)); + return mockBody; + } + + private CloseableAsyncRequestBody createMockAsyncRequestBodyWithEmptyContentLength() { + CloseableAsyncRequestBody mockBody = mock(CloseableAsyncRequestBody.class); + when(mockBody.contentLength()).thenReturn(Optional.empty()); + return mockBody; + } + + private void verifyFailRequestsElegantly(String expectedErrorMessage) { + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Throwable.class); + verify(multipartUploadHelper).failRequestsElegantly( + any(), exceptionCaptor.capture(), any(), eq(returnFuture), eq(putObjectRequest)); + + Throwable exception = exceptionCaptor.getValue(); + assertThat(exception).isInstanceOf(SdkClientException.class); + assertThat(exception.getMessage()).contains(expectedErrorMessage); + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java index 83eb8f284a72..c32b791d52f0 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java @@ -26,7 +26,6 @@ import static software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartUploadTestUtils.stubSuccessfulCreateMultipartCall; import static software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartUploadTestUtils.stubSuccessfulPutObjectCall; import static software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartUploadTestUtils.stubSuccessfulUploadPartCalls; - import java.io.File; import java.io.FileInputStream; import java.io.FileNotFoundException; @@ -57,6 +56,7 @@ import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.model.UploadPartResponse; import software.amazon.awssdk.testutils.RandomTempFile; import software.amazon.awssdk.utils.StringInputStream; @@ -85,7 +85,7 @@ public static void afterAll() throws Exception { @BeforeEach public void beforeEach() { s3AsyncClient = Mockito.mock(S3AsyncClient.class); - helper = new UploadWithUnknownContentLengthHelper(s3AsyncClient, PART_SIZE, PART_SIZE, PART_SIZE * 4); + helper = new UploadWithUnknownContentLengthHelper(s3AsyncClient, PART_SIZE, PART_SIZE, PART_SIZE * 4, 50); } @Test diff --git a/test/architecture-tests/src/test/java/software/amazon/awssdk/archtests/CodingConventionWithSuppressionTest.java b/test/architecture-tests/src/test/java/software/amazon/awssdk/archtests/CodingConventionWithSuppressionTest.java index 635cfdc834ce..1ceb5e3721a4 100644 --- a/test/architecture-tests/src/test/java/software/amazon/awssdk/archtests/CodingConventionWithSuppressionTest.java +++ b/test/architecture-tests/src/test/java/software/amazon/awssdk/archtests/CodingConventionWithSuppressionTest.java @@ -37,6 +37,8 @@ import software.amazon.awssdk.core.sync.ResponseTransformer; import software.amazon.awssdk.metrics.publishers.emf.EmfMetricLoggingPublisher; import software.amazon.awssdk.metrics.publishers.emf.internal.MetricEmfConverter; +import software.amazon.awssdk.services.s3.internal.multipart.KnownContentLengthAsyncRequestBodySubscriber; +import software.amazon.awssdk.services.s3.internal.multipart.UnknownContentLengthAsyncRequestBodySubscriber; import software.amazon.awssdk.utils.Logger; /** @@ -54,7 +56,9 @@ public class CodingConventionWithSuppressionTest { ArchUtils.classNameToPattern("software.amazon.awssdk.services.s3.internal.crt.S3CrtResponseHandlerAdapter"), ArchUtils.classNameToPattern( "software.amazon.awssdk.services.s3.internal.crt.CrtResponseFileResponseTransformer"), - ArchUtils.classNameToPattern(RetryableSubAsyncRequestBody.class))); + ArchUtils.classNameToPattern(RetryableSubAsyncRequestBody.class), + ArchUtils.classNameToPattern(KnownContentLengthAsyncRequestBodySubscriber.class), + ArchUtils.classNameToPattern(UnknownContentLengthAsyncRequestBodySubscriber.class))); private static final Set ALLOWED_ERROR_LOG_SUPPRESSION = new HashSet<>( Arrays.asList(