2121import static org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .base .Preconditions .checkNotNull ;
2222
2323import com .google .auto .value .AutoValue ;
24+ import com .hivemq .client .mqtt .MqttGlobalPublishFilter ;
25+ import com .hivemq .client .mqtt .datatypes .MqttQos ;
26+ import com .hivemq .client .mqtt .mqtt3 .Mqtt3BlockingClient ;
27+ import com .hivemq .client .mqtt .mqtt3 .Mqtt3Client ;
28+ import com .hivemq .client .mqtt .mqtt3 .Mqtt3ClientBuilder ;
29+ import com .hivemq .client .mqtt .mqtt3 .message .publish .Mqtt3Publish ;
2430import java .io .IOException ;
2531import java .io .Serializable ;
32+ import java .net .URI ;
2633import java .nio .charset .StandardCharsets ;
2734import java .util .ArrayList ;
2835import java .util .Collections ;
2936import java .util .List ;
3037import java .util .NoSuchElementException ;
3138import java .util .Objects ;
39+ import java .util .Optional ;
3240import java .util .UUID ;
3341import java .util .concurrent .TimeUnit ;
34- import java .util .concurrent .TimeoutException ;
3542import org .apache .beam .sdk .coders .ByteArrayCoder ;
3643import org .apache .beam .sdk .coders .Coder ;
3744import org .apache .beam .sdk .coders .SerializableCoder ;
5259import org .apache .beam .sdk .values .PDone ;
5360import org .apache .beam .vendor .guava .v32_1_2_jre .com .google .common .annotations .VisibleForTesting ;
5461import org .checkerframework .checker .nullness .qual .Nullable ;
55- import org .fusesource .mqtt .client .BlockingConnection ;
56- import org .fusesource .mqtt .client .FutureConnection ;
57- import org .fusesource .mqtt .client .MQTT ;
58- import org .fusesource .mqtt .client .Message ;
59- import org .fusesource .mqtt .client .QoS ;
60- import org .fusesource .mqtt .client .Topic ;
6162import org .joda .time .Duration ;
6263import org .joda .time .Instant ;
6364import org .slf4j .Logger ;
@@ -299,29 +300,42 @@ private void populateDisplayData(DisplayData.Builder builder) {
299300 builder .addIfNotNull (DisplayData .item ("username" , getUsername ()));
300301 }
301302
302- private MQTT createClient () throws Exception {
303+ private Mqtt3BlockingClient createClient () throws Exception {
303304 LOG .debug ("Creating MQTT client to {}" , getServerUri ());
304- MQTT client = new MQTT ();
305- client .setHost (getServerUri ());
306- if (getUsername () != null ) {
307- LOG .debug ("MQTT client uses username {}" , getUsername ());
308- client .setUserName (getUsername ());
309- client .setPassword (getPassword ());
305+ URI uri = new URI (getServerUri ());
306+ String host = uri .getHost ();
307+ int port = uri .getPort ();
308+ if (port == -1 ) {
309+ port = "ssl" .equals (uri .getScheme ()) || "tls" .equals (uri .getScheme ()) ? 8883 : 1883 ;
310+ }
311+
312+ Mqtt3ClientBuilder builder = Mqtt3Client .builder ().serverHost (host ).serverPort (port );
313+
314+ if ("ssl" .equals (uri .getScheme ()) || "tls" .equals (uri .getScheme ())) {
315+ builder = builder .sslWithDefaultConfig ();
310316 }
311- if (getClientId () != null ) {
312- String clientId = getClientId () + "-" + UUID .randomUUID ().toString ();
313- clientId =
314- clientId .substring (0 , Math .min (clientId .length (), MQTT_3_1_MAX_CLIENT_ID_LENGTH ));
315- LOG .debug ("MQTT client id set to {}" , clientId );
316- client .setClientId (clientId );
317+
318+ String clientId = getClientId ();
319+ if (clientId == null ) {
320+ clientId = UUID .randomUUID ().toString ();
317321 } else {
318- String clientId = UUID .randomUUID ().toString ();
319- clientId =
320- clientId .substring (0 , Math .min (clientId .length (), MQTT_3_1_MAX_CLIENT_ID_LENGTH ));
321- LOG .debug ("MQTT client id set to random value {}" , clientId );
322- client .setClientId (clientId );
322+ clientId = clientId + "-" + UUID .randomUUID ().toString ();
323323 }
324- return client ;
324+ clientId = clientId .substring (0 , Math .min (clientId .length (), MQTT_3_1_MAX_CLIENT_ID_LENGTH ));
325+ LOG .debug ("MQTT client id set to {}" , clientId );
326+ builder = builder .identifier (clientId );
327+
328+ if (getUsername () != null ) {
329+ LOG .debug ("MQTT client uses username {}" , getUsername ());
330+ builder =
331+ builder
332+ .simpleAuth ()
333+ .username (getUsername ())
334+ .password (getPassword ().getBytes (StandardCharsets .UTF_8 ))
335+ .applySimpleAuth ();
336+ }
337+
338+ return builder .buildBlocking ();
325339 }
326340 }
327341
@@ -429,9 +443,9 @@ public void populateDisplayData(DisplayData.Builder builder) {
429443 static class MqttCheckpointMark implements UnboundedSource .CheckpointMark , Serializable {
430444
431445 @ VisibleForTesting String clientId ;
432- @ VisibleForTesting transient List <Message > messages = new ArrayList <>();
446+ @ VisibleForTesting transient List <Mqtt3Publish > messages = new ArrayList <>();
433447
434- public MqttCheckpointMark (String id , List <Message > messages ) {
448+ public MqttCheckpointMark (String id , List <Mqtt3Publish > messages ) {
435449 this .clientId = id ;
436450 this .messages = messages ;
437451 }
@@ -444,9 +458,9 @@ public MqttCheckpointMark(String id, List<Message> messages) {
444458 @ Override
445459 public void finalizeCheckpoint () {
446460 LOG .debug ("Finalizing checkpoint acknowledging pending messages for client ID {}" , clientId );
447- for (Message message : messages ) {
461+ for (Mqtt3Publish message : messages ) {
448462 try {
449- message .ack ();
463+ message .acknowledge ();
450464 } catch (Exception e ) {
451465 LOG .warn ("Can't ack message for client ID {}" , clientId , e );
452466 }
@@ -480,7 +494,7 @@ public int hashCode() {
480494 static class Preparer {
481495 @ VisibleForTesting String clientId ;
482496 @ VisibleForTesting Instant oldestMessageTimestamp = Instant .now ();
483- @ VisibleForTesting transient List <Message > messages = new ArrayList <>();
497+ @ VisibleForTesting transient List <Mqtt3Publish > messages = new ArrayList <>();
484498
485499 public Preparer (MqttCheckpointMark checkpointMark ) {
486500 clientId = checkpointMark .clientId ;
@@ -493,15 +507,15 @@ public Preparer(String id) {
493507
494508 public Preparer () {}
495509
496- public void add (Message message , Instant timestamp ) {
510+ public void add (Mqtt3Publish message , Instant timestamp ) {
497511 if (timestamp .isBefore (oldestMessageTimestamp )) {
498512 oldestMessageTimestamp = timestamp ;
499513 }
500514 messages .add (message );
501515 }
502516
503517 MqttCheckpointMark newCheckpoint () {
504- List <Message > currentMessages = messages ;
518+ List <Mqtt3Publish > currentMessages = messages ;
505519 messages = new ArrayList <>();
506520 oldestMessageTimestamp = Instant .now ();
507521 return new MqttCheckpointMark (clientId , currentMessages );
@@ -532,7 +546,8 @@ public UnboundedReader<T> createReader(
532546 new UnboundedMqttReader <>(
533547 this ,
534548 preparer ,
535- message -> (T ) MqttRecord .of (message .getTopic (), message .getPayload ()));
549+ message ->
550+ (T ) MqttRecord .of (message .getTopic ().toString (), message .getPayloadAsBytes ()));
536551 } else {
537552 unboundedMqttReader = new UnboundedMqttReader <>(this , preparer );
538553 }
@@ -570,12 +585,13 @@ static class UnboundedMqttReader<T> extends UnboundedSource.UnboundedReader<T> {
570585
571586 private final UnboundedMqttSource <T > source ;
572587
573- private MQTT client ;
574- private BlockingConnection connection ;
588+ private Mqtt3BlockingClient client ;
589+ private Mqtt3BlockingClient .Mqtt3Publishes publishes ;
590+ private String clientId = "" ;
575591 private T current ;
576592 private Instant currentTimestamp ;
577593 private final MqttCheckpointMark .Preparer checkpointPreparer ;
578- private SerializableFunction <Message , T > extractFn ;
594+ private SerializableFunction <Mqtt3Publish , T > extractFn ;
579595
580596 public UnboundedMqttReader (
581597 UnboundedMqttSource <T > source , MqttCheckpointMark .Preparer checkpointPreparer ) {
@@ -586,13 +602,13 @@ public UnboundedMqttReader(
586602 } else {
587603 this .checkpointPreparer = new MqttCheckpointMark .Preparer ();
588604 }
589- this .extractFn = message -> (T ) message .getPayload ();
605+ this .extractFn = message -> (T ) message .getPayloadAsBytes ();
590606 }
591607
592608 public UnboundedMqttReader (
593609 UnboundedMqttSource <T > source ,
594610 MqttCheckpointMark .Preparer checkpointPreparer ,
595- SerializableFunction <Message , T > extractFn ) {
611+ SerializableFunction <Mqtt3Publish , T > extractFn ) {
596612 this (source , checkpointPreparer );
597613 this .extractFn = extractFn ;
598614 }
@@ -603,11 +619,20 @@ public boolean start() throws IOException {
603619 Read <T > spec = source .spec ;
604620 try {
605621 client = spec .connectionConfiguration ().createClient ();
606- LOG .debug ("Reader client ID is {}" , client .getClientId ());
607- checkpointPreparer .clientId = client .getClientId ().toString ();
608- connection = createConnection (client );
609- connection .subscribe (
610- new Topic [] {new Topic (spec .connectionConfiguration ().getTopic (), QoS .AT_LEAST_ONCE )});
622+ this .clientId = client .getConfig ().getClientIdentifier ().map (Object ::toString ).orElse ("" );
623+ LOG .debug ("Reader client ID is {}" , clientId );
624+ checkpointPreparer .clientId = clientId ;
625+ client .connect ();
626+
627+ // Subscribe and get the publishes stream with manual acks enabled
628+ publishes = client .publishes (MqttGlobalPublishFilter .ALL , true );
629+
630+ client
631+ .subscribeWith ()
632+ .topicFilter (spec .connectionConfiguration ().getTopic ())
633+ .qos (MqttQos .AT_LEAST_ONCE )
634+ .send ();
635+
611636 return advance ();
612637 } catch (Exception e ) {
613638 throw new IOException (e );
@@ -617,11 +642,12 @@ public boolean start() throws IOException {
617642 @ Override
618643 public boolean advance () throws IOException {
619644 try {
620- LOG .trace ("MQTT reader (client ID {}) waiting message ..." , client . getClientId () );
621- Message message = connection .receive (1 , TimeUnit .SECONDS );
622- if (message == null ) {
645+ LOG .trace ("MQTT reader (client ID {}) waiting message ..." , clientId );
646+ Optional < Mqtt3Publish > messageOpt = publishes .receive (1 , TimeUnit .SECONDS );
647+ if (! messageOpt . isPresent () ) {
623648 return false ;
624649 }
650+ Mqtt3Publish message = messageOpt .get ();
625651 current = this .extractFn .apply (message );
626652 currentTimestamp = Instant .now ();
627653 checkpointPreparer .add (message , currentTimestamp );
@@ -633,10 +659,13 @@ public boolean advance() throws IOException {
633659
634660 @ Override
635661 public void close () throws IOException {
636- LOG .debug ("Closing MQTT reader (client ID {})" , client . getClientId () );
662+ LOG .debug ("Closing MQTT reader (client ID {})" , clientId );
637663 try {
638- if (connection != null ) {
639- connection .disconnect ();
664+ if (publishes != null ) {
665+ publishes .close ();
666+ }
667+ if (client != null ) {
668+ client .disconnect ();
640669 }
641670 } catch (Exception e ) {
642671 throw new IOException (e );
@@ -764,8 +793,7 @@ private static class WriteFn<InputT> extends DoFn<InputT, Void> {
764793 private final SerializableFunction <InputT , byte []> payloadFn ;
765794 private final boolean retained ;
766795
767- private transient MQTT client ;
768- private transient BlockingConnection connection ;
796+ private transient Mqtt3BlockingClient client ;
769797
770798 public WriteFn (Write <InputT > spec ) {
771799 this .spec = spec ;
@@ -783,8 +811,9 @@ public WriteFn(Write<InputT> spec) {
783811 public void createMqttClient () throws Exception {
784812 LOG .debug ("Starting MQTT writer" );
785813 this .client = this .spec .connectionConfiguration ().createClient ();
786- LOG .debug ("MQTT writer client ID is {}" , client .getClientId ());
787- this .connection = createConnection (client );
814+ String clientId = client .getConfig ().getClientIdentifier ().map (Object ::toString ).orElse ("" );
815+ LOG .debug ("MQTT writer client ID is {}" , clientId );
816+ this .client .connect ();
788817 }
789818
790819 @ ProcessElement
@@ -793,32 +822,25 @@ public void processElement(ProcessContext context) throws Exception {
793822 byte [] payload = this .payloadFn .apply (element );
794823 String topic = this .topicFn .apply (element );
795824 LOG .debug ("Sending message {}" , new String (payload , StandardCharsets .UTF_8 ));
796- this .connection .publish (topic , payload , QoS .AT_LEAST_ONCE , this .retained );
825+
826+ client
827+ .publishWith ()
828+ .topic (topic )
829+ .payload (payload )
830+ .qos (MqttQos .AT_LEAST_ONCE )
831+ .retain (this .retained )
832+ .send ();
797833 }
798834
799835 @ Teardown
800836 public void closeMqttClient () throws Exception {
801- if (this .connection != null ) {
802- LOG .debug ("Disconnecting MQTT connection (client ID {})" , client .getClientId ());
803- this .connection .disconnect ();
837+ if (this .client != null ) {
838+ String clientId =
839+ client .getConfig ().getClientIdentifier ().map (Object ::toString ).orElse ("" );
840+ LOG .debug ("Disconnecting MQTT connection (client ID {})" , clientId );
841+ this .client .disconnect ();
804842 }
805843 }
806844 }
807845 }
808-
809- /** Create a connected MQTT BlockingConnection from given client, aware of connection timeout. */
810- static BlockingConnection createConnection (MQTT client ) throws Exception {
811- FutureConnection futureConnection = client .futureConnection ();
812- org .fusesource .mqtt .client .Future <Void > connecting = futureConnection .connect ();
813- while (true ) {
814- try {
815- connecting .await (1 , TimeUnit .MINUTES );
816- } catch (TimeoutException e ) {
817- LOG .warn ("Connection to {} pending after waiting for 1 minute" , client .getHost ());
818- continue ;
819- }
820- break ;
821- }
822- return new BlockingConnection (futureConnection );
823- }
824846}
0 commit comments