Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,21 @@ public enum DefaultDriverOption implements DriverOption {
*/
CONNECTION_POOL_REMOTE_SIZE("advanced.connection.pool.remote.size"),

/**
* The maximum number of connections to create at once when filling the connection pool. Relevant
* during channel pool creation and reconnections.
*
* <p>Value 0 means unlimited - all missing channels will be created at once. Any other value
* means that driver will create connections in batches of at most that size and the batches will
* be handled sequentially one after another. The actual batch size may be smaller, to ensure that
* at least two batches are created.
*
* <p>It is advised to use advanced shard awareness with this feature.
*
* <p>Value-type: int
*/
CONNECTION_POOL_INIT_BATCH_SIZE("advanced.connection.pool.init-batch-size"),

/**
* Whether to schedule reconnection attempts if all contact points are unreachable on the first
* initialization attempt.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ protected static void fillWithDriverDefaults(OptionsMap map) {
map.put(TypedDriverOption.CONNECTION_SET_KEYSPACE_TIMEOUT, initQueryTimeout);
map.put(TypedDriverOption.CONNECTION_POOL_LOCAL_SIZE, 1);
map.put(TypedDriverOption.CONNECTION_POOL_REMOTE_SIZE, 1);
map.put(TypedDriverOption.CONNECTION_POOL_INIT_BATCH_SIZE, 0);
map.put(TypedDriverOption.CONNECTION_MAX_REQUESTS, 1024);
map.put(TypedDriverOption.CONNECTION_MAX_ORPHAN_REQUESTS, 256);
map.put(TypedDriverOption.CONNECTION_WARN_INIT_ERROR, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ public String toString() {
/** The number of connections in the REMOTE pool. */
public static final TypedDriverOption<Integer> CONNECTION_POOL_REMOTE_SIZE =
new TypedDriverOption<>(DefaultDriverOption.CONNECTION_POOL_REMOTE_SIZE, GenericType.INTEGER);
/** The maximum number of connections to create at once when filling the connection pool. */
public static final TypedDriverOption<Integer> CONNECTION_POOL_INIT_BATCH_SIZE =
new TypedDriverOption<>(
DefaultDriverOption.CONNECTION_POOL_INIT_BATCH_SIZE, GenericType.INTEGER);
/**
* Whether to schedule reconnection attempts if all contact points are unreachable on the first
* initialization attempt.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,11 @@ private CompletionStage<Boolean> addMissingChannels() {
channels.length * wantedCount - Arrays.stream(channels).mapToInt(ChannelSet::size).sum();
LOG.debug("[{}] Trying to create {} missing channels", logPrefix, missing);
DriverChannelOptions options = buildDriverOptions();
int batchSize =
config.getDefaultProfile().getInt(DefaultDriverOption.CONNECTION_POOL_INIT_BATCH_SIZE);
batchSize = Integer.min(batchSize, (channels.length * wantedCount + 1) / 2);
List<CompletionStage<DriverChannel>> previousBatch = new ArrayList<>();
List<CompletionStage<DriverChannel>> currentBatch = new ArrayList<>();
for (int shard = 0; shard < channels.length; shard++) {
LOG.trace(
"[{}] Missing {} channels for shard {}",
Expand All @@ -501,11 +506,26 @@ private CompletionStage<Boolean> addMissingChannels() {
if (config
.getDefaultProfile()
.getBoolean(DefaultDriverOption.CONNECTION_ADVANCED_SHARD_AWARENESS_ENABLED)) {
channelFuture = channelFactory.connect(node, shard, options);
int finalShard = shard;
channelFuture =
CompletableFutures.allDone(previousBatch)
.thenComposeAsync(
ignored -> channelFactory.connect(node, finalShard, options),
adminExecutor);
} else {
channelFuture = channelFactory.connect(node, options);
channelFuture =
CompletableFutures.allDone(previousBatch)
.thenComposeAsync(
ignored -> channelFactory.connect(node, options), adminExecutor);
}
pendingChannels.add(channelFuture);
if (batchSize != 0) {
currentBatch.add(channelFuture);
if (currentBatch.size() >= batchSize) {
previousBatch = currentBatch;
currentBatch = new ArrayList<>();
}
}
}
}
return CompletableFutures.allDone(pendingChannels)
Expand Down
23 changes: 23 additions & 0 deletions core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,29 @@ datastax-java-driver {
# and will adjust their size.
# Overridable in a profile: no
remote.size = 1

# The maximum number of connections to create at once when filling the particular connection pool.
# This controls the batching behavior for connection creation to avoid overwhelming the
# server with too many simultaneous connection attempts.
#
# If this value is 0, there is no limit and all required connections will be created
# simultaneously as usual. Setting this to a positive value will create
# connections in batches of at most that specified size.
#
# This should be particularly useful for TLS 1.3 session resumptions when working with ScyllaDB.
# OpenJDK in version 24 improves on simultaneous session resumption by multiple threads.
# By batching the connection creation the driver can take advantage of NewSessionTickets received by previous
# batches when establishing new connections.
# OpenJDK 24 has a hardcoded max queue size of 10 for its sessionHostPortCache.
# Previous versions hold at most 1 cached session per host, so the main benefit of enabling this
# would be only the throttling.
#
# It is advised to use advanced shard awareness with this feature.
#
# Required: yes
# Modifiable at runtime: yes; the new value will be used for reconnection attempts after the change.
# Overridable in a profile: no
init-batch-size = 0
}

# The maximum number of requests that can be executed concurrently on a connection. This must be
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.datastax.oss.driver.api.testinfra.requirement.BackendType;
import com.datastax.oss.driver.api.testinfra.session.SessionUtils;
import com.datastax.oss.driver.categories.IsolatedTests;
import com.datastax.oss.driver.internal.core.connection.ConstantReconnectionPolicy;
import com.datastax.oss.driver.shaded.guava.common.util.concurrent.Uninterruptibles;
import com.google.common.collect.ImmutableList;
import java.io.ByteArrayOutputStream;
Expand All @@ -22,6 +23,7 @@
import java.nio.file.Paths;
import java.security.KeyStore;
import java.security.SecureRandom;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -183,6 +185,65 @@ public void all_reconnections_should_use_tickets_with_TLSv13() throws Exception
reconnectionPsks);
}

@Test
public void all_reconnections_but_one_should_use_tickets_when_throttled_TLSv13()
throws Exception {
int initialResumptions, reconnectionResumptions;
int initialHellos, reconnectionHellos;
int initialPsks, reconnectionPsks;
try {
SSLContext context = createSslContext("TLSv1.3");
try (DriverConfigLoader configLoader =
SessionUtils.configLoaderBuilder()
.withString(
DefaultDriverOption.PROTOCOL_VERSION, DefaultProtocolVersion.V4.name())
.withClass(
DefaultDriverOption.RECONNECTION_POLICY_CLASS,
ConstantReconnectionPolicy.class)
.withDuration(DefaultDriverOption.RECONNECTION_BASE_DELAY, Duration.ofSeconds(1))
.withInt(DefaultDriverOption.CONNECTION_POOL_INIT_BATCH_SIZE, 1)
.build();
CqlSession session =
(CqlSession)
SessionUtils.baseBuilder()
.addContactEndPoints(CCM_RULE.getContactPointsWithShardAwarePort())
.withSslContext(context)
.withConfigLoader(configLoader)
.build()) {
healthCheck(session);

// Perform a node restart to force all connections to be re-established
initialResumptions = resumptions.get();
initialHellos = serverHellos.get();
initialPsks = pskUses.get();
CCM_RULE.getCcmBridge().stop();
Uninterruptibles.sleepUninterruptibly(3, TimeUnit.SECONDS);
CCM_RULE.getCcmBridge().start();
healthCheck(session);
reconnectionResumptions = resumptions.get() - initialResumptions;
reconnectionHellos = serverHellos.get() - initialHellos;
reconnectionPsks = pskUses.get() - initialPsks;
}
} finally {
handler.flush();
}

Assert.assertEquals(
"Each connection should have negotiated TLSv1.3.",
serverHellos.get(),
negotiatedTls13.get());
Assert.assertTrue("Client should have received some tickets.", ticketsReceived.get() > 0);
Assert.assertTrue(
String.format(
"Each reconnection but first should be a resumption. Meanwhile found %s ServerHellos and %s resumptions.",
reconnectionHellos, reconnectionResumptions),
reconnectionResumptions >= reconnectionHellos - 1);
Assert.assertEquals(
"PSK should have been used for each resumption on reconnection.",
reconnectionResumptions,
reconnectionPsks);
}

private void healthCheck(CqlSession session) {
Awaitility.await()
.atMost(20, TimeUnit.SECONDS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,12 @@ public Set<EndPoint> getContactPoints() {
return Collections.singleton(
new DefaultEndPoint(new InetSocketAddress(ccmBridge.getNodeIpAddress(1), 9042)));
}

public Set<EndPoint> getContactPointsWithShardAwarePort() {
if (!CcmBridge.isDistributionOf(BackendType.SCYLLA)) {
throw new UnsupportedOperationException("Shard aware port is only supported in Scylla");
}
return Collections.singleton(
new DefaultEndPoint(new InetSocketAddress(ccmBridge.getNodeIpAddress(1), 19042)));
}
}
Loading