Skip to content

Commit 451c873

Browse files
committed
Merge branch '4.0.x'
Closes gh-50645
2 parents f9c9820 + 89597d4 commit 451c873

2 files changed

Lines changed: 37 additions & 7 deletions

File tree

module/spring-boot-rsocket/src/main/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactory.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -269,13 +269,13 @@ private TcpServer apply(TcpServer server) {
269269

270270
}
271271

272-
private static final class HttpServerSslCustomizer extends SslCustomizer {
272+
static final class HttpServerSslCustomizer extends SslCustomizer {
273273

274274
private final SslProvider sslProvider;
275275

276276
private final Map<String, SslProvider> serverNameSslProviders;
277277

278-
private HttpServerSslCustomizer(Ssl.@Nullable ClientAuth clientAuth, SslBundle sslBundle,
278+
HttpServerSslCustomizer(Ssl.@Nullable ClientAuth clientAuth, SslBundle sslBundle,
279279
Map<String, SslBundle> serverNameSslBundles) {
280280
super(Ssl.ClientAuth.map(clientAuth, ClientAuth.NONE, ClientAuth.OPTIONAL, ClientAuth.REQUIRE));
281281
this.sslProvider = createSslProvider(sslBundle);
@@ -287,11 +287,13 @@ private HttpServer apply(HttpServer server) {
287287
}
288288

289289
private void applySecurity(SslContextSpec spec) {
290-
spec.sslContext(this.sslProvider.getSslContext()).setSniAsyncMappings((serverName, promise) -> {
291-
SslProvider provider = (serverName != null) ? this.serverNameSslProviders.get(serverName)
292-
: this.sslProvider;
293-
return promise.setSuccess(provider);
294-
});
290+
spec.sslContext(this.sslProvider.getSslContext())
291+
.setSniAsyncMappings((serverName, promise) -> promise.setSuccess(getSslProvider(serverName)));
292+
}
293+
294+
SslProvider getSslProvider(@Nullable String serverName) {
295+
return (serverName != null) ? this.serverNameSslProviders.getOrDefault(serverName, this.sslProvider)
296+
: this.sslProvider;
295297
}
296298

297299
private Map<String, SslProvider> createServerNameSslProviders(Map<String, SslBundle> serverNameSslBundles) {

module/spring-boot-rsocket/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import java.nio.channels.ClosedChannelException;
2121
import java.time.Duration;
2222
import java.util.Arrays;
23+
import java.util.Collections;
24+
import java.util.Map;
2325
import java.util.concurrent.Callable;
2426

2527
import io.netty.buffer.PooledByteBufAllocator;
@@ -267,6 +269,26 @@ void websocketTransportBasicSslCertificateFromFileSystemWithBundle(@ResourcePath
267269
testBasicSslWithPemCertificateFromBundle(testCert, testKey, testCert, Transport.WEBSOCKET);
268270
}
269271

272+
@Test
273+
@WithPackageResources({ "test-cert.pem", "test-key.pem" })
274+
void websocketTransportSslProviderFallsBackToDefaultWhenServerNameIsUnmapped() {
275+
SslBundle defaultBundle = createBundle("test-cert.pem", "test-key.pem");
276+
SslBundle mappedBundle = createBundle("test-cert.pem", "test-key.pem");
277+
NettyRSocketServerFactory.HttpServerSslCustomizer customizer = new NettyRSocketServerFactory.HttpServerSslCustomizer(
278+
Ssl.ClientAuth.NONE, defaultBundle, Map.of("mapped.example", mappedBundle));
279+
assertThat(customizer.getSslProvider("unmapped.example")).isSameAs(customizer.getSslProvider(null));
280+
}
281+
282+
@Test
283+
@WithPackageResources({ "test-cert.pem", "test-key.pem" })
284+
@SuppressWarnings("NullAway") // Test null check
285+
void websocketTransportSslProviderReturnsDefaultWhenServerNameIsNull() {
286+
SslBundle defaultBundle = createBundle("test-cert.pem", "test-key.pem");
287+
NettyRSocketServerFactory.HttpServerSslCustomizer customizer = new NettyRSocketServerFactory.HttpServerSslCustomizer(
288+
Ssl.ClientAuth.NONE, defaultBundle, Collections.emptyMap());
289+
assertThat(customizer.getSslProvider(null)).isNotNull();
290+
}
291+
270292
private void checkEchoRequest() {
271293
String payload = "test payload";
272294
assertThat(this.requester).isNotNull();
@@ -338,6 +360,12 @@ private void testBasicSslWithPemCertificateFromBundle(String certificate, String
338360
checkEchoRequest();
339361
}
340362

363+
private static SslBundle createBundle(String certificate, String certificatePrivateKey) {
364+
PemSslStoreDetails keyStoreDetails = PemSslStoreDetails.forCertificate("classpath:" + certificate)
365+
.withPrivateKey("classpath:" + certificatePrivateKey);
366+
return SslBundle.of(new PemSslStoreBundle(keyStoreDetails, null));
367+
}
368+
341369
@Test
342370
void tcpTransportSslRejectsInsecureClient() {
343371
NettyRSocketServerFactory factory = getFactory();

0 commit comments

Comments
 (0)