Skip to content

Commit b39440d

Browse files
committed
fix: Move all logic about socket management o MonitoredCache and add unit test.
1 parent 47a0fc3 commit b39440d

File tree

3 files changed

+229
-9
lines changed

3 files changed

+229
-9
lines changed

core/src/main/java/com/google/cloud/sql/core/Connector.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,7 @@ Socket connect(ConnectionConfig config, long timeoutMs) throws IOException {
138138
}
139139

140140
logger.debug(String.format("[%s] Connected to instance successfully.", instanceIp));
141-
// If this connection was opened using a domain name, then store it
142-
// for later in case we need to forcibly close it on failover.
143-
if (!Strings.isNullOrEmpty(config.getDomainName())) {
144-
instance.addSocket(socket);
145-
}
141+
instance.addSocket(socket);
146142

147143
return socket;
148144
} catch (IOException e) {

core/src/main/java/com/google/cloud/sql/core/MonitoredCache.java

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,16 @@
1616

1717
package com.google.cloud.sql.core;
1818

19+
import com.google.common.annotations.VisibleForTesting;
1920
import com.google.common.base.Strings;
2021
import java.io.IOException;
2122
import java.net.Socket;
22-
import java.util.ArrayList;
2323
import java.util.Collections;
2424
import java.util.Iterator;
25-
import java.util.List;
25+
import java.util.Set;
2626
import java.util.Timer;
2727
import java.util.TimerTask;
28+
import java.util.WeakHashMap;
2829
import java.util.function.Function;
2930
import javax.net.ssl.SSLSocket;
3031
import org.slf4j.Logger;
@@ -38,7 +39,11 @@
3839
class MonitoredCache implements ConnectionInfoCache {
3940
private static final Logger logger = LoggerFactory.getLogger(Connector.class);
4041
private final ConnectionInfoCache cache;
41-
private final List<Socket> sockets = Collections.synchronizedList(new ArrayList<>());
42+
// Use weak references to hold the open sockets. If a socket is no longer in
43+
// use by the application, the garabage collector will automatically remove
44+
// it from this set.
45+
private final Set<Socket> sockets =
46+
Collections.synchronizedSet(Collections.newSetFromMap(new WeakHashMap<>()));
4247
private final Function<ConnectionConfig, CloudSqlInstanceName> resolve;
4348
private final TimerTask task;
4449

@@ -49,6 +54,8 @@ class MonitoredCache implements ConnectionInfoCache {
4954
this.cache = cache;
5055
this.resolve = resolve;
5156

57+
// If this was configured with a domain name, start the domain name check
58+
// and socket cleanup periodic task.
5259
if (!Strings.isNullOrEmpty(cache.getConfig().getDomainName())) {
5360
long failoverPeriod = cache.getConfig().getConnectorConfig().getFailoverPeriod().toMillis();
5461
this.task =
@@ -64,6 +71,11 @@ public void run() {
6471
}
6572
}
6673

74+
@VisibleForTesting
75+
int getOpenSocketCount() {
76+
return sockets.size();
77+
}
78+
6779
private void checkDomainName() {
6880
// Resolve the domain name again. If it changed, close the sockets
6981
try {
@@ -149,6 +161,10 @@ public synchronized boolean isClosed() {
149161
}
150162

151163
synchronized void addSocket(SSLSocket socket) {
152-
sockets.add(socket);
164+
// Only add the socket if this was configured using a domain name,
165+
// and therefore the background socket cleanup task is running.
166+
if (!Strings.isNullOrEmpty(cache.getConfig().getDomainName())) {
167+
sockets.add(socket);
168+
}
153169
}
154170
}
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.cloud.sql.core;
18+
19+
import com.google.cloud.sql.ConnectorConfig;
20+
import java.io.IOException;
21+
import java.time.Duration;
22+
import java.util.Timer;
23+
import javax.net.ssl.HandshakeCompletedListener;
24+
import javax.net.ssl.SSLSession;
25+
import javax.net.ssl.SSLSocket;
26+
import org.junit.AfterClass;
27+
import org.junit.Assert;
28+
import org.junit.Test;
29+
30+
public class MonitoredCacheTest {
31+
private static final Timer timer = new Timer(true);
32+
33+
@AfterClass
34+
public static void afterClass() {
35+
timer.cancel();
36+
}
37+
38+
@Test
39+
public void testMonitoredCacheHoldsSocketsWithDomainName() {
40+
CloudSqlInstanceName name = new CloudSqlInstanceName("proj:reg:inst", "db.example.com");
41+
ConnectionConfig config =
42+
new ConnectionConfig.Builder()
43+
.withCloudSqlInstance("proj:reg:inst")
44+
.withDomainName("db.example.com")
45+
.build();
46+
MockCache mockCache = new MockCache(config);
47+
48+
MonitoredCache cache = new MonitoredCache(mockCache, timer, connectionConfig -> name);
49+
MockSslSocket socket = new MockSslSocket();
50+
cache.addSocket(socket);
51+
Assert.assertEquals("1 socket in cache", 1, cache.getOpenSocketCount());
52+
cache.close();
53+
Assert.assertTrue("socket closed", socket.closed);
54+
}
55+
56+
@Test
57+
public void testMonitoredCachePurgesClosedSockets() throws InterruptedException {
58+
CloudSqlInstanceName name = new CloudSqlInstanceName("proj:reg:inst", "db.example.com");
59+
// Purge sockets every 10ms.
60+
ConnectionConfig config =
61+
new ConnectionConfig.Builder()
62+
.withCloudSqlInstance("proj:reg:inst")
63+
.withDomainName("db.example.com")
64+
.withConnectorConfig(
65+
new ConnectorConfig.Builder().withFailoverPeriod(Duration.ofMillis(10)).build())
66+
.build();
67+
MockCache mockCache = new MockCache(config);
68+
69+
MonitoredCache cache = new MonitoredCache(mockCache, timer, connectionConfig -> name);
70+
MockSslSocket socket = new MockSslSocket();
71+
cache.addSocket(socket);
72+
Assert.assertEquals("1 socket in cache", 1, cache.getOpenSocketCount());
73+
socket.close();
74+
Thread.sleep(20);
75+
Assert.assertEquals("0 socket in cache", 0, cache.getOpenSocketCount());
76+
}
77+
78+
@Test
79+
public void testMonitoredCacheWithoutDomainNameIgnoresSockets() {
80+
CloudSqlInstanceName name = new CloudSqlInstanceName("proj:reg:inst");
81+
ConnectionConfig config =
82+
new ConnectionConfig.Builder().withCloudSqlInstance("proj:reg:inst").build();
83+
MockCache mockCache = new MockCache(config);
84+
85+
MonitoredCache cache = new MonitoredCache(mockCache, timer, connectionConfig -> name);
86+
MockSslSocket socket = new MockSslSocket();
87+
cache.addSocket(socket);
88+
Assert.assertEquals("0 socket in cache", 0, cache.getOpenSocketCount());
89+
}
90+
91+
private static class MockSslSocket extends SSLSocket {
92+
boolean closed;
93+
94+
@Override
95+
public synchronized boolean isClosed() {
96+
return closed;
97+
}
98+
99+
@Override
100+
public synchronized void close() {
101+
this.closed = true;
102+
}
103+
104+
@Override
105+
public String[] getSupportedCipherSuites() {
106+
return new String[0];
107+
}
108+
109+
@Override
110+
public String[] getEnabledCipherSuites() {
111+
return new String[0];
112+
}
113+
114+
@Override
115+
public void setEnabledCipherSuites(String[] suites) {}
116+
117+
@Override
118+
public String[] getSupportedProtocols() {
119+
return new String[0];
120+
}
121+
122+
@Override
123+
public String[] getEnabledProtocols() {
124+
return new String[0];
125+
}
126+
127+
@Override
128+
public void setEnabledProtocols(String[] protocols) {}
129+
130+
@Override
131+
public SSLSession getSession() {
132+
return null;
133+
}
134+
135+
@Override
136+
public void addHandshakeCompletedListener(HandshakeCompletedListener listener) {}
137+
138+
@Override
139+
public void removeHandshakeCompletedListener(HandshakeCompletedListener listener) {}
140+
141+
@Override
142+
public void startHandshake() throws IOException {}
143+
144+
@Override
145+
public void setUseClientMode(boolean mode) {}
146+
147+
@Override
148+
public boolean getUseClientMode() {
149+
return false;
150+
}
151+
152+
@Override
153+
public void setNeedClientAuth(boolean need) {}
154+
155+
@Override
156+
public boolean getNeedClientAuth() {
157+
return false;
158+
}
159+
160+
@Override
161+
public void setWantClientAuth(boolean want) {}
162+
163+
@Override
164+
public boolean getWantClientAuth() {
165+
return false;
166+
}
167+
168+
@Override
169+
public void setEnableSessionCreation(boolean flag) {}
170+
171+
@Override
172+
public boolean getEnableSessionCreation() {
173+
return false;
174+
}
175+
}
176+
177+
private static class MockCache implements ConnectionInfoCache {
178+
private final ConnectionConfig config;
179+
180+
MockCache(ConnectionConfig config) {
181+
this.config = config;
182+
}
183+
184+
@Override
185+
public ConnectionMetadata getConnectionMetadata(long timeoutMs) {
186+
return null;
187+
}
188+
189+
@Override
190+
public void forceRefresh() {}
191+
192+
@Override
193+
public void refreshIfExpired() {}
194+
195+
@Override
196+
public void close() {}
197+
198+
@Override
199+
public boolean isClosed() {
200+
return false;
201+
}
202+
203+
@Override
204+
public ConnectionConfig getConfig() {
205+
return config;
206+
}
207+
}
208+
}

0 commit comments

Comments
 (0)