Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
import org.opensearch.dataprepper.model.buffer.Buffer;
import org.opensearch.dataprepper.model.configuration.PipelineDescription;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.event.EventFactory;
import org.opensearch.dataprepper.model.plugin.PluginConfigObservable;
Expand Down Expand Up @@ -64,6 +65,7 @@ public class RdsService {
private final RdsSourceConfig sourceConfig;
private final AcknowledgementSetManager acknowledgementSetManager;
private final PluginConfigObservable pluginConfigObservable;
private final PipelineDescription pipelineDescription;
private ExecutorService executor;
private LeaderScheduler leaderScheduler;
private ExportScheduler exportScheduler;
Expand All @@ -77,13 +79,15 @@ public RdsService(final EnhancedSourceCoordinator sourceCoordinator,
final ClientFactory clientFactory,
final PluginMetrics pluginMetrics,
final AcknowledgementSetManager acknowledgementSetManager,
final PluginConfigObservable pluginConfigObservable) {
final PluginConfigObservable pluginConfigObservable,
final PipelineDescription pipelineDescription) {
this.sourceCoordinator = sourceCoordinator;
this.eventFactory = eventFactory;
this.pluginMetrics = pluginMetrics;
this.sourceConfig = sourceConfig;
this.acknowledgementSetManager = acknowledgementSetManager;
this.pluginConfigObservable = pluginConfigObservable;
this.pipelineDescription = pipelineDescription;

rdsClient = clientFactory.buildRdsClient();
s3Client = clientFactory.buildS3Client();
Expand All @@ -109,7 +113,7 @@ public void start(Buffer<Record<Event>> buffer) {
DbTableMetadata dbTableMetadata = getDbTableMetadata(dbMetadata, schemaManager);

leaderScheduler = new LeaderScheduler(
sourceCoordinator, sourceConfig, s3PathPrefix, schemaManager, dbTableMetadata);
sourceCoordinator, sourceConfig, s3PathPrefix, schemaManager, dbTableMetadata, pipelineDescription.getPipelineName());
runnableList.add(leaderScheduler);

if (sourceConfig.isExportEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;
import org.opensearch.dataprepper.model.buffer.Buffer;
import org.opensearch.dataprepper.model.configuration.PipelineDescription;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.event.EventFactory;
import org.opensearch.dataprepper.model.plugin.PluginConfigObservable;
Expand Down Expand Up @@ -39,6 +40,7 @@ public class RdsSource implements Source<Record<Event>>, UsesEnhancedSourceCoord
private final EventFactory eventFactory;
private final AcknowledgementSetManager acknowledgementSetManager;
private final PluginConfigObservable pluginConfigObservable;
private final PipelineDescription pipelineDescription;
private EnhancedSourceCoordinator sourceCoordinator;
private RdsService rdsService;

Expand All @@ -48,12 +50,14 @@ public RdsSource(final PluginMetrics pluginMetrics,
final EventFactory eventFactory,
final AwsCredentialsSupplier awsCredentialsSupplier,
final AcknowledgementSetManager acknowledgementSetManager,
final PluginConfigObservable pluginConfigObservable) {
final PluginConfigObservable pluginConfigObservable,
final PipelineDescription pipelineDescription) {
this.pluginMetrics = pluginMetrics;
this.sourceConfig = sourceConfig;
this.eventFactory = eventFactory;
this.acknowledgementSetManager = acknowledgementSetManager;
this.pluginConfigObservable = pluginConfigObservable;
this.pipelineDescription = pipelineDescription;

clientFactory = new ClientFactory(awsCredentialsSupplier, sourceConfig);
}
Expand All @@ -64,7 +68,8 @@ public void start(Buffer<Record<Event>> buffer) {
Objects.requireNonNull(sourceCoordinator);
sourceCoordinator.createPartition(new LeaderPartition());

rdsService = new RdsService(sourceCoordinator, sourceConfig, eventFactory, clientFactory, pluginMetrics, acknowledgementSetManager, pluginConfigObservable);
rdsService = new RdsService(sourceCoordinator, sourceConfig, eventFactory, clientFactory, pluginMetrics,
acknowledgementSetManager, pluginConfigObservable, pipelineDescription);

LOG.info("Start RDS service");
rdsService.start(buffer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public class LeaderScheduler implements Runnable {
private final String s3Prefix;
private final SchemaManager schemaManager;
private final DbTableMetadata dbTableMetadata;
private final String pipelineName;

private LeaderPartition leaderPartition;
private List<String> tableNames;
Expand All @@ -56,12 +57,14 @@ public LeaderScheduler(final EnhancedSourceCoordinator sourceCoordinator,
final RdsSourceConfig sourceConfig,
final String s3Prefix,
final SchemaManager schemaManager,
final DbTableMetadata dbTableMetadata) {
final DbTableMetadata dbTableMetadata,
final String pipelineName) {
this.sourceCoordinator = sourceCoordinator;
this.sourceConfig = sourceConfig;
this.s3Prefix = s3Prefix;
this.schemaManager = schemaManager;
this.dbTableMetadata = dbTableMetadata;
this.pipelineName = pipelineName;
tableNames = new ArrayList<>(dbTableMetadata.getTableColumnDataTypeMap().keySet());
}

Expand Down Expand Up @@ -180,8 +183,9 @@ private void createStreamPartition(RdsSourceConfig sourceConfig) {
} else {
// Postgres
// Create replication slot, which will mark the starting point for stream
final String publicationName = generatePublicationName();
final String slotName = generateReplicationSlotName();
final String suffix = UUID.randomUUID().toString().substring(0, 8);
final String publicationName = generatePublicationName(suffix);
final String slotName = generateReplicationSlotName(suffix);
((PostgresSchemaManager)schemaManager).createLogicalReplicationSlot(tableNames, publicationName, slotName);
final PostgresStreamState postgresStreamState = new PostgresStreamState();
postgresStreamState.setPublicationName(publicationName);
Expand All @@ -199,11 +203,17 @@ private Optional<BinlogCoordinate> getCurrentBinlogPosition() {
return binlogCoordinate;
}

private String generatePublicationName() {
return "data_prepper_publication_" + UUID.randomUUID().toString().substring(0, 8);
private String generatePublicationName(final String suffix) {
return "data_prepper_" + getPipelineName() + "_pub_" + suffix;
}

private String generateReplicationSlotName() {
return "data_prepper_slot_" + UUID.randomUUID().toString().substring(0, 8);
private String generateReplicationSlotName(final String suffix) {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Is there any restriction on the replication slot name length?

@oeyh oeyh May 16, 2025

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Length limit is 63 chars. The generated slot name is around 40-50 chars, so it's within range.

return "data_prepper_" + getPipelineName() + "_slot_" + suffix;
}

private String getPipelineName() {
// Shorten the name (if needed) and replace any invalid characters with underscores
final String shortenedPipelineName = pipelineName.length() <= 16 ? pipelineName : pipelineName.substring(0, 16);
return shortenedPipelineName.replaceAll("[^a-zA-Z0-9_]", "_");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ public void createLogicalReplicationSlot(final List<String> tableNames, final St
try {
PreparedStatement statement = conn.prepareStatement(createPublicationStatement);
statement.executeUpdate();
LOG.info("Publication {} created successfully. ", publicationName);
} catch (Exception e) {
LOG.error("Failed to create publication: {}", e.getMessage());
throw e;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
import org.opensearch.dataprepper.model.buffer.Buffer;
import org.opensearch.dataprepper.model.configuration.PipelineDescription;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.event.EventFactory;
import org.opensearch.dataprepper.model.plugin.PluginConfigObservable;
Expand Down Expand Up @@ -95,6 +96,9 @@ class RdsServiceTest {
@Mock
private PluginConfigObservable pluginConfigObservable;

@Mock
private PipelineDescription pipelineDescription;

@BeforeEach
void setUp() {
when(clientFactory.buildRdsClient()).thenReturn(rdsClient);
Expand Down Expand Up @@ -206,6 +210,7 @@ private void prepareMocks() {
}

private RdsService createObjectUnderTest() {
return new RdsService(sourceCoordinator, sourceConfig, eventFactory, clientFactory, pluginMetrics, acknowledgementSetManager, pluginConfigObservable);
return new RdsService(sourceCoordinator, sourceConfig, eventFactory, clientFactory, pluginMetrics,
acknowledgementSetManager, pluginConfigObservable, pipelineDescription);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
import org.opensearch.dataprepper.model.configuration.PipelineDescription;
import org.opensearch.dataprepper.model.event.EventFactory;
import org.opensearch.dataprepper.model.plugin.PluginConfigObservable;
import org.opensearch.dataprepper.plugins.source.rds.configuration.AwsAuthenticationConfig;
Expand Down Expand Up @@ -45,6 +46,9 @@ class RdsSourceTest {
@Mock
private PluginConfigObservable pluginConfigObservable;

@Mock
private PipelineDescription pipelineDescription;

@BeforeEach
void setUp() {
when(sourceConfig.getAwsAuthenticationConfig()).thenReturn(awsAuthenticationConfig);
Expand All @@ -58,6 +62,7 @@ void test_when_buffer_is_null_then_start_throws_exception() {

private RdsSource createObjectUnderTest() {
return new RdsSource(
pluginMetrics, sourceConfig, eventFactory, awsCredentialsSupplier, acknowledgementSetManager, pluginConfigObservable);
pluginMetrics, sourceConfig, eventFactory, awsCredentialsSupplier, acknowledgementSetManager,
pluginConfigObservable, pipelineDescription);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.mockito.Answers;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.dataprepper.model.configuration.PipelineDescription;
import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator;
import org.opensearch.dataprepper.plugins.source.rds.RdsSourceConfig;
import org.opensearch.dataprepper.plugins.source.rds.configuration.AwsAuthenticationConfig;
Expand All @@ -22,9 +24,12 @@
import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.ExportPartition;
import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.GlobalState;
import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.LeaderPartition;
import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition;
import org.opensearch.dataprepper.plugins.source.rds.coordination.state.LeaderProgressState;
import org.opensearch.dataprepper.plugins.source.rds.coordination.state.PostgresStreamState;
import org.opensearch.dataprepper.plugins.source.rds.model.DbTableMetadata;
import org.opensearch.dataprepper.plugins.source.rds.schema.MySqlSchemaManager;
import org.opensearch.dataprepper.plugins.source.rds.schema.PostgresSchemaManager;
import org.opensearch.dataprepper.plugins.source.rds.schema.SchemaManager;

import java.time.Duration;
import java.util.Optional;
Expand All @@ -33,6 +38,10 @@
import java.util.concurrent.Executors;

import static org.awaitility.Awaitility.await;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.CoreMatchers.startsWith;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.lenient;
Expand All @@ -53,7 +62,7 @@ class LeaderSchedulerTest {
private RdsSourceConfig sourceConfig;

@Mock
private MySqlSchemaManager schemaManager;
private SchemaManager schemaManager;

@Mock
private DbTableMetadata dbTableMetadata;
Expand All @@ -64,20 +73,26 @@ class LeaderSchedulerTest {
@Mock
private LeaderProgressState leaderProgressState;

@Mock
private PipelineDescription pipelineDescription;

private String s3Prefix;
private String pipelineName;
private LeaderScheduler leaderScheduler;

@BeforeEach
void setUp() {
s3Prefix = UUID.randomUUID().toString();
leaderScheduler = createObjectUnderTest();
pipelineName = UUID.randomUUID().toString();

AwsAuthenticationConfig awsAuthenticationConfig = mock(AwsAuthenticationConfig.class);
lenient().when(awsAuthenticationConfig.getAwsStsRoleArn()).thenReturn(UUID.randomUUID().toString());
lenient().when(sourceConfig.getAwsAuthenticationConfig()).thenReturn(awsAuthenticationConfig);
ExportConfig exportConfig = mock(ExportConfig.class);
lenient().when(exportConfig.getKmsKeyId()).thenReturn(UUID.randomUUID().toString());
lenient().when(sourceConfig.getExport()).thenReturn(exportConfig);
lenient().when(pipelineDescription.getPipelineName()).thenReturn(pipelineName);
}

@Test
Expand Down Expand Up @@ -146,7 +161,43 @@ void test_shutDown() {
executorService.shutdownNow();
}

@Test
void leader_node_performs_init_creates_slot_with_expected_name() throws InterruptedException {
final PostgresSchemaManager postgresSchemaManager = mock(PostgresSchemaManager.class);
final String pipelineName = "simple-pipeline";

when(sourceCoordinator.acquireAvailablePartition(LeaderPartition.PARTITION_TYPE)).thenReturn(Optional.of(leaderPartition));
when(leaderPartition.getProgressState()).thenReturn(Optional.of(leaderProgressState));
when(leaderProgressState.isInitialized()).thenReturn(false);
when(sourceConfig.isStreamEnabled()).thenReturn(true);
when(sourceConfig.getEngine()).thenReturn(EngineType.POSTGRES);

final LeaderScheduler leaderScheduler = new LeaderScheduler(sourceCoordinator, sourceConfig, s3Prefix,
postgresSchemaManager, dbTableMetadata, pipelineName);
final ExecutorService executorService = Executors.newSingleThreadExecutor();
executorService.submit(leaderScheduler);
await().atMost(Duration.ofSeconds(1))
.untilAsserted(() -> verify(sourceCoordinator).acquireAvailablePartition(LeaderPartition.PARTITION_TYPE));
Thread.sleep(100);
executorService.shutdownNow();

ArgumentCaptor<StreamPartition> streamPartitionArgumentCaptor = ArgumentCaptor.forClass(StreamPartition.class);
verify(sourceCoordinator).createPartition(any(GlobalState.class));
verify(sourceCoordinator).createPartition(streamPartitionArgumentCaptor.capture());
verify(sourceCoordinator).saveProgressStateForPartition(eq(leaderPartition), any(Duration.class));

final StreamPartition streamPartition = streamPartitionArgumentCaptor.getValue();
assertThat(streamPartition.getProgressState().get().getPostgresStreamState(), notNullValue());

PostgresStreamState postgresStreamState = streamPartition.getProgressState().get().getPostgresStreamState();
final String publicationName = postgresStreamState.getPublicationName();
final String slotName = postgresStreamState.getReplicationSlotName();
assertThat(publicationName, startsWith("data_prepper_simple_pipeline"));
assertThat(slotName, startsWith("data_prepper_simple_pipeline"));
assertThat(publicationName.substring(publicationName.length() - 8), is(slotName.substring(slotName.length() - 8)));
}

private LeaderScheduler createObjectUnderTest() {
return new LeaderScheduler(sourceCoordinator, sourceConfig, s3Prefix, schemaManager, dbTableMetadata);
return new LeaderScheduler(sourceCoordinator, sourceConfig, s3Prefix, schemaManager, dbTableMetadata, pipelineName);
}
}
Loading