Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
"modification": 3
"modification": 4
}
3 changes: 1 addition & 2 deletions sdks/java/io/mqtt/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ dependencies {
implementation project(path: ":sdks:java:core", configuration: "shadow")
implementation library.java.slf4j_api
implementation library.java.joda_time
implementation "org.fusesource.mqtt-client:mqtt-client:1.15"
implementation "org.fusesource.hawtbuf:hawtbuf:1.11"
implementation "com.hivemq:hivemq-mqtt-client:1.3.15"
testImplementation project(path: ":sdks:java:io:common")
testImplementation library.java.activemq_broker
testImplementation library.java.activemq_mqtt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,24 @@
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;

import com.google.auto.value.AutoValue;
import com.hivemq.client.mqtt.MqttGlobalPublishFilter;
import com.hivemq.client.mqtt.datatypes.MqttQos;
import com.hivemq.client.mqtt.mqtt3.Mqtt3BlockingClient;
import com.hivemq.client.mqtt.mqtt3.Mqtt3Client;
import com.hivemq.client.mqtt.mqtt3.Mqtt3ClientBuilder;
import com.hivemq.client.mqtt.mqtt3.message.publish.Mqtt3Publish;
import java.io.IOException;
import java.io.Serializable;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.SerializableCoder;
Expand All @@ -52,12 +59,6 @@
import org.apache.beam.sdk.values.PDone;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.fusesource.mqtt.client.BlockingConnection;
import org.fusesource.mqtt.client.FutureConnection;
import org.fusesource.mqtt.client.MQTT;
import org.fusesource.mqtt.client.Message;
import org.fusesource.mqtt.client.QoS;
import org.fusesource.mqtt.client.Topic;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.slf4j.Logger;
Expand Down Expand Up @@ -299,29 +300,42 @@ private void populateDisplayData(DisplayData.Builder builder) {
builder.addIfNotNull(DisplayData.item("username", getUsername()));
}

private MQTT createClient() throws Exception {
private Mqtt3BlockingClient createClient() throws Exception {
LOG.debug("Creating MQTT client to {}", getServerUri());
MQTT client = new MQTT();
client.setHost(getServerUri());
if (getUsername() != null) {
LOG.debug("MQTT client uses username {}", getUsername());
client.setUserName(getUsername());
client.setPassword(getPassword());
URI uri = new URI(getServerUri());
String host = uri.getHost();
int port = uri.getPort();
if (port == -1) {
port = "ssl".equals(uri.getScheme()) || "tls".equals(uri.getScheme()) ? 8883 : 1883;
}

Mqtt3ClientBuilder builder = Mqtt3Client.builder().serverHost(host).serverPort(port);

if ("ssl".equals(uri.getScheme()) || "tls".equals(uri.getScheme())) {
builder = builder.sslWithDefaultConfig();
}
if (getClientId() != null) {
String clientId = getClientId() + "-" + UUID.randomUUID().toString();
clientId =
clientId.substring(0, Math.min(clientId.length(), MQTT_3_1_MAX_CLIENT_ID_LENGTH));
LOG.debug("MQTT client id set to {}", clientId);
client.setClientId(clientId);

String clientId = getClientId();
if (clientId == null) {
clientId = UUID.randomUUID().toString();
} else {
String clientId = UUID.randomUUID().toString();
clientId =
clientId.substring(0, Math.min(clientId.length(), MQTT_3_1_MAX_CLIENT_ID_LENGTH));
LOG.debug("MQTT client id set to random value {}", clientId);
client.setClientId(clientId);
clientId = clientId + "-" + UUID.randomUUID().toString();
}
return client;
clientId = clientId.substring(0, Math.min(clientId.length(), MQTT_3_1_MAX_CLIENT_ID_LENGTH));
LOG.debug("MQTT client id set to {}", clientId);
builder = builder.identifier(clientId);

if (getUsername() != null) {
LOG.debug("MQTT client uses username {}", getUsername());
builder =
builder
.simpleAuth()
.username(getUsername())
.password(getPassword().getBytes(StandardCharsets.UTF_8))
.applySimpleAuth();
}
Comment on lines +328 to +336

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If getPassword() is null, calling getPassword().getBytes(StandardCharsets.UTF_8) will throw a NullPointerException. We should check if the password is non-null before attempting to set it on the client builder.

      if (getUsername() != null) {
        LOG.debug("MQTT client uses username {}", getUsername());
        if (getPassword() != null) {
          builder =
              builder
                  .simpleAuth()
                  .username(getUsername())
                  .password(getPassword().getBytes(StandardCharsets.UTF_8))
                  .applySimpleAuth();
        } else {
          builder = builder.simpleAuth().username(getUsername()).applySimpleAuth();
        }
      }


return builder.buildBlocking();
}
}

Expand Down Expand Up @@ -429,9 +443,9 @@ public void populateDisplayData(DisplayData.Builder builder) {
static class MqttCheckpointMark implements UnboundedSource.CheckpointMark, Serializable {

@VisibleForTesting String clientId;
@VisibleForTesting transient List<Message> messages = new ArrayList<>();
@VisibleForTesting transient List<Mqtt3Publish> messages = new ArrayList<>();

public MqttCheckpointMark(String id, List<Message> messages) {
public MqttCheckpointMark(String id, List<Mqtt3Publish> messages) {
this.clientId = id;
this.messages = messages;
}
Expand All @@ -444,9 +458,9 @@ public MqttCheckpointMark(String id, List<Message> messages) {
@Override
public void finalizeCheckpoint() {
LOG.debug("Finalizing checkpoint acknowledging pending messages for client ID {}", clientId);
for (Message message : messages) {
for (Mqtt3Publish message : messages) {
try {
message.ack();
message.acknowledge();
} catch (Exception e) {
LOG.warn("Can't ack message for client ID {}", clientId, e);
}
Expand Down Expand Up @@ -480,7 +494,7 @@ public int hashCode() {
static class Preparer {
@VisibleForTesting String clientId;
@VisibleForTesting Instant oldestMessageTimestamp = Instant.now();
@VisibleForTesting transient List<Message> messages = new ArrayList<>();
@VisibleForTesting transient List<Mqtt3Publish> messages = new ArrayList<>();

public Preparer(MqttCheckpointMark checkpointMark) {
clientId = checkpointMark.clientId;
Expand All @@ -493,15 +507,15 @@ public Preparer(String id) {

public Preparer() {}

public void add(Message message, Instant timestamp) {
public void add(Mqtt3Publish message, Instant timestamp) {
if (timestamp.isBefore(oldestMessageTimestamp)) {
oldestMessageTimestamp = timestamp;
}
messages.add(message);
}

MqttCheckpointMark newCheckpoint() {
List<Message> currentMessages = messages;
List<Mqtt3Publish> currentMessages = messages;
messages = new ArrayList<>();
oldestMessageTimestamp = Instant.now();
return new MqttCheckpointMark(clientId, currentMessages);
Expand Down Expand Up @@ -532,7 +546,8 @@ public UnboundedReader<T> createReader(
new UnboundedMqttReader<>(
this,
preparer,
message -> (T) MqttRecord.of(message.getTopic(), message.getPayload()));
message ->
(T) MqttRecord.of(message.getTopic().toString(), message.getPayloadAsBytes()));
} else {
unboundedMqttReader = new UnboundedMqttReader<>(this, preparer);
}
Expand Down Expand Up @@ -570,12 +585,13 @@ static class UnboundedMqttReader<T> extends UnboundedSource.UnboundedReader<T> {

private final UnboundedMqttSource<T> source;

private MQTT client;
private BlockingConnection connection;
private Mqtt3BlockingClient client;
private Mqtt3BlockingClient.Mqtt3Publishes publishes;
private String clientId = "";
private T current;
private Instant currentTimestamp;
private final MqttCheckpointMark.Preparer checkpointPreparer;
private SerializableFunction<Message, T> extractFn;
private SerializableFunction<Mqtt3Publish, T> extractFn;

public UnboundedMqttReader(
UnboundedMqttSource<T> source, MqttCheckpointMark.Preparer checkpointPreparer) {
Expand All @@ -586,13 +602,13 @@ public UnboundedMqttReader(
} else {
this.checkpointPreparer = new MqttCheckpointMark.Preparer();
}
this.extractFn = message -> (T) message.getPayload();
this.extractFn = message -> (T) message.getPayloadAsBytes();
}

public UnboundedMqttReader(
UnboundedMqttSource<T> source,
MqttCheckpointMark.Preparer checkpointPreparer,
SerializableFunction<Message, T> extractFn) {
SerializableFunction<Mqtt3Publish, T> extractFn) {
this(source, checkpointPreparer);
this.extractFn = extractFn;
}
Expand All @@ -603,11 +619,20 @@ public boolean start() throws IOException {
Read<T> spec = source.spec;
try {
client = spec.connectionConfiguration().createClient();
LOG.debug("Reader client ID is {}", client.getClientId());
checkpointPreparer.clientId = client.getClientId().toString();
connection = createConnection(client);
connection.subscribe(
new Topic[] {new Topic(spec.connectionConfiguration().getTopic(), QoS.AT_LEAST_ONCE)});
this.clientId = client.getConfig().getClientIdentifier().map(Object::toString).orElse("");
LOG.debug("Reader client ID is {}", clientId);
checkpointPreparer.clientId = clientId;
client.connect();

// Subscribe and get the publishes stream with manual acks enabled
publishes = client.publishes(MqttGlobalPublishFilter.ALL, true);

client
.subscribeWith()
.topicFilter(spec.connectionConfiguration().getTopic())
.qos(MqttQos.AT_LEAST_ONCE)
.send();
Comment on lines +625 to +634

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If an exception occurs after client.connect() succeeds (e.g., during subscription or getting the publishes stream), the client connection is leaked because it is never disconnected. We should catch exceptions after connection and disconnect the client before rethrowing.

        client.connect();
        try {
          // Subscribe and get the publishes stream with manual acks enabled
          publishes = client.publishes(MqttGlobalPublishFilter.ALL, true);

          client
              .subscribeWith()
              .topicFilter(spec.connectionConfiguration().getTopic())
              .qos(MqttQos.AT_LEAST_ONCE)
              .send();
        } catch (Exception e) {
          try {
            client.disconnect();
          } catch (Exception disconnectEx) {
            e.addSuppressed(disconnectEx);
          }
          throw e;
        }


return advance();
} catch (Exception e) {
throw new IOException(e);
Expand All @@ -617,11 +642,12 @@ public boolean start() throws IOException {
@Override
public boolean advance() throws IOException {
try {
LOG.trace("MQTT reader (client ID {}) waiting message ...", client.getClientId());
Message message = connection.receive(1, TimeUnit.SECONDS);
if (message == null) {
LOG.trace("MQTT reader (client ID {}) waiting message ...", clientId);
Optional<Mqtt3Publish> messageOpt = publishes.receive(1, TimeUnit.SECONDS);
if (!messageOpt.isPresent()) {
return false;
}
Mqtt3Publish message = messageOpt.get();
current = this.extractFn.apply(message);
currentTimestamp = Instant.now();
checkpointPreparer.add(message, currentTimestamp);
Expand All @@ -633,10 +659,13 @@ public boolean advance() throws IOException {

@Override
public void close() throws IOException {
LOG.debug("Closing MQTT reader (client ID {})", client.getClientId());
LOG.debug("Closing MQTT reader (client ID {})", clientId);
try {
if (connection != null) {
connection.disconnect();
if (publishes != null) {
publishes.close();
}
if (client != null) {
client.disconnect();
}
Comment on lines +664 to 669

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If publishes.close() throws an exception, client.disconnect() will not be called, which can lead to connection leaks. We should ensure both resources are closed/disconnected even if one of them fails.

        if (publishes != null) {
          try {
            publishes.close();
          } catch (Exception e) {
            LOG.warn("Error closing publishes stream", e);
          }
        }
        if (client != null) {
          client.disconnect();
        }

} catch (Exception e) {
throw new IOException(e);
Expand Down Expand Up @@ -764,8 +793,7 @@ private static class WriteFn<InputT> extends DoFn<InputT, Void> {
private final SerializableFunction<InputT, byte[]> payloadFn;
private final boolean retained;

private transient MQTT client;
private transient BlockingConnection connection;
private transient Mqtt3BlockingClient client;

public WriteFn(Write<InputT> spec) {
this.spec = spec;
Expand All @@ -783,8 +811,9 @@ public WriteFn(Write<InputT> spec) {
public void createMqttClient() throws Exception {
LOG.debug("Starting MQTT writer");
this.client = this.spec.connectionConfiguration().createClient();
LOG.debug("MQTT writer client ID is {}", client.getClientId());
this.connection = createConnection(client);
String clientId = client.getConfig().getClientIdentifier().map(Object::toString).orElse("");
LOG.debug("MQTT writer client ID is {}", clientId);
this.client.connect();
}

@ProcessElement
Expand All @@ -793,32 +822,25 @@ public void processElement(ProcessContext context) throws Exception {
byte[] payload = this.payloadFn.apply(element);
String topic = this.topicFn.apply(element);
LOG.debug("Sending message {}", new String(payload, StandardCharsets.UTF_8));
this.connection.publish(topic, payload, QoS.AT_LEAST_ONCE, this.retained);

client
.publishWith()
.topic(topic)
.payload(payload)
.qos(MqttQos.AT_LEAST_ONCE)
.retain(this.retained)
.send();
}

@Teardown
public void closeMqttClient() throws Exception {
if (this.connection != null) {
LOG.debug("Disconnecting MQTT connection (client ID {})", client.getClientId());
this.connection.disconnect();
if (this.client != null) {
String clientId =
client.getConfig().getClientIdentifier().map(Object::toString).orElse("");
LOG.debug("Disconnecting MQTT connection (client ID {})", clientId);
this.client.disconnect();
}
}
}
}

/** Create a connected MQTT BlockingConnection from given client, aware of connection timeout. */
static BlockingConnection createConnection(MQTT client) throws Exception {
FutureConnection futureConnection = client.futureConnection();
org.fusesource.mqtt.client.Future<Void> connecting = futureConnection.connect();
while (true) {
try {
connecting.await(1, TimeUnit.MINUTES);
} catch (TimeoutException e) {
LOG.warn("Connection to {} pending after waiting for 1 minute", client.getHost());
continue;
}
break;
}
return new BlockingConnection(futureConnection);
}
}
Loading
Loading