Skip to content

Commit 56dac75

Browse files
authored
Fix MqttIO read checkpoint logic (apache#36056)
* Fix MqttIO read checkpoint logic
* Add tests
1 parent ed383fa commit 56dac75

2 files changed

Lines changed: 115 additions & 28 deletions

File tree

sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -422,20 +422,16 @@ public void populateDisplayData(DisplayData.Builder builder) {
422422
static class MqttCheckpointMark implements UnboundedSource.CheckpointMark, Serializable {
423423

424424
@VisibleForTesting String clientId;
425-
@VisibleForTesting Instant oldestMessageTimestamp = Instant.now();
426425
@VisibleForTesting transient List<Message> messages = new ArrayList<>();
427426

428-
public MqttCheckpointMark() {}
429-
430-
public MqttCheckpointMark(String id) {
431-
clientId = id;
427+
public MqttCheckpointMark(String id, List<Message> messages) {
428+
this.clientId = id;
429+
this.messages = messages;
432430
}
433431

434-
public void add(Message message, Instant timestamp) {
435-
if (timestamp.isBefore(oldestMessageTimestamp)) {
436-
oldestMessageTimestamp = timestamp;
437-
}
438-
messages.add(message);
432+
@VisibleForTesting
433+
MqttCheckpointMark(String id) {
434+
this.clientId = id;
439435
}
440436

441437
@Override
@@ -448,7 +444,6 @@ public void finalizeCheckpoint() {
448444
LOG.warn("Can't ack message for client ID {}", clientId, e);
449445
}
450446
}
451-
oldestMessageTimestamp = Instant.now();
452447
messages.clear();
453448
}
454449

@@ -464,7 +459,6 @@ public boolean equals(@Nullable Object other) {
464459
if (other instanceof MqttCheckpointMark) {
465460
MqttCheckpointMark that = (MqttCheckpointMark) other;
466461
return Objects.equals(this.clientId, that.clientId)
467-
&& Objects.equals(this.oldestMessageTimestamp, that.oldestMessageTimestamp)
468462
&& Objects.deepEquals(this.messages, that.messages);
469463
} else {
470464
return false;
@@ -473,7 +467,38 @@ public boolean equals(@Nullable Object other) {
473467

474468
@Override
475469
public int hashCode() {
476-
return Objects.hash(clientId, oldestMessageTimestamp, messages);
470+
return Objects.hash(clientId, messages);
471+
}
472+
473+
static class Preparer {
474+
@VisibleForTesting String clientId;
475+
@VisibleForTesting Instant oldestMessageTimestamp = Instant.now();
476+
@VisibleForTesting transient List<Message> messages = new ArrayList<>();
477+
478+
public Preparer(MqttCheckpointMark checkpointMark) {
479+
clientId = checkpointMark.clientId;
480+
messages = checkpointMark.messages;
481+
}
482+
483+
public Preparer(String id) {
484+
clientId = id;
485+
}
486+
487+
public Preparer() {}
488+
489+
public void add(Message message, Instant timestamp) {
490+
if (timestamp.isBefore(oldestMessageTimestamp)) {
491+
oldestMessageTimestamp = timestamp;
492+
}
493+
messages.add(message);
494+
}
495+
496+
MqttCheckpointMark newCheckpoint() {
497+
List<Message> currentMessages = messages;
498+
messages = new ArrayList<>();
499+
oldestMessageTimestamp = Instant.now();
500+
return new MqttCheckpointMark(clientId, currentMessages);
501+
}
477502
}
478503
}
479504

@@ -489,16 +514,20 @@ public UnboundedMqttSource(Read<T> spec) {
489514
@Override
490515
@SuppressWarnings("unchecked")
491516
public UnboundedReader<T> createReader(
492-
PipelineOptions options, MqttCheckpointMark checkpointMark) {
517+
PipelineOptions options, @Nullable MqttCheckpointMark checkpointMark) {
493518
final UnboundedMqttReader<T> unboundedMqttReader;
519+
MqttCheckpointMark.Preparer preparer =
520+
checkpointMark == null
521+
? new MqttCheckpointMark.Preparer()
522+
: new MqttCheckpointMark.Preparer(checkpointMark);
494523
if (spec.withMetadata()) {
495524
unboundedMqttReader =
496525
new UnboundedMqttReader<>(
497526
this,
498-
checkpointMark,
527+
preparer,
499528
message -> (T) MqttRecord.of(message.getTopic(), message.getPayload()));
500529
} else {
501-
unboundedMqttReader = new UnboundedMqttReader<>(this, checkpointMark);
530+
unboundedMqttReader = new UnboundedMqttReader<>(this, preparer);
502531
}
503532

504533
return unboundedMqttReader;
@@ -538,25 +567,26 @@ static class UnboundedMqttReader<T> extends UnboundedSource.UnboundedReader<T> {
538567
private BlockingConnection connection;
539568
private T current;
540569
private Instant currentTimestamp;
541-
private MqttCheckpointMark checkpointMark;
570+
private final MqttCheckpointMark.Preparer checkpointPreparer;
542571
private SerializableFunction<Message, T> extractFn;
543572

544-
public UnboundedMqttReader(UnboundedMqttSource<T> source, MqttCheckpointMark checkpointMark) {
573+
public UnboundedMqttReader(
574+
UnboundedMqttSource<T> source, MqttCheckpointMark.Preparer checkpointPreparer) {
545575
this.source = source;
546576
this.current = null;
547-
if (checkpointMark != null) {
548-
this.checkpointMark = checkpointMark;
577+
if (checkpointPreparer != null) {
578+
this.checkpointPreparer = checkpointPreparer;
549579
} else {
550-
this.checkpointMark = new MqttCheckpointMark();
580+
this.checkpointPreparer = new MqttCheckpointMark.Preparer();
551581
}
552582
this.extractFn = message -> (T) message.getPayload();
553583
}
554584

555585
public UnboundedMqttReader(
556586
UnboundedMqttSource<T> source,
557-
MqttCheckpointMark checkpointMark,
587+
MqttCheckpointMark.Preparer checkpointPreparer,
558588
SerializableFunction<Message, T> extractFn) {
559-
this(source, checkpointMark);
589+
this(source, checkpointPreparer);
560590
this.extractFn = extractFn;
561591
}
562592

@@ -567,7 +597,7 @@ public boolean start() throws IOException {
567597
try {
568598
client = spec.connectionConfiguration().createClient();
569599
LOG.debug("Reader client ID is {}", client.getClientId());
570-
checkpointMark.clientId = client.getClientId().toString();
600+
checkpointPreparer.clientId = client.getClientId().toString();
571601
connection = createConnection(client);
572602
connection.subscribe(
573603
new Topic[] {new Topic(spec.connectionConfiguration().getTopic(), QoS.AT_LEAST_ONCE)});
@@ -587,7 +617,7 @@ public boolean advance() throws IOException {
587617
}
588618
current = this.extractFn.apply(message);
589619
currentTimestamp = Instant.now();
590-
checkpointMark.add(message, currentTimestamp);
620+
checkpointPreparer.add(message, currentTimestamp);
591621
} catch (Exception e) {
592622
throw new IOException(e);
593623
}
@@ -608,12 +638,12 @@ public void close() throws IOException {
608638

609639
@Override
610640
public Instant getWatermark() {
611-
return checkpointMark.oldestMessageTimestamp;
641+
return checkpointPreparer.oldestMessageTimestamp;
612642
}
613643

614644
@Override
615645
public UnboundedSource.CheckpointMark getCheckpointMark() {
616-
return checkpointMark;
646+
return checkpointPreparer.newCheckpoint();
617647
}
618648

619649
@Override

sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import java.io.ObjectOutputStream;
2828
import java.nio.charset.StandardCharsets;
2929
import java.util.ArrayList;
30+
import java.util.Arrays;
3031
import java.util.Collection;
3132
import java.util.List;
3233
import java.util.Map;
@@ -50,11 +51,13 @@
5051
import org.apache.beam.sdk.values.PCollection;
5152
import org.fusesource.hawtbuf.Buffer;
5253
import org.fusesource.mqtt.client.BlockingConnection;
54+
import org.fusesource.mqtt.client.Callback;
5355
import org.fusesource.mqtt.client.MQTT;
5456
import org.fusesource.mqtt.client.Message;
5557
import org.fusesource.mqtt.client.QoS;
5658
import org.fusesource.mqtt.client.Topic;
5759
import org.joda.time.Duration;
60+
import org.joda.time.Instant;
5861
import org.junit.After;
5962
import org.junit.Before;
6063
import org.junit.Ignore;
@@ -286,6 +289,61 @@ public void testReceiveWithTimeoutAndNoData() throws Exception {
286289
pipeline.run();
287290
}
288291

292+
private static class FakeMessage extends Message {
293+
294+
private int ackCount;
295+
296+
public FakeMessage() {
297+
super(null, null, null, null);
298+
this.ackCount = 0;
299+
}
300+
301+
@Override
302+
public void ack() {
303+
++ackCount;
304+
}
305+
306+
@Override
307+
public void ack(final Callback<Void> unused) {
308+
++ackCount;
309+
}
310+
311+
public int getAckCount() {
312+
return ackCount;
313+
}
314+
}
315+
316+
@Test
317+
public void testReadCheckpoint() {
318+
MqttIO.MqttCheckpointMark.Preparer preparer = new MqttIO.MqttCheckpointMark.Preparer("id");
319+
ArrayList<Message> messages = new ArrayList<>();
320+
for (int i = 0; i < 5; ++i) {
321+
messages.add(new FakeMessage());
322+
}
323+
preparer.add(messages.get(0), Instant.ofEpochMilli(20));
324+
preparer.add(messages.get(1), Instant.ofEpochMilli(10));
325+
preparer.add(messages.get(2), Instant.ofEpochMilli(30));
326+
assertEquals(Instant.ofEpochMilli(10), preparer.oldestMessageTimestamp);
327+
MqttIO.MqttCheckpointMark checkpointA = preparer.newCheckpoint();
328+
preparer.add(messages.get(3), Instant.ofEpochMilli(40));
329+
preparer.add(messages.get(4), Instant.ofEpochMilli(50));
330+
MqttIO.MqttCheckpointMark checkpointB = preparer.newCheckpoint();
331+
assertTrue(
332+
Arrays.stream(messages.toArray()).allMatch((m -> ((FakeMessage) m).getAckCount() == 0)));
333+
checkpointA.finalizeCheckpoint();
334+
// only messages in finalized checkpoint acked
335+
assertTrue(
336+
Arrays.stream(messages.subList(0, 3).toArray())
337+
.allMatch((m -> ((FakeMessage) m).getAckCount() == 1)));
338+
assertTrue(
339+
Arrays.stream(messages.subList(3, 5).toArray())
340+
.allMatch((m -> ((FakeMessage) m).getAckCount() == 0)));
341+
checkpointB.finalizeCheckpoint();
342+
// all messaged acked once
343+
assertTrue(
344+
Arrays.stream(messages.toArray()).allMatch((m -> ((FakeMessage) m).getAckCount() == 1)));
345+
}
346+
289347
@Test
290348
public void testWrite() throws Exception {
291349
final int numberOfTestMessages = 200;
@@ -560,7 +618,6 @@ public void testReadObject() throws Exception {
560618
// the number of messages of the decoded checkpoint should be zero
561619
assertEquals(0, cp2.messages.size());
562620
assertEquals(cp1.clientId, cp2.clientId);
563-
assertEquals(cp1.oldestMessageTimestamp, cp2.oldestMessageTimestamp);
564621
}
565622

566623
/**

0 commit comments

Comments
 (0)