From 536b74017f432aae8f5e1f6e559e13e4bb346b8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20B=C4=85czkowski?= Date: Fri, 12 Sep 2025 23:12:43 +0200 Subject: [PATCH 1/3] Expand utility method in SSLTestBase Adds new parameter allowing to specify protocol for SSLContext when using `getSSLOptions`. --- .../com/datastax/driver/core/SSLTestBase.java | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/driver-core/src/test/java/com/datastax/driver/core/SSLTestBase.java b/driver-core/src/test/java/com/datastax/driver/core/SSLTestBase.java index e4b91b3abbf..0557de2f25b 100644 --- a/driver-core/src/test/java/com/datastax/driver/core/SSLTestBase.java +++ b/driver-core/src/test/java/com/datastax/driver/core/SSLTestBase.java @@ -82,11 +82,15 @@ enum SslImplementation { * @param sslImplementation the SSL implementation to use * @param clientAuth whether the client should authenticate * @param trustingServer whether the client should trust the server's certificate + * @param protocol SSLContext protocol to use, e.g. TLSv1.2 * @return {@link com.datastax.driver.core.SSLOptions} with the given configuration for server * certificate validation and client certificate authentication. */ public SSLOptions getSSLOptions( - SslImplementation sslImplementation, boolean clientAuth, boolean trustingServer) + SslImplementation sslImplementation, + boolean clientAuth, + boolean trustingServer, + String protocol) throws Exception { TrustManagerFactory tmf = null; @@ -113,7 +117,7 @@ public SSLOptions getSSLOptions( kmf.init(ks, CCMBridge.DEFAULT_CLIENT_KEYSTORE_PASSWORD.toCharArray()); } - SSLContext sslContext = SSLContext.getInstance("TLS"); + SSLContext sslContext = SSLContext.getInstance(protocol); sslContext.init( kmf != null ? kmf.getKeyManagers() : null, tmf != null ? tmf.getTrustManagers() : null, @@ -125,6 +129,14 @@ public SSLOptions getSSLOptions( SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(OPENSSL).trustManager(tmf); + if (protocol.equals("TLS") || protocol.isEmpty()) { + // There is no netty constant for "TLS". Use defaults. + // see + // https://netty.io/4.1/api/constant-values.html#io.netty.handler.ssl.SslProtocols.SSL_v2 + } else { + builder.protocols(protocol); + } + if (clientAuth) { builder.keyManager( CCMBridge.DEFAULT_CLIENT_CERT_CHAIN_FILE, CCMBridge.DEFAULT_CLIENT_PRIVATE_KEY_FILE); @@ -136,4 +148,15 @@ public SSLOptions getSSLOptions( return null; } } + + /** + * Legacy method using "TLS" as the protocol. + * + * @see SSLTestBase#getSSLOptions(SslImplementation, boolean, boolean, String) + */ + public SSLOptions getSSLOptions( + SslImplementation sslImplementation, boolean clientAuth, boolean trustingServer) + throws Exception { + return getSSLOptions(sslImplementation, clientAuth, trustingServer, "TLS"); + } } From 1efcdb5a0e7770584ae5ceac7df1dc9311f501d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20B=C4=85czkowski?= Date: Sat, 13 Sep 2025 00:04:42 +0200 Subject: [PATCH 2/3] Add SSLSessionTicketsTest Adds new test class that verifies the behavior of session tickets. Relevant only for Scylla clusters and TLSv1.3. There are two ssl implementations being tested: JDK and Netty. JDK implementation is tested by tracking `javax.net.ssl` logs. The specifics of TLS handshakes are read from them and custom metrics are collected. It is expected that the client will receive session tickets and use them when possible. With JDK implementation driver is not expected to be able to reconnect using solely session resumptions after node restart. The cache used in Java internal classes (before JDK 24) can hold only 1 ticket for this purpose. This ticket cannot be reused for simultaneous reconnection to multiple shards. Netty implementation is tested by extending `RemoteEndpointAwareNettySSLOptions`. The extension called `TestableNettySSLOptions` should behave nearly identically. The difference comes from additional listeners and handlers that are used for collecting statistics about completed handshakes and ClientHellos sent. In this implementation the cache stores enough sessions for reconnections, so the test method for this implementation expects all reconnections to use the session resumption. The driver also does not attempt to send ClientHello's into the void before the node gets back up, which would waste the session information from received tickets. --- .../driver/core/SSLSessionTicketsTest.java | 323 ++++++++++ .../com/datastax/driver/core/SSLTestBase.java | 10 +- .../driver/core/TestableNettySSLOptions.java | 579 ++++++++++++++++++ 3 files changed, 910 insertions(+), 2 deletions(-) create mode 100644 driver-core/src/test/java/com/datastax/driver/core/SSLSessionTicketsTest.java create mode 100644 driver-core/src/test/java/com/datastax/driver/core/TestableNettySSLOptions.java diff --git a/driver-core/src/test/java/com/datastax/driver/core/SSLSessionTicketsTest.java b/driver-core/src/test/java/com/datastax/driver/core/SSLSessionTicketsTest.java new file mode 100644 index 00000000000..131c8cec172 --- /dev/null +++ b/driver-core/src/test/java/com/datastax/driver/core/SSLSessionTicketsTest.java @@ -0,0 +1,323 @@ +package com.datastax.driver.core; + +import static com.datastax.driver.core.CreateCCM.TestMode.PER_METHOD; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +import com.datastax.driver.core.policies.ConstantReconnectionPolicy; +import com.datastax.driver.core.utils.ScyllaVersion; +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.Uninterruptibles; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.logging.Handler; +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; +import org.awaitility.Awaitility; +import org.testng.annotations.Test; + +@CreateCCM(PER_METHOD) +@CCMConfig( + auth = false, + config = "client_encryption_options.enable_session_tickets:true", + jvmArgs = {"--smp", "5"}, + dirtiesContext = true) +public class SSLSessionTicketsTest extends SSLTestBase { + + private static final int NUM_SHARDS = 5; // Has to match the smp value above + + private Logger sslLogger; + private Level originalLevel; + private TlsDebugLogHandler handler; + + private final OccurrenceCounter serverHellos = + new OccurrenceCounter("Consuming ServerHello handshake message"); + private final OccurrenceCounter negotiatedTls13 = + new OccurrenceCounter("Negotiated protocol version: TLSv1.3"); + private final OccurrenceCounter resumptions = new OccurrenceCounter("Resuming session:"); + private final OccurrenceCounter pskUses = + new OccurrenceCounter("Using PSK to derive early secret"); + private final OccurrenceCounter ticketsReceived = + new OccurrenceCounter("Consuming NewSessionTicket"); + private final List counters = + ImmutableList.of(serverHellos, resumptions, pskUses, ticketsReceived, negotiatedTls13); + + /** + * @test_category connection:ssl + * @expected_result Connection can be established. + */ + @Test(groups = "isolated") + @ScyllaVersion( + minEnterprise = "2025.2.0", + maxOSS = "0.0.0", + description = "Requires certain options to be enabled server side. Since scylladb/pull/22928") + public void should_receive_tickets_TLSv13_JDK() throws Exception { + try { + setupJavaSslLogTracking(); + SSLOptions sslOptions = getSSLOptions(SslImplementation.JDK, false, true, "TLSv1.3"); + Cluster cluster = register(createClusterBuilder().withSSL(sslOptions).build()); + Session session = cluster.connect(); + ResultSet rs = session.execute("SELECT * FROM system.local"); + healthCheck(session); + assertEquals( + negotiatedTls13.get(), serverHellos.get(), "Every negotiated TLS version should be 1.3"); + assertTrue(ticketsReceived.get() > 0, "Client should have received some tickets"); + // If server ever starts sending less (or more) tickets this check below will alert us + assertEquals( + ticketsReceived.get(), serverHellos.get() * 2, "We expect 2 tickets per connection"); + assertTrue(resumptions.get() > 0, "Client should have resumed at least one session"); + assertTrue(pskUses.get() > 0, "Client should have used PSK at least once for the resumption"); + } finally { + cleanUpJavaSslLogTracking(); + } + } + + @Test(groups = "isolated") + @ScyllaVersion( + minEnterprise = "2025.2.0", + maxOSS = "0.0.0", + description = "Requires certain options to be enabled server side. Since scylladb/pull/22928") + public void all_reconnections_should_use_tickets_TLSv13_netty() throws Exception { + TestableNettySSLOptions testableSSLOptions = + (TestableNettySSLOptions) + getSSLOptions(SslImplementation.NETTY_OPENSSL_DEBUG, false, true, "TLSv1.3"); + + testableSSLOptions.resetCounters(); + Cluster cluster = + register( + createClusterBuilder() + .withSSL(testableSSLOptions) + .withReconnectionPolicy(new ConstantReconnectionPolicy(200)) + .build()); + Session session = cluster.connect(); + ResultSet rs = session.execute("SELECT * FROM system.local"); + healthCheck(session); + + ccm().stop(1); + Uninterruptibles.sleepUninterruptibly(3, TimeUnit.SECONDS); + ccm().start(1); + healthCheck(session); + + // Assert that every connection negotiated TLS 1.3 + assertEquals( + testableSSLOptions.getTls13Negotiations(), + testableSSLOptions.getHandshakeCompletions(), + "Every " + "negotiated TLS version should always be 1.3"); + + // Assert that last of ClientHellos contained unique PSK identities + int expectedConnections = + getExpectedNumberOfConnectionsPerHost(session) + 1; // +1 for the control connection + List clientHelloHistory = + testableSSLOptions.getClientHelloHistory(); + List lastClientHellos = + clientHelloHistory.subList( + clientHelloHistory.size() - expectedConnections, clientHelloHistory.size()); + // Assert that every element in this list has a psk identity list of 1 + long pskIdentityListsOfSize1 = + lastClientHellos.stream().filter(c -> c.getPreSharedKeys().size() == 1).count(); + // Technically the client could send more than 1 PSK identity. It would be unexpected here + // though. + assertEquals( + pskIdentityListsOfSize1, + expectedConnections, + "All final ClientHellos should have a PSK identity list of size 1"); + // Assert that every element in this list has a unique PSK identity + long uniquePskIdentities = + lastClientHellos.stream() + .map(c -> c.getPreSharedKeys().get(0).getIdentity()) + .distinct() + .count(); + assertEquals( + uniquePskIdentities, + expectedConnections, + "Every final connection should have utilized PSK to resume the session"); + } + + @Test( + groups = "isolated", + expectedExceptions = AssertionError.class, + expectedExceptionsMessageRegExp = ".*Every reconnection should be a resumption.*") + @ScyllaVersion( + minEnterprise = "2025.2.0", + maxOSS = "0.0.0", + description = "Requires certain options to be enabled server side. Since scylladb/pull/22928") + public void all_reconnections_should_use_tickets_TLSv13_JDK() throws Exception { + // Unfortunately the OpenJDK's cache in older versions cannot hold more than 1 ticket + // making the reconnection scenario with all reconnections using tickets impossible. + // For additional context see https://github.com/scylladb/java-driver/issues/444 + // The insights on what's happening on JDK side should be still relevant despite + // different driver version + int initialResumptions, reconnectionResumptions; + int initialHellos, reconnectionHellos; + int initialPsks, reconnectionPsks; + try { + setupJavaSslLogTracking(); + SSLOptions sslOptions = getSSLOptions(SslImplementation.JDK, false, true, "TLSv1.3"); + Cluster cluster = register(createClusterBuilder().withSSL(sslOptions).build()); + Session session = cluster.connect(); + ResultSet rs = session.execute("SELECT * FROM system.local"); + healthCheck(session); + initialResumptions = resumptions.get(); + initialHellos = serverHellos.get(); + initialPsks = pskUses.get(); + ccm().stop(1); + Uninterruptibles.sleepUninterruptibly(3, TimeUnit.SECONDS); + ccm().start(1); + healthCheck(session); + reconnectionResumptions = resumptions.get() - initialResumptions; + reconnectionHellos = serverHellos.get() - initialHellos; + reconnectionPsks = pskUses.get() - initialPsks; + assertEquals( + negotiatedTls13.get(), serverHellos.get(), "Every negotiated TLS version should be 1.3"); + assertTrue(ticketsReceived.get() > 0, "Client should have received some tickets"); + assertEquals( + reconnectionResumptions, reconnectionHellos, "Every reconnection should be a resumption"); + assertEquals( + reconnectionPsks, reconnectionHellos, "Every reconnection resumption should use PSK"); + } finally { + cleanUpJavaSslLogTracking(); + } + } + + public void setupJavaSslLogTracking() { + System.setProperty("javax.net.debug", ""); + sslLogger = Logger.getLogger("javax.net.ssl"); + originalLevel = sslLogger.getLevel(); + sslLogger.setLevel(Level.ALL); + + for (OccurrenceCounter counter : counters) { + counter.reset(); + } + + // Custom handler to capture log messages + ByteArrayOutputStream logCapture = new ByteArrayOutputStream(); + handler = new TlsDebugLogHandler(logCapture, counters); + sslLogger.setUseParentHandlers(false); + sslLogger.addHandler(handler); + } + + public void cleanUpJavaSslLogTracking() { + sslLogger.removeHandler(handler); + sslLogger.setLevel(originalLevel); + } + + private void healthCheck(Session session) { + Awaitility.await() + .atMost(20, TimeUnit.SECONDS) + .pollInterval(1, TimeUnit.SECONDS) + .until( + () -> { + try { + for (Host host : session.getCluster().getMetadata().getAllHosts()) { + int expectedConnections = getExpectedNumberOfConnectionsPerHost(session); + if (session.getState().getOpenConnections(host) != expectedConnections) { + return false; + } + } + for (int i = 0; i < 3; i++) { + session.execute("select * from system.local where key='local'"); + } + return true; + } catch (Exception e) { + return false; + } + }); + } + + private int getExpectedNumberOfConnectionsPerHost(Session session) { + // In this test we care only about LOCAL connections. There should be no remote connections. + int expectedConnections = + session + .getCluster() + .getConfiguration() + .getPoolingOptions() + .getCoreConnectionsPerHost(HostDistance.LOCAL); + if (expectedConnections % NUM_SHARDS > 0) { + expectedConnections += NUM_SHARDS - (expectedConnections % NUM_SHARDS); + } + return expectedConnections; + } + + static class TlsDebugLogHandler extends Handler { + private final ByteArrayOutputStream outputStream; + private final List counters; + + TlsDebugLogHandler(ByteArrayOutputStream outputStream, List counters) { + this.outputStream = outputStream; + this.counters = counters; + } + + @Override + public void publish(LogRecord record) { + try { + for (OccurrenceCounter counter : counters) { + counter.incrementIfFound(record.getMessage()); + } + outputStream.write((record.getMessage() + "\n").getBytes(UTF_8)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void flush() { + try { + outputStream.flush(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void close() throws SecurityException { + try { + outputStream.close(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } + + static class OccurrenceCounter { + private final AtomicInteger count = new AtomicInteger(0); + private final String substring; // Exact substring to look for + + public OccurrenceCounter(String substring) { + this.substring = substring; + } + + /** + * Increment the counter if the substring is found in the log line. Multiple occurrences count + * as one. + * + * @param logLine log line to check + */ + public void incrementIfFound(String logLine) { + if (logLine.contains(substring)) { + count.incrementAndGet(); + } + } + + public int get() { + return count.get(); + } + + public String getSubstring() { + return substring; + } + + public void reset() { + count.set(0); + } + + @Override + public String toString() { + return "OccurrenceCounter{substring='" + substring + "', count=" + count.get() + "}"; + } + } +} diff --git a/driver-core/src/test/java/com/datastax/driver/core/SSLTestBase.java b/driver-core/src/test/java/com/datastax/driver/core/SSLTestBase.java index 0557de2f25b..9df57d5f7f9 100644 --- a/driver-core/src/test/java/com/datastax/driver/core/SSLTestBase.java +++ b/driver-core/src/test/java/com/datastax/driver/core/SSLTestBase.java @@ -75,7 +75,8 @@ protected void connectWithSSL() throws Exception { enum SslImplementation { JDK, - NETTY_OPENSSL + NETTY_OPENSSL, + NETTY_OPENSSL_DEBUG } /** @@ -126,6 +127,7 @@ public SSLOptions getSSLOptions( return RemoteEndpointAwareJdkSSLOptions.builder().withSSLContext(sslContext).build(); case NETTY_OPENSSL: + case NETTY_OPENSSL_DEBUG: SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(OPENSSL).trustManager(tmf); @@ -142,7 +144,11 @@ public SSLOptions getSSLOptions( CCMBridge.DEFAULT_CLIENT_CERT_CHAIN_FILE, CCMBridge.DEFAULT_CLIENT_PRIVATE_KEY_FILE); } - return new RemoteEndpointAwareNettySSLOptions(builder.build()); + if (sslImplementation.equals(NETTY_OPENSSL)) { + return new RemoteEndpointAwareNettySSLOptions(builder.build()); + } else { + return new TestableNettySSLOptions(builder.build()); + } default: fail("Unsupported SSL implementation: " + sslImplementation); return null; diff --git a/driver-core/src/test/java/com/datastax/driver/core/TestableNettySSLOptions.java b/driver-core/src/test/java/com/datastax/driver/core/TestableNettySSLOptions.java new file mode 100644 index 00000000000..08ce3658976 --- /dev/null +++ b/driver-core/src/test/java/com/datastax/driver/core/TestableNettySSLOptions.java @@ -0,0 +1,579 @@ +package com.datastax.driver.core; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslHandler; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import java.text.SimpleDateFormat; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Date; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import javax.net.ssl.SSLSession; + +/** + * A testable version of RemoteEndpointAwareNettySSLOptions that tracks SSL events for verification + * in tests. + */ +public class TestableNettySSLOptions extends RemoteEndpointAwareNettySSLOptions { + + private static final boolean DEBUG = false; + + private final AtomicInteger handshakeCompletions = new AtomicInteger(0); + private final AtomicInteger tls13Negotiations = new AtomicInteger(0); + + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + private final List sessionHistory = Collections.synchronizedList(new ArrayList<>()); + + private final List clientHelloHistory = + Collections.synchronizedList(new ArrayList<>()); + + public TestableNettySSLOptions(SslContext context) { + super(context); + } + + @Override + public SslHandler newSSLHandler(SocketChannel channel, EndPoint remoteEndpoint) { + SslHandler sslHandler = super.newSSLHandler(channel, remoteEndpoint); + setupSslEventTracking(channel, sslHandler); + return sslHandler; + } + + private void setupSslEventTracking(SocketChannel channel, SslHandler sslHandler) { + channel.pipeline().addFirst(new ClientHelloInspector()); + + // Track handshake completion events + sslHandler + .handshakeFuture() + .addListener( + (GenericFutureListener>) + future -> { + if (future.isSuccess()) { + handshakeCompletions.incrementAndGet(); + + SSLSession session = sslHandler.engine().getSession(); + String protocol = session.getProtocol(); + byte[] sessionId = session.getId(); + String sessionIdHex = bytesToHex(sessionId); + long sessionCreationTime = session.getCreationTime(); + long currentTime = System.currentTimeMillis(); + + if ("TLSv1.3".equals(protocol)) { + tls13Negotiations.incrementAndGet(); + } + + // Create session info + SessionInfo sessionInfo = + new SessionInfo( + sessionIdHex, + sessionCreationTime, + currentTime, + protocol, + session.getCipherSuite(), + channel.remoteAddress().toString()); + + if (!sessions.containsKey(sessionIdHex)) { + sessions.put(sessionIdHex, sessionInfo); + } + + sessionHistory.add(sessionInfo); + } + }); + } + + private String bytesToHex(byte[] bytes) { + if (bytes == null || bytes.length == 0) { + return "empty"; + } + StringBuilder result = new StringBuilder(); + for (byte b : bytes) { + result.append(String.format("%02x", b)); + } + return result.toString(); + } + + public int getHandshakeCompletions() { + return handshakeCompletions.get(); + } + + public int getTls13Negotiations() { + return tls13Negotiations.get(); + } + + public int getUniqueSessionsCount() { + return sessions.size(); + } + + public List getSessionHistory() { + return new ArrayList<>(sessionHistory); + } + + public List getClientHelloHistory() { + return new ArrayList<>(clientHelloHistory); + } + + // Reset counters for test setup + public void resetCounters() { + handshakeCompletions.set(0); + tls13Negotiations.set(0); + sessions.clear(); + sessionHistory.clear(); + clientHelloHistory.clear(); + } + + // Print session information to standard output + public void printSessionInfo() { + System.out.println("=== SSL Session Information ==="); + System.out.println("Total handshakes: " + getHandshakeCompletions()); + System.out.println("TLS 1.3 negotiations: " + getTls13Negotiations()); + System.out.println("Unique sessions: " + getUniqueSessionsCount()); + System.out.println(); + + System.out.println("=== Session History ==="); + SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS"); + + for (int i = 0; i < sessionHistory.size(); i++) { + SessionInfo info = sessionHistory.get(i); + System.out.println("Handshake #" + (i + 1) + ":"); + System.out.println(" Session ID: " + info.getSessionId()); + System.out.println(" Creation Time: " + dateFormat.format(new Date(info.getCreationTime()))); + System.out.println( + " Handshake Time: " + dateFormat.format(new Date(info.getHandshakeTime()))); + System.out.println(" Protocol: " + info.getProtocol()); + System.out.println(" Cipher Suite: " + info.getCipherSuite()); + System.out.println(" Remote Address: " + info.getRemoteAddress()); + System.out.println( + " Age at handshake: " + (info.getHandshakeTime() - info.getCreationTime()) + "ms"); + System.out.println(); + } + + System.out.println("=== Unique Sessions ==="); + for (SessionInfo info : sessions.values()) { + System.out.println( + "Session ID: " + + info.getSessionId() + + " | Created: " + + dateFormat.format(new Date(info.getCreationTime())) + + " | Protocol: " + + info.getProtocol() + + " | Cipher: " + + info.getCipherSuite()); + } + + System.out.println("=== ClientHello History ==="); + for (ClientHelloInfo helloInfo : clientHelloHistory) { + System.out.println( + "ClientHello ID: " + + helloInfo.getClientHelloId() + + " | Created: " + + dateFormat.format(new Date(helloInfo.getCreationTime())) + + " | Session ID: " + + helloInfo.getSessionId() + + " | Has PSK Extension: " + + helloInfo.hasPreSharedKeyExtension()); + + if (helloInfo.hasPreSharedKeyExtension()) { + System.out.println( + " Pre-shared Keys (" + helloInfo.getPreSharedKeys().size() + " total):"); + for (int i = 0; i < helloInfo.getPreSharedKeys().size(); i++) { + PreSharedKeyInfo psk = helloInfo.getPreSharedKeys().get(i); + System.out.println(" PSK[" + i + "] Identity: " + psk.getIdentity()); + System.out.println( + " PSK[" + i + "] Obfuscated Ticket Age: " + psk.getObfuscatedTicketAge()); + } + } + } + System.out.println("=============================="); + } + + public static class SessionInfo { + private final String sessionId; + private final long creationTime; + private final long handshakeTime; + private final String protocol; + private final String cipherSuite; + private final String remoteAddress; + + public SessionInfo( + String sessionId, + long creationTime, + long handshakeTime, + String protocol, + String cipherSuite, + String remoteAddress) { + this.sessionId = sessionId; + this.creationTime = creationTime; + this.handshakeTime = handshakeTime; + this.protocol = protocol; + this.cipherSuite = cipherSuite; + this.remoteAddress = remoteAddress; + } + + public String getSessionId() { + return sessionId; + } + + public long getCreationTime() { + return creationTime; + } + + public long getHandshakeTime() { + return handshakeTime; + } + + public String getProtocol() { + return protocol; + } + + public String getCipherSuite() { + return cipherSuite; + } + + public String getRemoteAddress() { + return remoteAddress; + } + + @Override + public String toString() { + SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS"); + return String.format( + "SessionInfo{id=%s, created=%s, handshake=%s, protocol=%s, cipher=%s, remote=%s}", + sessionId, + dateFormat.format(new Date(creationTime)), + dateFormat.format(new Date(handshakeTime)), + protocol, + cipherSuite, + remoteAddress); + } + } + + // Inner class to hold new session ticket information + public static class NewSessionTicketInfo { + private final String ticketId; + private final long creationTime; + private final String sessionId; + private boolean resumed; + + public NewSessionTicketInfo(String ticketId, long creationTime, String sessionId) { + this.ticketId = ticketId; + this.creationTime = creationTime; + this.sessionId = sessionId; + this.resumed = false; + } + + public String getTicketId() { + return ticketId; + } + + public long getCreationTime() { + return creationTime; + } + + public String getSessionId() { + return sessionId; + } + + public boolean isResumed() { + return resumed; + } + + public void setResumed(boolean resumed) { + this.resumed = resumed; + } + + @Override + public String toString() { + SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS"); + return String.format( + "NewSessionTicketInfo{ticketId=%s, created=%s, sessionId=%s, resumed=%s}", + ticketId, dateFormat.format(new Date(creationTime)), sessionId, resumed); + } + } + + public static class ClientHelloInfo { + private final String clientHelloId; + private final long creationTime; + private final String sessionId; + private final List preSharedKeys; + private final boolean hasPreSharedKeyExtension; + + public ClientHelloInfo(String clientHelloId, long creationTime, String sessionId) { + this.clientHelloId = clientHelloId; + this.creationTime = creationTime; + this.sessionId = sessionId; + this.preSharedKeys = new ArrayList<>(); + this.hasPreSharedKeyExtension = false; + } + + public ClientHelloInfo( + String clientHelloId, + long creationTime, + String sessionId, + List preSharedKeys) { + this.clientHelloId = clientHelloId; + this.creationTime = creationTime; + this.sessionId = sessionId; + this.preSharedKeys = + preSharedKeys != null ? new ArrayList<>(preSharedKeys) : new ArrayList<>(); + this.hasPreSharedKeyExtension = !this.preSharedKeys.isEmpty(); + } + + public String getClientHelloId() { + return clientHelloId; + } + + public long getCreationTime() { + return creationTime; + } + + public String getSessionId() { + return sessionId; + } + + public List getPreSharedKeys() { + return new ArrayList<>(preSharedKeys); + } + + public boolean hasPreSharedKeyExtension() { + return hasPreSharedKeyExtension; + } + + @Override + public String toString() { + SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS"); + return String.format( + "ClientHelloInfo{clientHelloId=%s, created=%s, sessionId=%s, hasPreSharedKeys=%s, preSharedKeyCount=%d}", + clientHelloId, + dateFormat.format(new Date(creationTime)), + sessionId, + hasPreSharedKeyExtension, + preSharedKeys.size()); + } + } + + // Inner class to hold pre-shared key information + public static class PreSharedKeyInfo { + private final String identity; + private final int obfuscatedTicketAge; + + public PreSharedKeyInfo(String identity, int obfuscatedTicketAge) { + this.identity = identity; + this.obfuscatedTicketAge = obfuscatedTicketAge; + } + + public String getIdentity() { + return identity; + } + + public int getObfuscatedTicketAge() { + return obfuscatedTicketAge; + } + + @Override + public String toString() { + return String.format( + "PreSharedKeyInfo{identity=%s, obfuscatedAge=%d}", identity, obfuscatedTicketAge); + } + } + + private class ClientHelloInspector extends ChannelOutboundHandlerAdapter { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + if (msg instanceof ByteBuf) { + ByteBuf buffer = (ByteBuf) msg; + + // Check if this looks like a TLS handshake message + if (buffer.readableBytes() >= 6) { + // Make a copy to inspect without affecting the original buffer + ByteBuf copy = buffer.duplicate(); + + // TLS record header: type (1 byte) + version (2 bytes) + length (2 bytes) + byte contentType = copy.readByte(); + short version = copy.readShort(); + short length = copy.readShort(); + + // Check if this is a handshake record (content type 22) + if (contentType == 22 && copy.readableBytes() >= 4) { + // Handshake message header: type (1 byte) + length (3 bytes) + byte handshakeType = copy.readByte(); + + // Check if this is a ClientHello (handshake type 1) + if (handshakeType == 1) { + // Read the handshake message length (3 bytes, big-endian) + int messageLength = + (copy.readByte() & 0xFF) << 16 + | (copy.readByte() & 0xFF) << 8 + | (copy.readByte() & 0xFF); + + if (copy.readableBytes() >= Math.min(messageLength, 34)) { + ClientHelloInfo clientHelloInfo = parseClientHello(copy, messageLength); + if (clientHelloInfo != null) { + clientHelloHistory.add(clientHelloInfo); + + if (DEBUG) { + System.out.println("=== ClientHello Detected ==="); + System.out.println("ClientHello ID: " + clientHelloInfo.getClientHelloId()); + System.out.println("Session ID: " + clientHelloInfo.getSessionId()); + System.out.println( + "Timestamp: " + + new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS") + .format(new Date(clientHelloInfo.getCreationTime()))); + System.out.println("Raw message length: " + messageLength + " bytes"); + System.out.println( + "Has pre_shared_key extension: " + + clientHelloInfo.hasPreSharedKeyExtension()); + if (clientHelloInfo.hasPreSharedKeyExtension()) { + System.out.println( + "Pre-shared keys count: " + clientHelloInfo.getPreSharedKeys().size()); + for (int i = 0; i < clientHelloInfo.getPreSharedKeys().size(); i++) { + PreSharedKeyInfo psk = clientHelloInfo.getPreSharedKeys().get(i); + System.out.println(" PSK[" + i + "] Identity: " + psk.getIdentity()); + System.out.println( + " PSK[" + + i + + "] Obfuscated Ticket Age: " + + psk.getObfuscatedTicketAge()); + } + } + System.out.println("============================"); + } + } + } + } + } + } + } + + // Pass the message along unchanged + super.write(ctx, msg, promise); + } + + private ClientHelloInfo parseClientHello(ByteBuf buffer, int messageLength) { + try { + // Skip protocol version (2 bytes) + buffer.skipBytes(2); + + // Skip client random (32 bytes) + buffer.skipBytes(32); + + // Read session ID length (1 byte) + if (buffer.readableBytes() < 1) return null; + int sessionIdLength = buffer.readUnsignedByte(); + + // Read session ID + String sessionId = "empty"; + if (sessionIdLength > 0 && buffer.readableBytes() >= sessionIdLength) { + byte[] sessionIdBytes = new byte[sessionIdLength]; + buffer.readBytes(sessionIdBytes); + sessionId = bytesToHex(sessionIdBytes); + } + + // Skip cipher suites length (2 bytes) and cipher suites + if (buffer.readableBytes() < 2) return null; + int cipherSuitesLength = buffer.readUnsignedShort(); + if (buffer.readableBytes() < cipherSuitesLength) return null; + buffer.skipBytes(cipherSuitesLength); + + // Skip compression methods length (1 byte) and compression methods + if (buffer.readableBytes() < 1) return null; + int compressionMethodsLength = buffer.readUnsignedByte(); + if (buffer.readableBytes() < compressionMethodsLength) return null; + buffer.skipBytes(compressionMethodsLength); + + // Parse extensions if present + List preSharedKeys = new ArrayList<>(); + if (buffer.readableBytes() >= 2) { + // Read extensions length (2 bytes) + int extensionsLength = buffer.readUnsignedShort(); + int extensionsStart = buffer.readerIndex(); + int extensionsEnd = extensionsStart + extensionsLength; + + // Parse each extension + while (buffer.readerIndex() < extensionsEnd && buffer.readableBytes() >= 4) { + int extType = buffer.readUnsignedShort(); + int extLength = buffer.readUnsignedShort(); + + if (buffer.readableBytes() < extLength) { + // Not enough data for this extension + break; + } + + // Check for pre_shared_key extension (type 41) + if (extType == 41) { + preSharedKeys = parsePreSharedKeyExtension(buffer, extLength); + } else { + // Skip this extension + buffer.skipBytes(extLength); + } + } + } + + // Generate a unique ID for this ClientHello + String clientHelloId = + "ch_" + System.currentTimeMillis() + "_" + Math.abs(buffer.hashCode() % 1000); + + return new ClientHelloInfo( + clientHelloId, System.currentTimeMillis(), sessionId, preSharedKeys); + + } catch (Exception e) { + // If parsing fails, return null + System.err.println("Failed to parse ClientHello: " + e.getMessage()); + return null; + } + } + + private List parsePreSharedKeyExtension(ByteBuf buffer, int extLength) { + List preSharedKeys = new ArrayList<>(); + try { + int startIndex = buffer.readerIndex(); + int endIndex = startIndex + extLength; + + // Read identities length (2 bytes) + if (buffer.readableBytes() < 2) return preSharedKeys; + int identitiesLength = buffer.readUnsignedShort(); + + int identitiesStart = buffer.readerIndex(); + int identitiesEnd = identitiesStart + identitiesLength; + + // Parse each identity + while (buffer.readerIndex() < identitiesEnd && buffer.readableBytes() >= 2) { + // Read identity length (2 bytes) + int identityLength = buffer.readUnsignedShort(); + + if (buffer.readableBytes() < identityLength + 4) { + // Not enough data for identity + obfuscated_ticket_age + break; + } + + // Read identity data + byte[] identityBytes = new byte[identityLength]; + buffer.readBytes(identityBytes); + String identity = bytesToHex(identityBytes); + + // Read obfuscated_ticket_age (4 bytes) + int obfuscatedTicketAge = buffer.readInt(); + + preSharedKeys.add(new PreSharedKeyInfo(identity, obfuscatedTicketAge)); + } + + // Skip any remaining bytes in the extension (like PSK binders) + int remainingBytes = endIndex - buffer.readerIndex(); + if (remainingBytes > 0 && buffer.readableBytes() >= remainingBytes) { + buffer.skipBytes(remainingBytes); + } + + } catch (Exception e) { + throw new RuntimeException("Failed to parse pre_shared_key extension", e); + } + return preSharedKeys; + } + } +} From 191d29a6d62bc44de82bb85a7b41e101b3f8b8ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20B=C4=85czkowski?= Date: Tue, 23 Sep 2025 01:43:24 +0200 Subject: [PATCH 3/3] Add copyright headers to modified parts --- .../java/com/datastax/driver/core/SSLSessionTicketsTest.java | 5 +++++ .../com/datastax/driver/core/TestableNettySSLOptions.java | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/driver-core/src/test/java/com/datastax/driver/core/SSLSessionTicketsTest.java b/driver-core/src/test/java/com/datastax/driver/core/SSLSessionTicketsTest.java index 131c8cec172..7b3aac272ff 100644 --- a/driver-core/src/test/java/com/datastax/driver/core/SSLSessionTicketsTest.java +++ b/driver-core/src/test/java/com/datastax/driver/core/SSLSessionTicketsTest.java @@ -1,3 +1,8 @@ +/* + * Copyright (C) 2025 ScyllaDB + * + * Modified by ScyllaDB + */ package com.datastax.driver.core; import static com.datastax.driver.core.CreateCCM.TestMode.PER_METHOD; diff --git a/driver-core/src/test/java/com/datastax/driver/core/TestableNettySSLOptions.java b/driver-core/src/test/java/com/datastax/driver/core/TestableNettySSLOptions.java index 08ce3658976..18c00d7e74f 100644 --- a/driver-core/src/test/java/com/datastax/driver/core/TestableNettySSLOptions.java +++ b/driver-core/src/test/java/com/datastax/driver/core/TestableNettySSLOptions.java @@ -1,3 +1,8 @@ +/* + * Copyright (C) 2025 ScyllaDB + * + * Modified by ScyllaDB + */ package com.datastax.driver.core; import io.netty.buffer.ByteBuf;