diff --git a/build.sh b/build.sh index 17c2c6c6e..f1a1c65ca 100755 --- a/build.sh +++ b/build.sh @@ -41,8 +41,12 @@ function test() { if [[ "$(uname -s)" == "Darwin" ]]; then echo "macOS detected. Setting up IP aliases for tests." echo "You may be prompted for your password to run sudo." - sudo ifconfig lo0 alias 127.0.0.2 up - sudo ifconfig lo0 alias 127.0.0.3 up + if ! ifconfig lo0 | grep -q 127.0.0.2 ; then + sudo ifconfig lo0 alias 127.0.0.2 up + fi + if ! ifconfig lo0 | grep -q 127.0.0.3 ; then + sudo ifconfig lo0 alias 127.0.0.3 up + fi fi $mvn_cmd -P coverage test } @@ -93,7 +97,7 @@ function write_e2e_env(){ secret_vars=( MYSQL_CONNECTION_NAME=MYSQL_CONNECTION_NAME MYSQL_USER=MYSQL_USER - MYSQL_USER_IAM=MYSQL_USER_IAM_GO + IMPERSONATED_USER=IMPERSONATED_USER MYSQL_PASS=MYSQL_PASS MYSQL_DB=MYSQL_DB MYSQL_MCP_CONNECTION_NAME=MYSQL_MCP_CONNECTION_NAME @@ -107,8 +111,8 @@ function write_e2e_env(){ POSTGRES_CAS_PASS=POSTGRES_CAS_PASS POSTGRES_CUSTOMER_CAS_CONNECTION_NAME=POSTGRES_CUSTOMER_CAS_CONNECTION_NAME POSTGRES_CUSTOMER_CAS_PASS=POSTGRES_CUSTOMER_CAS_PASS - POSTGRES_CUSTOMER_CAS_DOMAIN_NAME=POSTGRES_CUSTOMER_CAS_DOMAIN_NAME - POSTGRES_CUSTOMER_CAS_INVALID_DOMAIN_NAME=POSTGRES_CUSTOMER_CAS_INVALID_DOMAIN_NAME + POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME=POSTGRES_CUSTOMER_CAS_DOMAIN_NAME + POSTGRES_CUSTOMER_CAS_PASS_INVALID_DOMAIN_NAME=POSTGRES_CUSTOMER_CAS_INVALID_DOMAIN_NAME POSTGRES_MCP_CONNECTION_NAME=POSTGRES_MCP_CONNECTION_NAME POSTGRES_MCP_PASS=POSTGRES_MCP_PASS SQLSERVER_CONNECTION_NAME=SQLSERVER_CONNECTION_NAME @@ -133,6 +137,10 @@ function write_e2e_env(){ val=$(gcloud secrets versions access latest --project "$TEST_PROJECT" --secret="$secret_name") echo "export $env_var_name='$val'" done + + echo "export MYSQL_IAM_USER='$(whoami)'" + echo "export POSTGRES_IAM_USER='$(whoami)@google.com'" + } > "$outfile" } diff --git a/core/src/main/java/com/google/cloud/sql/core/Connector.java b/core/src/main/java/com/google/cloud/sql/core/Connector.java index af4b29812..088049bd0 100644 --- a/core/src/main/java/com/google/cloud/sql/core/Connector.java +++ b/core/src/main/java/com/google/cloud/sql/core/Connector.java @@ -24,9 +24,12 @@ import com.google.common.util.concurrent.ListeningScheduledExecutorService; import java.io.File; import java.io.IOException; +import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.Socket; +import java.net.UnknownHostException; import java.security.KeyPair; +import java.util.List; import java.util.Timer; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; @@ -52,6 +55,7 @@ class Connector { private final ConnectorConfig config; private final InstanceConnectionNameResolver instanceNameResolver; + private final DnsResolver dnsResolver; private final Timer instanceNameResolverTimer; private final ProtocolHandler mdxProtocolHandler; @@ -65,9 +69,9 @@ class Connector { long refreshTimeoutMs, int serverProxyPort, InstanceConnectionNameResolver instanceNameResolver, + DnsResolver dnsResolver, ProtocolHandler mdxProtocolHandler) { this.config = config; - this.adminApi = connectionInfoRepositoryFactory.create(instanceCredentialFactory.create(), config); this.instanceCredentialFactory = instanceCredentialFactory; @@ -76,6 +80,7 @@ class Connector { this.minRefreshDelayMs = minRefreshDelayMs; this.serverProxyPort = serverProxyPort; this.instanceNameResolver = instanceNameResolver; + this.dnsResolver = dnsResolver; this.instanceNameResolverTimer = new Timer("InstanceNameResolverTimer", true); this.mdxProtocolHandler = mdxProtocolHandler; } @@ -125,6 +130,40 @@ Socket connect(ConnectionConfig config, long timeoutMs) throws IOException { try { ConnectionMetadata metadata = instance.getConnectionMetadata(timeoutMs); String instanceIp = metadata.getPreferredIpAddress(); + + // If a domain name was used to connect, resolve it to an IP address + if (!Strings.isNullOrEmpty(instance.getConfig().getDomainName())) { + try { + List addrs = dnsResolver.resolveHost(instance.getConfig().getDomainName()); + if (addrs != null && !addrs.isEmpty()) { + logger.debug( + String.format( + "[%s] custom DNS name %s resolved to %s, using it to connect", + instance.getConfig().getCloudSqlInstance(), + instance.getConfig().getDomainName(), + addrs.get(0).getHostAddress())); + instanceIp = addrs.get(0).getHostAddress(); + } else { + logger.debug( + String.format( + "[%s] custom DNS name %s resolved but returned no entries, using %s from" + + " instance metadata", + instance.getConfig().getCloudSqlInstance(), + instance.getConfig().getDomainName(), + instanceIp)); + } + } catch (UnknownHostException e) { + logger.debug( + String.format( + "[%s] custom DNS name %s did not resolve to an IP address: %s, using %s from" + + " instance metadata", + instance.getConfig().getCloudSqlInstance(), + instance.getConfig().getDomainName(), + e.getMessage(), + instanceIp)); + } + } + logger.debug(String.format("[%s] Connecting to instance.", instanceIp)); SSLSocket socket = (SSLSocket) metadata.getSslContext().getSocketFactory().createSocket(); diff --git a/core/src/main/java/com/google/cloud/sql/core/DnsJavaResolver.java b/core/src/main/java/com/google/cloud/sql/core/DnsJavaResolver.java index 044d12f5f..6c961bf9e 100644 --- a/core/src/main/java/com/google/cloud/sql/core/DnsJavaResolver.java +++ b/core/src/main/java/com/google/cloud/sql/core/DnsJavaResolver.java @@ -16,11 +16,15 @@ package com.google.cloud.sql.core; +import java.net.InetAddress; import java.net.UnknownHostException; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; +import java.util.List; import java.util.stream.Collectors; import javax.naming.NameNotFoundException; +import org.xbill.DNS.ARecord; import org.xbill.DNS.Lookup; import org.xbill.DNS.Record; import org.xbill.DNS.SimpleResolver; @@ -105,4 +109,44 @@ public Collection resolveTxt(String domainName) throws NameNotFoundExcep throw new RuntimeException("Invalid domain name format: " + domainName, e); } } + + /** + * Resolve an A record. + * + * @param hostName the hostname to look up + * @return the resolved IP addresses + * @throws UnknownHostException if no records are found. + */ + @Override + public List resolveHost(String hostName) throws UnknownHostException { + try { + Lookup lookup = new Lookup(hostName, Type.A); + if (this.resolver != null) { + lookup.setResolver(this.resolver); + } + lookup.run(); + + int resultCode = lookup.getResult(); + if (resultCode == Lookup.HOST_NOT_FOUND) { + throw new UnknownHostException("DNS record not found for " + hostName); + } + if (resultCode != Lookup.SUCCESSFUL) { + throw new UnknownHostException( + "DNS lookup failed for " + hostName + ": " + lookup.getErrorString()); + } + + Record[] records = lookup.getAnswers(); + if (records == null || records.length == 0) { + return Collections.emptyList(); + } + + return Arrays.stream(records) + .map(r -> (ARecord) r) + .map(ARecord::getAddress) + .collect(Collectors.toList()); + + } catch (TextParseException e) { + throw new UnknownHostException("Invalid domain name format: " + hostName); + } + } } diff --git a/core/src/main/java/com/google/cloud/sql/core/DnsResolver.java b/core/src/main/java/com/google/cloud/sql/core/DnsResolver.java index 65b85c939..eda136138 100644 --- a/core/src/main/java/com/google/cloud/sql/core/DnsResolver.java +++ b/core/src/main/java/com/google/cloud/sql/core/DnsResolver.java @@ -16,10 +16,15 @@ package com.google.cloud.sql.core; +import java.net.InetAddress; +import java.net.UnknownHostException; import java.util.Collection; +import java.util.List; import javax.naming.NameNotFoundException; /** Wraps the Java DNS API. */ interface DnsResolver { Collection resolveTxt(String domainName) throws NameNotFoundException; + + List resolveHost(String hostName) throws UnknownHostException; } diff --git a/core/src/main/java/com/google/cloud/sql/core/InternalConnectorRegistry.java b/core/src/main/java/com/google/cloud/sql/core/InternalConnectorRegistry.java index e652f30a7..b2e5d5515 100644 --- a/core/src/main/java/com/google/cloud/sql/core/InternalConnectorRegistry.java +++ b/core/src/main/java/com/google/cloud/sql/core/InternalConnectorRegistry.java @@ -339,6 +339,7 @@ private Connector createConnector(ConnectorConfig config) { connectTimeoutMs, serverProxyPort, new DnsInstanceConnectionNameResolver(new DnsJavaResolver()), + new DnsJavaResolver(), this.mdxProtocolHandler); } diff --git a/core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java b/core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java index 1e8d0ca51..e182d1cf8 100644 --- a/core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java +++ b/core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java @@ -31,6 +31,7 @@ import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; +import java.net.InetAddress; import java.net.Socket; import java.nio.file.Files; import java.nio.file.Path; @@ -40,7 +41,10 @@ import java.time.Instant; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; +import java.util.List; import java.util.Locale; +import java.util.Map; import javax.naming.NameNotFoundException; import javax.net.ssl.SSLHandshakeException; import org.junit.After; @@ -283,6 +287,7 @@ public void create_successfulPrivateConnectionWhenDomainNameValueChanges() TEST_MAX_REFRESH_MS, port, new DnsInstanceConnectionNameResolver(resolver), + resolver, new ProtocolHandler("test")); // Open socket to initial instance @@ -334,6 +339,7 @@ public void create_refreshConnectorWhenDomainNameValueChanges() TEST_MAX_REFRESH_MS, port, new DnsInstanceConnectionNameResolver(resolver), + resolver, new ProtocolHandler("test")); // Open socket to initial instance @@ -388,6 +394,7 @@ public void create_noChangeWhenDomainNameFailsToResolve() TEST_MAX_REFRESH_MS, port, new DnsInstanceConnectionNameResolver(resolver), + resolver, new ProtocolHandler("test")); // Open socket to initial instance @@ -507,6 +514,8 @@ public void create_successfulPublicCasConnection() throws IOException, Interrupt ConnectionInfoRepositoryFactory factory = new StubConnectionInfoRepositoryFactory(fakeSuccessHttpCasTransport(Duration.ZERO)); + DnsResolver dnsResolver = new MockDnsResolver("example.com", "myProject:myRegion:myInstance"); + Connector connector = new Connector( config.getConnectorConfig(), @@ -517,8 +526,8 @@ public void create_successfulPublicCasConnection() throws IOException, Interrupt 10, TEST_MAX_REFRESH_MS, port, - new DnsInstanceConnectionNameResolver( - new MockDnsResolver("example.com", "myProject:myRegion:myInstance")), + new DnsInstanceConnectionNameResolver(dnsResolver), + dnsResolver, new ProtocolHandler("test")); Socket socket = connector.connect(config, TEST_MAX_REFRESH_MS); @@ -587,6 +596,9 @@ public void create_successfulDomainScopedConnection() throws IOException, Interr .withCloudSqlInstance("example.com:myProject:myRegion:myInstance") .withIpTypes("PRIMARY") .build(); + + DnsResolver dnsResolver = new MockDnsResolver(); + Connector c = new Connector( config.getConnectorConfig(), @@ -597,7 +609,8 @@ public void create_successfulDomainScopedConnection() throws IOException, Interr 10, TEST_MAX_REFRESH_MS, port, - new DnsInstanceConnectionNameResolver(new MockDnsResolver()), + new DnsInstanceConnectionNameResolver(dnsResolver), + dnsResolver, new ProtocolHandler("test")); Socket socket = c.connect(config, TEST_MAX_REFRESH_MS); @@ -653,6 +666,8 @@ public void create_throwsException_adminApiNotEnabled() throws IOException { .withCloudSqlInstance("NotMyProject:myRegion:myInstance") .withIpTypes("PRIMARY") .build(); + + DnsResolver dnsResolver = new MockDnsResolver(); Connector c = new Connector( config.getConnectorConfig(), @@ -663,7 +678,8 @@ public void create_throwsException_adminApiNotEnabled() throws IOException { 10, TEST_MAX_REFRESH_MS, DEFAULT_SERVER_PROXY_PORT, - new DnsInstanceConnectionNameResolver(new MockDnsResolver()), + new DnsInstanceConnectionNameResolver(dnsResolver), + dnsResolver, new ProtocolHandler("test")); // Use a different project to get Api Not Enabled Error. @@ -687,6 +703,8 @@ public void create_throwsException_adminApiReturnsNotAuthorized() throws IOExcep .withCloudSqlInstance("myProject:myRegion:NotMyInstance") .withIpTypes("PRIMARY") .build(); + + DnsResolver dnsResolver = new MockDnsResolver(); Connector c = new Connector( config.getConnectorConfig(), @@ -697,7 +715,8 @@ public void create_throwsException_adminApiReturnsNotAuthorized() throws IOExcep 10, TEST_MAX_REFRESH_MS, DEFAULT_SERVER_PROXY_PORT, - new DnsInstanceConnectionNameResolver(new MockDnsResolver()), + new DnsInstanceConnectionNameResolver(dnsResolver), + dnsResolver, new ProtocolHandler("test")); // Use a different instance to simulate incorrect permissions. @@ -721,6 +740,9 @@ public void create_throwsException_badGateway() throws IOException { .withCloudSqlInstance("myProject:myRegion:NotMyInstance") .withIpTypes("PRIMARY") .build(); + + DnsResolver dnsResolver = new MockDnsResolver(); + Connector c = new Connector( config.getConnectorConfig(), @@ -731,7 +753,8 @@ public void create_throwsException_badGateway() throws IOException { 10, TEST_MAX_REFRESH_MS, DEFAULT_SERVER_PROXY_PORT, - new DnsInstanceConnectionNameResolver(new MockDnsResolver()), + new DnsInstanceConnectionNameResolver(dnsResolver), + dnsResolver, new ProtocolHandler("test")); // If the gateway is down, then this is a temporary error, not a fatal error. @@ -765,6 +788,8 @@ public void create_successfulPublicConnection_withIntermittentBadGatewayErrors() int port = sslServer.start(PUBLIC_IP); + DnsResolver dnsResolver = new MockDnsResolver(); + Connector c = new Connector( config.getConnectorConfig(), @@ -775,7 +800,8 @@ public void create_successfulPublicConnection_withIntermittentBadGatewayErrors() 10, TEST_MAX_REFRESH_MS, port, - new DnsInstanceConnectionNameResolver(new MockDnsResolver()), + new DnsInstanceConnectionNameResolver(dnsResolver), + dnsResolver, new ProtocolHandler("test")); Socket socket = c.connect(config, TEST_MAX_REFRESH_MS); @@ -800,6 +826,9 @@ public void supportsCustomCredentialFactoryWithIAM() throws InterruptedException .withIpTypes("PRIMARY") .withAuthType(AuthType.IAM) .build(); + + DnsResolver dnsResolver = new MockDnsResolver(); + Connector c = new Connector( config.getConnectorConfig(), @@ -810,7 +839,8 @@ public void supportsCustomCredentialFactoryWithIAM() throws InterruptedException 10, TEST_MAX_REFRESH_MS, port, - new DnsInstanceConnectionNameResolver(new MockDnsResolver()), + new DnsInstanceConnectionNameResolver(dnsResolver), + dnsResolver, new ProtocolHandler("test")); Socket socket = c.connect(config, TEST_MAX_REFRESH_MS); @@ -834,6 +864,7 @@ public void supportsCustomCredentialFactoryWithNoExpirationTime() .withIpTypes("PRIMARY") .withAuthType(AuthType.IAM) .build(); + DnsResolver dnsResolver = new MockDnsResolver(); Connector c = new Connector( config.getConnectorConfig(), @@ -844,7 +875,8 @@ public void supportsCustomCredentialFactoryWithNoExpirationTime() 10, TEST_MAX_REFRESH_MS, port, - new DnsInstanceConnectionNameResolver(new MockDnsResolver()), + new DnsInstanceConnectionNameResolver(dnsResolver), + dnsResolver, new ProtocolHandler("test")); Socket socket = c.connect(config, TEST_MAX_REFRESH_MS); @@ -874,6 +906,7 @@ public HttpRequestInitializer create() { .withIpTypes("PRIMARY") .withAuthType(AuthType.IAM) .build(); + DnsResolver dnsResolver = new MockDnsResolver(); Connector c = new Connector( config.getConnectorConfig(), @@ -884,7 +917,8 @@ public HttpRequestInitializer create() { 10, TEST_MAX_REFRESH_MS, DEFAULT_SERVER_PROXY_PORT, - new DnsInstanceConnectionNameResolver(new MockDnsResolver()), + new DnsInstanceConnectionNameResolver(dnsResolver), + dnsResolver, new ProtocolHandler("test")); assertThrows(RuntimeException.class, () -> c.connect(config, TEST_MAX_REFRESH_MS)); @@ -902,6 +936,7 @@ private Connector newConnectorLegacyDnsField( false, domainName, null)); + DnsResolver dnsResolver = new MockDnsResolver(domainName, instanceName); Connector connector = new Connector( config, @@ -912,7 +947,8 @@ private Connector newConnectorLegacyDnsField( 10, TEST_MAX_REFRESH_MS, port, - new DnsInstanceConnectionNameResolver(new MockDnsResolver(domainName, instanceName)), + new DnsInstanceConnectionNameResolver(dnsResolver), + dnsResolver, new ProtocolHandler("test")); return connector; } @@ -933,6 +969,7 @@ private Connector newConnector( .setName(domainName) .setConnectionType("PRIVATE_SERVICE_CONNECT") .setDnsScope("INSTANCE")))); + DnsResolver dnsResolver = new MockDnsResolver(domainName, instanceName); Connector connector = new Connector( config, @@ -943,7 +980,8 @@ private Connector newConnector( 10, TEST_MAX_REFRESH_MS, port, - new DnsInstanceConnectionNameResolver(new MockDnsResolver(domainName, instanceName)), + new DnsInstanceConnectionNameResolver(dnsResolver), + dnsResolver, new ProtocolHandler("test")); return connector; } @@ -957,6 +995,7 @@ private String readLine(Socket socket) throws IOException { private static class MockDnsResolver implements DnsResolver { private final String domainName; private final String instanceName; + private final Map hosts = new HashMap<>(); private MockDnsResolver() { this.domainName = null; @@ -981,6 +1020,14 @@ public Collection resolveTxt(String domainName) throws NameNotFoundExcep } throw new NameNotFoundException("Not found: " + domainName); } + + @Override + public List resolveHost(String hostName) { + if (hosts.containsKey(hostName)) { + return Collections.singletonList(hosts.get(hostName)); + } + return Collections.emptyList(); + } } private static class MutableDnsResolver implements DnsResolver { @@ -1010,5 +1057,10 @@ public synchronized Collection resolveTxt(String domainName) } throw new NameNotFoundException("Not found: " + domainName); } + + @Override + public List resolveHost(String hostName) { + return Collections.emptyList(); + } } }