From f84f5f1e8765735630455e34e71a94b2013c4287 Mon Sep 17 00:00:00 2001 From: somiljain2006 Date: Sun, 24 May 2026 17:59:38 +0530 Subject: [PATCH 1/2] Fix ACL hook routing for internal proxy system resources --- .../proxy/common/InternalContextHolder.java | 34 +++ .../common/SystemResourceAwareRpcHook.java | 133 +++++++++ .../proxy/service/ClusterServiceManager.java | 31 ++- .../service/sysmessage/HeartbeatSyncer.java | 23 +- .../SystemResourceAwareRpcHookTest.java | 259 ++++++++++++++++++ .../sysmessage/HeartbeatSyncerTest.java | 46 +++- 6 files changed, 508 insertions(+), 18 deletions(-) create mode 100644 proxy/src/main/java/org/apache/rocketmq/proxy/common/InternalContextHolder.java create mode 100644 proxy/src/main/java/org/apache/rocketmq/proxy/common/SystemResourceAwareRpcHook.java create mode 100644 proxy/src/test/java/org/apache/rocketmq/proxy/common/SystemResourceAwareRpcHookTest.java diff --git a/proxy/src/main/java/org/apache/rocketmq/proxy/common/InternalContextHolder.java b/proxy/src/main/java/org/apache/rocketmq/proxy/common/InternalContextHolder.java new file mode 100644 index 00000000000..c4307e546e9 --- /dev/null +++ b/proxy/src/main/java/org/apache/rocketmq/proxy/common/InternalContextHolder.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.rocketmq.proxy.common; + +public class InternalContextHolder { + private static final ThreadLocal IS_INTERNAL = ThreadLocal.withInitial(() -> false); + + public static void beginInternalScope() { + IS_INTERNAL.set(true); + } + + public static void clear() { + IS_INTERNAL.remove(); + } + + public static boolean isInternalScope() { + return IS_INTERNAL.get(); + } +} diff --git a/proxy/src/main/java/org/apache/rocketmq/proxy/common/SystemResourceAwareRpcHook.java b/proxy/src/main/java/org/apache/rocketmq/proxy/common/SystemResourceAwareRpcHook.java new file mode 100644 index 00000000000..f3b2f1e8cbf --- /dev/null +++ b/proxy/src/main/java/org/apache/rocketmq/proxy/common/SystemResourceAwareRpcHook.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.rocketmq.proxy.common; + +import java.util.Map; +import org.apache.rocketmq.common.MixAll; +import org.apache.rocketmq.common.topic.TopicValidator; +import org.apache.rocketmq.remoting.RPCHook; +import org.apache.rocketmq.remoting.protocol.RemotingCommand; +import org.apache.rocketmq.remoting.protocol.RequestCode; +import org.apache.rocketmq.remoting.protocol.header.SendMessageRequestHeader; +import org.apache.rocketmq.remoting.protocol.header.SendMessageRequestHeaderV2; +import org.apache.rocketmq.remoting.protocol.header.UnregisterClientRequestHeader; +import org.apache.rocketmq.remoting.protocol.header.namesrv.GetRouteInfoRequestHeader; + +public class SystemResourceAwareRpcHook implements RPCHook { + private final RPCHook userHook; + private final RPCHook systemHook; + + public SystemResourceAwareRpcHook(RPCHook userHook, RPCHook systemHook) { + this.userHook = userHook; + this.systemHook = systemHook; + } + + @Override + public void doBeforeRequest(String remoteAddr, RemotingCommand request) { + if (isTargetingSystemResource(request)) { + if (systemHook != null) { + systemHook.doBeforeRequest(remoteAddr, request); + } + return; + } + if (userHook != null) { + userHook.doBeforeRequest(remoteAddr, request); + } + } + + @Override + public void doAfterResponse(String remoteAddr, RemotingCommand request, RemotingCommand response) { + if (isTargetingSystemResource(request)) { + if (systemHook != null) { + systemHook.doAfterResponse(remoteAddr, request, response); + } + return; + } + if (userHook != null) { + userHook.doAfterResponse(remoteAddr, request, response); + } + } + + private boolean isTargetingSystemResource(RemotingCommand request) { + if (!InternalContextHolder.isInternalScope()) { + return false; + } + + if (request == null) { + return false; + } + + int code = request.getCode(); + try { + switch (code) { + case RequestCode.GET_ROUTEINFO_BY_TOPIC: + GetRouteInfoRequestHeader routeHeader = request + .decodeCommandCustomHeader(GetRouteInfoRequestHeader.class); + return TopicValidator.AUTO_CREATE_TOPIC_KEY_TOPIC.equals(routeHeader.getTopic()); + + case RequestCode.SEND_MESSAGE: + SendMessageRequestHeader sendHeader = request + .decodeCommandCustomHeader(SendMessageRequestHeader.class); + return TopicValidator.AUTO_CREATE_TOPIC_KEY_TOPIC.equals(sendHeader.getTopic()) + || MixAll.CLIENT_INNER_PRODUCER_GROUP.equals(sendHeader.getProducerGroup()); + + case RequestCode.SEND_MESSAGE_V2: + SendMessageRequestHeaderV2 sendHeaderV2 = request + .decodeCommandCustomHeader(SendMessageRequestHeaderV2.class); + SendMessageRequestHeader v1Header = + SendMessageRequestHeaderV2.createSendMessageRequestHeaderV1(sendHeaderV2); + return TopicValidator.AUTO_CREATE_TOPIC_KEY_TOPIC.equals(v1Header.getTopic()) + || MixAll.CLIENT_INNER_PRODUCER_GROUP.equals(v1Header.getProducerGroup()); + + case RequestCode.UNREGISTER_CLIENT: + UnregisterClientRequestHeader unregisterHeader = request + .decodeCommandCustomHeader(UnregisterClientRequestHeader.class); + return MixAll.CLIENT_INNER_PRODUCER_GROUP.equals(unregisterHeader.getProducerGroup()) + || MixAll.TOOLS_CONSUMER_GROUP.equals(unregisterHeader.getConsumerGroup()); + + default: + return checkFallbackExtFields(request.getExtFields()); + } + } catch (Exception e) { + return false; + } + } + + private boolean checkFallbackExtFields(Map extFields) { + if (extFields == null || extFields.isEmpty()) { + return false; + } + String topic = extFields.get("topic"); + if (TopicValidator.AUTO_CREATE_TOPIC_KEY_TOPIC.equals(topic)) { + return true; + } + + String producerGroup = extFields.get("producerGroup"); + if (MixAll.CLIENT_INNER_PRODUCER_GROUP.equals(producerGroup)) { + return true; + } + + String consumerGroup = extFields.get("consumerGroup"); + if (MixAll.TOOLS_CONSUMER_GROUP.equals(consumerGroup)) { + return true; + } + + String generalGroup = extFields.get("group"); + return MixAll.CLIENT_INNER_PRODUCER_GROUP.equals(generalGroup); + } +} diff --git a/proxy/src/main/java/org/apache/rocketmq/proxy/service/ClusterServiceManager.java b/proxy/src/main/java/org/apache/rocketmq/proxy/service/ClusterServiceManager.java index 8b1c20c0bdb..dea03c23ed7 100644 --- a/proxy/src/main/java/org/apache/rocketmq/proxy/service/ClusterServiceManager.java +++ b/proxy/src/main/java/org/apache/rocketmq/proxy/service/ClusterServiceManager.java @@ -35,6 +35,7 @@ import org.apache.rocketmq.logging.org.slf4j.Logger; import org.apache.rocketmq.logging.org.slf4j.LoggerFactory; import org.apache.rocketmq.proxy.common.ProxyContext; +import org.apache.rocketmq.proxy.common.SystemResourceAwareRpcHook; import org.apache.rocketmq.proxy.config.ConfigurationManager; import org.apache.rocketmq.proxy.config.ProxyConfig; import org.apache.rocketmq.proxy.service.admin.AdminService; @@ -54,6 +55,9 @@ import org.apache.rocketmq.proxy.service.transaction.TransactionService; import org.apache.rocketmq.remoting.RPCHook; import org.apache.rocketmq.remoting.RemotingClient; +import org.apache.rocketmq.acl.common.AclClientRPCHook; +import org.apache.rocketmq.acl.common.SessionCredentials; +import org.apache.commons.lang3.StringUtils; public class ClusterServiceManager extends AbstractStartAndShutdown implements ServiceManager { private static final Logger log = LoggerFactory.getLogger(LoggerName.PROXY_LOGGER_NAME); @@ -74,22 +78,31 @@ public class ClusterServiceManager extends AbstractStartAndShutdown implements S protected MQClientAPIFactory transactionClientAPIFactory; protected MQClientAPIFactory liteSubscriptionAPIFactory; - public ClusterServiceManager(RPCHook rpcHook) { - this(rpcHook, null); - } - public ClusterServiceManager(RPCHook rpcHook, ObjectCreator remotingClientCreator) { ProxyConfig proxyConfig = ConfigurationManager.getProxyConfig(); NameserverAccessConfig nameserverAccessConfig = new NameserverAccessConfig(proxyConfig.getNamesrvAddr(), proxyConfig.getNamesrvDomain(), proxyConfig.getNamesrvDomainSubgroup()); this.scheduledExecutorService = ThreadUtils.newScheduledThreadPool(3); + String proxyAccessKey = System.getProperty("rocketmq.proxy.accessKey", System.getenv("ROCKETMQ_PROXY_ACCESS_KEY")); + String proxySecretKey = System.getProperty("rocketmq.proxy.secretKey", System.getenv("ROCKETMQ_PROXY_SECRET_KEY")); + + RPCHook systemHook = null; + if (StringUtils.isNotBlank(proxyAccessKey) && StringUtils.isNotBlank(proxySecretKey)) { + systemHook = new AclClientRPCHook(new SessionCredentials(proxyAccessKey, proxySecretKey)); + log.info("SystemResourceAwareRpcHook initialized with provided Proxy Admin Credentials."); + } else { + log.warn("No Proxy Admin Credentials found (rocketmq.proxy.accessKey/secretKey). System requests will be anonymous."); + } + + RPCHook systemAwareHook = new SystemResourceAwareRpcHook(rpcHook, systemHook); + this.messagingClientAPIFactory = new MQClientAPIFactory( nameserverAccessConfig, "ClusterMQClient_", proxyConfig.getRocketmqMQClientNum(), new DoNothingClientRemotingProcessor(null), - rpcHook, + systemAwareHook, scheduledExecutorService, remotingClientCreator ); @@ -99,7 +112,7 @@ public ClusterServiceManager(RPCHook rpcHook, ObjectCreator remo "OperationClient_", 1, new DoNothingClientRemotingProcessor(null), - rpcHook, + systemAwareHook, this.scheduledExecutorService, remotingClientCreator ); @@ -110,14 +123,14 @@ public ClusterServiceManager(RPCHook rpcHook, ObjectCreator remo this.adminService = new DefaultAdminService(this.operationClientAPIFactory); this.producerManager = new ProducerManager(); - this.consumerManager = new ClusterConsumerManager(this.topicRouteService, this.adminService, this.operationClientAPIFactory, new ConsumerIdsChangeListenerImpl(), proxyConfig.getChannelExpiredTimeout(), rpcHook); + this.consumerManager = new ClusterConsumerManager(this.topicRouteService, this.adminService, this.operationClientAPIFactory, new ConsumerIdsChangeListenerImpl(), proxyConfig.getChannelExpiredTimeout(), systemAwareHook); this.transactionClientAPIFactory = new MQClientAPIFactory( nameserverAccessConfig, "ClusterTransaction_", 1, new ProxyClientRemotingProcessor(producerManager, consumerManager), - rpcHook, + systemAwareHook, scheduledExecutorService, remotingClientCreator ); @@ -132,7 +145,7 @@ public ClusterServiceManager(RPCHook rpcHook, ObjectCreator remo "LiteSubscription_", 1, new ProxyClientRemotingProcessor(producerManager, consumerManager), - rpcHook, + systemAwareHook, scheduledExecutorService); this.liteSubscriptionService = new LiteSubscriptionService(this.topicRouteService, this.liteSubscriptionAPIFactory); diff --git a/proxy/src/main/java/org/apache/rocketmq/proxy/service/sysmessage/HeartbeatSyncer.java b/proxy/src/main/java/org/apache/rocketmq/proxy/service/sysmessage/HeartbeatSyncer.java index e063d79707b..5083170263c 100644 --- a/proxy/src/main/java/org/apache/rocketmq/proxy/service/sysmessage/HeartbeatSyncer.java +++ b/proxy/src/main/java/org/apache/rocketmq/proxy/service/sysmessage/HeartbeatSyncer.java @@ -48,6 +48,9 @@ import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import static org.apache.rocketmq.proxy.common.InternalContextHolder.beginInternalScope; +import static org.apache.rocketmq.proxy.common.InternalContextHolder.clear; + public class HeartbeatSyncer extends AbstractSystemMessageSyncer { protected ThreadPoolExecutor threadPoolExecutor; @@ -113,6 +116,8 @@ public void onConsumerRegister(String consumerGroup, ClientChannelInfo clientCha try { this.threadPoolExecutor.submit(() -> { try { + beginInternalScope(); + RemoteChannel remoteChannel = RemoteChannel.create(clientChannelInfo.getChannel()); if (remoteChannel == null) { return; @@ -134,13 +139,13 @@ public void onConsumerRegister(String consumerGroup, ClientChannelInfo clientCha log.debug("sync register heart beat. topic:{}, data:{}", this.getBroadcastTopicName(), data); this.sendSystemMessage(data); } catch (Throwable t) { - log.error("heartbeat register broadcast failed. group:{}, clientChannelInfo:{}, consumeType:{}, messageModel:{}, consumeFromWhere:{}, subList:{}", - consumerGroup, clientChannelInfo, consumeType, messageModel, consumeFromWhere, subList, t); + log.error("heartbeat register broadcast failed...", t); + } finally { + clear(); } }); } catch (Throwable t) { - log.error("heartbeat submit register broadcast failed. group:{}, clientChannelInfo:{}, consumeType:{}, messageModel:{}, consumeFromWhere:{}, subList:{}", - consumerGroup, clientChannelInfo, consumeType, messageModel, consumeFromWhere, subList, t); + log.error("heartbeat submit register broadcast failed...", t); } } @@ -151,6 +156,8 @@ public void onConsumerUnRegister(String consumerGroup, ClientChannelInfo clientC try { this.threadPoolExecutor.submit(() -> { try { + beginInternalScope(); + RemoteChannel remoteChannel = RemoteChannel.create(clientChannelInfo.getChannel()); if (remoteChannel == null) { return; @@ -171,13 +178,13 @@ public void onConsumerUnRegister(String consumerGroup, ClientChannelInfo clientC log.debug("sync unregister heart beat. topic:{}, data:{}", this.getBroadcastTopicName(), data); this.sendSystemMessage(data); } catch (Throwable t) { - log.error("heartbeat unregister broadcast failed. group:{}, clientChannelInfo:{}, consumeType:{}", - consumerGroup, clientChannelInfo, t); + log.error("heartbeat unregister broadcast failed...", t); + } finally { + clear(); } }); } catch (Throwable t) { - log.error("heartbeat submit unregister broadcast failed. group:{}, clientChannelInfo:{}, consumeType:{}", - consumerGroup, clientChannelInfo, t); + log.error("heartbeat submit unregister broadcast failed...", t); } } diff --git a/proxy/src/test/java/org/apache/rocketmq/proxy/common/SystemResourceAwareRpcHookTest.java b/proxy/src/test/java/org/apache/rocketmq/proxy/common/SystemResourceAwareRpcHookTest.java new file mode 100644 index 00000000000..d1d27fe90af --- /dev/null +++ b/proxy/src/test/java/org/apache/rocketmq/proxy/common/SystemResourceAwareRpcHookTest.java @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.rocketmq.proxy.common; + +import org.apache.rocketmq.common.MixAll; +import org.apache.rocketmq.common.topic.TopicValidator; +import org.apache.rocketmq.remoting.RPCHook; +import org.apache.rocketmq.remoting.protocol.RemotingCommand; +import org.apache.rocketmq.remoting.protocol.RequestCode; +import org.apache.rocketmq.remoting.protocol.header.SendMessageRequestHeader; +import org.apache.rocketmq.remoting.protocol.header.UnregisterClientRequestHeader; +import org.apache.rocketmq.remoting.protocol.header.namesrv.GetRouteInfoRequestHeader; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +public class SystemResourceAwareRpcHookTest { + private RecordingHook userHook; + private RecordingHook systemHook; + + private SystemResourceAwareRpcHook rpcHook; + private final String remoteAddr = "127.0.0.1:9876"; + + @Before + public void setUp() { + userHook = new RecordingHook(); + systemHook = new RecordingHook(); + rpcHook = new SystemResourceAwareRpcHook(userHook, systemHook); + InternalContextHolder.clear(); + } + + @After + public void tearDown() { + InternalContextHolder.clear(); + } + + @Test + public void testExternalClientAttackPrevented() { + SendMessageRequestHeader header = new SendMessageRequestHeader(); + header.setProducerGroup(MixAll.CLIENT_INNER_PRODUCER_GROUP); + header.setTopic("user_topic"); + RemotingCommand request = RemotingCommand.createRequestCommand(RequestCode.SEND_MESSAGE, header); + + request.makeCustomHeaderToNet(); + + rpcHook.doBeforeRequest(remoteAddr, request); + + assertEquals(1, userHook.beforeRequestCount); + assertSame(request, userHook.lastBeforeRequest); + assertEquals(remoteAddr, userHook.lastBeforeRemoteAddr); + assertEquals(0, systemHook.beforeRequestCount); + } + + @Test + public void testSystemHookUsedForInternalSystemRequest() { + InternalContextHolder.beginInternalScope(); + + RemotingCommand request = RemotingCommand.createRequestCommand(RequestCode.SEND_MESSAGE, null); + request.addExtField("producerGroup", MixAll.CLIENT_INNER_PRODUCER_GROUP); + request.addExtField("topic", TopicValidator.AUTO_CREATE_TOPIC_KEY_TOPIC); + + rpcHook.doBeforeRequest(remoteAddr, request); + + assertEquals(1, systemHook.beforeRequestCount); + assertSame(request, systemHook.lastBeforeRequest); + assertEquals(remoteAddr, systemHook.lastBeforeRemoteAddr); + assertEquals(0, userHook.beforeRequestCount); + } + + @Test + public void testInternalScopeWithNonSystemResourceRoute() { + InternalContextHolder.beginInternalScope(); + + SendMessageRequestHeader header = new SendMessageRequestHeader(); + header.setProducerGroup("STANDARD_USER_GROUP"); + header.setTopic("STANDARD_USER_TOPIC"); + RemotingCommand request = RemotingCommand.createRequestCommand(RequestCode.SEND_MESSAGE, header); + + request.makeCustomHeaderToNet(); + + rpcHook.doBeforeRequest(remoteAddr, request); + + assertEquals(1, userHook.beforeRequestCount); + assertSame(request, userHook.lastBeforeRequest); + assertEquals(remoteAddr, userHook.lastBeforeRemoteAddr); + assertEquals(0, systemHook.beforeRequestCount); + } + + @Test + public void testGetRouteInfoSystemTopicRoute() { + InternalContextHolder.beginInternalScope(); + + GetRouteInfoRequestHeader header = new GetRouteInfoRequestHeader(); + header.setTopic(TopicValidator.AUTO_CREATE_TOPIC_KEY_TOPIC); + RemotingCommand request = RemotingCommand.createRequestCommand(RequestCode.GET_ROUTEINFO_BY_TOPIC, header); + + request.makeCustomHeaderToNet(); + + rpcHook.doBeforeRequest(remoteAddr, request); + + assertEquals(1, systemHook.beforeRequestCount); + assertSame(request, systemHook.lastBeforeRequest); + assertEquals(remoteAddr, systemHook.lastBeforeRemoteAddr); + assertEquals(0, userHook.beforeRequestCount); + } + + @Test + public void testUnregisterClientStrictVerification() { + InternalContextHolder.beginInternalScope(); + + UnregisterClientRequestHeader headerA = new UnregisterClientRequestHeader(); + headerA.setClientID("clientA"); + headerA.setConsumerGroup(MixAll.TOOLS_CONSUMER_GROUP); + + RemotingCommand reqA = RemotingCommand.createRequestCommand(RequestCode.UNREGISTER_CLIENT, headerA); + reqA.makeCustomHeaderToNet(); + + rpcHook.doBeforeRequest(remoteAddr, reqA); + + assertEquals(1, systemHook.beforeRequestCount); + assertSame(reqA, systemHook.lastBeforeRequest); + assertEquals(remoteAddr, systemHook.lastBeforeRemoteAddr); + + InternalContextHolder.clear(); + InternalContextHolder.beginInternalScope(); + + UnregisterClientRequestHeader headerB = new UnregisterClientRequestHeader(); + headerB.setClientID("clientB"); + headerB.setConsumerGroup("NORMAL_CONSUMER_GROUP"); + + RemotingCommand reqB = RemotingCommand.createRequestCommand(RequestCode.UNREGISTER_CLIENT, headerB); + reqB.makeCustomHeaderToNet(); + + rpcHook.doBeforeRequest(remoteAddr, reqB); + + assertEquals(1, userHook.beforeRequestCount); + assertSame(reqB, userHook.lastBeforeRequest); + assertEquals(remoteAddr, userHook.lastBeforeRemoteAddr); + } + + @Test + public void testDoAfterResponseRouting() { + InternalContextHolder.beginInternalScope(); + + SendMessageRequestHeader header = new SendMessageRequestHeader(); + header.setProducerGroup(MixAll.CLIENT_INNER_PRODUCER_GROUP); + RemotingCommand request = RemotingCommand.createRequestCommand(RequestCode.SEND_MESSAGE, header); + + request.makeCustomHeaderToNet(); + + RemotingCommand response = RemotingCommand.createResponseCommand(0, "SUCCESS"); + + rpcHook.doAfterResponse(remoteAddr, request, response); + + assertEquals(1, systemHook.afterResponseCount); + assertSame(request, systemHook.lastAfterRequest); + assertSame(response, systemHook.lastAfterResponse); + assertEquals(remoteAddr, systemHook.lastAfterRemoteAddr); + assertEquals(0, userHook.afterResponseCount); + + InternalContextHolder.clear(); + rpcHook.doAfterResponse(remoteAddr, request, response); + + assertEquals(1, userHook.afterResponseCount); + assertSame(request, userHook.lastAfterRequest); + assertSame(response, userHook.lastAfterResponse); + assertEquals(remoteAddr, userHook.lastAfterRemoteAddr); + } + + @Test + public void testFallbackExtFieldsValidSystemMatch() { + InternalContextHolder.beginInternalScope(); + + RemotingCommand request = RemotingCommand.createRequestCommand(9999, null); + request.addExtField("topic", TopicValidator.AUTO_CREATE_TOPIC_KEY_TOPIC); + + rpcHook.doBeforeRequest(remoteAddr, request); + + assertEquals(1, systemHook.beforeRequestCount); + assertSame(request, systemHook.lastBeforeRequest); + assertEquals(remoteAddr, systemHook.lastBeforeRemoteAddr); + assertEquals(0, userHook.beforeRequestCount); + } + + @Test + public void testFallbackExtFieldsRejection() { + InternalContextHolder.beginInternalScope(); + + RemotingCommand request = RemotingCommand.createRequestCommand(9999, null); + request.addExtField("topic", "standard_user_topic"); + request.addExtField("producerGroup", "standard_user_group"); + + rpcHook.doBeforeRequest(remoteAddr, request); + + assertEquals(1, userHook.beforeRequestCount); + assertSame(request, userHook.lastBeforeRequest); + assertEquals(remoteAddr, userHook.lastBeforeRemoteAddr); + assertEquals(0, systemHook.beforeRequestCount); + } + + @Test + public void testFallbackExtFieldsNullMap() { + InternalContextHolder.beginInternalScope(); + + RemotingCommand request = RemotingCommand.createRequestCommand(9999, null); + + rpcHook.doBeforeRequest(remoteAddr, request); + + assertEquals(1, userHook.beforeRequestCount); + assertSame(request, userHook.lastBeforeRequest); + assertEquals(remoteAddr, userHook.lastBeforeRemoteAddr); + assertEquals(0, systemHook.beforeRequestCount); + assertNull(systemHook.lastBeforeRequest); + } + + private static class RecordingHook implements RPCHook { + private int beforeRequestCount; + private int afterResponseCount; + private String lastBeforeRemoteAddr; + private String lastAfterRemoteAddr; + private RemotingCommand lastBeforeRequest; + private RemotingCommand lastAfterRequest; + private RemotingCommand lastAfterResponse; + + @Override + public void doBeforeRequest(String remoteAddr, RemotingCommand request) { + beforeRequestCount++; + lastBeforeRemoteAddr = remoteAddr; + lastBeforeRequest = request; + } + + @Override + public void doAfterResponse(String remoteAddr, RemotingCommand request, RemotingCommand response) { + afterResponseCount++; + lastAfterRemoteAddr = remoteAddr; + lastAfterRequest = request; + lastAfterResponse = response; + } + } +} diff --git a/proxy/src/test/java/org/apache/rocketmq/proxy/service/sysmessage/HeartbeatSyncerTest.java b/proxy/src/test/java/org/apache/rocketmq/proxy/service/sysmessage/HeartbeatSyncerTest.java index 9a2c5e3437d..a04060404ee 100644 --- a/proxy/src/test/java/org/apache/rocketmq/proxy/service/sysmessage/HeartbeatSyncerTest.java +++ b/proxy/src/test/java/org/apache/rocketmq/proxy/service/sysmessage/HeartbeatSyncerTest.java @@ -33,6 +33,7 @@ import java.util.List; import java.util.Set; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import org.apache.commons.lang3.RandomStringUtils; import org.apache.rocketmq.broker.client.ClientChannelInfo; @@ -75,8 +76,11 @@ import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import static org.apache.rocketmq.proxy.common.InternalContextHolder.clear; +import static org.apache.rocketmq.proxy.common.InternalContextHolder.isInternalScope; import static org.awaitility.Awaitility.await; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; @@ -433,4 +437,44 @@ public int compareTo(@NotNull ChannelId o) { return this.channelId.compareTo(o.asLongText()); } } -} \ No newline at end of file + + @Test + public void testAsyncThreadContextPropagationAndCleanup() throws Exception { + String consumerGroup = "consumerGroup"; + Channel channel = createMockChannel(); + RemotingProxyOutClient remotingProxyOutClient = mock(RemotingProxyOutClient.class); + RemotingChannel remotingChannel = new RemotingChannel(remotingProxyOutClient, proxyRelayService, channel, clientId, Collections.emptySet()); + ClientChannelInfo clientChannelInfo = new ClientChannelInfo(remotingChannel, clientId, LanguageCode.JAVA, 4); + + AtomicBoolean contextValidatedDuringExecution = new AtomicBoolean(false); + + clear(); + assertFalse("Test must start with clean context", isInternalScope()); + + when(this.mqClientAPIExt.sendMessageAsync(anyString(), anyString(), any(Message.class), any(), anyLong())) + .thenAnswer(invocation -> { + boolean isInternal = isInternalScope(); + contextValidatedDuringExecution.set(isInternal); + + SendResult result = new SendResult(); + result.setSendStatus(SendStatus.SEND_OK); + return CompletableFuture.completedFuture(result); + }); + + HeartbeatSyncer heartbeatSyncer = new HeartbeatSyncer(topicRouteService, adminService, consumerManager, mqClientAPIFactory, null); + + heartbeatSyncer.onConsumerRegister( + consumerGroup, + clientChannelInfo, + ConsumeType.CONSUME_PASSIVELY, + MessageModel.CLUSTERING, + ConsumeFromWhere.CONSUME_FROM_LAST_OFFSET, + Collections.emptySet() + ); + + await().atMost(Duration.ofSeconds(3)).until(contextValidatedDuringExecution::get); + assertTrue("Context must explicitly exist during the execution phase", contextValidatedDuringExecution.get()); + + assertFalse("CRITICAL: Context MUST be cleared after execution to prevent ThreadLocal privilege leakage in thread pools", isInternalScope()); + } +} From 2ae552c485b748cfdb370b47f57ba97aec4f3461 Mon Sep 17 00:00:00 2001 From: somiljain2006 Date: Tue, 26 May 2026 00:28:21 +0530 Subject: [PATCH 2/2] Added more tests --- .../SystemResourceAwareRpcHookTest.java | 99 ++++++++ .../cluster/ClusterServiceManagerTest.java | 71 ++++++ .../sysmessage/HeartbeatSyncerTest.java | 234 +++++++++++++++++- 3 files changed, 403 insertions(+), 1 deletion(-) create mode 100644 proxy/src/test/java/org/apache/rocketmq/proxy/service/cluster/ClusterServiceManagerTest.java diff --git a/proxy/src/test/java/org/apache/rocketmq/proxy/common/SystemResourceAwareRpcHookTest.java b/proxy/src/test/java/org/apache/rocketmq/proxy/common/SystemResourceAwareRpcHookTest.java index d1d27fe90af..2eae67c2f69 100644 --- a/proxy/src/test/java/org/apache/rocketmq/proxy/common/SystemResourceAwareRpcHookTest.java +++ b/proxy/src/test/java/org/apache/rocketmq/proxy/common/SystemResourceAwareRpcHookTest.java @@ -23,6 +23,7 @@ import org.apache.rocketmq.remoting.protocol.RemotingCommand; import org.apache.rocketmq.remoting.protocol.RequestCode; import org.apache.rocketmq.remoting.protocol.header.SendMessageRequestHeader; +import org.apache.rocketmq.remoting.protocol.header.SendMessageRequestHeaderV2; import org.apache.rocketmq.remoting.protocol.header.UnregisterClientRequestHeader; import org.apache.rocketmq.remoting.protocol.header.namesrv.GetRouteInfoRequestHeader; import org.junit.After; @@ -232,6 +233,104 @@ public void testFallbackExtFieldsNullMap() { assertNull(systemHook.lastBeforeRequest); } + @Test + public void testNullRequestFallsBackToUserHook() { + InternalContextHolder.beginInternalScope(); + + rpcHook.doBeforeRequest(remoteAddr, null); + + assertEquals(1, userHook.beforeRequestCount); + assertNull(userHook.lastBeforeRequest); + assertEquals(remoteAddr, userHook.lastBeforeRemoteAddr); + + assertEquals(0, systemHook.beforeRequestCount); + } + + @Test + public void testSendMessageV2SystemRouting() { + InternalContextHolder.beginInternalScope(); + + SendMessageRequestHeaderV2 headerV2 = new SendMessageRequestHeaderV2(); + + headerV2.setA(MixAll.CLIENT_INNER_PRODUCER_GROUP); + headerV2.setB(TopicValidator.AUTO_CREATE_TOPIC_KEY_TOPIC); + + RemotingCommand request = RemotingCommand.createRequestCommand( + RequestCode.SEND_MESSAGE_V2, + headerV2 + ); + + request.makeCustomHeaderToNet(); + + rpcHook.doBeforeRequest(remoteAddr, request); + + assertEquals(1, systemHook.beforeRequestCount); + assertSame(request, systemHook.lastBeforeRequest); + assertEquals(remoteAddr, systemHook.lastBeforeRemoteAddr); + + assertEquals(0, userHook.beforeRequestCount); + } + + @Test + public void testExceptionDuringHeaderDecodeFallsBackToUserHook() { + InternalContextHolder.beginInternalScope(); + + RemotingCommand request = RemotingCommand.createRequestCommand( + RequestCode.SEND_MESSAGE, + null + ); + + request.addExtField("producerGroup", "groupOnly"); + + rpcHook.doBeforeRequest(remoteAddr, request); + + assertEquals(1, userHook.beforeRequestCount); + assertSame(request, userHook.lastBeforeRequest); + assertEquals(remoteAddr, userHook.lastBeforeRemoteAddr); + + assertEquals(0, systemHook.beforeRequestCount); + } + + @Test + public void testFallbackExtFieldsProducerGroupSystemMatch() { + InternalContextHolder.beginInternalScope(); + + RemotingCommand request = RemotingCommand.createRequestCommand(9999, null); + + request.addExtField( + "producerGroup", + MixAll.CLIENT_INNER_PRODUCER_GROUP + ); + + rpcHook.doBeforeRequest(remoteAddr, request); + + assertEquals(1, systemHook.beforeRequestCount); + assertSame(request, systemHook.lastBeforeRequest); + assertEquals(remoteAddr, systemHook.lastBeforeRemoteAddr); + + assertEquals(0, userHook.beforeRequestCount); + } + + @Test + public void testFallbackExtFieldsConsumerGroupSystemMatch() { + InternalContextHolder.beginInternalScope(); + + RemotingCommand request = RemotingCommand.createRequestCommand(9999, null); + + request.addExtField( + "consumerGroup", + MixAll.TOOLS_CONSUMER_GROUP + ); + + rpcHook.doBeforeRequest(remoteAddr, request); + + assertEquals(1, systemHook.beforeRequestCount); + assertSame(request, systemHook.lastBeforeRequest); + assertEquals(remoteAddr, systemHook.lastBeforeRemoteAddr); + + assertEquals(0, userHook.beforeRequestCount); + } + private static class RecordingHook implements RPCHook { private int beforeRequestCount; private int afterResponseCount; diff --git a/proxy/src/test/java/org/apache/rocketmq/proxy/service/cluster/ClusterServiceManagerTest.java b/proxy/src/test/java/org/apache/rocketmq/proxy/service/cluster/ClusterServiceManagerTest.java new file mode 100644 index 00000000000..16808971568 --- /dev/null +++ b/proxy/src/test/java/org/apache/rocketmq/proxy/service/cluster/ClusterServiceManagerTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.rocketmq.proxy.service.cluster; + +import org.apache.rocketmq.proxy.config.ConfigurationManager; +import org.apache.rocketmq.proxy.service.ClusterServiceManager; +import org.apache.rocketmq.remoting.RPCHook; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +import static org.junit.Assert.assertNotNull; + +@RunWith(MockitoJUnitRunner.class) +public class ClusterServiceManagerTest { + + @Mock + private RPCHook rpcHook; + + @Before + public void setUp() throws Exception { + System.setProperty("rocketmq.namesrv.addr", "127.0.0.1:9876"); + ConfigurationManager.initEnv(); + ConfigurationManager.initConfig(); + } + + @After + public void tearDown() { + System.clearProperty("rocketmq.proxy.accessKey"); + System.clearProperty("rocketmq.proxy.secretKey"); + System.clearProperty("rocketmq.namesrv.addr"); + } + + @Test + public void testConstructorWithAdminCredentials() { + System.setProperty("rocketmq.proxy.accessKey", "admin"); + System.setProperty("rocketmq.proxy.secretKey", "admin123"); + + ClusterServiceManager manager = new ClusterServiceManager(rpcHook, null); + + assertNotNull(manager); + } + + @Test + public void testConstructorWithoutAdminCredentials() { + System.clearProperty("rocketmq.proxy.accessKey"); + System.clearProperty("rocketmq.proxy.secretKey"); + + ClusterServiceManager manager = new ClusterServiceManager(rpcHook, null); + + assertNotNull(manager); + } +} diff --git a/proxy/src/test/java/org/apache/rocketmq/proxy/service/sysmessage/HeartbeatSyncerTest.java b/proxy/src/test/java/org/apache/rocketmq/proxy/service/sysmessage/HeartbeatSyncerTest.java index a04060404ee..3e490a64223 100644 --- a/proxy/src/test/java/org/apache/rocketmq/proxy/service/sysmessage/HeartbeatSyncerTest.java +++ b/proxy/src/test/java/org/apache/rocketmq/proxy/service/sysmessage/HeartbeatSyncerTest.java @@ -439,7 +439,7 @@ public int compareTo(@NotNull ChannelId o) { } @Test - public void testAsyncThreadContextPropagationAndCleanup() throws Exception { + public void testAsyncThreadContextPropagationAndCleanup() { String consumerGroup = "consumerGroup"; Channel channel = createMockChannel(); RemotingProxyOutClient remotingProxyOutClient = mock(RemotingProxyOutClient.class); @@ -477,4 +477,236 @@ public void testAsyncThreadContextPropagationAndCleanup() throws Exception { assertFalse("CRITICAL: Context MUST be cleared after execution to prevent ThreadLocal privilege leakage in thread pools", isInternalScope()); } + + @Test + public void testOnConsumerRegisterWithNullRemoteChannel() { + String consumerGroup = "consumerGroup"; + + Channel channel = mock(Channel.class); + + ClientChannelInfo clientChannelInfo = new ClientChannelInfo( + channel, + clientId, + LanguageCode.JAVA, + 4 + ); + + HeartbeatSyncer heartbeatSyncer = new HeartbeatSyncer( + topicRouteService, + adminService, + consumerManager, + mqClientAPIFactory, + null + ); + + heartbeatSyncer.onConsumerRegister( + consumerGroup, + clientChannelInfo, + ConsumeType.CONSUME_PASSIVELY, + MessageModel.CLUSTERING, + ConsumeFromWhere.CONSUME_FROM_LAST_OFFSET, + Collections.emptySet() + ); + + await().atMost(Duration.ofSeconds(1)).untilAsserted(() -> + verify(mqClientAPIExt, never()).sendMessageAsync( + anyString(), + anyString(), + any(Message.class), + any(), + anyLong() + )); + } + + @Test + public void testRegisterBroadcastFailureHandled() { + String consumerGroup = "consumerGroup"; + + Channel channel = createMockChannel(); + RemotingProxyOutClient remotingProxyOutClient = mock(RemotingProxyOutClient.class); + + RemotingChannel remotingChannel = new RemotingChannel( + remotingProxyOutClient, + proxyRelayService, + channel, + clientId, + Collections.emptySet() + ); + + ClientChannelInfo clientChannelInfo = new ClientChannelInfo( + remotingChannel, + clientId, + LanguageCode.JAVA, + 4 + ); + + when(this.mqClientAPIExt.sendMessageAsync( + anyString(), + anyString(), + any(Message.class), + any(), + anyLong() + )).thenThrow(new RuntimeException("mock failure")); + + HeartbeatSyncer heartbeatSyncer = new HeartbeatSyncer( + topicRouteService, + adminService, + consumerManager, + mqClientAPIFactory, + null + ); + + heartbeatSyncer.onConsumerRegister( + consumerGroup, + clientChannelInfo, + ConsumeType.CONSUME_PASSIVELY, + MessageModel.CLUSTERING, + ConsumeFromWhere.CONSUME_FROM_LAST_OFFSET, + Collections.emptySet() + ); + + await().atMost(Duration.ofSeconds(2)).untilAsserted(() -> + verify(mqClientAPIExt).sendMessageAsync( + anyString(), + anyString(), + any(Message.class), + any(), + anyLong() + ) + ); + } + + @Test + public void testRegisterSubmitFailureHandled() { + String consumerGroup = "consumerGroup"; + + Channel channel = createMockChannel(); + RemotingProxyOutClient remotingProxyOutClient = mock(RemotingProxyOutClient.class); + + RemotingChannel remotingChannel = new RemotingChannel( + remotingProxyOutClient, + proxyRelayService, + channel, + clientId, + Collections.emptySet() + ); + + ClientChannelInfo clientChannelInfo = new ClientChannelInfo( + remotingChannel, + clientId, + LanguageCode.JAVA, + 4 + ); + + HeartbeatSyncer heartbeatSyncer = new HeartbeatSyncer( + topicRouteService, + adminService, + consumerManager, + mqClientAPIFactory, + null + ); + + heartbeatSyncer.threadPoolExecutor.shutdownNow(); + + heartbeatSyncer.onConsumerRegister( + consumerGroup, + clientChannelInfo, + ConsumeType.CONSUME_PASSIVELY, + MessageModel.CLUSTERING, + ConsumeFromWhere.CONSUME_FROM_LAST_OFFSET, + Collections.emptySet() + ); + } + + @Test + public void testUnregisterBroadcastFailureHandled() throws Exception { + String consumerGroup = "consumerGroup"; + + Channel channel = createMockChannel(); + RemotingProxyOutClient remotingProxyOutClient = mock(RemotingProxyOutClient.class); + + RemotingChannel remotingChannel = new RemotingChannel( + remotingProxyOutClient, + proxyRelayService, + channel, + clientId, + Collections.emptySet() + ); + + ClientChannelInfo clientChannelInfo = new ClientChannelInfo( + remotingChannel, + clientId, + LanguageCode.JAVA, + 4 + ); + + when(this.mqClientAPIExt.sendMessageAsync( + anyString(), + anyString(), + any(Message.class), + any(), + anyLong() + )).thenThrow(new RuntimeException("mock failure")); + + HeartbeatSyncer heartbeatSyncer = new HeartbeatSyncer( + topicRouteService, + adminService, + consumerManager, + mqClientAPIFactory, + null + ); + + heartbeatSyncer.onConsumerUnRegister( + consumerGroup, + clientChannelInfo + ); + + await().atMost(Duration.ofSeconds(2)).untilAsserted(() -> + verify(mqClientAPIExt).sendMessageAsync( + anyString(), + anyString(), + any(Message.class), + any(), + anyLong() + ) + ); + } + + @Test + public void testUnregisterSubmitFailureHandled() { + String consumerGroup = "consumerGroup"; + + Channel channel = createMockChannel(); + RemotingProxyOutClient remotingProxyOutClient = mock(RemotingProxyOutClient.class); + + RemotingChannel remotingChannel = new RemotingChannel( + remotingProxyOutClient, + proxyRelayService, + channel, + clientId, + Collections.emptySet() + ); + + ClientChannelInfo clientChannelInfo = new ClientChannelInfo( + remotingChannel, + clientId, + LanguageCode.JAVA, + 4 + ); + + HeartbeatSyncer heartbeatSyncer = new HeartbeatSyncer( + topicRouteService, + adminService, + consumerManager, + mqClientAPIFactory, + null + ); + + heartbeatSyncer.threadPoolExecutor.shutdownNow(); + + heartbeatSyncer.onConsumerUnRegister( + consumerGroup, + clientChannelInfo + ); + } }