Skip to content

Commit 1efcdb5

Browse files
committed
Add SSLSessionTicketsTest
Adds new test class that verifies the behavior of session tickets. Relevant only for Scylla clusters and TLSv1.3. There are two ssl implementations being tested: JDK and Netty. JDK implementation is tested by tracking `javax.net.ssl` logs. The specifics of TLS handshakes are read from them and custom metrics are collected. It is expected that the client will receive session tickets and use them when possible. With JDK implementation driver is not expected to be able to reconnect using solely session resumptions after node restart. The cache used in Java internal classes (before JDK 24) can hold only 1 ticket for this purpose. This ticket cannot be reused for simultaneous reconnection to multiple shards. Netty implementation is tested by extending `RemoteEndpointAwareNettySSLOptions`. The extension called `TestableNettySSLOptions` should behave nearly identically. The difference comes from additional listeners and handlers that are used for collecting statistics about completed handshakes and ClientHellos sent. In this implementation the cache stores enough sessions for reconnections, so the test method for this implementation expects all reconnections to use the session resumption. The driver also does not attempt to send ClientHello's into the void before the node gets back up, which would waste the session information from received tickets.
1 parent 536b740 commit 1efcdb5

3 files changed

Lines changed: 910 additions & 2 deletions

File tree

Lines changed: 323 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
1+
package com.datastax.driver.core;
2+
3+
import static com.datastax.driver.core.CreateCCM.TestMode.PER_METHOD;
4+
import static java.nio.charset.StandardCharsets.UTF_8;
5+
import static org.testng.Assert.assertEquals;
6+
import static org.testng.Assert.assertTrue;
7+
8+
import com.datastax.driver.core.policies.ConstantReconnectionPolicy;
9+
import com.datastax.driver.core.utils.ScyllaVersion;
10+
import com.google.common.collect.ImmutableList;
11+
import com.google.common.util.concurrent.Uninterruptibles;
12+
import java.io.ByteArrayOutputStream;
13+
import java.io.IOException;
14+
import java.util.List;
15+
import java.util.concurrent.TimeUnit;
16+
import java.util.concurrent.atomic.AtomicInteger;
17+
import java.util.logging.Handler;
18+
import java.util.logging.Level;
19+
import java.util.logging.LogRecord;
20+
import java.util.logging.Logger;
21+
import org.awaitility.Awaitility;
22+
import org.testng.annotations.Test;
23+
24+
@CreateCCM(PER_METHOD)
25+
@CCMConfig(
26+
auth = false,
27+
config = "client_encryption_options.enable_session_tickets:true",
28+
jvmArgs = {"--smp", "5"},
29+
dirtiesContext = true)
30+
public class SSLSessionTicketsTest extends SSLTestBase {
31+
32+
private static final int NUM_SHARDS = 5; // Has to match the smp value above
33+
34+
private Logger sslLogger;
35+
private Level originalLevel;
36+
private TlsDebugLogHandler handler;
37+
38+
private final OccurrenceCounter serverHellos =
39+
new OccurrenceCounter("Consuming ServerHello handshake message");
40+
private final OccurrenceCounter negotiatedTls13 =
41+
new OccurrenceCounter("Negotiated protocol version: TLSv1.3");
42+
private final OccurrenceCounter resumptions = new OccurrenceCounter("Resuming session:");
43+
private final OccurrenceCounter pskUses =
44+
new OccurrenceCounter("Using PSK to derive early secret");
45+
private final OccurrenceCounter ticketsReceived =
46+
new OccurrenceCounter("Consuming NewSessionTicket");
47+
private final List<OccurrenceCounter> counters =
48+
ImmutableList.of(serverHellos, resumptions, pskUses, ticketsReceived, negotiatedTls13);
49+
50+
/**
51+
* @test_category connection:ssl
52+
* @expected_result Connection can be established.
53+
*/
54+
@Test(groups = "isolated")
55+
@ScyllaVersion(
56+
minEnterprise = "2025.2.0",
57+
maxOSS = "0.0.0",
58+
description = "Requires certain options to be enabled server side. Since scylladb/pull/22928")
59+
public void should_receive_tickets_TLSv13_JDK() throws Exception {
60+
try {
61+
setupJavaSslLogTracking();
62+
SSLOptions sslOptions = getSSLOptions(SslImplementation.JDK, false, true, "TLSv1.3");
63+
Cluster cluster = register(createClusterBuilder().withSSL(sslOptions).build());
64+
Session session = cluster.connect();
65+
ResultSet rs = session.execute("SELECT * FROM system.local");
66+
healthCheck(session);
67+
assertEquals(
68+
negotiatedTls13.get(), serverHellos.get(), "Every negotiated TLS version should be 1.3");
69+
assertTrue(ticketsReceived.get() > 0, "Client should have received some tickets");
70+
// If server ever starts sending less (or more) tickets this check below will alert us
71+
assertEquals(
72+
ticketsReceived.get(), serverHellos.get() * 2, "We expect 2 tickets per connection");
73+
assertTrue(resumptions.get() > 0, "Client should have resumed at least one session");
74+
assertTrue(pskUses.get() > 0, "Client should have used PSK at least once for the resumption");
75+
} finally {
76+
cleanUpJavaSslLogTracking();
77+
}
78+
}
79+
80+
@Test(groups = "isolated")
81+
@ScyllaVersion(
82+
minEnterprise = "2025.2.0",
83+
maxOSS = "0.0.0",
84+
description = "Requires certain options to be enabled server side. Since scylladb/pull/22928")
85+
public void all_reconnections_should_use_tickets_TLSv13_netty() throws Exception {
86+
TestableNettySSLOptions testableSSLOptions =
87+
(TestableNettySSLOptions)
88+
getSSLOptions(SslImplementation.NETTY_OPENSSL_DEBUG, false, true, "TLSv1.3");
89+
90+
testableSSLOptions.resetCounters();
91+
Cluster cluster =
92+
register(
93+
createClusterBuilder()
94+
.withSSL(testableSSLOptions)
95+
.withReconnectionPolicy(new ConstantReconnectionPolicy(200))
96+
.build());
97+
Session session = cluster.connect();
98+
ResultSet rs = session.execute("SELECT * FROM system.local");
99+
healthCheck(session);
100+
101+
ccm().stop(1);
102+
Uninterruptibles.sleepUninterruptibly(3, TimeUnit.SECONDS);
103+
ccm().start(1);
104+
healthCheck(session);
105+
106+
// Assert that every connection negotiated TLS 1.3
107+
assertEquals(
108+
testableSSLOptions.getTls13Negotiations(),
109+
testableSSLOptions.getHandshakeCompletions(),
110+
"Every " + "negotiated TLS version should always be 1.3");
111+
112+
// Assert that last <expectedConnections> of ClientHellos contained unique PSK identities
113+
int expectedConnections =
114+
getExpectedNumberOfConnectionsPerHost(session) + 1; // +1 for the control connection
115+
List<TestableNettySSLOptions.ClientHelloInfo> clientHelloHistory =
116+
testableSSLOptions.getClientHelloHistory();
117+
List<TestableNettySSLOptions.ClientHelloInfo> lastClientHellos =
118+
clientHelloHistory.subList(
119+
clientHelloHistory.size() - expectedConnections, clientHelloHistory.size());
120+
// Assert that every element in this list has a psk identity list of 1
121+
long pskIdentityListsOfSize1 =
122+
lastClientHellos.stream().filter(c -> c.getPreSharedKeys().size() == 1).count();
123+
// Technically the client could send more than 1 PSK identity. It would be unexpected here
124+
// though.
125+
assertEquals(
126+
pskIdentityListsOfSize1,
127+
expectedConnections,
128+
"All final ClientHellos should have a PSK identity list of size 1");
129+
// Assert that every element in this list has a unique PSK identity
130+
long uniquePskIdentities =
131+
lastClientHellos.stream()
132+
.map(c -> c.getPreSharedKeys().get(0).getIdentity())
133+
.distinct()
134+
.count();
135+
assertEquals(
136+
uniquePskIdentities,
137+
expectedConnections,
138+
"Every final connection should have utilized PSK to resume the session");
139+
}
140+
141+
@Test(
142+
groups = "isolated",
143+
expectedExceptions = AssertionError.class,
144+
expectedExceptionsMessageRegExp = ".*Every reconnection should be a resumption.*")
145+
@ScyllaVersion(
146+
minEnterprise = "2025.2.0",
147+
maxOSS = "0.0.0",
148+
description = "Requires certain options to be enabled server side. Since scylladb/pull/22928")
149+
public void all_reconnections_should_use_tickets_TLSv13_JDK() throws Exception {
150+
// Unfortunately the OpenJDK's cache in older versions cannot hold more than 1 ticket
151+
// making the reconnection scenario with all reconnections using tickets impossible.
152+
// For additional context see https://github.com/scylladb/java-driver/issues/444
153+
// The insights on what's happening on JDK side should be still relevant despite
154+
// different driver version
155+
int initialResumptions, reconnectionResumptions;
156+
int initialHellos, reconnectionHellos;
157+
int initialPsks, reconnectionPsks;
158+
try {
159+
setupJavaSslLogTracking();
160+
SSLOptions sslOptions = getSSLOptions(SslImplementation.JDK, false, true, "TLSv1.3");
161+
Cluster cluster = register(createClusterBuilder().withSSL(sslOptions).build());
162+
Session session = cluster.connect();
163+
ResultSet rs = session.execute("SELECT * FROM system.local");
164+
healthCheck(session);
165+
initialResumptions = resumptions.get();
166+
initialHellos = serverHellos.get();
167+
initialPsks = pskUses.get();
168+
ccm().stop(1);
169+
Uninterruptibles.sleepUninterruptibly(3, TimeUnit.SECONDS);
170+
ccm().start(1);
171+
healthCheck(session);
172+
reconnectionResumptions = resumptions.get() - initialResumptions;
173+
reconnectionHellos = serverHellos.get() - initialHellos;
174+
reconnectionPsks = pskUses.get() - initialPsks;
175+
assertEquals(
176+
negotiatedTls13.get(), serverHellos.get(), "Every negotiated TLS version should be 1.3");
177+
assertTrue(ticketsReceived.get() > 0, "Client should have received some tickets");
178+
assertEquals(
179+
reconnectionResumptions, reconnectionHellos, "Every reconnection should be a resumption");
180+
assertEquals(
181+
reconnectionPsks, reconnectionHellos, "Every reconnection resumption should use PSK");
182+
} finally {
183+
cleanUpJavaSslLogTracking();
184+
}
185+
}
186+
187+
public void setupJavaSslLogTracking() {
188+
System.setProperty("javax.net.debug", "");
189+
sslLogger = Logger.getLogger("javax.net.ssl");
190+
originalLevel = sslLogger.getLevel();
191+
sslLogger.setLevel(Level.ALL);
192+
193+
for (OccurrenceCounter counter : counters) {
194+
counter.reset();
195+
}
196+
197+
// Custom handler to capture log messages
198+
ByteArrayOutputStream logCapture = new ByteArrayOutputStream();
199+
handler = new TlsDebugLogHandler(logCapture, counters);
200+
sslLogger.setUseParentHandlers(false);
201+
sslLogger.addHandler(handler);
202+
}
203+
204+
public void cleanUpJavaSslLogTracking() {
205+
sslLogger.removeHandler(handler);
206+
sslLogger.setLevel(originalLevel);
207+
}
208+
209+
private void healthCheck(Session session) {
210+
Awaitility.await()
211+
.atMost(20, TimeUnit.SECONDS)
212+
.pollInterval(1, TimeUnit.SECONDS)
213+
.until(
214+
() -> {
215+
try {
216+
for (Host host : session.getCluster().getMetadata().getAllHosts()) {
217+
int expectedConnections = getExpectedNumberOfConnectionsPerHost(session);
218+
if (session.getState().getOpenConnections(host) != expectedConnections) {
219+
return false;
220+
}
221+
}
222+
for (int i = 0; i < 3; i++) {
223+
session.execute("select * from system.local where key='local'");
224+
}
225+
return true;
226+
} catch (Exception e) {
227+
return false;
228+
}
229+
});
230+
}
231+
232+
private int getExpectedNumberOfConnectionsPerHost(Session session) {
233+
// In this test we care only about LOCAL connections. There should be no remote connections.
234+
int expectedConnections =
235+
session
236+
.getCluster()
237+
.getConfiguration()
238+
.getPoolingOptions()
239+
.getCoreConnectionsPerHost(HostDistance.LOCAL);
240+
if (expectedConnections % NUM_SHARDS > 0) {
241+
expectedConnections += NUM_SHARDS - (expectedConnections % NUM_SHARDS);
242+
}
243+
return expectedConnections;
244+
}
245+
246+
static class TlsDebugLogHandler extends Handler {
247+
private final ByteArrayOutputStream outputStream;
248+
private final List<OccurrenceCounter> counters;
249+
250+
TlsDebugLogHandler(ByteArrayOutputStream outputStream, List<OccurrenceCounter> counters) {
251+
this.outputStream = outputStream;
252+
this.counters = counters;
253+
}
254+
255+
@Override
256+
public void publish(LogRecord record) {
257+
try {
258+
for (OccurrenceCounter counter : counters) {
259+
counter.incrementIfFound(record.getMessage());
260+
}
261+
outputStream.write((record.getMessage() + "\n").getBytes(UTF_8));
262+
} catch (IOException e) {
263+
throw new RuntimeException(e);
264+
}
265+
}
266+
267+
@Override
268+
public void flush() {
269+
try {
270+
outputStream.flush();
271+
} catch (IOException e) {
272+
throw new RuntimeException(e);
273+
}
274+
}
275+
276+
@Override
277+
public void close() throws SecurityException {
278+
try {
279+
outputStream.close();
280+
} catch (IOException e) {
281+
throw new RuntimeException(e);
282+
}
283+
}
284+
}
285+
286+
static class OccurrenceCounter {
287+
private final AtomicInteger count = new AtomicInteger(0);
288+
private final String substring; // Exact substring to look for
289+
290+
public OccurrenceCounter(String substring) {
291+
this.substring = substring;
292+
}
293+
294+
/**
295+
* Increment the counter if the substring is found in the log line. Multiple occurrences count
296+
* as one.
297+
*
298+
* @param logLine log line to check
299+
*/
300+
public void incrementIfFound(String logLine) {
301+
if (logLine.contains(substring)) {
302+
count.incrementAndGet();
303+
}
304+
}
305+
306+
public int get() {
307+
return count.get();
308+
}
309+
310+
public String getSubstring() {
311+
return substring;
312+
}
313+
314+
public void reset() {
315+
count.set(0);
316+
}
317+
318+
@Override
319+
public String toString() {
320+
return "OccurrenceCounter{substring='" + substring + "', count=" + count.get() + "}";
321+
}
322+
}
323+
}

driver-core/src/test/java/com/datastax/driver/core/SSLTestBase.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ protected void connectWithSSL() throws Exception {
7575

7676
enum SslImplementation {
7777
JDK,
78-
NETTY_OPENSSL
78+
NETTY_OPENSSL,
79+
NETTY_OPENSSL_DEBUG
7980
}
8081

8182
/**
@@ -126,6 +127,7 @@ public SSLOptions getSSLOptions(
126127
return RemoteEndpointAwareJdkSSLOptions.builder().withSSLContext(sslContext).build();
127128

128129
case NETTY_OPENSSL:
130+
case NETTY_OPENSSL_DEBUG:
129131
SslContextBuilder builder =
130132
SslContextBuilder.forClient().sslProvider(OPENSSL).trustManager(tmf);
131133

@@ -142,7 +144,11 @@ public SSLOptions getSSLOptions(
142144
CCMBridge.DEFAULT_CLIENT_CERT_CHAIN_FILE, CCMBridge.DEFAULT_CLIENT_PRIVATE_KEY_FILE);
143145
}
144146

145-
return new RemoteEndpointAwareNettySSLOptions(builder.build());
147+
if (sslImplementation.equals(NETTY_OPENSSL)) {
148+
return new RemoteEndpointAwareNettySSLOptions(builder.build());
149+
} else {
150+
return new TestableNettySSLOptions(builder.build());
151+
}
146152
default:
147153
fail("Unsupported SSL implementation: " + sslImplementation);
148154
return null;

0 commit comments

Comments
 (0)