Skip to content

Commit 959c877

Browse files
committed
fix: Move all logic about socket management o MonitoredCache and add unit test.
1 parent 47a0fc3 commit 959c877

3 files changed

Lines changed: 212 additions & 9 deletions

File tree

core/src/main/java/com/google/cloud/sql/core/Connector.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,7 @@ Socket connect(ConnectionConfig config, long timeoutMs) throws IOException {
138138
}
139139

140140
logger.debug(String.format("[%s] Connected to instance successfully.", instanceIp));
141-
// If this connection was opened using a domain name, then store it
142-
// for later in case we need to forcibly close it on failover.
143-
if (!Strings.isNullOrEmpty(config.getDomainName())) {
144-
instance.addSocket(socket);
145-
}
141+
instance.addSocket(socket);
146142

147143
return socket;
148144
} catch (IOException e) {

core/src/main/java/com/google/cloud/sql/core/MonitoredCache.java

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,16 @@
1616

1717
package com.google.cloud.sql.core;
1818

19+
import com.google.common.annotations.VisibleForTesting;
1920
import com.google.common.base.Strings;
2021
import java.io.IOException;
2122
import java.net.Socket;
22-
import java.util.ArrayList;
2323
import java.util.Collections;
2424
import java.util.Iterator;
25-
import java.util.List;
25+
import java.util.Set;
2626
import java.util.Timer;
2727
import java.util.TimerTask;
28+
import java.util.WeakHashMap;
2829
import java.util.function.Function;
2930
import javax.net.ssl.SSLSocket;
3031
import org.slf4j.Logger;
@@ -38,7 +39,11 @@
3839
class MonitoredCache implements ConnectionInfoCache {
3940
private static final Logger logger = LoggerFactory.getLogger(Connector.class);
4041
private final ConnectionInfoCache cache;
41-
private final List<Socket> sockets = Collections.synchronizedList(new ArrayList<>());
42+
// Use weak references to hold the open sockets. If a socket is no longer in
43+
// use by the application, the garabage collector will automatically remove
44+
// it from this set.
45+
private final Set<Socket> sockets =
46+
Collections.synchronizedSet(Collections.newSetFromMap(new WeakHashMap<>()));
4247
private final Function<ConnectionConfig, CloudSqlInstanceName> resolve;
4348
private final TimerTask task;
4449

@@ -49,6 +54,8 @@ class MonitoredCache implements ConnectionInfoCache {
4954
this.cache = cache;
5055
this.resolve = resolve;
5156

57+
// If this was configured with a domain name, start the domain name check
58+
// and socket cleanup periodic task.
5259
if (!Strings.isNullOrEmpty(cache.getConfig().getDomainName())) {
5360
long failoverPeriod = cache.getConfig().getConnectorConfig().getFailoverPeriod().toMillis();
5461
this.task =
@@ -64,6 +71,11 @@ public void run() {
6471
}
6572
}
6673

74+
@VisibleForTesting
75+
int getOpenSocketCount() {
76+
return sockets.size();
77+
}
78+
6779
private void checkDomainName() {
6880
// Resolve the domain name again. If it changed, close the sockets
6981
try {
@@ -149,6 +161,10 @@ public synchronized boolean isClosed() {
149161
}
150162

151163
synchronized void addSocket(SSLSocket socket) {
152-
sockets.add(socket);
164+
// Only add the socket if this was configured using a domain name,
165+
// and therefore the background socket cleanup task is running.
166+
if (!Strings.isNullOrEmpty(cache.getConfig().getDomainName())) {
167+
sockets.add(socket);
168+
}
153169
}
154170
}
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
package com.google.cloud.sql.core;
2+
3+
import com.google.cloud.sql.ConnectorConfig;
4+
import java.io.IOException;
5+
import java.time.Duration;
6+
import java.util.Timer;
7+
import javax.net.ssl.HandshakeCompletedListener;
8+
import javax.net.ssl.SSLSession;
9+
import javax.net.ssl.SSLSocket;
10+
import org.junit.AfterClass;
11+
import org.junit.Assert;
12+
import org.junit.Test;
13+
14+
public class MonitoredCacheTest {
15+
private static final Timer timer = new Timer(true);
16+
17+
@AfterClass
18+
public static void afterClass() {
19+
timer.cancel();
20+
}
21+
22+
@Test
23+
public void testMonitoredCacheHoldsSocketsWithDomainName() {
24+
CloudSqlInstanceName name = new CloudSqlInstanceName("proj:reg:inst", "db.example.com");
25+
ConnectionConfig config =
26+
new ConnectionConfig.Builder()
27+
.withCloudSqlInstance("proj:reg:inst")
28+
.withDomainName("db.example.com")
29+
.build();
30+
MockCache mockCache = new MockCache(config);
31+
32+
MonitoredCache cache = new MonitoredCache(mockCache, timer, connectionConfig -> name);
33+
MockSslSocket socket = new MockSslSocket();
34+
cache.addSocket(socket);
35+
Assert.assertEquals("1 socket in cache", 1, cache.getOpenSocketCount());
36+
cache.close();
37+
Assert.assertTrue("socket closed", socket.closed);
38+
}
39+
40+
@Test
41+
public void testMonitoredCachePurgesClosedSockets() throws InterruptedException {
42+
CloudSqlInstanceName name = new CloudSqlInstanceName("proj:reg:inst", "db.example.com");
43+
// Purge sockets every 10ms.
44+
ConnectionConfig config =
45+
new ConnectionConfig.Builder()
46+
.withCloudSqlInstance("proj:reg:inst")
47+
.withDomainName("db.example.com")
48+
.withConnectorConfig(
49+
new ConnectorConfig.Builder().withFailoverPeriod(Duration.ofMillis(10)).build())
50+
.build();
51+
MockCache mockCache = new MockCache(config);
52+
53+
MonitoredCache cache = new MonitoredCache(mockCache, timer, connectionConfig -> name);
54+
MockSslSocket socket = new MockSslSocket();
55+
cache.addSocket(socket);
56+
Assert.assertEquals("1 socket in cache", 1, cache.getOpenSocketCount());
57+
socket.close();
58+
Thread.sleep(20);
59+
Assert.assertEquals("0 socket in cache", 0, cache.getOpenSocketCount());
60+
}
61+
62+
@Test
63+
public void testMonitoredCacheWithoutDomainNameIgnoresSockets() {
64+
CloudSqlInstanceName name = new CloudSqlInstanceName("proj:reg:inst");
65+
ConnectionConfig config =
66+
new ConnectionConfig.Builder().withCloudSqlInstance("proj:reg:inst").build();
67+
MockCache mockCache = new MockCache(config);
68+
69+
MonitoredCache cache = new MonitoredCache(mockCache, timer, connectionConfig -> name);
70+
MockSslSocket socket = new MockSslSocket();
71+
cache.addSocket(socket);
72+
Assert.assertEquals("0 socket in cache", 0, cache.getOpenSocketCount());
73+
}
74+
75+
private static class MockSslSocket extends SSLSocket {
76+
boolean closed;
77+
78+
@Override
79+
public boolean isClosed() {
80+
return closed;
81+
}
82+
83+
public void close() {
84+
this.closed = true;
85+
}
86+
87+
@Override
88+
public String[] getSupportedCipherSuites() {
89+
return new String[0];
90+
}
91+
92+
@Override
93+
public String[] getEnabledCipherSuites() {
94+
return new String[0];
95+
}
96+
97+
@Override
98+
public void setEnabledCipherSuites(String[] suites) {}
99+
100+
@Override
101+
public String[] getSupportedProtocols() {
102+
return new String[0];
103+
}
104+
105+
@Override
106+
public String[] getEnabledProtocols() {
107+
return new String[0];
108+
}
109+
110+
@Override
111+
public void setEnabledProtocols(String[] protocols) {}
112+
113+
@Override
114+
public SSLSession getSession() {
115+
return null;
116+
}
117+
118+
@Override
119+
public void addHandshakeCompletedListener(HandshakeCompletedListener listener) {}
120+
121+
@Override
122+
public void removeHandshakeCompletedListener(HandshakeCompletedListener listener) {}
123+
124+
@Override
125+
public void startHandshake() throws IOException {}
126+
127+
@Override
128+
public void setUseClientMode(boolean mode) {}
129+
130+
@Override
131+
public boolean getUseClientMode() {
132+
return false;
133+
}
134+
135+
@Override
136+
public void setNeedClientAuth(boolean need) {}
137+
138+
@Override
139+
public boolean getNeedClientAuth() {
140+
return false;
141+
}
142+
143+
@Override
144+
public void setWantClientAuth(boolean want) {}
145+
146+
@Override
147+
public boolean getWantClientAuth() {
148+
return false;
149+
}
150+
151+
@Override
152+
public void setEnableSessionCreation(boolean flag) {}
153+
154+
@Override
155+
public boolean getEnableSessionCreation() {
156+
return false;
157+
}
158+
}
159+
160+
class MockCache implements ConnectionInfoCache {
161+
private final ConnectionConfig config;
162+
163+
MockCache(ConnectionConfig config) {
164+
this.config = config;
165+
}
166+
167+
@Override
168+
public ConnectionMetadata getConnectionMetadata(long timeoutMs) {
169+
return null;
170+
}
171+
172+
@Override
173+
public void forceRefresh() {}
174+
175+
@Override
176+
public void refreshIfExpired() {}
177+
178+
@Override
179+
public void close() {}
180+
181+
@Override
182+
public boolean isClosed() {
183+
return false;
184+
}
185+
186+
@Override
187+
public ConnectionConfig getConfig() {
188+
return config;
189+
}
190+
}
191+
}

0 commit comments

Comments
 (0)