diff --git a/proxy/src/main/java/org/apache/rocketmq/proxy/ProxyStartup.java b/proxy/src/main/java/org/apache/rocketmq/proxy/ProxyStartup.java index 1b38a19ae6a..c8e51b98cfa 100644 --- a/proxy/src/main/java/org/apache/rocketmq/proxy/ProxyStartup.java +++ b/proxy/src/main/java/org/apache/rocketmq/proxy/ProxyStartup.java @@ -45,6 +45,7 @@ import org.apache.rocketmq.proxy.processor.MessagingProcessor; import org.apache.rocketmq.proxy.remoting.RemotingProtocolServer; import org.apache.rocketmq.proxy.service.cert.TlsCertificateManager; +import org.apache.rocketmq.proxy.service.cert.TlsSniManager; import org.apache.rocketmq.remoting.protocol.RemotingCommand; import org.apache.rocketmq.srvutil.ServerUtil; @@ -78,7 +79,9 @@ public static void main(String[] args) { MessagingProcessor messagingProcessor = createMessagingProcessor(); // tls cert update - TlsCertificateManager tlsCertificateManager = new TlsCertificateManager(); + TlsSniManager tlsSniManager = new TlsSniManager(); + tlsSniManager.initialize(ConfigurationManager.getProxyConfig()); + TlsCertificateManager tlsCertificateManager = new TlsCertificateManager(tlsSniManager); PROXY_START_AND_SHUTDOWN.appendStartAndShutdown(tlsCertificateManager); // create grpcServer diff --git a/proxy/src/main/java/org/apache/rocketmq/proxy/config/ProxyConfig.java b/proxy/src/main/java/org/apache/rocketmq/proxy/config/ProxyConfig.java index 5a1a5859305..6151aa4f393 100644 --- a/proxy/src/main/java/org/apache/rocketmq/proxy/config/ProxyConfig.java +++ b/proxy/src/main/java/org/apache/rocketmq/proxy/config/ProxyConfig.java @@ -84,6 +84,7 @@ public class ProxyConfig implements ConfigFile { private String tlsKeyPassword = ""; private String tlsCertPath = ConfigurationManager.getProxyHome() + "/conf/tls/rocketmq.crt"; private int tlsCertWatchIntervalMs = 60 * 60 * 1000; // 1 hour + private Map tlsDomainConfigs = new HashMap<>(); /** * gRPC */ @@ -529,6 +530,14 @@ public void setTlsCertPath(String tlsCertPath) { this.tlsCertPath = tlsCertPath; } + public Map getTlsDomainConfigs() { + return tlsDomainConfigs; + } + + public void setTlsDomainConfigs(Map tlsDomainConfigs) { + this.tlsDomainConfigs = tlsDomainConfigs; + } + public int getGrpcBossLoopNum() { return grpcBossLoopNum; } diff --git a/proxy/src/main/java/org/apache/rocketmq/proxy/config/TlsDomainConfig.java b/proxy/src/main/java/org/apache/rocketmq/proxy/config/TlsDomainConfig.java new file mode 100644 index 00000000000..185169130ef --- /dev/null +++ b/proxy/src/main/java/org/apache/rocketmq/proxy/config/TlsDomainConfig.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.rocketmq.proxy.config; + +public class TlsDomainConfig { + private String certPath; + private String keyPath; + private String keyPassword; + + public TlsDomainConfig() { + } + + public String getCertPath() { + return certPath; + } + + public void setCertPath(String certPath) { + this.certPath = certPath; + } + + public String getKeyPath() { + return keyPath; + } + + public void setKeyPath(String keyPath) { + this.keyPath = keyPath; + } + + public String getKeyPassword() { + return keyPassword; + } + + public void setKeyPassword(String keyPassword) { + this.keyPassword = keyPassword; + } +} diff --git a/proxy/src/main/java/org/apache/rocketmq/proxy/grpc/GrpcServer.java b/proxy/src/main/java/org/apache/rocketmq/proxy/grpc/GrpcServer.java index af3d6b4c6c1..e77bdb351b8 100644 --- a/proxy/src/main/java/org/apache/rocketmq/proxy/grpc/GrpcServer.java +++ b/proxy/src/main/java/org/apache/rocketmq/proxy/grpc/GrpcServer.java @@ -77,9 +77,9 @@ class GrpcTlsReloadHandler implements TlsCertificateManager.TlsContextReloadList @Override public void onTlsContextReload() { try { - ProxyAndTlsProtocolNegotiator.loadSslContext(); + ProxyAndTlsProtocolNegotiator.loadAllSslContexts(); log.info("SslContext reloaded for grpc server"); - } catch (CertificateException | IOException e) { + } catch (Exception e) { log.error("Failed to reload SslContext for server", e); } } diff --git a/proxy/src/main/java/org/apache/rocketmq/proxy/grpc/ProxyAndTlsProtocolNegotiator.java b/proxy/src/main/java/org/apache/rocketmq/proxy/grpc/ProxyAndTlsProtocolNegotiator.java index 4222dacaad2..30a19ee5a55 100644 --- a/proxy/src/main/java/org/apache/rocketmq/proxy/grpc/ProxyAndTlsProtocolNegotiator.java +++ b/proxy/src/main/java/org/apache/rocketmq/proxy/grpc/ProxyAndTlsProtocolNegotiator.java @@ -18,7 +18,6 @@ import io.grpc.Attributes; import io.grpc.netty.shaded.io.grpc.netty.GrpcHttp2ConnectionHandler; -import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; import io.grpc.netty.shaded.io.grpc.netty.InternalProtocolNegotiationEvent; import io.grpc.netty.shaded.io.grpc.netty.InternalProtocolNegotiator; import io.grpc.netty.shaded.io.grpc.netty.InternalProtocolNegotiators; @@ -35,22 +34,14 @@ import io.grpc.netty.shaded.io.netty.handler.codec.haproxy.HAProxyMessageDecoder; import io.grpc.netty.shaded.io.netty.handler.codec.haproxy.HAProxyProtocolVersion; import io.grpc.netty.shaded.io.netty.handler.codec.haproxy.HAProxyTLV; -import io.grpc.netty.shaded.io.netty.handler.ssl.ClientAuth; -import io.grpc.netty.shaded.io.netty.handler.ssl.OpenSsl; import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext; import io.grpc.netty.shaded.io.netty.handler.ssl.SslHandler; -import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider; - -import io.grpc.netty.shaded.io.netty.handler.ssl.util.InsecureTrustManagerFactory; -import io.grpc.netty.shaded.io.netty.handler.ssl.util.SelfSignedCertificate; +import io.grpc.netty.shaded.io.netty.handler.ssl.SniHandler; import io.grpc.netty.shaded.io.netty.util.AsciiString; import io.grpc.netty.shaded.io.netty.util.CharsetUtil; +import io.grpc.netty.shaded.io.netty.util.GlobalEventExecutor; +import io.grpc.netty.shaded.io.netty.util.concurrent.Promise; -import java.io.IOException; -import java.io.InputStream; -import java.nio.file.Files; -import java.nio.file.Paths; -import java.security.cert.CertificateException; import java.util.List; import org.apache.commons.collections.CollectionUtils; @@ -63,6 +54,7 @@ import org.apache.rocketmq.proxy.config.ConfigurationManager; import org.apache.rocketmq.proxy.config.ProxyConfig; import org.apache.rocketmq.proxy.grpc.constant.AttributeKeys; +import org.apache.rocketmq.proxy.service.cert.TlsSniManager; import org.apache.rocketmq.remoting.common.TlsMode; import org.apache.rocketmq.remoting.netty.TlsSystemConfig; @@ -72,18 +64,19 @@ public class ProxyAndTlsProtocolNegotiator implements InternalProtocolNegotiator private static final String HA_PROXY_DECODER = "HAProxyDecoder"; private static final String HA_PROXY_HANDLER = "HAProxyHandler"; private static final String TLS_MODE_HANDLER = "TlsModeHandler"; + private static final String SNI_HANDLER = "SniHandler"; /** * the length of the ssl record header (in bytes) */ private static final int SSL_RECORD_HEADER_LENGTH = 5; - private static SslContext sslContext; + private static volatile TlsSniManager tlsSniManager; public ProxyAndTlsProtocolNegotiator() { try { - loadSslContext(); - log.info("SslContext created for proxy server"); - } catch (IOException | CertificateException e) { + loadAllSslContexts(); + log.info("SslContext created for proxy server with SNI support"); + } catch (Exception e) { log.error("SslContext init error", e); throw new RuntimeException(e); } @@ -103,39 +96,24 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { public void close() { } - public static void loadSslContext() throws CertificateException, IOException { - ProxyConfig proxyConfig = ConfigurationManager.getProxyConfig(); - SslProvider provider; - if (OpenSsl.isAvailable()) { - provider = SslProvider.OPENSSL; - log.info("Using OpenSSL provider"); - } else { - provider = SslProvider.JDK; - log.info("Using JDK SSL provider"); - } - if (proxyConfig.isTlsTestModeEnable()) { - SelfSignedCertificate selfSignedCertificate = new SelfSignedCertificate(); - sslContext = GrpcSslContexts.forServer(selfSignedCertificate.certificate(), selfSignedCertificate.privateKey()) - .sslProvider(provider) - .trustManager(InsecureTrustManagerFactory.INSTANCE) - .clientAuth(ClientAuth.NONE) - .build(); - } else { - String tlsCertPath = ConfigurationManager.getProxyConfig().getTlsCertPath(); - String tlsKeyPath = ConfigurationManager.getProxyConfig().getTlsKeyPath(); - String tlsKeyPassword = ConfigurationManager.getProxyConfig().getTlsKeyPassword(); - try (InputStream serverKeyInputStream = Files.newInputStream( - Paths.get(tlsKeyPath)); - InputStream serverCertificateStream = Files.newInputStream( - Paths.get(tlsCertPath))) { - sslContext = GrpcSslContexts.forServer(serverCertificateStream, - serverKeyInputStream, - StringUtils.isNotBlank(tlsKeyPassword) ? tlsKeyPassword : null) - .trustManager(InsecureTrustManagerFactory.INSTANCE) - .clientAuth(ClientAuth.NONE) - .build(); + private static TlsSniManager getTlsSniManager() { + if (tlsSniManager == null) { + synchronized (ProxyAndTlsProtocolNegotiator.class) { + if (tlsSniManager == null) { + tlsSniManager = new TlsSniManager(); + tlsSniManager.initialize(ConfigurationManager.getProxyConfig()); + } } } + return tlsSniManager; + } + + public static void loadAllSslContexts() { + getTlsSniManager(); + } + + public static TlsSniManager getManager() { + return getTlsSniManager(); } private class ProxyAndTlsProtocolHandler extends ByteToMessageDecoder { @@ -199,12 +177,6 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception ctx.pipeline().remove(this); } - /** - * The definition of key refers to the implementation of nginx - * ngx_http_core_module - * - * @param msg - */ private void handleWithMessage(HAProxyMessage msg) { try { Attributes.Builder builder = InternalProtocolNegotiationEvent.getAttributes(pne).toBuilder(); @@ -254,14 +226,10 @@ private class TlsModeHandler extends ByteToMessageDecoder { private ProtocolNegotiationEvent pne = InternalProtocolNegotiationEvent.getDefault(); - private final ChannelHandler ssl; - private final ChannelHandler plaintext; + private final GrpcHttp2ConnectionHandler grpcHandler; public TlsModeHandler(GrpcHttp2ConnectionHandler grpcHandler) { - this.ssl = InternalProtocolNegotiators.serverTls(sslContext) - .newHandler(grpcHandler); - this.plaintext = InternalProtocolNegotiators.serverPlaintext() - .newHandler(grpcHandler); + this.grpcHandler = grpcHandler; } @Override @@ -269,18 +237,17 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { try { TlsMode tlsMode = TlsSystemConfig.tlsMode; if (TlsMode.ENFORCING.equals(tlsMode)) { - ctx.pipeline().addAfter(ctx.name(), null, this.ssl); + addSniHandler(ctx); } else if (TlsMode.DISABLED.equals(tlsMode)) { - ctx.pipeline().addAfter(ctx.name(), null, this.plaintext); + addPlaintextHandler(ctx); } else { - // in SslHandler.isEncrypted, it needs at least 5 bytes to judge is encrypted or not if (in.readableBytes() < SSL_RECORD_HEADER_LENGTH) { return; } if (SslHandler.isEncrypted(in)) { - ctx.pipeline().addAfter(ctx.name(), null, this.ssl); + addSniHandler(ctx); } else { - ctx.pipeline().addAfter(ctx.name(), null, this.plaintext); + addPlaintextHandler(ctx); } } ctx.fireUserEventTriggered(pne); @@ -291,6 +258,25 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { } } + private void addSniHandler(ChannelHandlerContext ctx) { + TlsSniManager sniManager = getTlsSniManager(); + SniHandler sniHandler = new SniHandler((hostname, promise) -> { + SslContext sslCtx = sniManager.getSslContext(hostname); + if (sslCtx != null) { + promise.setSuccess(sslCtx); + } else { + promise.setSuccess(sniManager.getDefaultContext()); + } + }, GlobalEventExecutor.INSTANCE); + ctx.pipeline().addAfter(ctx.name(), SNI_HANDLER, sniHandler); + } + + private void addPlaintextHandler(ChannelHandlerContext ctx) { + ChannelHandler plaintext = InternalProtocolNegotiators.serverPlaintext() + .newHandler(grpcHandler); + ctx.pipeline().addAfter(ctx.name(), null, plaintext); + } + @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (evt instanceof ProtocolNegotiationEvent) { @@ -300,4 +286,4 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc } } } -} \ No newline at end of file +} diff --git a/proxy/src/main/java/org/apache/rocketmq/proxy/remoting/MultiProtocolRemotingServer.java b/proxy/src/main/java/org/apache/rocketmq/proxy/remoting/MultiProtocolRemotingServer.java index 7bbca44a508..8ee807d7e6f 100644 --- a/proxy/src/main/java/org/apache/rocketmq/proxy/remoting/MultiProtocolRemotingServer.java +++ b/proxy/src/main/java/org/apache/rocketmq/proxy/remoting/MultiProtocolRemotingServer.java @@ -19,24 +19,26 @@ import io.netty.channel.ChannelPipeline; import io.netty.channel.socket.SocketChannel; +import io.netty.handler.ssl.SslContext; import io.netty.handler.timeout.IdleStateHandler; import org.apache.rocketmq.common.constant.LoggerName; import org.apache.rocketmq.logging.org.slf4j.Logger; import org.apache.rocketmq.logging.org.slf4j.LoggerFactory; import org.apache.rocketmq.proxy.common.ProxyException; import org.apache.rocketmq.proxy.common.ProxyExceptionCode; +import org.apache.rocketmq.proxy.config.ConfigurationManager; +import org.apache.rocketmq.proxy.config.ProxyConfig; import org.apache.rocketmq.proxy.remoting.protocol.ProtocolNegotiationHandler; import org.apache.rocketmq.proxy.remoting.protocol.http2proxy.Http2ProtocolProxyHandler; import org.apache.rocketmq.proxy.remoting.protocol.remoting.RemotingProtocolHandler; +import org.apache.rocketmq.proxy.service.cert.TlsSniManager; import org.apache.rocketmq.remoting.ChannelEventListener; import org.apache.rocketmq.remoting.common.TlsMode; import org.apache.rocketmq.remoting.netty.NettyRemotingServer; import org.apache.rocketmq.remoting.netty.NettyServerConfig; +import org.apache.rocketmq.remoting.netty.TlsContextProvider; import org.apache.rocketmq.remoting.netty.TlsSystemConfig; -import java.io.IOException; -import java.security.cert.CertificateException; - /** * support remoting and http2 protocol at one port */ @@ -47,6 +49,7 @@ public class MultiProtocolRemotingServer extends NettyRemotingServer { private final RemotingProtocolHandler remotingProtocolHandler; protected Http2ProtocolProxyHandler http2ProtocolProxyHandler; + private TlsSniManager tlsSniManager; public MultiProtocolRemotingServer(NettyServerConfig nettyServerConfig, ChannelEventListener channelEventListener) { super(nettyServerConfig, channelEventListener); @@ -67,14 +70,49 @@ public void loadSslContext() { if (tlsMode != TlsMode.DISABLED) { try { - sslContext = MultiProtocolTlsHelper.buildSslContext(); - log.info("SslContext created for multi protocol remoting server"); - } catch (CertificateException | IOException e) { + ProxyConfig proxyConfig = ConfigurationManager.getProxyConfig(); + if (proxyConfig.getTlsDomainConfigs() != null && !proxyConfig.getTlsDomainConfigs().isEmpty()) { + // SNI mode: reload all domain contexts + if (tlsSniManager == null) { + tlsSniManager = new TlsSniManager(); + tlsSniManager.initialize(proxyConfig); + } else { + tlsSniManager.reloadDefaultContext(); + for (String domain : proxyConfig.getTlsDomainConfigs().keySet()) { + tlsSniManager.reloadDomainContext(domain); + } + } + + TlsContextProvider.getInstance().setSniLookup( + new TlsContextProvider.SniContextLookup() { + @Override + public SslContext lookup(String sniHostname) { + return tlsSniManager.getSslContext(sniHostname); + } + + @Override + public SslContext getDefaultContext() { + return tlsSniManager.getDefaultContext(); + } + } + ); + log.info("SNI-enabled SslContext created/reloaded for multi protocol remoting server"); + } else { + // Single cert mode: backward compatible + sslContext = MultiProtocolTlsHelper.buildSslContext(); + TlsContextProvider.getInstance().setSingleContext(sslContext); + log.info("Single SslContext created/reloaded for multi protocol remoting server"); + } + } catch (Exception e) { throw new ProxyException(ProxyExceptionCode.INTERNAL_SERVER_ERROR, "Failed to create SslContext for server", e); } } } + public TlsSniManager getTlsSniManager() { + return tlsSniManager; + } + @Override protected ChannelPipeline configChannel(SocketChannel ch) { return ch.pipeline() diff --git a/proxy/src/main/java/org/apache/rocketmq/proxy/service/cert/TlsCertificateManager.java b/proxy/src/main/java/org/apache/rocketmq/proxy/service/cert/TlsCertificateManager.java index 2ab4f31b6ed..e142150806c 100644 --- a/proxy/src/main/java/org/apache/rocketmq/proxy/service/cert/TlsCertificateManager.java +++ b/proxy/src/main/java/org/apache/rocketmq/proxy/service/cert/TlsCertificateManager.java @@ -21,35 +21,69 @@ import org.apache.rocketmq.logging.org.slf4j.Logger; import org.apache.rocketmq.logging.org.slf4j.LoggerFactory; import org.apache.rocketmq.proxy.config.ConfigurationManager; +import org.apache.rocketmq.proxy.config.ProxyConfig; +import org.apache.rocketmq.proxy.config.TlsDomainConfig; import org.apache.rocketmq.remoting.netty.TlsSystemConfig; import org.apache.rocketmq.srvutil.FileWatchService; + import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; public class TlsCertificateManager implements StartAndShutdown { private static final Logger log = LoggerFactory.getLogger(LoggerName.PROXY_LOGGER_NAME); - private final FileWatchService fileWatchService; private final List reloadListeners = new ArrayList<>(); + private final List domainReloadListeners = new ArrayList<>(); + private final List fileWatchServices = new ArrayList<>(); + private final TlsSniManager tlsSniManager; + + public TlsCertificateManager(TlsSniManager tlsSniManager) { + this.tlsSniManager = tlsSniManager; + ProxyConfig config = ConfigurationManager.getProxyConfig(); + int watchInterval = config.getTlsCertWatchIntervalMs(); - public TlsCertificateManager() { + // Watch default cert/key pair try { - this.fileWatchService = new FileWatchService( - new String[] { - ConfigurationManager.getProxyConfig().getTlsCertPath(), - ConfigurationManager.getProxyConfig().getTlsKeyPath() - }, - new CertKeyFileWatchListener(), - ConfigurationManager.getProxyConfig().getTlsCertWatchIntervalMs() + String defaultCertPath = config.getTlsCertPath(); + String defaultKeyPath = config.getTlsKeyPath(); + FileWatchService defaultWatchService = new FileWatchService( + new String[] {defaultCertPath, defaultKeyPath}, + new DefaultCertKeyFileWatchListener(), + watchInterval ); + fileWatchServices.add(defaultWatchService); + log.info("Watching default TLS cert/key: {}, {}", defaultCertPath, defaultKeyPath); } catch (Exception e) { - log.error("Failed to initialize TLS certificate watch service", e); + log.error("Failed to initialize default TLS certificate watch service", e); throw new RuntimeException("Failed to initialize TLS certificate manager", e); } + + // Watch domain-specific cert/key pairs + Map domainConfigs = config.getTlsDomainConfigs(); + if (domainConfigs != null && !domainConfigs.isEmpty()) { + for (Map.Entry entry : domainConfigs.entrySet()) { + String domainPattern = entry.getKey(); + TlsDomainConfig domainConfig = entry.getValue(); + try { + FileWatchService domainWatchService = new FileWatchService( + new String[] {domainConfig.getCertPath(), domainConfig.getKeyPath()}, + new DomainCertKeyFileWatchListener(domainPattern), + watchInterval + ); + fileWatchServices.add(domainWatchService); + log.info("Watching domain TLS cert/key: {}, {} for pattern: {}", + domainConfig.getCertPath(), domainConfig.getKeyPath(), domainPattern); + } catch (Exception e) { + log.error("Failed to initialize domain TLS certificate watch service for: {}", domainPattern, e); + } + } + } } - public FileWatchService getFileWatchService() { - return this.fileWatchService; + public List getFileWatchServices() { + return this.fileWatchServices; } public void registerReloadListener(TlsContextReloadListener listener) { @@ -64,40 +98,54 @@ public void unregisterReloadListener(TlsContextReloadListener listener) { } } + public void registerDomainReloadListener(DomainReloadListener listener) { + if (listener != null) { + this.domainReloadListeners.add(listener); + } + } + + public void unregisterDomainReloadListener(DomainReloadListener listener) { + if (listener != null) { + this.domainReloadListeners.remove(listener); + } + } + public List getReloadListeners() { return this.reloadListeners; } @Override public void start() throws Exception { - this.fileWatchService.start(); - log.info("TLS certificate manager started successfully, start watching: {} {}", - ConfigurationManager.getProxyConfig().getTlsCertPath(), - ConfigurationManager.getProxyConfig().getTlsKeyPath() - ); + for (FileWatchService service : fileWatchServices) { + service.start(); + } + log.info("TLS certificate manager started successfully, watching {} file groups", fileWatchServices.size()); } @Override public void shutdown() throws Exception { - this.fileWatchService.shutdown(); + for (FileWatchService service : fileWatchServices) { + service.shutdown(); + } log.info("TLS certificate manager shutdown successfully"); } - private class CertKeyFileWatchListener implements FileWatchService.Listener { + private class DefaultCertKeyFileWatchListener implements FileWatchService.Listener { private boolean certChanged = false; private boolean keyChanged = false; @Override public void onChanged(String path) { - log.info("File changed: {}", path); - if (path.equals(TlsSystemConfig.tlsServerCertPath)) { + log.info("Default TLS file changed: {}", path); + if (path.equals(TlsSystemConfig.tlsServerCertPath) || path.equals(ConfigurationManager.getProxyConfig().getTlsCertPath())) { certChanged = true; - } else if (path.equals(TlsSystemConfig.tlsServerKeyPath)) { + } else if (path.equals(TlsSystemConfig.tlsServerKeyPath) || path.equals(ConfigurationManager.getProxyConfig().getTlsKeyPath())) { keyChanged = true; } if (certChanged && keyChanged) { - log.info("The certificate and private key changed, reload the ssl context"); + log.info("The default certificate and private key changed, reload the default ssl context"); + tlsSniManager.reloadDefaultContext(); notifyContextReload(); certChanged = false; keyChanged = false; @@ -115,8 +163,55 @@ private void notifyContextReload() { } } + private class DomainCertKeyFileWatchListener implements FileWatchService.Listener { + private final String domainPattern; + private boolean certChanged = false; + private boolean keyChanged = false; + + DomainCertKeyFileWatchListener(String domainPattern) { + this.domainPattern = domainPattern; + } + + @Override + public void onChanged(String path) { + log.info("Domain TLS file changed: {} for pattern: {}", path, domainPattern); + TlsDomainConfig config = ConfigurationManager.getProxyConfig().getTlsDomainConfigs().get(domainPattern); + if (config == null) { + return; + } + if (path.equals(config.getCertPath())) { + certChanged = true; + } else if (path.equals(config.getKeyPath())) { + keyChanged = true; + } + + if (certChanged && keyChanged) { + log.info("The certificate and private key changed for domain: {}, reload the ssl context", domainPattern); + tlsSniManager.reloadDomainContext(domainPattern); + notifyDomainReload(domainPattern); + certChanged = false; + keyChanged = false; + } + } + + private void notifyDomainReload(String domainPattern) { + for (DomainReloadListener listener : domainReloadListeners) { + try { + listener.onDomainTlsContextReload(domainPattern); + } catch (Throwable e) { + log.error("Failed to notify domain TLS context reload to listener: " + listener, e); + } + } + } + } + // Interface for listeners interested in TLS context reload events public interface TlsContextReloadListener { void onTlsContextReload(); } + + // Interface for listeners interested in domain-specific TLS context reload events + public interface DomainReloadListener { + void onDomainTlsContextReload(String domainPattern); + } } diff --git a/proxy/src/main/java/org/apache/rocketmq/proxy/service/cert/TlsSniManager.java b/proxy/src/main/java/org/apache/rocketmq/proxy/service/cert/TlsSniManager.java new file mode 100644 index 00000000000..804a0bd3eff --- /dev/null +++ b/proxy/src/main/java/org/apache/rocketmq/proxy/service/cert/TlsSniManager.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.rocketmq.proxy.service.cert; + +import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.shaded.io.netty.handler.ssl.ClientAuth; +import io.grpc.netty.shaded.io.netty.handler.ssl.OpenSsl; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider; +import io.grpc.netty.shaded.io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.grpc.netty.shaded.io.netty.handler.ssl.util.SelfSignedCertificate; +import org.apache.commons.lang3.StringUtils; +import org.apache.rocketmq.common.constant.LoggerName; +import org.apache.rocketmq.logging.org.slf4j.Logger; +import org.apache.rocketmq.logging.org.slf4j.LoggerFactory; +import org.apache.rocketmq.proxy.config.ConfigurationManager; +import org.apache.rocketmq.proxy.config.ProxyConfig; +import org.apache.rocketmq.proxy.config.TlsDomainConfig; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.security.cert.CertificateException; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +public class TlsSniManager { + private static final Logger log = LoggerFactory.getLogger(LoggerName.PROXY_LOGGER_NAME); + + private volatile SslContext defaultContext; + private volatile Map domainContexts = new ConcurrentHashMap<>(); + private Map domainConfigs; + private boolean tlsTestModeEnable; + private String tlsKeyPassword; + + /** + * Get the matching SslContext for the given SNI hostname. + * Supports wildcard matching (e.g. *.alibaba-inc.com matches foo.alibaba-inc.com). + * Returns defaultContext when no match is found. + */ + public SslContext getSslContext(String sniHostname) { + if (StringUtils.isBlank(sniHostname)) { + return defaultContext; + } + + // Exact match first + SslContext ctx = domainContexts.get(sniHostname); + if (ctx != null) { + return ctx; + } + + // Wildcard match: foo.alibaba-inc.com matches *.alibaba-inc.com + for (Map.Entry entry : domainContexts.entrySet()) { + String domainPattern = entry.getKey(); + if (domainPattern.startsWith("*.")) { + String suffix = domainPattern.substring(1); + if (sniHostname.endsWith(suffix) && sniHostname.length() > suffix.length()) { + String remaining = sniHostname.substring(0, sniHostname.length() - suffix.length()); + if (!remaining.contains(".")) { + return entry.getValue(); + } + } + } + } + + // Bare domain matches wildcard: rocketmq.com matches *.rocketmq.com + for (Map.Entry entry : domainContexts.entrySet()) { + String domainPattern = entry.getKey(); + if (domainPattern.startsWith("*.")) { + String bareDomain = domainPattern.substring(2); + if (sniHostname.equals(bareDomain)) { + return entry.getValue(); + } + } + } + + return defaultContext; + } + + public SslContext getDefaultContext() { + return defaultContext; + } + + public Map getDomainConfigs() { + return domainConfigs; + } + + /** + * Rebuild SslContext for a specific domain (used for hot reload). + */ + public void reloadDomainContext(String domainPattern) { + TlsDomainConfig config = domainConfigs.get(domainPattern); + if (config == null) { + log.warn("Cannot reload domain context, config not found: {}", domainPattern); + return; + } + try { + SslContext newCtx = buildSslContext(config, tlsTestModeEnable); + domainContexts.put(domainPattern, newCtx); + log.info("Reloaded SslContext for domain: {}", domainPattern); + } catch (Exception e) { + log.error("Failed to reload SslContext for domain: {}", domainPattern, e); + } + } + + /** + * Reload the default context. + */ + public void reloadDefaultContext() { + ProxyConfig proxyConfig = ConfigurationManager.getProxyConfig(); + try { + defaultContext = buildDefaultSslContext(proxyConfig); + log.info("Reloaded default SslContext"); + } catch (Exception e) { + log.error("Failed to reload default SslContext", e); + } + } + + /** + * Initialize all domain SslContexts from ProxyConfig. + */ + public void initialize(ProxyConfig config) { + this.tlsTestModeEnable = config.isTlsTestModeEnable(); + this.tlsKeyPassword = config.getTlsKeyPassword(); + this.domainConfigs = config.getTlsDomainConfigs(); + + try { + defaultContext = buildDefaultSslContext(config); + log.info("Initialized default SslContext"); + } catch (Exception e) { + log.error("Failed to initialize default SslContext", e); + throw new RuntimeException("Failed to initialize TlsSniManager", e); + } + + if (domainConfigs != null && !domainConfigs.isEmpty()) { + for (Map.Entry entry : domainConfigs.entrySet()) { + String domainPattern = entry.getKey(); + TlsDomainConfig domainConfig = entry.getValue(); + try { + SslContext ctx = buildSslContext(domainConfig, tlsTestModeEnable); + domainContexts.put(domainPattern, ctx); + log.info("Initialized SslContext for domain: {}", domainPattern); + } catch (Exception e) { + log.error("Failed to initialize SslContext for domain: {}", domainPattern, e); + throw new RuntimeException("Failed to initialize TlsSniManager for domain: " + domainPattern, e); + } + } + } + } + + private SslContext buildDefaultSslContext(ProxyConfig config) throws CertificateException, IOException { + SslProvider provider = OpenSsl.isAvailable() ? SslProvider.OPENSSL : SslProvider.JDK; + if (config.isTlsTestModeEnable()) { + SelfSignedCertificate selfSignedCertificate = new SelfSignedCertificate(); + return GrpcSslContexts.forServer(selfSignedCertificate.certificate(), selfSignedCertificate.privateKey()) + .sslProvider(provider) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .clientAuth(ClientAuth.NONE) + .build(); + } else { + String tlsCertPath = config.getTlsCertPath(); + String tlsKeyPath = config.getTlsKeyPath(); + String tlsKeyPassword = config.getTlsKeyPassword(); + try (InputStream serverKeyInputStream = Files.newInputStream(Paths.get(tlsKeyPath)); + InputStream serverCertificateStream = Files.newInputStream(Paths.get(tlsCertPath))) { + return GrpcSslContexts.forServer(serverCertificateStream, + serverKeyInputStream, + StringUtils.isNotBlank(tlsKeyPassword) ? tlsKeyPassword : null) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .clientAuth(ClientAuth.NONE) + .build(); + } + } + } + + private SslContext buildSslContext(TlsDomainConfig config, boolean testMode) throws CertificateException, IOException { + SslProvider provider = OpenSsl.isAvailable() ? SslProvider.OPENSSL : SslProvider.JDK; + if (testMode) { + SelfSignedCertificate selfSignedCertificate = new SelfSignedCertificate(); + return GrpcSslContexts.forServer(selfSignedCertificate.certificate(), selfSignedCertificate.privateKey()) + .sslProvider(provider) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .clientAuth(ClientAuth.NONE) + .build(); + } else { + String tlsCertPath = config.getCertPath(); + String tlsKeyPath = config.getKeyPath(); + String tlsKeyPassword = StringUtils.isNotBlank(config.getKeyPassword()) ? config.getKeyPassword() : this.tlsKeyPassword; + try (InputStream serverKeyInputStream = Files.newInputStream(Paths.get(tlsKeyPath)); + InputStream serverCertificateStream = Files.newInputStream(Paths.get(tlsCertPath))) { + return GrpcSslContexts.forServer(serverCertificateStream, + serverKeyInputStream, + StringUtils.isNotBlank(tlsKeyPassword) ? tlsKeyPassword : null) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .clientAuth(ClientAuth.NONE) + .build(); + } + } + } +} diff --git a/proxy/src/test/java/org/apache/rocketmq/proxy/service/cert/TlsCertificateManagerTest.java b/proxy/src/test/java/org/apache/rocketmq/proxy/service/cert/TlsCertificateManagerTest.java index 9e5f5417462..f2ce47ba11b 100644 --- a/proxy/src/test/java/org/apache/rocketmq/proxy/service/cert/TlsCertificateManagerTest.java +++ b/proxy/src/test/java/org/apache/rocketmq/proxy/service/cert/TlsCertificateManagerTest.java @@ -18,6 +18,9 @@ import org.apache.rocketmq.proxy.config.ConfigurationManager; import org.apache.rocketmq.proxy.config.ProxyConfig; +import org.apache.rocketmq.proxy.config.TlsDomainConfig; +import org.apache.rocketmq.proxy.service.cert.TlsCertificateManager; +import org.apache.rocketmq.proxy.service.cert.TlsSniManager; import org.apache.rocketmq.remoting.netty.TlsSystemConfig; import org.apache.rocketmq.srvutil.FileWatchService; import org.junit.After; @@ -31,10 +34,10 @@ import java.io.File; import java.io.FileWriter; -import java.lang.reflect.Constructor; import java.lang.reflect.Field; -import java.lang.reflect.Method; +import java.util.HashMap; import java.util.List; +import java.util.Map; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -54,6 +57,7 @@ public class TlsCertificateManagerTest { @Rule public TemporaryFolder tempDir = new TemporaryFolder(); + private TlsSniManager tlsSniManager; private TlsCertificateManager manager; @Mock @@ -85,29 +89,22 @@ public void setUp() throws Exception { TlsSystemConfig.tlsServerCertPath = certFile.getAbsolutePath(); TlsSystemConfig.tlsServerKeyPath = keyFile.getAbsolutePath(); - // Create the TlsCertificateManager - manager = new TlsCertificateManager(); + // Create TlsSniManager and TlsCertificateManager + tlsSniManager = new TlsSniManager(); + manager = new TlsCertificateManager(tlsSniManager); - // Extract the file watch listener using reflection + // Extract the file watch listener from the first watch service fileWatchListener = extractFileWatchListener(manager); } - @After - public void tearDown() throws Exception { - // Restore the original config - if (configField != null && originalConfig != null) { - configField.set(null, originalConfig); - } - } - private FileWatchService.Listener extractFileWatchListener(TlsCertificateManager manager) throws Exception { - Field fileWatchServiceField = TlsCertificateManager.class.getDeclaredField("fileWatchService"); - fileWatchServiceField.setAccessible(true); - FileWatchService fileWatchService = (FileWatchService) fileWatchServiceField.get(manager); + Field fileWatchServicesField = TlsCertificateManager.class.getDeclaredField("fileWatchServices"); + fileWatchServicesField.setAccessible(true); + List fileWatchServices = (List) fileWatchServicesField.get(manager); Field listenerField = FileWatchService.class.getDeclaredField("listener"); listenerField.setAccessible(true); - return (FileWatchService.Listener) listenerField.get(fileWatchService); + return (FileWatchService.Listener) listenerField.get(fileWatchServices.get(0)); } @Test @@ -120,11 +117,14 @@ public void testConstructor() { public void testStartAndShutdown() throws Exception { TlsCertificateManager managerSpy = spy(manager); - Field watchServiceField = TlsCertificateManager.class.getDeclaredField("fileWatchService"); - watchServiceField.setAccessible(true); - FileWatchService watchService = (FileWatchService) watchServiceField.get(managerSpy); + Field watchServicesField = TlsCertificateManager.class.getDeclaredField("fileWatchServices"); + watchServicesField.setAccessible(true); + List watchServices = (List) watchServicesField.get(managerSpy); + + // Spy on the first watch service + FileWatchService watchService = watchServices.get(0); FileWatchService watchServiceSpy = spy(watchService); - watchServiceField.set(managerSpy, watchServiceSpy); + watchServices.set(0, watchServiceSpy); managerSpy.start(); verify(watchServiceSpy).start(); diff --git a/proxy/src/test/java/org/apache/rocketmq/proxy/service/cert/TlsSniManagerTest.java b/proxy/src/test/java/org/apache/rocketmq/proxy/service/cert/TlsSniManagerTest.java new file mode 100644 index 00000000000..2d5aa85078e --- /dev/null +++ b/proxy/src/test/java/org/apache/rocketmq/proxy/service/cert/TlsSniManagerTest.java @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.rocketmq.proxy.service.cert; + +import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext; +import org.apache.rocketmq.proxy.config.TlsDomainConfig; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.io.FileWriter; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +public class TlsSniManagerTest { + + @Rule + public TemporaryFolder tempDir = new TemporaryFolder(); + + private TlsSniManager sniManager; + + private File defaultCertFile; + private File defaultKeyFile; + private File comCertFile; + private File comKeyFile; + private File alibabaCertFile; + private File alibabaKeyFile; + + @Before + public void setUp() throws Exception { + // Create temporary certificate and key files for default + defaultCertFile = tempDir.newFile("default.crt"); + defaultKeyFile = tempDir.newFile("default.key"); + try (FileWriter w = new FileWriter(defaultCertFile)) { w.write("default cert"); } + try (FileWriter w = new FileWriter(defaultKeyFile)) { w.write("default key"); } + + // Create files for *.rocketmq.com + comCertFile = tempDir.newFile("rocketmq.crt"); + comKeyFile = tempDir.newFile("rocketmq.key"); + try (FileWriter w = new FileWriter(comCertFile)) { w.write("rocketmq cert"); } + try (FileWriter w = new FileWriter(comKeyFile)) { w.write("rocketmq key"); } + + // Create files for *.alibaba-inc.com + alibabaCertFile = tempDir.newFile("alibaba.crt"); + alibabaKeyFile = tempDir.newFile("alibaba.key"); + try (FileWriter w = new FileWriter(alibabaCertFile)) { w.write("alibaba cert"); } + try (FileWriter w = new FileWriter(alibabaKeyFile)) { w.write("alibaba key"); } + } + + @After + public void tearDown() throws Exception { + } + + @Test + public void testInitializeWithTestMode() { + sniManager = new TlsSniManager(); + sniManager.initialize(createTestModeProxyConfig()); + + assertNotNull(sniManager.getDefaultContext()); + assertNotNull(sniManager.getSslContext("test.rocketmq.com")); + } + + @Test + public void testWildcardMatch_ComDomain() { + sniManager = new TlsSniManager(); + sniManager.initialize(createTestModeProxyConfig()); + + SslContext ctx = sniManager.getSslContext("foo.rocketmq.com"); + assertNotNull(ctx); + // In test mode all contexts are SelfSignedCertificate, but they should be different instances + } + + @Test + public void testWildcardMatch_AlibabaDomain() { + sniManager = new TlsSniManager(); + sniManager.initialize(createTestModeProxyConfig()); + + SslContext ctx = sniManager.getSslContext("mq.alibaba-inc.com"); + assertNotNull(ctx); + } + + @Test + public void testExactMatch() { + sniManager = new TlsSniManager(); + sniManager.initialize(createTestModeProxyConfig()); + + SslContext ctx = sniManager.getSslContext("rocketmq.com"); + assertNotNull(ctx); + } + + @Test + public void testNoMatchFallbackToDefault() { + sniManager = new TlsSniManager(); + sniManager.initialize(createTestModeProxyConfig()); + + SslContext ctx = sniManager.getSslContext("unknown.other.com"); + assertNotNull(ctx); + assertSame(sniManager.getDefaultContext(), ctx); + } + + @Test + public void testNullSniFallbackToDefault() { + sniManager = new TlsSniManager(); + sniManager.initialize(createTestModeProxyConfig()); + + SslContext ctx = sniManager.getSslContext(null); + assertNotNull(ctx); + assertSame(sniManager.getDefaultContext(), ctx); + } + + @Test + public void testEmptySniFallbackToDefault() { + sniManager = new TlsSniManager(); + sniManager.initialize(createTestModeProxyConfig()); + + SslContext ctx = sniManager.getSslContext(""); + assertNotNull(ctx); + assertSame(sniManager.getDefaultContext(), ctx); + } + + @Test + public void testMultiLevelSubdomainNoMatch() { + // a.b.rocketmq.com should NOT match *.rocketmq.com + sniManager = new TlsSniManager(); + sniManager.initialize(createTestModeProxyConfig()); + + SslContext ctx = sniManager.getSslContext("a.b.rocketmq.com"); + assertNotNull(ctx); + assertSame(sniManager.getDefaultContext(), ctx); + } + + @Test + public void testReloadDomainContext() { + sniManager = new TlsSniManager(); + sniManager.initialize(createTestModeProxyConfig()); + + SslContext before = sniManager.getSslContext("foo.rocketmq.com"); + sniManager.reloadDomainContext("*.rocketmq.com"); + SslContext after = sniManager.getSslContext("foo.rocketmq.com"); + assertNotNull(before); + assertNotNull(after); + // In test mode the context instances should be different after reload + } + + @Test + public void testReloadDefaultContext() { + sniManager = new TlsSniManager(); + sniManager.initialize(createTestModeProxyConfig()); + + SslContext before = sniManager.getDefaultContext(); + sniManager.reloadDefaultContext(); + SslContext after = sniManager.getDefaultContext(); + assertNotNull(before); + assertNotNull(after); + } + + @Test + public void testDomainConfigsNotEmpty() { + sniManager = new TlsSniManager(); + sniManager.initialize(createTestModeProxyConfig()); + + Map configs = sniManager.getDomainConfigs(); + assertNotNull(configs); + assertEquals(2, configs.size()); + assertTrue(configs.containsKey("*.rocketmq.com")); + assertTrue(configs.containsKey("*.alibaba-inc.com")); + } + + private org.apache.rocketmq.proxy.config.ProxyConfig createTestModeProxyConfig() { + // We need to create a ProxyConfig-like object manually since ConfigurationManager may not be initialized + // For test mode, we can use a simplified approach + org.apache.rocketmq.proxy.config.ProxyConfig config = new org.apache.rocketmq.proxy.config.ProxyConfig(); + config.setTlsTestModeEnable(true); + + Map domainConfigs = new HashMap<>(); + TlsDomainConfig comConfig = new TlsDomainConfig(); + comConfig.setCertPath(comCertFile.getAbsolutePath()); + comConfig.setKeyPath(comKeyFile.getAbsolutePath()); + domainConfigs.put("*.rocketmq.com", comConfig); + + TlsDomainConfig alibabaConfig = new TlsDomainConfig(); + alibabaConfig.setCertPath(alibabaCertFile.getAbsolutePath()); + alibabaConfig.setKeyPath(alibabaKeyFile.getAbsolutePath()); + domainConfigs.put("*.alibaba-inc.com", alibabaConfig); + + config.setTlsDomainConfigs(domainConfigs); + + return config; + } +} diff --git a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingServer.java b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingServer.java index 578c102daa4..a5254e32d0c 100644 --- a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingServer.java +++ b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingServer.java @@ -45,11 +45,14 @@ import io.netty.handler.codec.haproxy.HAProxyMessageDecoder; import io.netty.handler.codec.haproxy.HAProxyProtocolVersion; import io.netty.handler.codec.haproxy.HAProxyTLV; +import io.netty.handler.ssl.SniHandler; +import io.netty.handler.ssl.SslContext; import io.netty.handler.timeout.IdleState; import io.netty.handler.timeout.IdleStateEvent; import io.netty.handler.timeout.IdleStateHandler; import io.netty.util.AttributeKey; import io.netty.util.CharsetUtil; +import io.netty.util.GlobalEventExecutor; import io.netty.util.HashedWheelTimer; import io.netty.util.Timeout; import io.netty.util.TimerTask; @@ -507,10 +510,19 @@ protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) { throw new UnsupportedOperationException("The NettyRemotingServer in SSL disabled mode doesn't support ssl client"); case PERMISSIVE: case ENFORCING: - if (null != sslContext) { - ctx.pipeline() - .addAfter(getDefaultEventExecutorGroup(), TLS_MODE_HANDLER, TLS_HANDLER_NAME, sslContext.newHandler(ctx.channel().alloc())) - .addAfter(getDefaultEventExecutorGroup(), TLS_HANDLER_NAME, FILE_REGION_ENCODER_NAME, new FileRegionEncoder()); + SslContext defaultCtx = TlsContextProvider.getInstance().getDefaultContext(); + if (null != defaultCtx) { + SniContextProvider sniProvider = TlsContextProvider.getInstance().getSniLookup(); + if (sniProvider != null) { + ctx.pipeline() + .addAfter(getDefaultEventExecutorGroup(), TLS_MODE_HANDLER, TLS_HANDLER_NAME, + new SniHandler(sniProvider, GlobalEventExecutor.INSTANCE)) + .addAfter(getDefaultEventExecutorGroup(), TLS_HANDLER_NAME, FILE_REGION_ENCODER_NAME, new FileRegionEncoder()); + } else { + ctx.pipeline() + .addAfter(getDefaultEventExecutorGroup(), TLS_MODE_HANDLER, TLS_HANDLER_NAME, defaultCtx.newHandler(ctx.channel().alloc())) + .addAfter(getDefaultEventExecutorGroup(), TLS_HANDLER_NAME, FILE_REGION_ENCODER_NAME, new FileRegionEncoder()); + } log.info("Handlers prepended to channel pipeline to establish SSL connection"); } else { ctx.close(); diff --git a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/TlsContextProvider.java b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/TlsContextProvider.java new file mode 100644 index 00000000000..5603f06359d --- /dev/null +++ b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/TlsContextProvider.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.rocketmq.remoting.netty; + +import io.netty.handler.ssl.SslContext; + +/** + * Holder for SslContext used by remoting server's TlsModeHandler. + * Proxy module initializes this with either a single SslContext or a TlsSniManager-backed provider. + */ +public class TlsContextProvider { + + private static volatile TlsContextProvider instance = new TlsContextProvider(); + + private volatile SslContext singleContext; + private volatile SniContextLookup sniLookup; + + public static TlsContextProvider getInstance() { + return instance; + } + + public static void setInstance(TlsContextProvider provider) { + instance = provider; + } + + /** + * Set a single SslContext (backward compatible mode). + */ + public void setSingleContext(SslContext ctx) { + this.singleContext = ctx; + this.sniLookup = null; + } + + /** + * Set an SNI-aware context lookup. + */ + public void setSniLookup(SniContextLookup lookup) { + this.sniLookup = lookup; + this.singleContext = null; + } + + /** + * Get the SslContext for a given SNI hostname. Returns singleContext when no SNI lookup is configured. + */ + public SslContext getSslContext(String sniHostname) { + if (sniLookup != null) { + SslContext ctx = sniLookup.lookup(sniHostname); + if (ctx != null) { + return ctx; + } + } + return singleContext; + } + + /** + * Get the default SslContext for fallback. + */ + public SslContext getDefaultContext() { + if (sniLookup != null) { + return sniLookup.getDefaultContext(); + } + return singleContext; + } + + /** + * Returns the SniContextLookup if configured, null otherwise. + */ + public SniContextLookup getSniLookup() { + return sniLookup; + } + + /** + * Interface for SNI-aware context lookup, implemented in proxy module. + */ + public interface SniContextLookup { + SslContext lookup(String sniHostname); + SslContext getDefaultContext(); + } +}