Skip to content

Commit a711758

Browse files
committed
Migrate MqttIO to HiveMQ Mqtt client
1 parent 78fdc95 commit a711758

4 files changed

Lines changed: 212 additions & 167 deletions

File tree

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
{
22
"comment": "Modify this file in a trivial way to cause this test suite to run",
3-
"modification": 3
3+
"modification": 4
44
}

sdks/java/io/mqtt/build.gradle

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ dependencies {
2727
implementation project(path: ":sdks:java:core", configuration: "shadow")
2828
implementation library.java.slf4j_api
2929
implementation library.java.joda_time
30-
implementation "org.fusesource.mqtt-client:mqtt-client:1.15"
31-
implementation "org.fusesource.hawtbuf:hawtbuf:1.11"
30+
implementation "com.hivemq:hivemq-mqtt-client:1.3.15"
3231
testImplementation project(path: ":sdks:java:io:common")
3332
testImplementation library.java.activemq_broker
3433
testImplementation library.java.activemq_mqtt

sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java

Lines changed: 96 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,24 @@
2121
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
2222

2323
import com.google.auto.value.AutoValue;
24+
import com.hivemq.client.mqtt.MqttGlobalPublishFilter;
25+
import com.hivemq.client.mqtt.datatypes.MqttQos;
26+
import com.hivemq.client.mqtt.mqtt3.Mqtt3BlockingClient;
27+
import com.hivemq.client.mqtt.mqtt3.Mqtt3Client;
28+
import com.hivemq.client.mqtt.mqtt3.Mqtt3ClientBuilder;
29+
import com.hivemq.client.mqtt.mqtt3.message.publish.Mqtt3Publish;
2430
import java.io.IOException;
2531
import java.io.Serializable;
32+
import java.net.URI;
2633
import java.nio.charset.StandardCharsets;
2734
import java.util.ArrayList;
2835
import java.util.Collections;
2936
import java.util.List;
3037
import java.util.NoSuchElementException;
3138
import java.util.Objects;
39+
import java.util.Optional;
3240
import java.util.UUID;
3341
import java.util.concurrent.TimeUnit;
34-
import java.util.concurrent.TimeoutException;
3542
import org.apache.beam.sdk.coders.ByteArrayCoder;
3643
import org.apache.beam.sdk.coders.Coder;
3744
import org.apache.beam.sdk.coders.SerializableCoder;
@@ -52,12 +59,6 @@
5259
import org.apache.beam.sdk.values.PDone;
5360
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
5461
import org.checkerframework.checker.nullness.qual.Nullable;
55-
import org.fusesource.mqtt.client.BlockingConnection;
56-
import org.fusesource.mqtt.client.FutureConnection;
57-
import org.fusesource.mqtt.client.MQTT;
58-
import org.fusesource.mqtt.client.Message;
59-
import org.fusesource.mqtt.client.QoS;
60-
import org.fusesource.mqtt.client.Topic;
6162
import org.joda.time.Duration;
6263
import org.joda.time.Instant;
6364
import org.slf4j.Logger;
@@ -299,29 +300,42 @@ private void populateDisplayData(DisplayData.Builder builder) {
299300
builder.addIfNotNull(DisplayData.item("username", getUsername()));
300301
}
301302

302-
private MQTT createClient() throws Exception {
303+
private Mqtt3BlockingClient createClient() throws Exception {
303304
LOG.debug("Creating MQTT client to {}", getServerUri());
304-
MQTT client = new MQTT();
305-
client.setHost(getServerUri());
306-
if (getUsername() != null) {
307-
LOG.debug("MQTT client uses username {}", getUsername());
308-
client.setUserName(getUsername());
309-
client.setPassword(getPassword());
305+
URI uri = new URI(getServerUri());
306+
String host = uri.getHost();
307+
int port = uri.getPort();
308+
if (port == -1) {
309+
port = "ssl".equals(uri.getScheme()) || "tls".equals(uri.getScheme()) ? 8883 : 1883;
310+
}
311+
312+
Mqtt3ClientBuilder builder = Mqtt3Client.builder().serverHost(host).serverPort(port);
313+
314+
if ("ssl".equals(uri.getScheme()) || "tls".equals(uri.getScheme())) {
315+
builder = builder.sslWithDefaultConfig();
310316
}
311-
if (getClientId() != null) {
312-
String clientId = getClientId() + "-" + UUID.randomUUID().toString();
313-
clientId =
314-
clientId.substring(0, Math.min(clientId.length(), MQTT_3_1_MAX_CLIENT_ID_LENGTH));
315-
LOG.debug("MQTT client id set to {}", clientId);
316-
client.setClientId(clientId);
317+
318+
String clientId = getClientId();
319+
if (clientId == null) {
320+
clientId = UUID.randomUUID().toString();
317321
} else {
318-
String clientId = UUID.randomUUID().toString();
319-
clientId =
320-
clientId.substring(0, Math.min(clientId.length(), MQTT_3_1_MAX_CLIENT_ID_LENGTH));
321-
LOG.debug("MQTT client id set to random value {}", clientId);
322-
client.setClientId(clientId);
322+
clientId = clientId + "-" + UUID.randomUUID().toString();
323323
}
324-
return client;
324+
clientId = clientId.substring(0, Math.min(clientId.length(), MQTT_3_1_MAX_CLIENT_ID_LENGTH));
325+
LOG.debug("MQTT client id set to {}", clientId);
326+
builder = builder.identifier(clientId);
327+
328+
if (getUsername() != null) {
329+
LOG.debug("MQTT client uses username {}", getUsername());
330+
builder =
331+
builder
332+
.simpleAuth()
333+
.username(getUsername())
334+
.password(getPassword().getBytes(StandardCharsets.UTF_8))
335+
.applySimpleAuth();
336+
}
337+
338+
return builder.buildBlocking();
325339
}
326340
}
327341

@@ -429,9 +443,9 @@ public void populateDisplayData(DisplayData.Builder builder) {
429443
static class MqttCheckpointMark implements UnboundedSource.CheckpointMark, Serializable {
430444

431445
@VisibleForTesting String clientId;
432-
@VisibleForTesting transient List<Message> messages = new ArrayList<>();
446+
@VisibleForTesting transient List<Mqtt3Publish> messages = new ArrayList<>();
433447

434-
public MqttCheckpointMark(String id, List<Message> messages) {
448+
public MqttCheckpointMark(String id, List<Mqtt3Publish> messages) {
435449
this.clientId = id;
436450
this.messages = messages;
437451
}
@@ -444,9 +458,9 @@ public MqttCheckpointMark(String id, List<Message> messages) {
444458
@Override
445459
public void finalizeCheckpoint() {
446460
LOG.debug("Finalizing checkpoint acknowledging pending messages for client ID {}", clientId);
447-
for (Message message : messages) {
461+
for (Mqtt3Publish message : messages) {
448462
try {
449-
message.ack();
463+
message.acknowledge();
450464
} catch (Exception e) {
451465
LOG.warn("Can't ack message for client ID {}", clientId, e);
452466
}
@@ -480,7 +494,7 @@ public int hashCode() {
480494
static class Preparer {
481495
@VisibleForTesting String clientId;
482496
@VisibleForTesting Instant oldestMessageTimestamp = Instant.now();
483-
@VisibleForTesting transient List<Message> messages = new ArrayList<>();
497+
@VisibleForTesting transient List<Mqtt3Publish> messages = new ArrayList<>();
484498

485499
public Preparer(MqttCheckpointMark checkpointMark) {
486500
clientId = checkpointMark.clientId;
@@ -493,15 +507,15 @@ public Preparer(String id) {
493507

494508
public Preparer() {}
495509

496-
public void add(Message message, Instant timestamp) {
510+
public void add(Mqtt3Publish message, Instant timestamp) {
497511
if (timestamp.isBefore(oldestMessageTimestamp)) {
498512
oldestMessageTimestamp = timestamp;
499513
}
500514
messages.add(message);
501515
}
502516

503517
MqttCheckpointMark newCheckpoint() {
504-
List<Message> currentMessages = messages;
518+
List<Mqtt3Publish> currentMessages = messages;
505519
messages = new ArrayList<>();
506520
oldestMessageTimestamp = Instant.now();
507521
return new MqttCheckpointMark(clientId, currentMessages);
@@ -532,7 +546,8 @@ public UnboundedReader<T> createReader(
532546
new UnboundedMqttReader<>(
533547
this,
534548
preparer,
535-
message -> (T) MqttRecord.of(message.getTopic(), message.getPayload()));
549+
message ->
550+
(T) MqttRecord.of(message.getTopic().toString(), message.getPayloadAsBytes()));
536551
} else {
537552
unboundedMqttReader = new UnboundedMqttReader<>(this, preparer);
538553
}
@@ -570,12 +585,13 @@ static class UnboundedMqttReader<T> extends UnboundedSource.UnboundedReader<T> {
570585

571586
private final UnboundedMqttSource<T> source;
572587

573-
private MQTT client;
574-
private BlockingConnection connection;
588+
private Mqtt3BlockingClient client;
589+
private Mqtt3BlockingClient.Mqtt3Publishes publishes;
590+
private String clientId = "";
575591
private T current;
576592
private Instant currentTimestamp;
577593
private final MqttCheckpointMark.Preparer checkpointPreparer;
578-
private SerializableFunction<Message, T> extractFn;
594+
private SerializableFunction<Mqtt3Publish, T> extractFn;
579595

580596
public UnboundedMqttReader(
581597
UnboundedMqttSource<T> source, MqttCheckpointMark.Preparer checkpointPreparer) {
@@ -586,13 +602,13 @@ public UnboundedMqttReader(
586602
} else {
587603
this.checkpointPreparer = new MqttCheckpointMark.Preparer();
588604
}
589-
this.extractFn = message -> (T) message.getPayload();
605+
this.extractFn = message -> (T) message.getPayloadAsBytes();
590606
}
591607

592608
public UnboundedMqttReader(
593609
UnboundedMqttSource<T> source,
594610
MqttCheckpointMark.Preparer checkpointPreparer,
595-
SerializableFunction<Message, T> extractFn) {
611+
SerializableFunction<Mqtt3Publish, T> extractFn) {
596612
this(source, checkpointPreparer);
597613
this.extractFn = extractFn;
598614
}
@@ -603,11 +619,20 @@ public boolean start() throws IOException {
603619
Read<T> spec = source.spec;
604620
try {
605621
client = spec.connectionConfiguration().createClient();
606-
LOG.debug("Reader client ID is {}", client.getClientId());
607-
checkpointPreparer.clientId = client.getClientId().toString();
608-
connection = createConnection(client);
609-
connection.subscribe(
610-
new Topic[] {new Topic(spec.connectionConfiguration().getTopic(), QoS.AT_LEAST_ONCE)});
622+
this.clientId = client.getConfig().getClientIdentifier().map(Object::toString).orElse("");
623+
LOG.debug("Reader client ID is {}", clientId);
624+
checkpointPreparer.clientId = clientId;
625+
client.connect();
626+
627+
// Subscribe and get the publishes stream with manual acks enabled
628+
publishes = client.publishes(MqttGlobalPublishFilter.ALL, true);
629+
630+
client
631+
.subscribeWith()
632+
.topicFilter(spec.connectionConfiguration().getTopic())
633+
.qos(MqttQos.AT_LEAST_ONCE)
634+
.send();
635+
611636
return advance();
612637
} catch (Exception e) {
613638
throw new IOException(e);
@@ -617,11 +642,12 @@ public boolean start() throws IOException {
617642
@Override
618643
public boolean advance() throws IOException {
619644
try {
620-
LOG.trace("MQTT reader (client ID {}) waiting message ...", client.getClientId());
621-
Message message = connection.receive(1, TimeUnit.SECONDS);
622-
if (message == null) {
645+
LOG.trace("MQTT reader (client ID {}) waiting message ...", clientId);
646+
Optional<Mqtt3Publish> messageOpt = publishes.receive(1, TimeUnit.SECONDS);
647+
if (!messageOpt.isPresent()) {
623648
return false;
624649
}
650+
Mqtt3Publish message = messageOpt.get();
625651
current = this.extractFn.apply(message);
626652
currentTimestamp = Instant.now();
627653
checkpointPreparer.add(message, currentTimestamp);
@@ -633,10 +659,13 @@ public boolean advance() throws IOException {
633659

634660
@Override
635661
public void close() throws IOException {
636-
LOG.debug("Closing MQTT reader (client ID {})", client.getClientId());
662+
LOG.debug("Closing MQTT reader (client ID {})", clientId);
637663
try {
638-
if (connection != null) {
639-
connection.disconnect();
664+
if (publishes != null) {
665+
publishes.close();
666+
}
667+
if (client != null) {
668+
client.disconnect();
640669
}
641670
} catch (Exception e) {
642671
throw new IOException(e);
@@ -764,8 +793,7 @@ private static class WriteFn<InputT> extends DoFn<InputT, Void> {
764793
private final SerializableFunction<InputT, byte[]> payloadFn;
765794
private final boolean retained;
766795

767-
private transient MQTT client;
768-
private transient BlockingConnection connection;
796+
private transient Mqtt3BlockingClient client;
769797

770798
public WriteFn(Write<InputT> spec) {
771799
this.spec = spec;
@@ -783,8 +811,9 @@ public WriteFn(Write<InputT> spec) {
783811
public void createMqttClient() throws Exception {
784812
LOG.debug("Starting MQTT writer");
785813
this.client = this.spec.connectionConfiguration().createClient();
786-
LOG.debug("MQTT writer client ID is {}", client.getClientId());
787-
this.connection = createConnection(client);
814+
String clientId = client.getConfig().getClientIdentifier().map(Object::toString).orElse("");
815+
LOG.debug("MQTT writer client ID is {}", clientId);
816+
this.client.connect();
788817
}
789818

790819
@ProcessElement
@@ -793,32 +822,25 @@ public void processElement(ProcessContext context) throws Exception {
793822
byte[] payload = this.payloadFn.apply(element);
794823
String topic = this.topicFn.apply(element);
795824
LOG.debug("Sending message {}", new String(payload, StandardCharsets.UTF_8));
796-
this.connection.publish(topic, payload, QoS.AT_LEAST_ONCE, this.retained);
825+
826+
client
827+
.publishWith()
828+
.topic(topic)
829+
.payload(payload)
830+
.qos(MqttQos.AT_LEAST_ONCE)
831+
.retain(this.retained)
832+
.send();
797833
}
798834

799835
@Teardown
800836
public void closeMqttClient() throws Exception {
801-
if (this.connection != null) {
802-
LOG.debug("Disconnecting MQTT connection (client ID {})", client.getClientId());
803-
this.connection.disconnect();
837+
if (this.client != null) {
838+
String clientId =
839+
client.getConfig().getClientIdentifier().map(Object::toString).orElse("");
840+
LOG.debug("Disconnecting MQTT connection (client ID {})", clientId);
841+
this.client.disconnect();
804842
}
805843
}
806844
}
807845
}
808-
809-
/** Create a connected MQTT BlockingConnection from given client, aware of connection timeout. */
810-
static BlockingConnection createConnection(MQTT client) throws Exception {
811-
FutureConnection futureConnection = client.futureConnection();
812-
org.fusesource.mqtt.client.Future<Void> connecting = futureConnection.connect();
813-
while (true) {
814-
try {
815-
connecting.await(1, TimeUnit.MINUTES);
816-
} catch (TimeoutException e) {
817-
LOG.warn("Connection to {} pending after waiting for 1 minute", client.getHost());
818-
continue;
819-
}
820-
break;
821-
}
822-
return new BlockingConnection(futureConnection);
823-
}
824846
}

0 commit comments

Comments
 (0)