Skip to content

Commit cd396d0

Browse files
k-krawczykjbertram
authored andcommitted
ARTEMIS-6085 Fix NPE from race in MQTTStateManager.getSessionState
getSessionState() used a non-atomic containsKey()/get() pair on a ConcurrentHashMap. During link stealing (two connections sharing a client ID) removeSessionState() can run between the two calls, so get() returns null and handleLinkStealing() dereferences it via .getSession(), throwing a NullPointerException. Use computeIfAbsent() for an atomic get-or-create that never returns null. Adds a concurrency regression test that reproduces the NPE.
1 parent 419caaf commit cd396d0

3 files changed

Lines changed: 116 additions & 8 deletions

File tree

artemis-protocols/artemis-mqtt-protocol/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@
107107
<artifactId>junit-jupiter-engine</artifactId>
108108
<scope>test</scope>
109109
</dependency>
110+
<dependency>
111+
<groupId>org.mockito</groupId>
112+
<artifactId>mockito-core</artifactId>
113+
<scope>test</scope>
114+
</dependency>
110115
<dependency>
111116
<groupId>org.apache.logging.log4j</groupId>
112117
<artifactId>log4j-slf4j2-impl</artifactId>

artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTStateManager.java

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,10 @@ public void scanSessions() {
137137

138138
public MQTTSessionState getSessionState(String clientId) throws Exception {
139139
// [MQTT-3.1.2-4] Attach an existing session if one exists otherwise create a new one.
140-
if (sessionStates.containsKey(clientId)) {
141-
return sessionStates.get(clientId);
142-
} else {
143-
MQTTSessionState sessionState = new MQTTSessionState(clientId);
144-
logger.debug("Adding MQTT session state for: {}", clientId);
145-
sessionStates.put(clientId, sessionState);
146-
return sessionState;
147-
}
140+
return sessionStates.computeIfAbsent(clientId, key -> {
141+
logger.debug("Adding MQTT session state for: {}", key);
142+
return new MQTTSessionState(key);
143+
});
148144
}
149145

150146
public MQTTSessionState removeSessionState(String clientId) throws Exception {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.activemq.artemis.core.protocol.mqtt;
20+
21+
import java.util.concurrent.CountDownLatch;
22+
import java.util.concurrent.ExecutorService;
23+
import java.util.concurrent.Executors;
24+
import java.util.concurrent.TimeUnit;
25+
import java.util.concurrent.atomic.AtomicBoolean;
26+
import java.util.concurrent.atomic.AtomicReference;
27+
28+
import org.apache.activemq.artemis.core.config.Configuration;
29+
import org.apache.activemq.artemis.core.server.ActiveMQServer;
30+
import org.junit.jupiter.api.Test;
31+
import org.junit.jupiter.api.Timeout;
32+
33+
import static org.junit.jupiter.api.Assertions.assertNotNull;
34+
import static org.junit.jupiter.api.Assertions.assertTrue;
35+
import static org.junit.jupiter.api.Assertions.fail;
36+
import static org.mockito.Mockito.mock;
37+
import static org.mockito.Mockito.when;
38+
39+
public class MQTTStateManagerTest {
40+
41+
@Test
42+
@Timeout(60)
43+
public void testGetSessionStateNeverReturnsNullUnderConcurrentRemoval() throws Exception {
44+
final ActiveMQServer server = mock(ActiveMQServer.class);
45+
final Configuration configuration = mock(Configuration.class);
46+
when(server.getConfiguration()).thenReturn(configuration);
47+
when(configuration.isMqttSubscriptionPersistenceEnabled()).thenReturn(false);
48+
49+
final MQTTStateManager manager = MQTTStateManager.getInstance(server);
50+
try {
51+
final String clientId = "link-stealing-client";
52+
final int iterations = 2_000_000;
53+
final int removerThreads = 3;
54+
55+
final AtomicReference<Throwable> failure = new AtomicReference<>();
56+
final AtomicBoolean stop = new AtomicBoolean(false);
57+
final CountDownLatch start = new CountDownLatch(1);
58+
final ExecutorService pool = Executors.newFixedThreadPool(removerThreads + 1);
59+
60+
for (int i = 0; i < removerThreads; i++) {
61+
pool.submit(() -> {
62+
awaitQuietly(start);
63+
while (!stop.get() && failure.get() == null) {
64+
try {
65+
manager.removeSessionState(clientId);
66+
} catch (Throwable t) {
67+
failure.compareAndSet(null, t);
68+
return;
69+
}
70+
}
71+
});
72+
}
73+
74+
pool.submit(() -> {
75+
awaitQuietly(start);
76+
try {
77+
for (int i = 0; i < iterations && failure.get() == null; i++) {
78+
MQTTSessionState state = manager.getSessionState(clientId);
79+
assertNotNull(state, "getSessionState(String) must never return null (ARTEMIS-6085)");
80+
}
81+
} catch (Throwable t) {
82+
failure.compareAndSet(null, t);
83+
} finally {
84+
stop.set(true);
85+
}
86+
});
87+
88+
start.countDown();
89+
pool.shutdown();
90+
assertTrue(pool.awaitTermination(50, TimeUnit.SECONDS), "concurrency test did not finish in time");
91+
92+
if (failure.get() != null) {
93+
fail("getSessionState() returned null / threw under concurrent removal (ARTEMIS-6085)", failure.get());
94+
}
95+
} finally {
96+
MQTTStateManager.removeInstance(server);
97+
}
98+
}
99+
100+
private static void awaitQuietly(CountDownLatch latch) {
101+
try {
102+
latch.await();
103+
} catch (InterruptedException e) {
104+
Thread.currentThread().interrupt();
105+
}
106+
}
107+
}

0 commit comments

Comments
 (0)