Skip to content

Commit 0d51df3

Browse files
authored
Fix race conditions in aggregate ProfileFileSupplier (aws#6665)
* Fix a race condition in aggregate ProfileFileSupplier * Move aggretate profile supplier logic to a dedicated (internal) class * Make fields private + add javadoc
1 parent d3c1e52 commit 0d51df3

4 files changed

Lines changed: 164 additions & 45 deletions

File tree

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"type": "bugfix",
3+
"category": "AWS SDK for Java v2",
4+
"contributor": "",
5+
"description": "Fix a race condition in aggregate ProfileFileSupplier that could cause credential resolution failures with shared DefaultCredentialsProvider."
6+
}

core/profiles/src/main/java/software/amazon/awssdk/profiles/ProfileFileSupplier.java

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,10 @@
1717

1818
import java.nio.file.Files;
1919
import java.nio.file.Path;
20-
import java.util.Collections;
21-
import java.util.LinkedHashMap;
22-
import java.util.Map;
23-
import java.util.Objects;
2420
import java.util.Optional;
25-
import java.util.concurrent.atomic.AtomicReference;
2621
import java.util.function.Supplier;
2722
import software.amazon.awssdk.annotations.SdkPublicApi;
23+
import software.amazon.awssdk.profiles.internal.AggregateProfileFileSupplier;
2824
import software.amazon.awssdk.profiles.internal.ProfileFileRefresher;
2925

3026
/**
@@ -125,46 +121,7 @@ static ProfileFileSupplier fixedProfileFile(ProfileFile profileFile) {
125121
*/
126122
static ProfileFileSupplier aggregate(ProfileFileSupplier... suppliers) {
127123

128-
return new ProfileFileSupplier() {
129-
130-
final AtomicReference<ProfileFile> currentAggregateProfileFile = new AtomicReference<>();
131-
final Map<Supplier<ProfileFile>, ProfileFile> currentValuesBySupplier
132-
= Collections.synchronizedMap(new LinkedHashMap<>());
133-
134-
@Override
135-
public ProfileFile get() {
136-
boolean refreshAggregate = false;
137-
for (ProfileFileSupplier supplier : suppliers) {
138-
if (didSuppliedValueChange(supplier)) {
139-
refreshAggregate = true;
140-
}
141-
}
142-
143-
if (refreshAggregate) {
144-
refreshCurrentAggregate();
145-
}
146-
147-
return currentAggregateProfileFile.get();
148-
}
149-
150-
private boolean didSuppliedValueChange(Supplier<ProfileFile> supplier) {
151-
ProfileFile next = supplier.get();
152-
ProfileFile current = currentValuesBySupplier.put(supplier, next);
153-
154-
return !Objects.equals(next, current);
155-
}
156-
157-
private void refreshCurrentAggregate() {
158-
ProfileFile.Aggregator aggregator = ProfileFile.aggregator();
159-
currentValuesBySupplier.values().forEach(aggregator::addFile);
160-
ProfileFile current = currentAggregateProfileFile.get();
161-
ProfileFile next = aggregator.build();
162-
if (!Objects.equals(current, next)) {
163-
currentAggregateProfileFile.compareAndSet(current, next);
164-
}
165-
}
166-
167-
};
124+
return new AggregateProfileFileSupplier(suppliers);
168125
}
169126

170127
}
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.profiles.internal;
17+
18+
import java.util.Arrays;
19+
import java.util.Collections;
20+
import java.util.LinkedHashMap;
21+
import java.util.List;
22+
import java.util.Map;
23+
import java.util.concurrent.atomic.AtomicReference;
24+
import java.util.function.Supplier;
25+
import software.amazon.awssdk.annotations.SdkInternalApi;
26+
import software.amazon.awssdk.profiles.ProfileFile;
27+
import software.amazon.awssdk.profiles.ProfileFileSupplier;
28+
29+
/**
30+
* A {@link ProfileFileSupplier} that combines the {@link ProfileFile} objects from multiple
31+
* {@code ProfileFileSupplier}s. Objects are passed into {@link ProfileFile.Aggregator}.
32+
*/
33+
@SdkInternalApi
34+
public class AggregateProfileFileSupplier implements ProfileFileSupplier {
35+
private final List<ProfileFileSupplier> suppliers;
36+
37+
// supplier values and the resulting aggregate must always be updated atomically together
38+
private final AtomicReference<SupplierState> state =
39+
new AtomicReference<>(new SupplierState(Collections.emptyMap(), null));
40+
41+
public AggregateProfileFileSupplier(ProfileFileSupplier... suppliers) {
42+
this.suppliers = Collections.unmodifiableList(Arrays.asList(suppliers));
43+
}
44+
45+
@Override
46+
public ProfileFile get() {
47+
SupplierState currentState = state.get();
48+
Map<Supplier<ProfileFile>, ProfileFile> currentValues = currentState.values;
49+
Map<Supplier<ProfileFile>, ProfileFile> changedValues = changedSupplierValues(currentValues);
50+
51+
if (changedValues == null) {
52+
// no suppliers have changed values, return the current aggregate
53+
return currentState.aggregate;
54+
}
55+
56+
// one or more supplier values have changed, we need to update the aggregate (and the state)
57+
// the order of the suppliers matters so we MUST preserve it using LinkedHashMap with insertion ordering
58+
Map<Supplier<ProfileFile>, ProfileFile> nextValues = new LinkedHashMap<>(currentValues);
59+
nextValues.putAll(changedValues);
60+
61+
ProfileFile.Aggregator aggregator = ProfileFile.aggregator();
62+
nextValues.values().forEach(aggregator::addFile);
63+
ProfileFile nextAggregate = aggregator.build();
64+
65+
SupplierState nextState = new SupplierState(nextValues, nextAggregate);
66+
if (state.compareAndSet(currentState, nextState)) {
67+
return nextAggregate;
68+
}
69+
// else: another thread has modified the state in between, assume it is up to date and use the new state
70+
return state.get().aggregate;
71+
}
72+
73+
// return the suppliers with changed values. Returns null if no values have changed
74+
private Map<Supplier<ProfileFile>, ProfileFile> changedSupplierValues(Map<Supplier<ProfileFile>, ProfileFile> currentValues) {
75+
Map<Supplier<ProfileFile>, ProfileFile> changedValues = null;
76+
for (ProfileFileSupplier supplier : suppliers) {
77+
ProfileFile next = supplier.get();
78+
ProfileFile prev = currentValues.get(supplier);
79+
// we ONLY care about if the reference has changed, we don't care about object equality here
80+
if (prev != next) {
81+
if (changedValues == null) {
82+
// changed values must also preserve supplier order
83+
changedValues = new LinkedHashMap<>();
84+
}
85+
changedValues.put(supplier, next);
86+
}
87+
}
88+
return changedValues;
89+
}
90+
91+
/**
92+
* Supplier values and the resulting aggregate must always be updated atomically together.
93+
* This record class tracks all mutable elements of the supplier's state together.
94+
*/
95+
private static final class SupplierState {
96+
private final Map<Supplier<ProfileFile>, ProfileFile> values;
97+
private final ProfileFile aggregate;
98+
99+
private SupplierState(Map<Supplier<ProfileFile>, ProfileFile> values, ProfileFile aggregate) {
100+
this.values = values;
101+
this.aggregate = aggregate;
102+
}
103+
}
104+
}

core/profiles/src/test/java/software/amazon/awssdk/profiles/ProfileFileSupplierTest.java

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,20 @@
3131
import java.time.ZoneId;
3232
import java.time.ZoneOffset;
3333
import java.time.temporal.TemporalAmount;
34+
import java.util.ArrayList;
3435
import java.util.Arrays;
3536
import java.util.List;
3637
import java.util.Objects;
3738
import java.util.Optional;
3839
import java.util.Set;
40+
import java.util.concurrent.CompletableFuture;
3941
import java.util.concurrent.ConcurrentHashMap;
42+
import java.util.concurrent.CountDownLatch;
43+
import java.util.concurrent.ExecutionException;
44+
import java.util.concurrent.ExecutorService;
45+
import java.util.concurrent.Executors;
46+
import java.util.concurrent.Future;
47+
import java.util.concurrent.TimeUnit;
4048
import java.util.concurrent.atomic.AtomicInteger;
4149
import java.util.function.Predicate;
4250
import java.util.stream.Collectors;
@@ -503,6 +511,50 @@ void aggregate_duplicateOptionsGivenReloadingProfileFirst_preservesPrecedence()
503511
assertThat(accessKeyId).isEqualTo("defaultAccessKey2");
504512
}
505513

514+
@Test
515+
void aggregate_concurrentGetAlwaysReturnsCorrectAggregate() throws ExecutionException, InterruptedException {
516+
ProfileFile credentialFile = credentialProfileFile("test1", "key1", "secret1");
517+
ProfileFile configFile = configProfileFile("profile test",
518+
Pair.of("region", "us-west-2"),
519+
Pair.of("aws_account_id", "012354678922"));
520+
521+
522+
ProfileFile expectedAggregate = ProfileFile.aggregator().addFile(credentialFile).addFile(configFile).build();
523+
524+
ProfileFileSupplier supplier = ProfileFileSupplier.aggregate(() -> credentialFile, () -> configFile);
525+
526+
ExecutorService executor = Executors.newFixedThreadPool(24);
527+
CountDownLatch startLatch = new CountDownLatch(1);
528+
List<Future<Boolean>> tasks = new ArrayList<>();
529+
530+
for(int i = 0; i < 24; i++) {
531+
tasks.add(executor.submit(() -> {
532+
try {
533+
startLatch.await();
534+
ProfileFile resolved = supplier.get();
535+
return Objects.equals(expectedAggregate, resolved);
536+
} catch (InterruptedException e) {
537+
throw new RuntimeException(e);
538+
}
539+
}));
540+
}
541+
// All tasks are now submitted — release them
542+
startLatch.countDown();
543+
executor.shutdown();
544+
try {
545+
assertThat(executor.awaitTermination(10, TimeUnit.SECONDS))
546+
.as("executor did not terminate")
547+
.isTrue();
548+
} finally {
549+
executor.shutdownNow();
550+
}
551+
552+
// assert that all concurrent get's returned the same, expected aggregate
553+
for(Future<Boolean> task : tasks) {
554+
assertThat(task.get()).isTrue();
555+
}
556+
}
557+
506558
@Test
507559
void fixedProfileFile_nullProfileFile_returnsNonNullSupplier() {
508560
ProfileFile file = null;

0 commit comments

Comments
 (0)