diff --git a/artemis-protocols/artemis-mqtt-protocol/pom.xml b/artemis-protocols/artemis-mqtt-protocol/pom.xml
index 62c575526c9..ccb22c31955 100644
--- a/artemis-protocols/artemis-mqtt-protocol/pom.xml
+++ b/artemis-protocols/artemis-mqtt-protocol/pom.xml
@@ -107,6 +107,11 @@
junit-jupiter-engine
test
+
+ org.mockito
+ mockito-core
+ test
+
org.apache.logging.log4j
log4j-slf4j2-impl
diff --git a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTStateManager.java b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTStateManager.java
index 69c21015727..99e3d2265f3 100644
--- a/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTStateManager.java
+++ b/artemis-protocols/artemis-mqtt-protocol/src/main/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTStateManager.java
@@ -137,14 +137,10 @@ public void scanSessions() {
public MQTTSessionState getSessionState(String clientId) throws Exception {
// [MQTT-3.1.2-4] Attach an existing session if one exists otherwise create a new one.
- if (sessionStates.containsKey(clientId)) {
- return sessionStates.get(clientId);
- } else {
- MQTTSessionState sessionState = new MQTTSessionState(clientId);
- logger.debug("Adding MQTT session state for: {}", clientId);
- sessionStates.put(clientId, sessionState);
- return sessionState;
- }
+ return sessionStates.computeIfAbsent(clientId, key -> {
+ logger.debug("Adding MQTT session state for: {}", key);
+ return new MQTTSessionState(key);
+ });
}
public MQTTSessionState removeSessionState(String clientId) throws Exception {
diff --git a/artemis-protocols/artemis-mqtt-protocol/src/test/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTStateManagerTest.java b/artemis-protocols/artemis-mqtt-protocol/src/test/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTStateManagerTest.java
new file mode 100644
index 00000000000..0b4b461ce71
--- /dev/null
+++ b/artemis-protocols/artemis-mqtt-protocol/src/test/java/org/apache/activemq/artemis/core/protocol/mqtt/MQTTStateManagerTest.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.activemq.artemis.core.protocol.mqtt;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+
+import org.apache.activemq.artemis.core.config.Configuration;
+import org.apache.activemq.artemis.core.server.ActiveMQServer;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class MQTTStateManagerTest {
+
+ @Test
+ @Timeout(60)
+ public void testGetSessionStateNeverReturnsNullUnderConcurrentRemoval() throws Exception {
+ final ActiveMQServer server = mock(ActiveMQServer.class);
+ final Configuration configuration = mock(Configuration.class);
+ when(server.getConfiguration()).thenReturn(configuration);
+ when(configuration.isMqttSubscriptionPersistenceEnabled()).thenReturn(false);
+
+ final MQTTStateManager manager = MQTTStateManager.getInstance(server);
+ try {
+ final String clientId = "link-stealing-client";
+ final int iterations = 2_000_000;
+ final int removerThreads = 3;
+
+ final AtomicReference failure = new AtomicReference<>();
+ final AtomicBoolean stop = new AtomicBoolean(false);
+ final CountDownLatch start = new CountDownLatch(1);
+ final ExecutorService pool = Executors.newFixedThreadPool(removerThreads + 1);
+
+ for (int i = 0; i < removerThreads; i++) {
+ pool.submit(() -> {
+ awaitQuietly(start);
+ while (!stop.get() && failure.get() == null) {
+ try {
+ manager.removeSessionState(clientId);
+ } catch (Throwable t) {
+ failure.compareAndSet(null, t);
+ return;
+ }
+ }
+ });
+ }
+
+ pool.submit(() -> {
+ awaitQuietly(start);
+ try {
+ for (int i = 0; i < iterations && failure.get() == null; i++) {
+ MQTTSessionState state = manager.getSessionState(clientId);
+ assertNotNull(state, "getSessionState(String) must never return null (ARTEMIS-6085)");
+ }
+ } catch (Throwable t) {
+ failure.compareAndSet(null, t);
+ } finally {
+ stop.set(true);
+ }
+ });
+
+ start.countDown();
+ pool.shutdown();
+ assertTrue(pool.awaitTermination(50, TimeUnit.SECONDS), "concurrency test did not finish in time");
+
+ if (failure.get() != null) {
+ fail("getSessionState() returned null / threw under concurrent removal (ARTEMIS-6085)", failure.get());
+ }
+ } finally {
+ MQTTStateManager.removeInstance(server);
+ }
+ }
+
+ private static void awaitQuietly(CountDownLatch latch) {
+ try {
+ latch.await();
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ }
+ }
+}