Skip to content

Commit 97bc92c

Browse files
committed
fix: Move all logic about socket management o MonitoredCache and add unit test.
1 parent 47a0fc3 commit 97bc92c

File tree

3 files changed

+238
-9
lines changed

3 files changed

+238
-9
lines changed

core/src/main/java/com/google/cloud/sql/core/Connector.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,7 @@ Socket connect(ConnectionConfig config, long timeoutMs) throws IOException {
138138
}
139139

140140
logger.debug(String.format("[%s] Connected to instance successfully.", instanceIp));
141-
// If this connection was opened using a domain name, then store it
142-
// for later in case we need to forcibly close it on failover.
143-
if (!Strings.isNullOrEmpty(config.getDomainName())) {
144-
instance.addSocket(socket);
145-
}
141+
instance.addSocket(socket);
146142

147143
return socket;
148144
} catch (IOException e) {

core/src/main/java/com/google/cloud/sql/core/MonitoredCache.java

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,16 @@
1616

1717
package com.google.cloud.sql.core;
1818

19+
import com.google.common.annotations.VisibleForTesting;
1920
import com.google.common.base.Strings;
2021
import java.io.IOException;
2122
import java.net.Socket;
22-
import java.util.ArrayList;
2323
import java.util.Collections;
2424
import java.util.Iterator;
25-
import java.util.List;
25+
import java.util.Set;
2626
import java.util.Timer;
2727
import java.util.TimerTask;
28+
import java.util.WeakHashMap;
2829
import java.util.function.Function;
2930
import javax.net.ssl.SSLSocket;
3031
import org.slf4j.Logger;
@@ -38,7 +39,11 @@
3839
class MonitoredCache implements ConnectionInfoCache {
3940
private static final Logger logger = LoggerFactory.getLogger(Connector.class);
4041
private final ConnectionInfoCache cache;
41-
private final List<Socket> sockets = Collections.synchronizedList(new ArrayList<>());
42+
// Use weak references to hold the open sockets. If a socket is no longer in
43+
// use by the application, the garabage collector will automatically remove
44+
// it from this set.
45+
private final Set<Socket> sockets = Collections.synchronizedSet(
46+
Collections.newSetFromMap(new WeakHashMap<>()));
4247
private final Function<ConnectionConfig, CloudSqlInstanceName> resolve;
4348
private final TimerTask task;
4449

@@ -49,6 +54,8 @@ class MonitoredCache implements ConnectionInfoCache {
4954
this.cache = cache;
5055
this.resolve = resolve;
5156

57+
// If this was configured with a domain name, start the domain name check
58+
// and socket cleanup periodic task.
5259
if (!Strings.isNullOrEmpty(cache.getConfig().getDomainName())) {
5360
long failoverPeriod = cache.getConfig().getConnectorConfig().getFailoverPeriod().toMillis();
5461
this.task =
@@ -64,6 +71,11 @@ public void run() {
6471
}
6572
}
6673

74+
@VisibleForTesting
75+
int getOpenSocketCount(){
76+
return sockets.size();
77+
}
78+
6779
private void checkDomainName() {
6880
// Resolve the domain name again. If it changed, close the sockets
6981
try {
@@ -149,6 +161,10 @@ public synchronized boolean isClosed() {
149161
}
150162

151163
synchronized void addSocket(SSLSocket socket) {
152-
sockets.add(socket);
164+
// Only add the socket if this was configured using a domain name,
165+
// and therefore the background socket cleanup task is running.
166+
if(!Strings.isNullOrEmpty(cache.getConfig().getDomainName())) {
167+
sockets.add(socket);
168+
}
153169
}
154170
}
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
package com.google.cloud.sql.core;
2+
3+
import com.google.cloud.sql.ConnectorConfig;
4+
import java.io.IOException;
5+
import java.time.Duration;
6+
import java.util.Timer;
7+
import javax.net.ssl.HandshakeCompletedListener;
8+
import javax.net.ssl.SSLSession;
9+
import javax.net.ssl.SSLSocket;
10+
import org.junit.AfterClass;
11+
import org.junit.Assert;
12+
import org.junit.Test;
13+
14+
public class MonitoredCacheTest {
15+
private static final Timer timer = new Timer(true);
16+
17+
@AfterClass
18+
public static void afterClass(){
19+
timer.cancel();
20+
}
21+
22+
@Test
23+
public void testMonitoredCacheHoldsSocketsWithDomainName() {
24+
CloudSqlInstanceName name = new CloudSqlInstanceName("proj:reg:inst","db.example.com");
25+
ConnectionConfig config =
26+
new ConnectionConfig.Builder()
27+
.withCloudSqlInstance("proj:reg:inst")
28+
.withDomainName("db.example.com")
29+
.build();
30+
MockCache mockCache = new MockCache(config);
31+
32+
MonitoredCache cache = new MonitoredCache(mockCache, timer, connectionConfig -> name);
33+
MockSslSocket socket = new MockSslSocket();
34+
cache.addSocket(socket);
35+
Assert.assertEquals("1 socket in cache", 1, cache.getOpenSocketCount());
36+
cache.close();
37+
Assert.assertTrue("socket closed", socket.closed);
38+
}
39+
40+
@Test
41+
public void testMonitoredCachePurgesClosedSockets() throws InterruptedException {
42+
CloudSqlInstanceName name = new CloudSqlInstanceName("proj:reg:inst","db.example.com");
43+
// Purge sockets every 10ms.
44+
ConnectionConfig config =
45+
new ConnectionConfig.Builder()
46+
.withCloudSqlInstance("proj:reg:inst")
47+
.withDomainName("db.example.com")
48+
.withConnectorConfig(
49+
new ConnectorConfig.Builder().withFailoverPeriod(Duration.ofMillis(10)).build())
50+
.build();
51+
MockCache mockCache = new MockCache(config);
52+
53+
MonitoredCache cache = new MonitoredCache(mockCache, timer, connectionConfig -> name);
54+
MockSslSocket socket = new MockSslSocket();
55+
cache.addSocket(socket);
56+
Assert.assertEquals("1 socket in cache", 1, cache.getOpenSocketCount());
57+
socket.close();
58+
Thread.sleep(20);
59+
Assert.assertEquals("0 socket in cache", 0, cache.getOpenSocketCount());
60+
}
61+
62+
@Test
63+
public void testMonitoredCacheWithoutDomainNameIgnoresSockets() {
64+
CloudSqlInstanceName name = new CloudSqlInstanceName("proj:reg:inst");
65+
ConnectionConfig config =
66+
new ConnectionConfig.Builder()
67+
.withCloudSqlInstance("proj:reg:inst")
68+
.build();
69+
MockCache mockCache = new MockCache(config);
70+
71+
MonitoredCache cache = new MonitoredCache(mockCache, timer, connectionConfig -> name);
72+
MockSslSocket socket = new MockSslSocket();
73+
cache.addSocket(socket);
74+
Assert.assertEquals("0 socket in cache", 0, cache.getOpenSocketCount());
75+
}
76+
77+
private static class MockSslSocket extends SSLSocket {
78+
boolean closed;
79+
80+
@Override
81+
public boolean isClosed() {
82+
return closed;
83+
}
84+
85+
public void close(){
86+
this.closed = true;
87+
}
88+
89+
@Override
90+
public String[] getSupportedCipherSuites() {
91+
return new String[0];
92+
}
93+
94+
@Override
95+
public String[] getEnabledCipherSuites() {
96+
return new String[0];
97+
}
98+
99+
@Override
100+
public void setEnabledCipherSuites(String[] suites) {
101+
102+
}
103+
104+
@Override
105+
public String[] getSupportedProtocols() {
106+
return new String[0];
107+
}
108+
109+
@Override
110+
public String[] getEnabledProtocols() {
111+
return new String[0];
112+
}
113+
114+
@Override
115+
public void setEnabledProtocols(String[] protocols) {
116+
117+
}
118+
119+
@Override
120+
public SSLSession getSession() {
121+
return null;
122+
}
123+
124+
@Override
125+
public void addHandshakeCompletedListener(HandshakeCompletedListener listener) {
126+
127+
}
128+
129+
@Override
130+
public void removeHandshakeCompletedListener(HandshakeCompletedListener listener) {
131+
132+
}
133+
134+
@Override
135+
public void startHandshake() throws IOException {
136+
137+
}
138+
139+
@Override
140+
public void setUseClientMode(boolean mode) {
141+
142+
}
143+
144+
@Override
145+
public boolean getUseClientMode() {
146+
return false;
147+
}
148+
149+
@Override
150+
public void setNeedClientAuth(boolean need) {
151+
152+
}
153+
154+
@Override
155+
public boolean getNeedClientAuth() {
156+
return false;
157+
}
158+
159+
@Override
160+
public void setWantClientAuth(boolean want) {
161+
162+
}
163+
164+
@Override
165+
public boolean getWantClientAuth() {
166+
return false;
167+
}
168+
169+
@Override
170+
public void setEnableSessionCreation(boolean flag) {
171+
172+
}
173+
174+
@Override
175+
public boolean getEnableSessionCreation() {
176+
return false;
177+
}
178+
}
179+
180+
class MockCache implements ConnectionInfoCache {
181+
private final ConnectionConfig config;
182+
183+
MockCache(ConnectionConfig config) {
184+
this.config = config;
185+
}
186+
187+
@Override
188+
public ConnectionMetadata getConnectionMetadata(long timeoutMs) {
189+
return null;
190+
}
191+
192+
@Override
193+
public void forceRefresh() {
194+
195+
}
196+
197+
@Override
198+
public void refreshIfExpired() {
199+
200+
}
201+
202+
@Override
203+
public void close() {
204+
205+
}
206+
207+
@Override
208+
public boolean isClosed() {
209+
return false;
210+
}
211+
212+
@Override
213+
public ConnectionConfig getConfig() {
214+
return config;
215+
}
216+
}
217+
}

0 commit comments

Comments
 (0)