Skip to content

Commit 7d6ac6d

Browse files
authored
Set server id for mysql binlog client (#5725)
Signed-off-by: Hai Yan <oeyh@amazon.com>
1 parent e14722f commit 7d6ac6d

4 files changed

Lines changed: 282 additions & 53 deletions

File tree

data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorker.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourcePartition;
1212
import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition;
1313
import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate;
14+
import org.opensearch.dataprepper.plugins.source.rds.utils.ServerIdGenerator;
1415
import org.postgresql.replication.LogSequenceNumber;
1516
import org.slf4j.Logger;
1617
import org.slf4j.LoggerFactory;
@@ -55,6 +56,7 @@ public void processStream(final StreamPartition streamPartition) {
5556

5657
if (replicationLogClient instanceof BinlogClientWrapper) {
5758
setStartBinlogPosition(streamPartition);
59+
setServerId();
5860
} else {
5961
setStartLsn(streamPartition);
6062
}
@@ -110,6 +112,12 @@ private void setStartBinlogPosition(final StreamPartition streamPartition) {
110112
}
111113
}
112114

115+
private void setServerId() {
116+
final int serverId = ServerIdGenerator.generateServerId();
117+
LOG.info("Binary log client server id is {}", serverId);
118+
((BinlogClientWrapper) replicationLogClient).getBinlogClient().setServerId(serverId);
119+
}
120+
113121
private void setStartLsn(final StreamPartition streamPartition) {
114122
final String startLsn = streamPartition.getProgressState().get().getPostgresStreamState().getCurrentLsn();
115123

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* The OpenSearch Contributors require contributions made to
6+
* this file be licensed under the Apache-2.0 license or a
7+
* compatible open source license.
8+
*/
9+
10+
package org.opensearch.dataprepper.plugins.source.rds.utils;
11+
12+
import java.net.InetAddress;
13+
import java.util.Random;
14+
15+
public class ServerIdGenerator {
16+
static final int MIN_SERVER_ID = 100_000;
17+
static final int MAX_SERVER_ID = 999_999;
18+
19+
public static int generateServerId() {
20+
try {
21+
// Get local IP address
22+
String hostAddress = InetAddress.getLocalHost().getHostAddress();
23+
24+
// Get process-specific info
25+
long processId = ProcessHandle.current().pid();
26+
long currentTime = System.currentTimeMillis() % MIN_SERVER_ID;
27+
28+
int hash = Math.abs((hostAddress + processId + currentTime).hashCode());
29+
30+
return MIN_SERVER_ID + (hash % (MAX_SERVER_ID - MIN_SERVER_ID + 1));
31+
32+
} catch (Exception e) {
33+
// Fallback to random number
34+
return MIN_SERVER_ID + new Random().nextInt(MAX_SERVER_ID - MIN_SERVER_ID + 1);
35+
}
36+
}
37+
}

data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTest.java

Lines changed: 115 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import com.github.shyiko.mysql.binlog.BinaryLogClient;
99
import org.junit.jupiter.api.BeforeEach;
10+
import org.junit.jupiter.api.Nested;
1011
import org.junit.jupiter.api.Test;
1112
import org.junit.jupiter.api.extension.ExtendWith;
1213
import org.mockito.Mock;
@@ -15,13 +16,16 @@
1516
import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator;
1617
import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition;
1718
import org.opensearch.dataprepper.plugins.source.rds.coordination.state.MySqlStreamState;
19+
import org.opensearch.dataprepper.plugins.source.rds.coordination.state.PostgresStreamState;
1820
import org.opensearch.dataprepper.plugins.source.rds.coordination.state.StreamProgressState;
1921
import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate;
22+
import org.postgresql.replication.LogSequenceNumber;
2023

2124
import java.io.IOException;
2225
import java.util.Optional;
2326
import java.util.UUID;
2427

28+
import static org.mockito.ArgumentMatchers.any;
2529
import static org.mockito.Mockito.mock;
2630
import static org.mockito.Mockito.never;
2731
import static org.mockito.Mockito.verify;
@@ -33,12 +37,6 @@ class StreamWorkerTest {
3337
@Mock
3438
private EnhancedSourceCoordinator sourceCoordinator;
3539

36-
@Mock
37-
private BinlogClientWrapper binlogClientWrapper;
38-
39-
@Mock
40-
private BinaryLogClient binaryLogClient;
41-
4240
@Mock
4341
private PluginMetrics pluginMetrics;
4442

@@ -47,55 +45,119 @@ class StreamWorkerTest {
4745

4846
private StreamWorker streamWorker;
4947

50-
@BeforeEach
51-
void setUp() {
52-
streamWorker = createObjectUnderTest();
53-
}
54-
55-
@Test
56-
void test_processStream_with_given_binlog_coordinates() throws IOException {
57-
final StreamProgressState streamProgressState = mock(StreamProgressState.class);
58-
final MySqlStreamState mySqlStreamState = mock(MySqlStreamState.class);
59-
final String binlogFilename = UUID.randomUUID().toString();
60-
final long binlogPosition = 100L;
61-
when(streamPartition.getProgressState()).thenReturn(Optional.of(streamProgressState));
62-
when(streamProgressState.getMySqlStreamState()).thenReturn(mySqlStreamState);
63-
when(mySqlStreamState.getCurrentPosition()).thenReturn(new BinlogCoordinate(binlogFilename, binlogPosition));
64-
when(streamProgressState.shouldWaitForExport()).thenReturn(false);
65-
when(binlogClientWrapper.getBinlogClient()).thenReturn(binaryLogClient);
66-
67-
streamWorker.processStream(streamPartition);
68-
69-
verify(binaryLogClient).setBinlogFilename(binlogFilename);
70-
verify(binaryLogClient).setBinlogPosition(binlogPosition);
71-
verify(binlogClientWrapper).connect();
72-
}
48+
@Nested
49+
class TestForMySql {
50+
@Mock
51+
private BinlogClientWrapper binlogClientWrapper;
52+
53+
@Mock
54+
private BinaryLogClient binaryLogClient;
55+
56+
@BeforeEach
57+
void setUp() {
58+
streamWorker = createObjectUnderTest();
59+
}
60+
61+
@Test
62+
void test_processStream_with_given_binlog_coordinates() throws IOException {
63+
final StreamProgressState streamProgressState = mock(StreamProgressState.class);
64+
final MySqlStreamState mySqlStreamState = mock(MySqlStreamState.class);
65+
final String binlogFilename = UUID.randomUUID().toString();
66+
final long binlogPosition = 100L;
67+
when(streamPartition.getProgressState()).thenReturn(Optional.of(streamProgressState));
68+
when(streamProgressState.getMySqlStreamState()).thenReturn(mySqlStreamState);
69+
when(mySqlStreamState.getCurrentPosition()).thenReturn(new BinlogCoordinate(binlogFilename, binlogPosition));
70+
when(streamProgressState.shouldWaitForExport()).thenReturn(false);
71+
when(binlogClientWrapper.getBinlogClient()).thenReturn(binaryLogClient);
72+
73+
streamWorker.processStream(streamPartition);
74+
75+
verify(binaryLogClient).setBinlogFilename(binlogFilename);
76+
verify(binaryLogClient).setBinlogPosition(binlogPosition);
77+
verify(binlogClientWrapper).connect();
78+
}
79+
80+
@Test
81+
void test_processStream_without_current_binlog_coordinates() throws IOException {
82+
final StreamProgressState streamProgressState = mock(StreamProgressState.class);
83+
final MySqlStreamState mySqlStreamState = mock(MySqlStreamState.class);
84+
when(streamPartition.getProgressState()).thenReturn(Optional.of(streamProgressState));
85+
final String binlogFilename = "binlog-001";
86+
final long binlogPosition = 100L;
87+
when(streamProgressState.getMySqlStreamState()).thenReturn(mySqlStreamState);
88+
when(mySqlStreamState.getCurrentPosition()).thenReturn(null);
89+
when(streamProgressState.shouldWaitForExport()).thenReturn(false);
90+
when(binlogClientWrapper.getBinlogClient()).thenReturn(binaryLogClient);
91+
92+
streamWorker.processStream(streamPartition);
93+
94+
verify(binaryLogClient, never()).setBinlogFilename(binlogFilename);
95+
verify(binaryLogClient, never()).setBinlogPosition(binlogPosition);
96+
verify(binlogClientWrapper).connect();
97+
}
98+
99+
@Test
100+
void test_shutdown() throws IOException {
101+
streamWorker.shutdown();
102+
verify(binlogClientWrapper).disconnect();
103+
}
104+
105+
private StreamWorker createObjectUnderTest() {
106+
return new StreamWorker(sourceCoordinator, binlogClientWrapper, pluginMetrics);
107+
}
73108

74-
@Test
75-
void test_processStream_without_current_binlog_coordinates() throws IOException {
76-
final StreamProgressState streamProgressState = mock(StreamProgressState.class);
77-
final MySqlStreamState mySqlStreamState = mock(MySqlStreamState.class);
78-
when(streamPartition.getProgressState()).thenReturn(Optional.of(streamProgressState));
79-
final String binlogFilename = "binlog-001";
80-
final long binlogPosition = 100L;
81-
when(streamProgressState.getMySqlStreamState()).thenReturn(mySqlStreamState);
82-
when(mySqlStreamState.getCurrentPosition()).thenReturn(null);
83-
when(streamProgressState.shouldWaitForExport()).thenReturn(false);
84-
85-
streamWorker.processStream(streamPartition);
86-
87-
verify(binaryLogClient, never()).setBinlogFilename(binlogFilename);
88-
verify(binaryLogClient, never()).setBinlogPosition(binlogPosition);
89-
verify(binlogClientWrapper).connect();
90109
}
91110

92-
@Test
93-
void test_shutdown() throws IOException {
94-
streamWorker.shutdown();
95-
verify(binlogClientWrapper).disconnect();
96-
}
111+
@Nested
112+
class TestForPostgres {
113+
@Mock
114+
private LogicalReplicationClient logicalReplicationClient;
115+
116+
@BeforeEach
117+
void setUp() {
118+
streamWorker = createObjectUnderTest();
119+
}
120+
121+
@Test
122+
void test_processStream_with_given_currentLsn() {
123+
final StreamProgressState streamProgressState = mock(StreamProgressState.class);
124+
final PostgresStreamState postgresStreamState = mock(PostgresStreamState.class);
125+
final String currentLsn = UUID.randomUUID().toString();
126+
when(streamPartition.getProgressState()).thenReturn(Optional.of(streamProgressState));
127+
when(streamProgressState.getPostgresStreamState()).thenReturn(postgresStreamState);
128+
when(postgresStreamState.getCurrentLsn()).thenReturn(currentLsn);
129+
when(streamProgressState.shouldWaitForExport()).thenReturn(false);
130+
131+
streamWorker.processStream(streamPartition);
132+
133+
verify(logicalReplicationClient).setStartLsn(LogSequenceNumber.valueOf(currentLsn));
134+
verify(logicalReplicationClient).connect();
135+
}
136+
137+
@Test
138+
void test_processStream_without_currentLsn() {
139+
final StreamProgressState streamProgressState = mock(StreamProgressState.class);
140+
final PostgresStreamState postgresStreamState = mock(PostgresStreamState.class);
141+
when(streamPartition.getProgressState()).thenReturn(Optional.of(streamProgressState));
142+
when(streamProgressState.getPostgresStreamState()).thenReturn(postgresStreamState);
143+
when(postgresStreamState.getCurrentLsn()).thenReturn(null);
144+
when(streamProgressState.shouldWaitForExport()).thenReturn(false);
145+
146+
streamWorker.processStream(streamPartition);
147+
148+
verify(logicalReplicationClient, never()).setStartLsn(any());
149+
verify(logicalReplicationClient).connect();
150+
}
151+
152+
@Test
153+
void test_shutdown() throws IOException {
154+
streamWorker.shutdown();
155+
verify(logicalReplicationClient).disconnect();
156+
}
157+
158+
private StreamWorker createObjectUnderTest() {
159+
return new StreamWorker(sourceCoordinator, logicalReplicationClient, pluginMetrics);
160+
}
97161

98-
private StreamWorker createObjectUnderTest() {
99-
return new StreamWorker(sourceCoordinator, binlogClientWrapper, pluginMetrics);
100162
}
101-
}
163+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* The OpenSearch Contributors require contributions made to
6+
* this file be licensed under the Apache-2.0 license or a
7+
* compatible open source license.
8+
*/
9+
10+
package org.opensearch.dataprepper.plugins.source.rds.utils;
11+
12+
import org.junit.jupiter.api.Test;
13+
import org.mockito.MockedStatic;
14+
15+
import java.net.InetAddress;
16+
import java.net.UnknownHostException;
17+
18+
import static org.apache.parquet.filter.ColumnPredicates.equalTo;
19+
import static org.hamcrest.MatcherAssert.assertThat;
20+
import static org.hamcrest.Matchers.allOf;
21+
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
22+
import static org.hamcrest.Matchers.is;
23+
import static org.hamcrest.Matchers.lessThanOrEqualTo;
24+
import static org.hamcrest.Matchers.not;
25+
import static org.mockito.Mockito.mock;
26+
import static org.mockito.Mockito.mockStatic;
27+
import static org.mockito.Mockito.when;
28+
import static org.opensearch.dataprepper.plugins.source.rds.utils.ServerIdGenerator.MAX_SERVER_ID;
29+
import static org.opensearch.dataprepper.plugins.source.rds.utils.ServerIdGenerator.MIN_SERVER_ID;
30+
31+
class ServerIdGeneratorTest {
32+
@Test
33+
public void generateServerId_shouldReturnValueWithinValidRange() {
34+
// When
35+
int serverId = ServerIdGenerator.generateServerId();
36+
37+
// Then
38+
assertThat("Server ID should be within valid range",
39+
serverId, allOf(
40+
greaterThanOrEqualTo(MIN_SERVER_ID),
41+
lessThanOrEqualTo(MAX_SERVER_ID)
42+
));
43+
}
44+
45+
@Test
46+
public void generateServerId_shouldFallbackToRandomWhenInetAddressFails() {
47+
// Given - Mock InetAddress to throw exception
48+
try (MockedStatic<InetAddress> mockedInetAddress = mockStatic(InetAddress.class)) {
49+
mockedInetAddress.when(InetAddress::getLocalHost)
50+
.thenThrow(new UnknownHostException("Mocked exception"));
51+
52+
// When
53+
int serverId = ServerIdGenerator.generateServerId();
54+
55+
// Then
56+
assertThat("Fallback should generate valid server ID",
57+
serverId, allOf(
58+
greaterThanOrEqualTo(MIN_SERVER_ID),
59+
lessThanOrEqualTo(MAX_SERVER_ID)
60+
));
61+
62+
// Verify the exception path was taken
63+
mockedInetAddress.verify(InetAddress::getLocalHost);
64+
}
65+
}
66+
67+
@Test
68+
public void generateServerId_shouldHandleProcessHandleFailure() {
69+
// Given - Mock ProcessHandle to throw exception
70+
try (MockedStatic<ProcessHandle> mockedProcessHandle = mockStatic(ProcessHandle.class)) {
71+
mockedProcessHandle.when(ProcessHandle::current)
72+
.thenThrow(new RuntimeException("Mocked process failure"));
73+
74+
// When
75+
int serverId = ServerIdGenerator.generateServerId();
76+
77+
// Then
78+
assertThat("Should fallback gracefully on ProcessHandle failure",
79+
serverId, allOf(
80+
greaterThanOrEqualTo(MIN_SERVER_ID),
81+
lessThanOrEqualTo(MAX_SERVER_ID)
82+
));
83+
}
84+
}
85+
86+
@Test
87+
public void generateServerId_shouldProduceDifferentValuesForDifferentHosts() {
88+
// Given - Mock different IP addresses
89+
try (MockedStatic<InetAddress> mockedInetAddress = mockStatic(InetAddress.class)) {
90+
91+
// Mock first host
92+
InetAddress mockAddress1 = mock(InetAddress.class);
93+
when(mockAddress1.getHostAddress()).thenReturn("192.168.1.100");
94+
mockedInetAddress.when(InetAddress::getLocalHost).thenReturn(mockAddress1);
95+
96+
int serverId1 = ServerIdGenerator.generateServerId();
97+
98+
// Mock second host
99+
InetAddress mockAddress2 = mock(InetAddress.class);
100+
when(mockAddress2.getHostAddress()).thenReturn("192.168.1.200");
101+
mockedInetAddress.when(InetAddress::getLocalHost).thenReturn(mockAddress2);
102+
103+
int serverId2 = ServerIdGenerator.generateServerId();
104+
105+
// Then
106+
assertThat("Both IDs should be in valid range",
107+
serverId1, allOf(
108+
greaterThanOrEqualTo(MIN_SERVER_ID),
109+
lessThanOrEqualTo(MAX_SERVER_ID)
110+
));
111+
112+
assertThat("Both IDs should be in valid range",
113+
serverId2, allOf(
114+
greaterThanOrEqualTo(MIN_SERVER_ID),
115+
lessThanOrEqualTo(MAX_SERVER_ID)
116+
));
117+
118+
assertThat("Different hosts should likely generate different server IDs",
119+
serverId1, is(not(equalTo(serverId2))));
120+
}
121+
}
122+
}

0 commit comments

Comments
 (0)