Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.google.api.client.googleapis.json.GoogleJsonResponseException;
import com.google.api.services.sqladmin.SQLAdmin;
import com.google.api.services.sqladmin.model.ConnectSettings;
import com.google.api.services.sqladmin.model.DnsNameMapping;
import com.google.api.services.sqladmin.model.GenerateEphemeralCertRequest;
import com.google.api.services.sqladmin.model.GenerateEphemeralCertResponse;
import com.google.api.services.sqladmin.model.IpMapping;
Expand Down Expand Up @@ -287,10 +288,31 @@ private InstanceMetadata fetchMetadata(CloudSqlInstanceName instanceName, AuthTy
boolean pscEnabled =
instanceMetadata.getPscEnabled() != null
&& instanceMetadata.getPscEnabled().booleanValue();
if (pscEnabled
&& instanceMetadata.getDnsName() != null
&& !instanceMetadata.getDnsName().isEmpty()) {
ipAddrs.put(IpType.PSC, instanceMetadata.getDnsName());

if (pscEnabled) {
// Search the dns_names field for the PSC DNS Name.
String pscDnsName = null;
if (instanceMetadata.getDnsNames() != null) {
for (DnsNameMapping dnm : instanceMetadata.getDnsNames()) {
if ("PRIVATE_SERVICE_CONNECT".equals(dnm.getConnectionType())
&& "INSTANCE".equals(dnm.getDnsScope())) {
pscDnsName = dnm.getName();
break;
}
}
}

// If the psc dns name was not found, use the legacy dns_name field
if (pscDnsName == null
&& instanceMetadata.getDnsName() != null
&& !instanceMetadata.getDnsName().isEmpty()) {
pscDnsName = instanceMetadata.getDnsName();
}

// If the psc dns name was found, add it to the ipaddrs map.
if (pscDnsName != null) {
ipAddrs.put(IpType.PSC, pscDnsName);
}
}

// Verify the instance has at least one IP type assigned that can be used to connect.
Expand All @@ -301,6 +323,18 @@ private InstanceMetadata fetchMetadata(CloudSqlInstanceName instanceName, AuthTy
+ "IP address.",
instanceName.getConnectionName()));
}

// Find a DNS name to use to validate the certificate from the dns_names field. Any
// name in the list may be used to validate the server TLS certificate.
// Fall back to legacy dns_name field if necessary.
String serverName = null;
if (instanceMetadata.getDnsNames() != null && !instanceMetadata.getDnsNames().isEmpty()) {
serverName = instanceMetadata.getDnsNames().get(0).getName();
}
if (serverName == null) {
serverName = instanceMetadata.getDnsName();
}

// Update the Server CA certificate used to create the SSL connection with the instance.
try {
List<Certificate> instanceCaCertificates =
Expand All @@ -313,7 +347,7 @@ private InstanceMetadata fetchMetadata(CloudSqlInstanceName instanceName, AuthTy
ipAddrs,
instanceCaCertificates,
isCasManagedCertificate(instanceMetadata.getServerCaMode()),
instanceMetadata.getDnsName(),
serverName,
pscEnabled);
} catch (CertificateException ex) {
throw new RuntimeException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,12 @@ private void checkCertificateChain(X509Certificate[] chain) throws CertificateEx
throw new CertificateException("Subject is missing");
}

if (instanceMetadata.isCasManagedCertificate() || instanceMetadata.isPscEnabled()) {
checkSan(chain);
} else {
// If the instance metadata does not contain a domain name, use legacy CN validation
if (Strings.isNullOrEmpty(instanceMetadata.getDnsName())) {
checkCn(chain);
} else {
// If there is a DNS name, check the Subject Alternative Names.
checkSan(chain);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import com.google.api.client.testing.http.MockLowLevelHttpRequest;
import com.google.api.client.testing.http.MockLowLevelHttpResponse;
import com.google.api.services.sqladmin.model.ConnectSettings;
import com.google.api.services.sqladmin.model.DnsNameMapping;
import com.google.api.services.sqladmin.model.GenerateEphemeralCertResponse;
import com.google.api.services.sqladmin.model.IpMapping;
import com.google.api.services.sqladmin.model.SslCert;
Expand All @@ -42,6 +43,7 @@
import java.time.Duration;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
Expand Down Expand Up @@ -169,6 +171,18 @@ private String parseCertCnFromUrl(String url) {

MockHttpTransport fakeSuccessHttpTransport(
String serverCert, Duration certDuration, String baseUrl, boolean cas, boolean psc) {
return this.fakeSuccessHttpTransport(
serverCert, certDuration, baseUrl, cas, psc, cas ? "db.example.com" : null, null);
}

MockHttpTransport fakeSuccessHttpTransport(
String serverCert,
Duration certDuration,
String baseUrl,
boolean cas,
boolean psc,
String legacyDnsName,
List<DnsNameMapping> dnsNames) {
final JsonFactory jsonFactory = new GsonFactory();
return new MockHttpTransport() {
@Override
Expand Down Expand Up @@ -205,7 +219,8 @@ public LowLevelHttpResponse execute() throws IOException {
.setDatabaseVersion("POSTGRES14")
.setRegion("myRegion")
.setPscEnabled(psc ? Boolean.TRUE : null)
.setDnsName(cas || psc ? "db.example.com" : null)
.setDnsName(legacyDnsName)
.setDnsNames(dnsNames)
.setServerCaMode(
cas ? "GOOGLE_MANAGED_CAS_CA" : "GOOGLE_MANAGED_INTERNAL_CA");
settings.setFactory(jsonFactory);
Expand Down
73 changes: 71 additions & 2 deletions core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import com.google.api.client.http.BasicAuthentication;
import com.google.api.client.http.HttpRequestInitializer;
import com.google.api.services.sqladmin.model.DnsNameMapping;
import com.google.cloud.sql.AuthType;
import com.google.cloud.sql.ConnectorConfig;
import com.google.cloud.sql.CredentialFactory;
Expand Down Expand Up @@ -193,6 +194,31 @@ public void create_successfulPrivateConnection_UsesInstanceName_EmptyDomainNameI
assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE);
}

@Test
public void create_successfulPublicConnectionWithDomainNameLegacyDns()
throws IOException, InterruptedException {
FakeSslServer sslServer = new FakeSslServer();
ConnectionConfig config =
new ConnectionConfig.Builder()
.withDomainName("db.example.com")
.withIpTypes("PRIMARY")
.build();

int port = sslServer.start(PUBLIC_IP);

Connector connector =
newConnectorLegacyDnsField(
config.getConnectorConfig(),
port,
"db.example.com",
"myProject:myRegion:myInstance",
false);

Socket socket = connector.connect(config, TEST_MAX_REFRESH_MS);

assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE);
}

@Test
public void create_successfulPublicConnectionWithDomainName()
throws IOException, InterruptedException {
Expand Down Expand Up @@ -516,7 +542,14 @@ public void create_successfulDomainScopedConnection() throws IOException, Interr
new CredentialFactoryProvider(new StubCredentialFactory("foo", null));
ConnectionInfoRepositoryFactory factory =
new StubConnectionInfoRepositoryFactory(
fakeSuccessHttpTransport(TestKeys.getDomainServerCertPem(), Duration.ofSeconds(60)));
fakeSuccessHttpTransport(
TestKeys.getDomainServerCertPem(),
Duration.ofSeconds(60),
null,
false,
false,
null,
null));

int port = sslServer.start(PUBLIC_IP);
ConnectionConfig config =
Expand Down Expand Up @@ -819,12 +852,48 @@ public HttpRequestInitializer create() {
assertThrows(RuntimeException.class, () -> c.connect(config, TEST_MAX_REFRESH_MS));
}

private Connector newConnectorLegacyDnsField(
ConnectorConfig config, int port, String domainName, String instanceName, boolean cas) {
ConnectionInfoRepositoryFactory factory =
new StubConnectionInfoRepositoryFactory(
fakeSuccessHttpTransport(
TestKeys.getServerCertPem(),
Duration.ofSeconds(0),
null,
cas,
false,
domainName,
null));
Connector connector =
new Connector(
config,
factory,
stubCredentialFactoryProvider.getInstanceCredentialFactory(config),
defaultExecutor,
clientKeyPair,
10,
TEST_MAX_REFRESH_MS,
port,
new DnsInstanceConnectionNameResolver(new MockDnsResolver(domainName, instanceName)));
return connector;
}

private Connector newConnector(
ConnectorConfig config, int port, String domainName, String instanceName, boolean cas) {
ConnectionInfoRepositoryFactory factory =
new StubConnectionInfoRepositoryFactory(
fakeSuccessHttpTransport(
TestKeys.getServerCertPem(), Duration.ofSeconds(0), null, cas, false));
TestKeys.getServerCertPem(),
Duration.ofSeconds(0),
null,
cas,
false,
null,
Collections.singletonList(
new DnsNameMapping()
.setName(domainName)
.setConnectionType("PRIVATE_SERVICE_CONNECT")
.setDnsScope("INSTANCE"))));
Connector connector =
new Connector(
config,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public void testFetchInstanceData_returnsIpAddresses()
throws ExecutionException, InterruptedException, GeneralSecurityException,
OperatorCreationException {
MockAdminApi mockAdminApi =
buildMockAdminApi(INSTANCE_CONNECTION_NAME, DATABASE_VERSION, DEFAULT_BASE_URL);
buildMockAdminApi(INSTANCE_CONNECTION_NAME, DATABASE_VERSION, DEFAULT_BASE_URL, false);
ConnectorConfig config = new ConnectorConfig.Builder().build();
ConnectionInfoRepository repo =
new StubConnectionInfoRepositoryFactory(mockAdminApi.getHttpTransport())
Expand Down Expand Up @@ -85,7 +85,45 @@ public void testFetchInstanceData_returnsPscForNonIpDatabase()
null,
DATABASE_VERSION,
SAMPLE_PCS_DNS_NAME,
DEFAULT_BASE_URL);
DEFAULT_BASE_URL,
false);
mockAdminApi.addGenerateEphemeralCertResponse(
INSTANCE_CONNECTION_NAME, Duration.ofHours(1), DEFAULT_BASE_URL);
ConnectorConfig config = new ConnectorConfig.Builder().build();

ConnectionInfoRepository repo =
new StubConnectionInfoRepositoryFactory(mockAdminApi.getHttpTransport())
.create(new StubCredentialFactory().create(), config);

ConnectionInfo connectionInfo =
repo.getConnectionInfo(
new CloudSqlInstanceName(INSTANCE_CONNECTION_NAME),
() -> Optional.empty(),
AuthType.PASSWORD,
newTestExecutor(),
Futures.immediateFuture(mockAdminApi.getClientKeyPair()))
.get();
assertThat(connectionInfo.getSslContext()).isInstanceOf(SSLContext.class);

Map<IpType, String> ipAddrs = connectionInfo.getIpAddrs();
assertThat(ipAddrs.get(IpType.PSC)).isEqualTo(SAMPLE_PCS_DNS_NAME);
assertThat(ipAddrs.size()).isEqualTo(1);
}

@Test
public void testFetchInstanceData_legacyPscDns_returnsPscForNonIpDatabase()
throws ExecutionException, InterruptedException, GeneralSecurityException,
OperatorCreationException {

MockAdminApi mockAdminApi = new MockAdminApi();
mockAdminApi.addConnectSettingsResponse(
INSTANCE_CONNECTION_NAME,
null,
null,
DATABASE_VERSION,
SAMPLE_PCS_DNS_NAME,
DEFAULT_BASE_URL,
true);
mockAdminApi.addGenerateEphemeralCertResponse(
INSTANCE_CONNECTION_NAME, Duration.ofHours(1), DEFAULT_BASE_URL);
ConnectorConfig config = new ConnectorConfig.Builder().build();
Expand Down Expand Up @@ -122,7 +160,8 @@ private ListeningScheduledExecutorService newTestExecutor() {
public void testFetchInstanceData_throwsException_whenIamAuthnIsNotSupported()
throws GeneralSecurityException, OperatorCreationException {
MockAdminApi mockAdminApi =
buildMockAdminApi(INSTANCE_CONNECTION_NAME, "SQLSERVER_2019_STANDARD", DEFAULT_BASE_URL);
buildMockAdminApi(
INSTANCE_CONNECTION_NAME, "SQLSERVER_2019_STANDARD", DEFAULT_BASE_URL, false);
ConnectorConfig config = new ConnectorConfig.Builder().build();
ConnectionInfoRepository repo =
new StubConnectionInfoRepositoryFactory(mockAdminApi.getHttpTransport())
Expand All @@ -149,7 +188,7 @@ public void testFetchInstanceData_throwsException_whenIamAuthnIsNotSupported()
public void testFetchInstanceData_throwsException_whenRequestsTimeout()
throws GeneralSecurityException, OperatorCreationException {
MockAdminApi mockAdminApi =
buildMockAdminApi(INSTANCE_CONNECTION_NAME, DATABASE_VERSION, DEFAULT_BASE_URL);
buildMockAdminApi(INSTANCE_CONNECTION_NAME, DATABASE_VERSION, DEFAULT_BASE_URL, false);
ConnectorConfig config = new ConnectorConfig.Builder().build();
ConnectionInfoRepository repo =
new StubConnectionInfoRepositoryFactory(new BadConnectionFactory())
Expand Down Expand Up @@ -182,7 +221,7 @@ public void testSetAdminUrl_FetchInstanceData_returnsIpAddresses()
String adminServicePath = "sqladmin/";
String baseUrl = adminRootUrl + adminServicePath;
MockAdminApi mockAdminApi =
buildMockAdminApi(INSTANCE_CONNECTION_NAME, DATABASE_VERSION, baseUrl);
buildMockAdminApi(INSTANCE_CONNECTION_NAME, DATABASE_VERSION, baseUrl, false);
ConnectorConfig config =
new ConnectorConfig.Builder()
.withAdminRootUrl(adminRootUrl)
Expand Down Expand Up @@ -210,7 +249,7 @@ public void testSetAdminUrl_FetchInstanceData_returnsIpAddresses()

@SuppressWarnings("SameParameterValue")
private MockAdminApi buildMockAdminApi(
String instanceConnectionName, String databaseVersion, String baseUrl)
String instanceConnectionName, String databaseVersion, String baseUrl, boolean legacyDnsName)
throws GeneralSecurityException, OperatorCreationException {
MockAdminApi mockAdminApi = new MockAdminApi();
mockAdminApi.addConnectSettingsResponse(
Expand All @@ -219,7 +258,8 @@ private MockAdminApi buildMockAdminApi(
SAMPLE_PRIVATE_IP,
databaseVersion,
SAMPLE_PCS_DNS_NAME,
baseUrl);
baseUrl,
legacyDnsName);
mockAdminApi.addGenerateEphemeralCertResponse(
instanceConnectionName, Duration.ofHours(1), baseUrl);
return mockAdminApi;
Expand Down
16 changes: 14 additions & 2 deletions core/src/test/java/com/google/cloud/sql/core/MockAdminApi.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.google.api.client.testing.http.MockLowLevelHttpRequest;
import com.google.api.client.testing.http.MockLowLevelHttpResponse;
import com.google.api.services.sqladmin.model.ConnectSettings;
import com.google.api.services.sqladmin.model.DnsNameMapping;
import com.google.api.services.sqladmin.model.GenerateEphemeralCertResponse;
import com.google.api.services.sqladmin.model.IpMapping;
import com.google.api.services.sqladmin.model.SslCert;
Expand All @@ -38,6 +39,7 @@
import java.security.spec.InvalidKeySpecException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -83,7 +85,8 @@ public void addConnectSettingsResponse(
String privateIp,
String databaseVersion,
String pscHostname,
String baseUrl) {
String baseUrl,
boolean legacyPscDnsName) {
CloudSqlInstanceName cloudSqlInstanceName = new CloudSqlInstanceName(instanceConnectionName);

ArrayList<IpMapping> ipMappings = new ArrayList<>();
Expand All @@ -103,10 +106,19 @@ public void addConnectSettingsResponse(
.setServerCaCert(new SslCert().setCert(TestKeys.getServerCertPem()))
.setDatabaseVersion(databaseVersion)
.setPscEnabled(pscHostname != null)
.setDnsName(pscHostname)
.setPscEnabled(pscHostname != null)
.setRegion(cloudSqlInstanceName.getRegionId());
settings.setFactory(GsonFactory.getDefaultInstance());
if (legacyPscDnsName) {
settings.setDnsName(pscHostname);
} else {
settings.setDnsNames(
Collections.singletonList(
new DnsNameMapping()
.setDnsScope("INSTANCE")
.setConnectionType("PRIVATE_SERVICE_CONNECT")
.setName(pscHostname)));
}

connectSettingsRequests.add(
new ConnectSettingsRequest(cloudSqlInstanceName, settings, baseUrl));
Expand Down
Loading
Loading