From 0f35e01c69f5975d5e85284461d564504b7f2b18 Mon Sep 17 00:00:00 2001 From: Jonathan Hess Date: Tue, 11 Mar 2025 12:39:44 -0600 Subject: [PATCH] refactor: Use new ConnectSettings.DnsNames field to validate the server certificate's server name. --- .../core/DefaultConnectionInfoRepository.java | 44 +++++++++-- .../sql/core/InstanceCheckingTrustManger.java | 8 +- .../sql/core/CloudSqlCoreTestingBase.java | 17 ++++- .../google/cloud/sql/core/ConnectorTest.java | 73 ++++++++++++++++++- .../DefaultConnectionInfoRepositoryTest.java | 54 ++++++++++++-- .../google/cloud/sql/core/MockAdminApi.java | 16 +++- .../sql/core/TestCertificateGenerator.java | 4 +- pom.xml | 2 +- 8 files changed, 195 insertions(+), 23 deletions(-) diff --git a/core/src/main/java/com/google/cloud/sql/core/DefaultConnectionInfoRepository.java b/core/src/main/java/com/google/cloud/sql/core/DefaultConnectionInfoRepository.java index 350d573a4..82f88ff17 100644 --- a/core/src/main/java/com/google/cloud/sql/core/DefaultConnectionInfoRepository.java +++ b/core/src/main/java/com/google/cloud/sql/core/DefaultConnectionInfoRepository.java @@ -19,6 +19,7 @@ import com.google.api.client.googleapis.json.GoogleJsonResponseException; import com.google.api.services.sqladmin.SQLAdmin; import com.google.api.services.sqladmin.model.ConnectSettings; +import com.google.api.services.sqladmin.model.DnsNameMapping; import com.google.api.services.sqladmin.model.GenerateEphemeralCertRequest; import com.google.api.services.sqladmin.model.GenerateEphemeralCertResponse; import com.google.api.services.sqladmin.model.IpMapping; @@ -287,10 +288,31 @@ private InstanceMetadata fetchMetadata(CloudSqlInstanceName instanceName, AuthTy boolean pscEnabled = instanceMetadata.getPscEnabled() != null && instanceMetadata.getPscEnabled().booleanValue(); - if (pscEnabled - && instanceMetadata.getDnsName() != null - && !instanceMetadata.getDnsName().isEmpty()) { - ipAddrs.put(IpType.PSC, instanceMetadata.getDnsName()); + + if (pscEnabled) { + // Search the dns_names field for the PSC DNS Name. + String pscDnsName = null; + if (instanceMetadata.getDnsNames() != null) { + for (DnsNameMapping dnm : instanceMetadata.getDnsNames()) { + if ("PRIVATE_SERVICE_CONNECT".equals(dnm.getConnectionType()) + && "INSTANCE".equals(dnm.getDnsScope())) { + pscDnsName = dnm.getName(); + break; + } + } + } + + // If the psc dns name was not found, use the legacy dns_name field + if (pscDnsName == null + && instanceMetadata.getDnsName() != null + && !instanceMetadata.getDnsName().isEmpty()) { + pscDnsName = instanceMetadata.getDnsName(); + } + + // If the psc dns name was found, add it to the ipaddrs map. + if (pscDnsName != null) { + ipAddrs.put(IpType.PSC, pscDnsName); + } } // Verify the instance has at least one IP type assigned that can be used to connect. @@ -301,6 +323,18 @@ private InstanceMetadata fetchMetadata(CloudSqlInstanceName instanceName, AuthTy + "IP address.", instanceName.getConnectionName())); } + + // Find a DNS name to use to validate the certificate from the dns_names field. Any + // name in the list may be used to validate the server TLS certificate. + // Fall back to legacy dns_name field if necessary. + String serverName = null; + if (instanceMetadata.getDnsNames() != null && !instanceMetadata.getDnsNames().isEmpty()) { + serverName = instanceMetadata.getDnsNames().get(0).getName(); + } + if (serverName == null) { + serverName = instanceMetadata.getDnsName(); + } + // Update the Server CA certificate used to create the SSL connection with the instance. try { List instanceCaCertificates = @@ -313,7 +347,7 @@ private InstanceMetadata fetchMetadata(CloudSqlInstanceName instanceName, AuthTy ipAddrs, instanceCaCertificates, isCasManagedCertificate(instanceMetadata.getServerCaMode()), - instanceMetadata.getDnsName(), + serverName, pscEnabled); } catch (CertificateException ex) { throw new RuntimeException( diff --git a/core/src/main/java/com/google/cloud/sql/core/InstanceCheckingTrustManger.java b/core/src/main/java/com/google/cloud/sql/core/InstanceCheckingTrustManger.java index 182e705eb..e39dba995 100644 --- a/core/src/main/java/com/google/cloud/sql/core/InstanceCheckingTrustManger.java +++ b/core/src/main/java/com/google/cloud/sql/core/InstanceCheckingTrustManger.java @@ -96,10 +96,12 @@ private void checkCertificateChain(X509Certificate[] chain) throws CertificateEx throw new CertificateException("Subject is missing"); } - if (instanceMetadata.isCasManagedCertificate() || instanceMetadata.isPscEnabled()) { - checkSan(chain); - } else { + // If the instance metadata does not contain a domain name, use legacy CN validation + if (Strings.isNullOrEmpty(instanceMetadata.getDnsName())) { checkCn(chain); + } else { + // If there is a DNS name, check the Subject Alternative Names. + checkSan(chain); } } diff --git a/core/src/test/java/com/google/cloud/sql/core/CloudSqlCoreTestingBase.java b/core/src/test/java/com/google/cloud/sql/core/CloudSqlCoreTestingBase.java index 4f5020e5f..3cc597fed 100644 --- a/core/src/test/java/com/google/cloud/sql/core/CloudSqlCoreTestingBase.java +++ b/core/src/test/java/com/google/cloud/sql/core/CloudSqlCoreTestingBase.java @@ -30,6 +30,7 @@ import com.google.api.client.testing.http.MockLowLevelHttpRequest; import com.google.api.client.testing.http.MockLowLevelHttpResponse; import com.google.api.services.sqladmin.model.ConnectSettings; +import com.google.api.services.sqladmin.model.DnsNameMapping; import com.google.api.services.sqladmin.model.GenerateEphemeralCertResponse; import com.google.api.services.sqladmin.model.IpMapping; import com.google.api.services.sqladmin.model.SslCert; @@ -42,6 +43,7 @@ import java.time.Duration; import java.util.Base64; import java.util.Collections; +import java.util.List; import java.util.concurrent.ConcurrentHashMap; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -169,6 +171,18 @@ private String parseCertCnFromUrl(String url) { MockHttpTransport fakeSuccessHttpTransport( String serverCert, Duration certDuration, String baseUrl, boolean cas, boolean psc) { + return this.fakeSuccessHttpTransport( + serverCert, certDuration, baseUrl, cas, psc, cas ? "db.example.com" : null, null); + } + + MockHttpTransport fakeSuccessHttpTransport( + String serverCert, + Duration certDuration, + String baseUrl, + boolean cas, + boolean psc, + String legacyDnsName, + List dnsNames) { final JsonFactory jsonFactory = new GsonFactory(); return new MockHttpTransport() { @Override @@ -205,7 +219,8 @@ public LowLevelHttpResponse execute() throws IOException { .setDatabaseVersion("POSTGRES14") .setRegion("myRegion") .setPscEnabled(psc ? Boolean.TRUE : null) - .setDnsName(cas || psc ? "db.example.com" : null) + .setDnsName(legacyDnsName) + .setDnsNames(dnsNames) .setServerCaMode( cas ? "GOOGLE_MANAGED_CAS_CA" : "GOOGLE_MANAGED_INTERNAL_CA"); settings.setFactory(jsonFactory); diff --git a/core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java b/core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java index 3ebee5213..b41ae27d5 100644 --- a/core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java +++ b/core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java @@ -22,6 +22,7 @@ import com.google.api.client.http.BasicAuthentication; import com.google.api.client.http.HttpRequestInitializer; +import com.google.api.services.sqladmin.model.DnsNameMapping; import com.google.cloud.sql.AuthType; import com.google.cloud.sql.ConnectorConfig; import com.google.cloud.sql.CredentialFactory; @@ -193,6 +194,31 @@ public void create_successfulPrivateConnection_UsesInstanceName_EmptyDomainNameI assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE); } + @Test + public void create_successfulPublicConnectionWithDomainNameLegacyDns() + throws IOException, InterruptedException { + FakeSslServer sslServer = new FakeSslServer(); + ConnectionConfig config = + new ConnectionConfig.Builder() + .withDomainName("db.example.com") + .withIpTypes("PRIMARY") + .build(); + + int port = sslServer.start(PUBLIC_IP); + + Connector connector = + newConnectorLegacyDnsField( + config.getConnectorConfig(), + port, + "db.example.com", + "myProject:myRegion:myInstance", + false); + + Socket socket = connector.connect(config, TEST_MAX_REFRESH_MS); + + assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE); + } + @Test public void create_successfulPublicConnectionWithDomainName() throws IOException, InterruptedException { @@ -516,7 +542,14 @@ public void create_successfulDomainScopedConnection() throws IOException, Interr new CredentialFactoryProvider(new StubCredentialFactory("foo", null)); ConnectionInfoRepositoryFactory factory = new StubConnectionInfoRepositoryFactory( - fakeSuccessHttpTransport(TestKeys.getDomainServerCertPem(), Duration.ofSeconds(60))); + fakeSuccessHttpTransport( + TestKeys.getDomainServerCertPem(), + Duration.ofSeconds(60), + null, + false, + false, + null, + null)); int port = sslServer.start(PUBLIC_IP); ConnectionConfig config = @@ -819,12 +852,48 @@ public HttpRequestInitializer create() { assertThrows(RuntimeException.class, () -> c.connect(config, TEST_MAX_REFRESH_MS)); } + private Connector newConnectorLegacyDnsField( + ConnectorConfig config, int port, String domainName, String instanceName, boolean cas) { + ConnectionInfoRepositoryFactory factory = + new StubConnectionInfoRepositoryFactory( + fakeSuccessHttpTransport( + TestKeys.getServerCertPem(), + Duration.ofSeconds(0), + null, + cas, + false, + domainName, + null)); + Connector connector = + new Connector( + config, + factory, + stubCredentialFactoryProvider.getInstanceCredentialFactory(config), + defaultExecutor, + clientKeyPair, + 10, + TEST_MAX_REFRESH_MS, + port, + new DnsInstanceConnectionNameResolver(new MockDnsResolver(domainName, instanceName))); + return connector; + } + private Connector newConnector( ConnectorConfig config, int port, String domainName, String instanceName, boolean cas) { ConnectionInfoRepositoryFactory factory = new StubConnectionInfoRepositoryFactory( fakeSuccessHttpTransport( - TestKeys.getServerCertPem(), Duration.ofSeconds(0), null, cas, false)); + TestKeys.getServerCertPem(), + Duration.ofSeconds(0), + null, + cas, + false, + null, + Collections.singletonList( + new DnsNameMapping() + .setName(domainName) + .setConnectionType("PRIVATE_SERVICE_CONNECT") + .setDnsScope("INSTANCE")))); Connector connector = new Connector( config, diff --git a/core/src/test/java/com/google/cloud/sql/core/DefaultConnectionInfoRepositoryTest.java b/core/src/test/java/com/google/cloud/sql/core/DefaultConnectionInfoRepositoryTest.java index c74a1c174..59d110b62 100644 --- a/core/src/test/java/com/google/cloud/sql/core/DefaultConnectionInfoRepositoryTest.java +++ b/core/src/test/java/com/google/cloud/sql/core/DefaultConnectionInfoRepositoryTest.java @@ -51,7 +51,7 @@ public void testFetchInstanceData_returnsIpAddresses() throws ExecutionException, InterruptedException, GeneralSecurityException, OperatorCreationException { MockAdminApi mockAdminApi = - buildMockAdminApi(INSTANCE_CONNECTION_NAME, DATABASE_VERSION, DEFAULT_BASE_URL); + buildMockAdminApi(INSTANCE_CONNECTION_NAME, DATABASE_VERSION, DEFAULT_BASE_URL, false); ConnectorConfig config = new ConnectorConfig.Builder().build(); ConnectionInfoRepository repo = new StubConnectionInfoRepositoryFactory(mockAdminApi.getHttpTransport()) @@ -85,7 +85,45 @@ public void testFetchInstanceData_returnsPscForNonIpDatabase() null, DATABASE_VERSION, SAMPLE_PCS_DNS_NAME, - DEFAULT_BASE_URL); + DEFAULT_BASE_URL, + false); + mockAdminApi.addGenerateEphemeralCertResponse( + INSTANCE_CONNECTION_NAME, Duration.ofHours(1), DEFAULT_BASE_URL); + ConnectorConfig config = new ConnectorConfig.Builder().build(); + + ConnectionInfoRepository repo = + new StubConnectionInfoRepositoryFactory(mockAdminApi.getHttpTransport()) + .create(new StubCredentialFactory().create(), config); + + ConnectionInfo connectionInfo = + repo.getConnectionInfo( + new CloudSqlInstanceName(INSTANCE_CONNECTION_NAME), + () -> Optional.empty(), + AuthType.PASSWORD, + newTestExecutor(), + Futures.immediateFuture(mockAdminApi.getClientKeyPair())) + .get(); + assertThat(connectionInfo.getSslContext()).isInstanceOf(SSLContext.class); + + Map ipAddrs = connectionInfo.getIpAddrs(); + assertThat(ipAddrs.get(IpType.PSC)).isEqualTo(SAMPLE_PCS_DNS_NAME); + assertThat(ipAddrs.size()).isEqualTo(1); + } + + @Test + public void testFetchInstanceData_legacyPscDns_returnsPscForNonIpDatabase() + throws ExecutionException, InterruptedException, GeneralSecurityException, + OperatorCreationException { + + MockAdminApi mockAdminApi = new MockAdminApi(); + mockAdminApi.addConnectSettingsResponse( + INSTANCE_CONNECTION_NAME, + null, + null, + DATABASE_VERSION, + SAMPLE_PCS_DNS_NAME, + DEFAULT_BASE_URL, + true); mockAdminApi.addGenerateEphemeralCertResponse( INSTANCE_CONNECTION_NAME, Duration.ofHours(1), DEFAULT_BASE_URL); ConnectorConfig config = new ConnectorConfig.Builder().build(); @@ -122,7 +160,8 @@ private ListeningScheduledExecutorService newTestExecutor() { public void testFetchInstanceData_throwsException_whenIamAuthnIsNotSupported() throws GeneralSecurityException, OperatorCreationException { MockAdminApi mockAdminApi = - buildMockAdminApi(INSTANCE_CONNECTION_NAME, "SQLSERVER_2019_STANDARD", DEFAULT_BASE_URL); + buildMockAdminApi( + INSTANCE_CONNECTION_NAME, "SQLSERVER_2019_STANDARD", DEFAULT_BASE_URL, false); ConnectorConfig config = new ConnectorConfig.Builder().build(); ConnectionInfoRepository repo = new StubConnectionInfoRepositoryFactory(mockAdminApi.getHttpTransport()) @@ -149,7 +188,7 @@ public void testFetchInstanceData_throwsException_whenIamAuthnIsNotSupported() public void testFetchInstanceData_throwsException_whenRequestsTimeout() throws GeneralSecurityException, OperatorCreationException { MockAdminApi mockAdminApi = - buildMockAdminApi(INSTANCE_CONNECTION_NAME, DATABASE_VERSION, DEFAULT_BASE_URL); + buildMockAdminApi(INSTANCE_CONNECTION_NAME, DATABASE_VERSION, DEFAULT_BASE_URL, false); ConnectorConfig config = new ConnectorConfig.Builder().build(); ConnectionInfoRepository repo = new StubConnectionInfoRepositoryFactory(new BadConnectionFactory()) @@ -182,7 +221,7 @@ public void testSetAdminUrl_FetchInstanceData_returnsIpAddresses() String adminServicePath = "sqladmin/"; String baseUrl = adminRootUrl + adminServicePath; MockAdminApi mockAdminApi = - buildMockAdminApi(INSTANCE_CONNECTION_NAME, DATABASE_VERSION, baseUrl); + buildMockAdminApi(INSTANCE_CONNECTION_NAME, DATABASE_VERSION, baseUrl, false); ConnectorConfig config = new ConnectorConfig.Builder() .withAdminRootUrl(adminRootUrl) @@ -210,7 +249,7 @@ public void testSetAdminUrl_FetchInstanceData_returnsIpAddresses() @SuppressWarnings("SameParameterValue") private MockAdminApi buildMockAdminApi( - String instanceConnectionName, String databaseVersion, String baseUrl) + String instanceConnectionName, String databaseVersion, String baseUrl, boolean legacyDnsName) throws GeneralSecurityException, OperatorCreationException { MockAdminApi mockAdminApi = new MockAdminApi(); mockAdminApi.addConnectSettingsResponse( @@ -219,7 +258,8 @@ private MockAdminApi buildMockAdminApi( SAMPLE_PRIVATE_IP, databaseVersion, SAMPLE_PCS_DNS_NAME, - baseUrl); + baseUrl, + legacyDnsName); mockAdminApi.addGenerateEphemeralCertResponse( instanceConnectionName, Duration.ofHours(1), baseUrl); return mockAdminApi; diff --git a/core/src/test/java/com/google/cloud/sql/core/MockAdminApi.java b/core/src/test/java/com/google/cloud/sql/core/MockAdminApi.java index eea4cc678..6cae2273a 100644 --- a/core/src/test/java/com/google/cloud/sql/core/MockAdminApi.java +++ b/core/src/test/java/com/google/cloud/sql/core/MockAdminApi.java @@ -26,6 +26,7 @@ import com.google.api.client.testing.http.MockLowLevelHttpRequest; import com.google.api.client.testing.http.MockLowLevelHttpResponse; import com.google.api.services.sqladmin.model.ConnectSettings; +import com.google.api.services.sqladmin.model.DnsNameMapping; import com.google.api.services.sqladmin.model.GenerateEphemeralCertResponse; import com.google.api.services.sqladmin.model.IpMapping; import com.google.api.services.sqladmin.model.SslCert; @@ -38,6 +39,7 @@ import java.security.spec.InvalidKeySpecException; import java.time.Duration; import java.util.ArrayList; +import java.util.Collections; import java.util.Date; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; @@ -83,7 +85,8 @@ public void addConnectSettingsResponse( String privateIp, String databaseVersion, String pscHostname, - String baseUrl) { + String baseUrl, + boolean legacyPscDnsName) { CloudSqlInstanceName cloudSqlInstanceName = new CloudSqlInstanceName(instanceConnectionName); ArrayList ipMappings = new ArrayList<>(); @@ -103,10 +106,19 @@ public void addConnectSettingsResponse( .setServerCaCert(new SslCert().setCert(TestKeys.getServerCertPem())) .setDatabaseVersion(databaseVersion) .setPscEnabled(pscHostname != null) - .setDnsName(pscHostname) .setPscEnabled(pscHostname != null) .setRegion(cloudSqlInstanceName.getRegionId()); settings.setFactory(GsonFactory.getDefaultInstance()); + if (legacyPscDnsName) { + settings.setDnsName(pscHostname); + } else { + settings.setDnsNames( + Collections.singletonList( + new DnsNameMapping() + .setDnsScope("INSTANCE") + .setConnectionType("PRIVATE_SERVICE_CONNECT") + .setName(pscHostname))); + } connectSettingsRequests.add( new ConnectSettingsRequest(cloudSqlInstanceName, settings, baseUrl)); diff --git a/core/src/test/java/com/google/cloud/sql/core/TestCertificateGenerator.java b/core/src/test/java/com/google/cloud/sql/core/TestCertificateGenerator.java index c55fa81dc..00c7c0d52 100644 --- a/core/src/test/java/com/google/cloud/sql/core/TestCertificateGenerator.java +++ b/core/src/test/java/com/google/cloud/sql/core/TestCertificateGenerator.java @@ -130,7 +130,7 @@ static KeyPair generateKeyPair() { SERVER_CA_SUBJECT, serverCaKeyPair.getPrivate(), ONE_YEAR_FROM_NOW, - null); + Collections.singletonList(new GeneralName(GeneralName.dNSName, "db.example.com"))); this.serverCertificate2 = buildSignedCertificate( @@ -139,7 +139,7 @@ static KeyPair generateKeyPair() { SERVER_CA_SUBJECT, serverCaKeyPair.getPrivate(), ONE_YEAR_FROM_NOW, - null); + Collections.singletonList(new GeneralName(GeneralName.dNSName, "db.example.com"))); this.serverIntemediateCaCert = buildSignedCertificate( diff --git a/pom.xml b/pom.xml index 72a28c7ee..0abd6a244 100644 --- a/pom.xml +++ b/pom.xml @@ -148,7 +148,7 @@ com.google.apis google-api-services-sqladmin - v1beta4-rev20250205-2.0.0 + v1beta4-rev20250226-2.0.0 com.google.http-client