5454import javax .net .ssl .SSLContext ;
5555import javax .net .ssl .SSLEngine ;
5656import javax .net .ssl .SSLPeerUnverifiedException ;
57+ import javax .net .ssl .SSLSocket ;
5758import javax .net .ssl .SSLSocketFactory ;
5859import javax .net .ssl .TrustManager ;
5960import javax .net .ssl .TrustManagerFactory ;
@@ -263,11 +264,10 @@ public void perRpcAuthorityOverride_checkServerTrustedIsCalled() throws Exceptio
263264 }
264265
265266 /**
266- * Uses a fake Trust Manager to fail authority verification for rpc identified by the call count .
267+ * Uses a fake Trust Manager to fail authority verification.
267268 */
268269 @ Test
269- public void perRpcAuthorityOverride_peerVerificationFails_rpcFails ()
270- throws Exception {
270+ public void perRpcAuthorityOverride_peerVerificationFails_rpcFails () throws Exception {
271271 OkHttpClientTransport .enablePerRpcAuthorityCheck = true ;
272272 try {
273273 ServerCredentials serverCreds ;
@@ -278,21 +278,21 @@ public void perRpcAuthorityOverride_peerVerificationFails_rpcFails()
278278 .build ();
279279 }
280280 ChannelCredentials channelCreds ;
281- FakeX509ExtendedTrustManager fakeTrustManager ;
282281 try (InputStream caCert = TlsTesting .loadCert ("ca.pem" )) {
283- X509ExtendedTrustManager x509ExtendedTrustManager =
282+ X509ExtendedTrustManager regularTrustManager =
284283 (X509ExtendedTrustManager ) getX509ExtendedTrustManager (caCert ).get ();
285- fakeTrustManager = new FakeX509ExtendedTrustManager (x509ExtendedTrustManager );
286- fakeTrustManager .setFailCheckServerTrusted ();
287284 channelCreds = TlsChannelCredentials .newBuilder ()
288- .trustManager (fakeTrustManager )
285+ .trustManager (new HostnameCheckingX509ExtendedTrustManager ( regularTrustManager ) )
289286 .build ();
290287 }
291288 Server server = grpcCleanupRule .register (server (serverCreds ));
292289 ManagedChannel channel = grpcCleanupRule .register (clientChannel (server , channelCreds ));
293290
291+ // Regular RPC succeeds
292+ ClientCalls .blockingUnaryCall (channel , SimpleServiceGrpc .getUnaryRpcMethod (),
293+ CallOptions .DEFAULT , SimpleRequest .getDefaultInstance ());
294294 try {
295- fakeTrustManager . setFailCheckServerTrusted ();
295+ // But with an authority it fails
296296 ClientCalls .blockingUnaryCall (channel , SimpleServiceGrpc .getUnaryRpcMethod (),
297297 CallOptions .DEFAULT .withAuthority ("moo.test.google.fr" ),
298298 SimpleRequest .getDefaultInstance ());
@@ -317,19 +317,16 @@ public void perRpcAuthorityOverride_peerVerificationFails_featureDisabled_rpcSuc
317317 .build ();
318318 }
319319 ChannelCredentials channelCreds ;
320- FakeX509ExtendedTrustManager fakeTrustManager ;
321320 try (InputStream caCert = TlsTesting .loadCert ("ca.pem" )) {
322- X509ExtendedTrustManager x509ExtendedTrustManager =
321+ X509ExtendedTrustManager regularTrustManager =
323322 (X509ExtendedTrustManager ) getX509ExtendedTrustManager (caCert ).get ();
324- fakeTrustManager = new FakeX509ExtendedTrustManager (x509ExtendedTrustManager );
325323 channelCreds = TlsChannelCredentials .newBuilder ()
326- .trustManager (fakeTrustManager )
324+ .trustManager (new HostnameCheckingX509ExtendedTrustManager ( regularTrustManager ) )
327325 .build ();
328326 }
329327 Server server = grpcCleanupRule .register (server (serverCreds ));
330328 ManagedChannel channel = grpcCleanupRule .register (clientChannel (server , channelCreds ));
331329
332- fakeTrustManager .setFailCheckServerTrusted ();
333330 ClientCalls .blockingUnaryCall (channel , SimpleServiceGrpc .getUnaryRpcMethod (),
334331 CallOptions .DEFAULT .withAuthority ("foo.test.google.fr" ),
335332 SimpleRequest .getDefaultInstance ());
@@ -623,28 +620,49 @@ private FakeX509ExtendedTrustManager getFakeX509ExtendedTrustManager()
623620 }
624621 }
625622
626- @ IgnoreJRERequirement
627- private static class FakeX509ExtendedTrustManager extends X509ExtendedTrustManager {
628- private final X509ExtendedTrustManager delegate ;
623+ private static class HostnameCheckingX509ExtendedTrustManager
624+ extends ForwardingX509ExtendedTrustManager {
625+ public HostnameCheckingX509ExtendedTrustManager (X509ExtendedTrustManager tm ) {
626+ super (tm );
627+ }
628+
629+ @ Override
630+ public void checkServerTrusted (X509Certificate [] chain , String authType , Socket socket )
631+ throws CertificateException {
632+ String peer = ((SSLSocket ) socket ).getHandshakeSession ().getPeerHost ();
633+ if (!TestUtils .TEST_SERVER_HOST .equals (peer )) {
634+ throw new CertificateException ("Peer verification failed." );
635+ }
636+ super .checkServerTrusted (chain , authType , socket );
637+ }
638+ }
639+
640+ private static class FakeX509ExtendedTrustManager extends ForwardingX509ExtendedTrustManager {
629641 private boolean checkServerTrustedCalled ;
630- private boolean shouldFailCheckServerTrustedForRpc ;
631- private int numCalls ;
632642
633643 private FakeX509ExtendedTrustManager (X509ExtendedTrustManager delegate ) {
634- this . delegate = delegate ;
644+ super ( delegate ) ;
635645 }
636646
637- private void setFailCheckServerTrusted () {
638- shouldFailCheckServerTrustedForRpc = true ;
647+ @ Override
648+ public void checkServerTrusted (X509Certificate [] chain , String authType , Socket socket )
649+ throws CertificateException {
650+ this .checkServerTrustedCalled = true ;
651+ super .checkServerTrusted (chain , authType , socket );
652+ }
653+ }
654+
655+ @ IgnoreJRERequirement
656+ private static class ForwardingX509ExtendedTrustManager extends X509ExtendedTrustManager {
657+ private final X509ExtendedTrustManager delegate ;
658+
659+ private ForwardingX509ExtendedTrustManager (X509ExtendedTrustManager delegate ) {
660+ this .delegate = delegate ;
639661 }
640662
641663 @ Override
642664 public void checkServerTrusted (X509Certificate [] chain , String authType , Socket socket )
643665 throws CertificateException {
644- this .checkServerTrustedCalled = true ;
645- if (shouldFailCheckServerTrustedForRpc && ++numCalls > 1 ) {
646- throw new CertificateException ("Peer verification failed." );
647- }
648666 delegate .checkServerTrusted (chain , authType , socket );
649667 }
650668
@@ -661,20 +679,26 @@ public void checkServerTrusted(X509Certificate[] chain, String authType)
661679 }
662680
663681 @ Override
664- public void checkClientTrusted (X509Certificate [] chain , String authType , SSLEngine engine ) {
682+ public void checkClientTrusted (X509Certificate [] chain , String authType , SSLEngine engine )
683+ throws CertificateException {
684+ delegate .checkClientTrusted (chain , authType , engine );
665685 }
666686
667687 @ Override
668- public void checkClientTrusted (X509Certificate [] chain , String authType ) {
688+ public void checkClientTrusted (X509Certificate [] chain , String authType )
689+ throws CertificateException {
690+ delegate .checkClientTrusted (chain , authType );
669691 }
670692
671693 @ Override
672- public void checkClientTrusted (X509Certificate [] chain , String authType , Socket socket ) {
694+ public void checkClientTrusted (X509Certificate [] chain , String authType , Socket socket )
695+ throws CertificateException {
696+ delegate .checkClientTrusted (chain , authType , socket );
673697 }
674698
675699 @ Override
676700 public X509Certificate [] getAcceptedIssuers () {
677- return new X509Certificate [ 0 ] ;
701+ return delegate . getAcceptedIssuers () ;
678702 }
679703 }
680704
0 commit comments