Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@

package com.google.cloud.sql.core;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
import java.io.IOException;
import java.net.Socket;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.Timer;
import java.util.TimerTask;
import java.util.WeakHashMap;
import java.util.function.Function;
import javax.net.ssl.SSLSocket;
import org.slf4j.Logger;
Expand All @@ -38,7 +39,11 @@
class MonitoredCache implements ConnectionInfoCache {
private static final Logger logger = LoggerFactory.getLogger(Connector.class);
private final ConnectionInfoCache cache;
private final List<Socket> sockets = Collections.synchronizedList(new ArrayList<>());
// Use weak references to hold the open sockets. If a socket is no longer in
// use by the application, the garabage collector will automatically remove
// it from this set.
private final Set<Socket> sockets =
Collections.synchronizedSet(Collections.newSetFromMap(new WeakHashMap<>()));
private final Function<ConnectionConfig, CloudSqlInstanceName> resolve;
private final TimerTask task;

Expand All @@ -49,6 +54,8 @@ class MonitoredCache implements ConnectionInfoCache {
this.cache = cache;
this.resolve = resolve;

// If this was configured with a domain name, start the domain name check
// and socket cleanup periodic task.
if (!Strings.isNullOrEmpty(cache.getConfig().getDomainName())) {
long failoverPeriod = cache.getConfig().getConnectorConfig().getFailoverPeriod().toMillis();
this.task =
Expand All @@ -64,6 +71,11 @@ public void run() {
}
}

@VisibleForTesting
int getOpenSocketCount() {
return sockets.size();
}

private void checkDomainName() {
// Resolve the domain name again. If it changed, close the sockets
try {
Expand Down Expand Up @@ -149,6 +161,10 @@ public synchronized boolean isClosed() {
}

synchronized void addSocket(SSLSocket socket) {
sockets.add(socket);
// Only add the socket if this was configured using a domain name,
// and therefore the background socket cleanup task is running.
if (!Strings.isNullOrEmpty(cache.getConfig().getDomainName())) {
sockets.add(socket);
}
}
}
208 changes: 208 additions & 0 deletions core/src/test/java/com/google/cloud/sql/core/MonitoredCacheTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.sql.core;

import com.google.cloud.sql.ConnectorConfig;
import java.io.IOException;
import java.time.Duration;
import java.util.Timer;
import javax.net.ssl.HandshakeCompletedListener;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Test;

public class MonitoredCacheTest {
private static final Timer timer = new Timer(true);

@AfterClass
public static void afterClass() {
timer.cancel();
}

@Test
public void testMonitoredCacheHoldsSocketsWithDomainName() {
CloudSqlInstanceName name = new CloudSqlInstanceName("proj:reg:inst", "db.example.com");
ConnectionConfig config =
new ConnectionConfig.Builder()
.withCloudSqlInstance("proj:reg:inst")
.withDomainName("db.example.com")
.build();
MockCache mockCache = new MockCache(config);

MonitoredCache cache = new MonitoredCache(mockCache, timer, connectionConfig -> name);
MockSslSocket socket = new MockSslSocket();
cache.addSocket(socket);
Assert.assertEquals("1 socket in cache", 1, cache.getOpenSocketCount());
cache.close();
Assert.assertTrue("socket closed", socket.closed);
}

@Test
public void testMonitoredCachePurgesClosedSockets() throws InterruptedException {
CloudSqlInstanceName name = new CloudSqlInstanceName("proj:reg:inst", "db.example.com");
// Purge sockets every 10ms.
ConnectionConfig config =
new ConnectionConfig.Builder()
.withCloudSqlInstance("proj:reg:inst")
.withDomainName("db.example.com")
.withConnectorConfig(
new ConnectorConfig.Builder().withFailoverPeriod(Duration.ofMillis(10)).build())
.build();
MockCache mockCache = new MockCache(config);

MonitoredCache cache = new MonitoredCache(mockCache, timer, connectionConfig -> name);
MockSslSocket socket = new MockSslSocket();
cache.addSocket(socket);
Assert.assertEquals("1 socket in cache", 1, cache.getOpenSocketCount());
socket.close();
Thread.sleep(20);
Assert.assertEquals("0 socket in cache", 0, cache.getOpenSocketCount());
}

@Test
public void testMonitoredCacheWithoutDomainNameIgnoresSockets() {
CloudSqlInstanceName name = new CloudSqlInstanceName("proj:reg:inst");
ConnectionConfig config =
new ConnectionConfig.Builder().withCloudSqlInstance("proj:reg:inst").build();
MockCache mockCache = new MockCache(config);

MonitoredCache cache = new MonitoredCache(mockCache, timer, connectionConfig -> name);
MockSslSocket socket = new MockSslSocket();
cache.addSocket(socket);
Assert.assertEquals("0 socket in cache", 0, cache.getOpenSocketCount());
}

private static class MockSslSocket extends SSLSocket {
boolean closed;

@Override
public synchronized boolean isClosed() {
return closed;
}

@Override
public synchronized void close() {
this.closed = true;
}

@Override
public String[] getSupportedCipherSuites() {
return new String[0];
}

@Override
public String[] getEnabledCipherSuites() {
return new String[0];
}

@Override
public void setEnabledCipherSuites(String[] suites) {}

@Override
public String[] getSupportedProtocols() {
return new String[0];
}

@Override
public String[] getEnabledProtocols() {
return new String[0];
}

@Override
public void setEnabledProtocols(String[] protocols) {}

@Override
public SSLSession getSession() {
return null;
}

@Override
public void addHandshakeCompletedListener(HandshakeCompletedListener listener) {}

@Override
public void removeHandshakeCompletedListener(HandshakeCompletedListener listener) {}

@Override
public void startHandshake() throws IOException {}

@Override
public void setUseClientMode(boolean mode) {}

@Override
public boolean getUseClientMode() {
return false;
}

@Override
public void setNeedClientAuth(boolean need) {}

@Override
public boolean getNeedClientAuth() {
return false;
}

@Override
public void setWantClientAuth(boolean want) {}

@Override
public boolean getWantClientAuth() {
return false;
}

@Override
public void setEnableSessionCreation(boolean flag) {}

@Override
public boolean getEnableSessionCreation() {
return false;
}
}

private static class MockCache implements ConnectionInfoCache {
private final ConnectionConfig config;

MockCache(ConnectionConfig config) {
this.config = config;
}

@Override
public ConnectionMetadata getConnectionMetadata(long timeoutMs) {
return null;
}

@Override
public void forceRefresh() {}

@Override
public void refreshIfExpired() {}

@Override
public void close() {}

@Override
public boolean isClosed() {
return false;
}

@Override
public ConnectionConfig getConfig() {
return config;
}
}
}
Loading