From 47a0fc3b1a0afa9e662d7a9b0b662c2850ef015a Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 27 Mar 2025 13:23:13 +0000 Subject: [PATCH 1/2] fix: only keep track of sockets opened using domain name --- core/src/main/java/com/google/cloud/sql/core/Connector.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/com/google/cloud/sql/core/Connector.java b/core/src/main/java/com/google/cloud/sql/core/Connector.java index ee277b400..4716e8f1e 100644 --- a/core/src/main/java/com/google/cloud/sql/core/Connector.java +++ b/core/src/main/java/com/google/cloud/sql/core/Connector.java @@ -138,7 +138,11 @@ Socket connect(ConnectionConfig config, long timeoutMs) throws IOException { } logger.debug(String.format("[%s] Connected to instance successfully.", instanceIp)); - instance.addSocket(socket); + // If this connection was opened using a domain name, then store it + // for later in case we need to forcibly close it on failover. + if (!Strings.isNullOrEmpty(config.getDomainName())) { + instance.addSocket(socket); + } return socket; } catch (IOException e) { From b39440d28909e81cb55a30fae4eeff161af866f0 Mon Sep 17 00:00:00 2001 From: Jonathan Hess Date: Thu, 27 Mar 2025 09:30:34 -0600 Subject: [PATCH 2/2] fix: Move all logic about socket management o MonitoredCache and add unit test. --- .../com/google/cloud/sql/core/Connector.java | 6 +- .../google/cloud/sql/core/MonitoredCache.java | 24 +- .../cloud/sql/core/MonitoredCacheTest.java | 208 ++++++++++++++++++ 3 files changed, 229 insertions(+), 9 deletions(-) create mode 100644 core/src/test/java/com/google/cloud/sql/core/MonitoredCacheTest.java diff --git a/core/src/main/java/com/google/cloud/sql/core/Connector.java b/core/src/main/java/com/google/cloud/sql/core/Connector.java index 4716e8f1e..ee277b400 100644 --- a/core/src/main/java/com/google/cloud/sql/core/Connector.java +++ b/core/src/main/java/com/google/cloud/sql/core/Connector.java @@ -138,11 +138,7 @@ Socket connect(ConnectionConfig config, long timeoutMs) throws IOException { } logger.debug(String.format("[%s] Connected to instance successfully.", instanceIp)); - // If this connection was opened using a domain name, then store it - // for later in case we need to forcibly close it on failover. - if (!Strings.isNullOrEmpty(config.getDomainName())) { - instance.addSocket(socket); - } + instance.addSocket(socket); return socket; } catch (IOException e) { diff --git a/core/src/main/java/com/google/cloud/sql/core/MonitoredCache.java b/core/src/main/java/com/google/cloud/sql/core/MonitoredCache.java index 069422ab6..990c709cc 100644 --- a/core/src/main/java/com/google/cloud/sql/core/MonitoredCache.java +++ b/core/src/main/java/com/google/cloud/sql/core/MonitoredCache.java @@ -16,15 +16,16 @@ package com.google.cloud.sql.core; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Strings; import java.io.IOException; import java.net.Socket; -import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; -import java.util.List; +import java.util.Set; import java.util.Timer; import java.util.TimerTask; +import java.util.WeakHashMap; import java.util.function.Function; import javax.net.ssl.SSLSocket; import org.slf4j.Logger; @@ -38,7 +39,11 @@ class MonitoredCache implements ConnectionInfoCache { private static final Logger logger = LoggerFactory.getLogger(Connector.class); private final ConnectionInfoCache cache; - private final List sockets = Collections.synchronizedList(new ArrayList<>()); + // Use weak references to hold the open sockets. If a socket is no longer in + // use by the application, the garabage collector will automatically remove + // it from this set. + private final Set sockets = + Collections.synchronizedSet(Collections.newSetFromMap(new WeakHashMap<>())); private final Function resolve; private final TimerTask task; @@ -49,6 +54,8 @@ class MonitoredCache implements ConnectionInfoCache { this.cache = cache; this.resolve = resolve; + // If this was configured with a domain name, start the domain name check + // and socket cleanup periodic task. if (!Strings.isNullOrEmpty(cache.getConfig().getDomainName())) { long failoverPeriod = cache.getConfig().getConnectorConfig().getFailoverPeriod().toMillis(); this.task = @@ -64,6 +71,11 @@ public void run() { } } + @VisibleForTesting + int getOpenSocketCount() { + return sockets.size(); + } + private void checkDomainName() { // Resolve the domain name again. If it changed, close the sockets try { @@ -149,6 +161,10 @@ public synchronized boolean isClosed() { } synchronized void addSocket(SSLSocket socket) { - sockets.add(socket); + // Only add the socket if this was configured using a domain name, + // and therefore the background socket cleanup task is running. + if (!Strings.isNullOrEmpty(cache.getConfig().getDomainName())) { + sockets.add(socket); + } } } diff --git a/core/src/test/java/com/google/cloud/sql/core/MonitoredCacheTest.java b/core/src/test/java/com/google/cloud/sql/core/MonitoredCacheTest.java new file mode 100644 index 000000000..47d78eb14 --- /dev/null +++ b/core/src/test/java/com/google/cloud/sql/core/MonitoredCacheTest.java @@ -0,0 +1,208 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.sql.core; + +import com.google.cloud.sql.ConnectorConfig; +import java.io.IOException; +import java.time.Duration; +import java.util.Timer; +import javax.net.ssl.HandshakeCompletedListener; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocket; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Test; + +public class MonitoredCacheTest { + private static final Timer timer = new Timer(true); + + @AfterClass + public static void afterClass() { + timer.cancel(); + } + + @Test + public void testMonitoredCacheHoldsSocketsWithDomainName() { + CloudSqlInstanceName name = new CloudSqlInstanceName("proj:reg:inst", "db.example.com"); + ConnectionConfig config = + new ConnectionConfig.Builder() + .withCloudSqlInstance("proj:reg:inst") + .withDomainName("db.example.com") + .build(); + MockCache mockCache = new MockCache(config); + + MonitoredCache cache = new MonitoredCache(mockCache, timer, connectionConfig -> name); + MockSslSocket socket = new MockSslSocket(); + cache.addSocket(socket); + Assert.assertEquals("1 socket in cache", 1, cache.getOpenSocketCount()); + cache.close(); + Assert.assertTrue("socket closed", socket.closed); + } + + @Test + public void testMonitoredCachePurgesClosedSockets() throws InterruptedException { + CloudSqlInstanceName name = new CloudSqlInstanceName("proj:reg:inst", "db.example.com"); + // Purge sockets every 10ms. + ConnectionConfig config = + new ConnectionConfig.Builder() + .withCloudSqlInstance("proj:reg:inst") + .withDomainName("db.example.com") + .withConnectorConfig( + new ConnectorConfig.Builder().withFailoverPeriod(Duration.ofMillis(10)).build()) + .build(); + MockCache mockCache = new MockCache(config); + + MonitoredCache cache = new MonitoredCache(mockCache, timer, connectionConfig -> name); + MockSslSocket socket = new MockSslSocket(); + cache.addSocket(socket); + Assert.assertEquals("1 socket in cache", 1, cache.getOpenSocketCount()); + socket.close(); + Thread.sleep(20); + Assert.assertEquals("0 socket in cache", 0, cache.getOpenSocketCount()); + } + + @Test + public void testMonitoredCacheWithoutDomainNameIgnoresSockets() { + CloudSqlInstanceName name = new CloudSqlInstanceName("proj:reg:inst"); + ConnectionConfig config = + new ConnectionConfig.Builder().withCloudSqlInstance("proj:reg:inst").build(); + MockCache mockCache = new MockCache(config); + + MonitoredCache cache = new MonitoredCache(mockCache, timer, connectionConfig -> name); + MockSslSocket socket = new MockSslSocket(); + cache.addSocket(socket); + Assert.assertEquals("0 socket in cache", 0, cache.getOpenSocketCount()); + } + + private static class MockSslSocket extends SSLSocket { + boolean closed; + + @Override + public synchronized boolean isClosed() { + return closed; + } + + @Override + public synchronized void close() { + this.closed = true; + } + + @Override + public String[] getSupportedCipherSuites() { + return new String[0]; + } + + @Override + public String[] getEnabledCipherSuites() { + return new String[0]; + } + + @Override + public void setEnabledCipherSuites(String[] suites) {} + + @Override + public String[] getSupportedProtocols() { + return new String[0]; + } + + @Override + public String[] getEnabledProtocols() { + return new String[0]; + } + + @Override + public void setEnabledProtocols(String[] protocols) {} + + @Override + public SSLSession getSession() { + return null; + } + + @Override + public void addHandshakeCompletedListener(HandshakeCompletedListener listener) {} + + @Override + public void removeHandshakeCompletedListener(HandshakeCompletedListener listener) {} + + @Override + public void startHandshake() throws IOException {} + + @Override + public void setUseClientMode(boolean mode) {} + + @Override + public boolean getUseClientMode() { + return false; + } + + @Override + public void setNeedClientAuth(boolean need) {} + + @Override + public boolean getNeedClientAuth() { + return false; + } + + @Override + public void setWantClientAuth(boolean want) {} + + @Override + public boolean getWantClientAuth() { + return false; + } + + @Override + public void setEnableSessionCreation(boolean flag) {} + + @Override + public boolean getEnableSessionCreation() { + return false; + } + } + + private static class MockCache implements ConnectionInfoCache { + private final ConnectionConfig config; + + MockCache(ConnectionConfig config) { + this.config = config; + } + + @Override + public ConnectionMetadata getConnectionMetadata(long timeoutMs) { + return null; + } + + @Override + public void forceRefresh() {} + + @Override + public void refreshIfExpired() {} + + @Override + public void close() {} + + @Override + public boolean isClosed() { + return false; + } + + @Override + public ConnectionConfig getConfig() { + return config; + } + } +}