1515
1616package software .amazon .awssdk .services .s3 .internal .multipart ;
1717
18+ import java .util .Optional ;
1819import java .util .concurrent .CompletableFuture ;
1920import java .util .concurrent .atomic .AtomicInteger ;
20- import java .util .regex .Matcher ;
21- import java .util .regex .Pattern ;
2221import org .reactivestreams .Subscriber ;
2322import org .reactivestreams .Subscription ;
2423import software .amazon .awssdk .annotations .Immutable ;
2524import software .amazon .awssdk .annotations .SdkInternalApi ;
2625import software .amazon .awssdk .annotations .ThreadSafe ;
2726import software .amazon .awssdk .core .async .AsyncResponseTransformer ;
27+ import software .amazon .awssdk .core .exception .SdkClientException ;
2828import software .amazon .awssdk .services .s3 .S3AsyncClient ;
2929import software .amazon .awssdk .services .s3 .model .GetObjectResponse ;
3030import software .amazon .awssdk .services .s3 .presignedurl .model .PresignedUrlDownloadRequest ;
@@ -49,18 +49,18 @@ public class PresignedUrlMultipartDownloaderSubscriber
4949
5050 private static final Logger log = Logger .loggerFor (PresignedUrlMultipartDownloaderSubscriber .class );
5151 private static final String BYTES_RANGE_PREFIX = "bytes=" ;
52- private static final Pattern CONTENT_RANGE_PATTERN = Pattern .compile ("bytes\\ s+(\\ d+)-(\\ d+)/(\\ d+)" );
5352
5453 private final S3AsyncClient s3AsyncClient ;
5554 private final PresignedUrlDownloadRequest presignedUrlDownloadRequest ;
5655 private final Long configuredPartSizeInBytes ;
5756 private final CompletableFuture <Void > future ;
5857 private final Object lock = new Object ();
5958 private final AtomicInteger completedParts ;
59+ private final AtomicInteger requestsSent ;
6060
61- private Long totalContentLength ;
62- private Integer totalParts ;
63- private String eTag ;
61+ private volatile Long totalContentLength ;
62+ private volatile Integer totalParts ;
63+ private volatile String eTag ;
6464 private Subscription subscription ;
6565
6666 public PresignedUrlMultipartDownloaderSubscriber (
@@ -71,27 +71,26 @@ public PresignedUrlMultipartDownloaderSubscriber(
7171 this .presignedUrlDownloadRequest = presignedUrlDownloadRequest ;
7272 this .configuredPartSizeInBytes = configuredPartSizeInBytes ;
7373 this .completedParts = new AtomicInteger (0 );
74+ this .requestsSent = new AtomicInteger (0 );
7475 this .future = new CompletableFuture <>();
7576 }
7677
7778 @ Override
7879 public void onSubscribe (Subscription s ) {
79- synchronized (lock ) {
80- if (subscription != null ) {
81- s .cancel ();
82- return ;
83- }
84- this .subscription = s ;
85- s .request (1 );
80+ if (subscription != null ) {
81+ s .cancel ();
82+ return ;
8683 }
84+ this .subscription = s ;
85+ s .request (1 );
8786 }
8887
8988 @ Override
9089 public void onNext (AsyncResponseTransformer <GetObjectResponse , GetObjectResponse > asyncResponseTransformer ) {
9190 if (asyncResponseTransformer == null ) {
9291 throw new NullPointerException ("onNext must not be called with null asyncResponseTransformer" );
9392 }
94-
93+
9594 int nextPartIndex ;
9695 synchronized (lock ) {
9796 nextPartIndex = completedParts .get ();
@@ -102,16 +101,16 @@ public void onNext(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse
102101 }
103102 completedParts .incrementAndGet ();
104103 }
105-
106104 makeRangeRequest (nextPartIndex , asyncResponseTransformer );
107105 }
108106
109107 private void makeRangeRequest (int partIndex ,
110108 AsyncResponseTransformer <GetObjectResponse ,
111109 GetObjectResponse > asyncResponseTransformer ) {
112- PresignedUrlDownloadRequest partRequest = createPartRequest (partIndex );
110+ PresignedUrlDownloadRequest partRequest = createRangedGetRequest (partIndex );
113111 log .debug (() -> "Sending range request for part " + partIndex + " with range=" + partRequest .range ());
114112
113+ requestsSent .incrementAndGet ();
115114 s3AsyncClient .presignedUrlExtension ()
116115 .getObject (partRequest , asyncResponseTransformer )
117116 .whenComplete ((response , error ) -> {
@@ -120,81 +119,134 @@ private void makeRangeRequest(int partIndex,
120119 handleError (error );
121120 return ;
122121 }
123- requestMoreIfNeeded (response );
122+ requestMoreIfNeeded (response , partIndex );
124123 });
125124 }
126125
127- private void requestMoreIfNeeded (GetObjectResponse response ) {
126+ private void requestMoreIfNeeded (GetObjectResponse response , int partIndex ) {
128127 int totalComplete = completedParts .get ();
129128 log .debug (() -> String .format ("Completed part %d" , totalComplete ));
129+
130+ String responseETag = response .eTag ();
131+ String responseContentRange = response .contentRange ();
132+ if (eTag == null ) {
133+ this .eTag = responseETag ;
134+ log .debug (() -> String .format ("Multipart object ETag: %s" , this .eTag ));
135+ }
136+
137+ Optional <SdkClientException > validationError = validateResponse (response , partIndex );
138+ if (validationError .isPresent ()) {
139+ log .debug (() -> "Response validation failed" , validationError .get ());
140+ handleError (validationError .get ());
141+ return ;
142+ }
130143
131- synchronized ( lock ) {
132- if ( eTag == null ) {
133- this . eTag = response . eTag ();
134- log . debug (() -> String . format ( "Multipart object ETag: %s" , this . eTag ) );
135- } else if ( response . eTag () != null && ! eTag . equals ( response . eTag ())) {
136- handleError (new IllegalStateException ( "ETag mismatch - object may have changed during download" ) );
144+ if ( totalContentLength == null && responseContentRange != null ) {
145+ Optional < Long > parsedContentLength = MultipartDownloadUtils . parseContentRangeForTotalSize ( responseContentRange );
146+ if (! parsedContentLength . isPresent ()) {
147+ SdkClientException error = PresignedUrlDownloadHelper . invalidContentRangeHeader ( responseContentRange );
148+ log . debug (() -> "Failed to parse content range" , error );
149+ handleError (error );
137150 return ;
138151 }
139- if (totalContentLength == null && response .contentRange () != null ) {
140- try {
141- validateResponse (response );
142- this .totalContentLength = parseContentRangeForTotalSize (response .contentRange ());
143- this .totalParts = calculateTotalParts (totalContentLength , configuredPartSizeInBytes );
144- log .debug (() -> String .format ("Total content length: %d, Total parts: %d" , totalContentLength , totalParts ));
145- } catch (Exception e ) {
146- log .debug (() -> "Failed to parse content range" , e );
147- handleError (e );
148- return ;
149- }
150- }
151- if (totalParts != null && totalParts > 1 && totalComplete < totalParts ) {
152+
153+ this .totalContentLength = parsedContentLength .get ();
154+ this .totalParts = calculateTotalParts (totalContentLength , configuredPartSizeInBytes );
155+ log .debug (() -> String .format ("Total content length: %d, Total parts: %d" , totalContentLength , totalParts ));
156+ }
157+
158+ synchronized (lock ) {
159+ if (hasMoreParts (totalComplete )) {
152160 subscription .request (1 );
153161 } else {
162+ if (totalParts != null && requestsSent .get () != totalParts ) {
163+ handleError (new IllegalStateException (
164+ "Request count mismatch. Expected: " + totalParts + ", sent: " + requestsSent .get ()));
165+ return ;
166+ }
154167 log .debug (() -> String .format ("Completing multipart download after a total of %d parts downloaded." , totalParts ));
155168 subscription .cancel ();
156169 }
157170 }
158171 }
159172
160- private void validateResponse (GetObjectResponse response ) {
173+ private Optional < SdkClientException > validateResponse (GetObjectResponse response , int partIndex ) {
161174 if (response == null ) {
162- throw new IllegalStateException ( "Response cannot be null" );
175+ return Optional . of ( SdkClientException . create ( "Response cannot be null" ) );
163176 }
164- if (response .contentRange () == null ) {
165- throw new IllegalStateException ("No Content-Range header in response" );
177+
178+ String contentRange = response .contentRange ();
179+ if (contentRange == null ) {
180+ return Optional .of (PresignedUrlDownloadHelper .missingContentRangeHeader ());
166181 }
182+
167183 Long contentLength = response .contentLength ();
168184 if (contentLength == null || contentLength < 0 ) {
169- throw new IllegalStateException ( "Invalid or missing Content-Length in response" );
185+ return Optional . of ( PresignedUrlDownloadHelper . invalidContentLength () );
170186 }
171- }
172187
173- private long parseContentRangeForTotalSize (String contentRange ) {
174- Matcher matcher = CONTENT_RANGE_PATTERN .matcher (contentRange );
175- if (!matcher .matches ()) {
176- throw new IllegalArgumentException ("Invalid Content-Range header: " + contentRange );
188+ long expectedStartByte = partIndex * configuredPartSizeInBytes ;
189+ long expectedEndByte ;
190+ if (totalContentLength != null ) {
191+ expectedEndByte = Math .min (expectedStartByte + configuredPartSizeInBytes - 1 , totalContentLength - 1 );
192+ } else {
193+ expectedEndByte = expectedStartByte + configuredPartSizeInBytes - 1 ;
194+ }
195+
196+ String expectedRange = "bytes " + expectedStartByte + "-" + expectedEndByte + "/" ;
197+ if (!contentRange .startsWith (expectedRange )) {
198+ return Optional .of (SdkClientException .create (
199+ "Content-Range mismatch. Expected range starting with: " + expectedRange +
200+ ", but got: " + contentRange ));
201+ }
202+
203+ long expectedPartSize ;
204+ if (totalContentLength != null && partIndex == totalParts - 1 ) {
205+ expectedPartSize = totalContentLength - (partIndex * configuredPartSizeInBytes );
206+ } else {
207+ expectedPartSize = configuredPartSizeInBytes ;
208+ }
209+
210+ if (!contentLength .equals (expectedPartSize )) {
211+ return Optional .of (SdkClientException .create (
212+ "Part content length validation failed for part " + partIndex +
213+ ". Expected: " + expectedPartSize + ", but got: " + contentLength ));
177214 }
178- return Long .parseLong (matcher .group (3 ));
179- }
180215
216+ long actualStartByte = MultipartDownloadUtils .parseStartByteFromContentRange (contentRange );
217+ if (actualStartByte != expectedStartByte ) {
218+ return Optional .of (SdkClientException .create (
219+ "Content range offset mismatch for part " + partIndex +
220+ ". Expected start: " + expectedStartByte + ", but got: " + actualStartByte ));
221+ }
222+
223+ return Optional .empty ();
224+ }
225+
181226 private int calculateTotalParts (long contentLength , long partSize ) {
182227 return (int ) Math .ceil ((double ) contentLength / partSize );
183228 }
184229
185- private PresignedUrlDownloadRequest createPartRequest (int partIndex ) {
230+ private boolean hasMoreParts (int completedPartsCount ) {
231+ return totalParts != null && totalParts > 1 && completedPartsCount < totalParts ;
232+ }
233+
234+ private PresignedUrlDownloadRequest createRangedGetRequest (int partIndex ) {
186235 long startByte = partIndex * configuredPartSizeInBytes ;
187236 long endByte ;
188-
189237 if (totalContentLength != null ) {
190238 endByte = Math .min (startByte + configuredPartSizeInBytes - 1 , totalContentLength - 1 );
191239 } else {
192240 endByte = startByte + configuredPartSizeInBytes - 1 ;
193241 }
194242 String rangeHeader = BYTES_RANGE_PREFIX + startByte + "-" + endByte ;
195- return presignedUrlDownloadRequest .toBuilder ()
196- .range (rangeHeader )
197- .build ();
243+ PresignedUrlDownloadRequest .Builder builder = presignedUrlDownloadRequest .toBuilder ()
244+ .range (rangeHeader );
245+ if (partIndex > 0 && eTag != null ) {
246+ builder .ifMatch (eTag );
247+ log .debug (() -> "Setting IfMatch header to: " + eTag + " for part " + partIndex );
248+ }
249+ return builder .build ();
198250 }
199251
200252 private void handleError (Throwable t ) {
@@ -218,6 +270,6 @@ public void onComplete() {
218270 }
219271
220272 public CompletableFuture <Void > future () {
221- return this . future ;
273+ return future ;
222274 }
223275}
0 commit comments