88import io .micrometer .core .instrument .DistributionSummary ;
99import org .junit .jupiter .api .BeforeEach ;
1010import org .junit .jupiter .api .Test ;
11+ import org .mockito .MockedConstruction ;
1112import org .mockito .MockedStatic ;
1213import org .opensearch .dataprepper .aws .api .AwsCredentialsSupplier ;
1314import org .opensearch .dataprepper .metrics .PluginMetrics ;
15+ import org .opensearch .dataprepper .model .configuration .PluginModel ;
1416import org .opensearch .dataprepper .model .plugin .PluginFactory ;
1517import org .opensearch .dataprepper .model .configuration .PluginSetting ;
1618import org .opensearch .dataprepper .model .event .Event ;
1719import org .opensearch .dataprepper .model .event .JacksonEvent ;
1820import org .opensearch .dataprepper .model .record .Record ;
21+ import org .opensearch .dataprepper .plugins .dlq .DlqPushHandler ;
1922import org .opensearch .dataprepper .plugins .sink .cloudwatch_logs .client .CloudWatchLogsClientFactory ;
23+ import org .opensearch .dataprepper .plugins .sink .cloudwatch_logs .client .CloudWatchLogsDispatcher ;
2024import org .opensearch .dataprepper .plugins .sink .cloudwatch_logs .client .CloudWatchLogsMetrics ;
2125import org .opensearch .dataprepper .plugins .sink .cloudwatch_logs .config .AwsConfig ;
2226import org .opensearch .dataprepper .plugins .sink .cloudwatch_logs .config .CloudWatchLogsSinkConfig ;
2327import org .opensearch .dataprepper .plugins .sink .cloudwatch_logs .config .ThresholdConfig ;
2428import software .amazon .awssdk .services .cloudwatchlogs .CloudWatchLogsClient ;
29+ import software .amazon .awssdk .regions .Region ;
30+
2531
2632import java .util .ArrayList ;
2733import java .util .Collection ;
2834import java .util .HashMap ;
2935import java .util .Map ;
3036
37+ import static org .hamcrest .MatcherAssert .assertThat ;
38+ import static org .hamcrest .Matchers .equalTo ;
3139import static org .junit .jupiter .api .Assertions .assertTrue ;
3240import static org .junit .jupiter .api .Assertions .assertThrows ;
3341import static org .mockito .ArgumentMatchers .any ;
3644import static org .mockito .Mockito .atLeast ;
3745import static org .mockito .Mockito .mock ;
3846import static org .mockito .Mockito .mockStatic ;
47+ import static org .mockito .Mockito .mockConstruction ;
3948import static org .mockito .Mockito .spy ;
4049import static org .mockito .Mockito .times ;
4150import static org .mockito .Mockito .verify ;
4251import static org .mockito .Mockito .when ;
4352
4453class CloudWatchLogsSinkTest {
54+ private static int TEST_MAX_RETRIES = 3 ;
4555 private PluginSetting mockPluginSetting ;
4656 private PluginMetrics mockPluginMetrics ;
4757 private PluginFactory mockPluginFactory ;
@@ -57,6 +67,7 @@ class CloudWatchLogsSinkTest {
5767 private static final String TEST_PLUGIN_NAME = "testPluginName" ;
5868 private static final String TEST_PIPELINE_NAME = "testPipelineName" ;
5969 private static final String TEST_BUFFER_TYPE = "in_memory" ;
70+ private int numRetries ;
6071 @ BeforeEach
6172 void setUp () {
6273 mockPluginSetting = mock (PluginSetting .class );
@@ -73,12 +84,13 @@ void setUp() {
7384 DistributionSummary summary = mock (DistributionSummary .class );
7485 when (mockPluginMetrics .summary (anyString ())).thenReturn (summary );
7586
87+ when (mockCloudWatchLogsSinkConfig .getDlq ()).thenReturn (null );
7688 when (mockCloudWatchLogsSinkConfig .getAwsConfig ()).thenReturn (mockAwsConfig );
7789 when (mockCloudWatchLogsSinkConfig .getThresholdConfig ()).thenReturn (thresholdConfig );
7890 when (mockCloudWatchLogsSinkConfig .getHeaderOverrides ()).thenReturn (new HashMap <>());
7991 when (mockCloudWatchLogsSinkConfig .getLogGroup ()).thenReturn (TEST_LOG_GROUP );
8092 when (mockCloudWatchLogsSinkConfig .getLogStream ()).thenReturn (TEST_LOG_STREAM );
81- when (mockCloudWatchLogsSinkConfig .getMaxRetries ()).thenReturn (3 );
93+ when (mockCloudWatchLogsSinkConfig .getMaxRetries ()).thenReturn (TEST_MAX_RETRIES );
8294 when (mockCloudWatchLogsSinkConfig .getWorkers ()).thenReturn (10 );
8395
8496 when (mockPluginSetting .getName ()).thenReturn (TEST_PLUGIN_NAME );
@@ -167,17 +179,17 @@ void WHEN_given_sample_empty_records_THEN_records_are_not_processed() {
167179 void WHEN_header_overrides_is_empty_THEN_empty_map_is_passed_to_client_factory () {
168180 Map <String , String > emptyHeaders = new HashMap <>();
169181 when (mockCloudWatchLogsSinkConfig .getHeaderOverrides ()).thenReturn (emptyHeaders );
170-
182+
171183 try (MockedStatic <CloudWatchLogsClientFactory > mockedStatic = mockStatic (CloudWatchLogsClientFactory .class )) {
172184 mockedStatic .when (() -> CloudWatchLogsClientFactory .createCwlClient (any (AwsConfig .class ),
173185 any (AwsCredentialsSupplier .class ), any (), any ()))
174186 .thenReturn (mockClient );
175187
176188 CloudWatchLogsSink testCloudWatchSink = getTestCloudWatchSink ();
177-
189+
178190 mockedStatic .verify (() -> CloudWatchLogsClientFactory .createCwlClient (
179- eq (mockAwsConfig ),
180- eq (mockCredentialSupplier ),
191+ eq (mockAwsConfig ),
192+ eq (mockCredentialSupplier ),
181193 eq (emptyHeaders ),
182194 any ()));
183195 }
@@ -186,17 +198,17 @@ void WHEN_header_overrides_is_empty_THEN_empty_map_is_passed_to_client_factory()
186198 @ Test
187199 void WHEN_header_overrides_is_provided_THEN_headers_are_passed_to_client_factory () {
188200 when (mockCloudWatchLogsSinkConfig .getHeaderOverrides ()).thenReturn (mockHeaderOverrides );
189-
201+
190202 try (MockedStatic <CloudWatchLogsClientFactory > mockedStatic = mockStatic (CloudWatchLogsClientFactory .class )) {
191203 mockedStatic .when (() -> CloudWatchLogsClientFactory .createCwlClient (any (AwsConfig .class ),
192204 any (AwsCredentialsSupplier .class ), any (), any ()))
193205 .thenReturn (mockClient );
194206
195207 CloudWatchLogsSink testCloudWatchSink = getTestCloudWatchSink ();
196-
208+
197209 mockedStatic .verify (() -> CloudWatchLogsClientFactory .createCwlClient (
198- eq (mockAwsConfig ),
199- eq (mockCredentialSupplier ),
210+ eq (mockAwsConfig ),
211+ eq (mockCredentialSupplier ),
200212 eq (mockHeaderOverrides ),
201213 any ()));
202214 }
@@ -205,16 +217,67 @@ void WHEN_header_overrides_is_provided_THEN_headers_are_passed_to_client_factory
205217 @ Test
206218 void WHEN_sink_initialization_with_header_overrides_THEN_sink_is_ready () {
207219 when (mockCloudWatchLogsSinkConfig .getHeaderOverrides ()).thenReturn (mockHeaderOverrides );
208-
220+
209221 try (MockedStatic <CloudWatchLogsClientFactory > mockedStatic = mockStatic (CloudWatchLogsClientFactory .class )) {
210222 mockedStatic .when (() -> CloudWatchLogsClientFactory .createCwlClient (any (AwsConfig .class ),
211223 any (AwsCredentialsSupplier .class ), any (), any ()))
212224 .thenReturn (mockClient );
213225
214226 CloudWatchLogsSink testCloudWatchSink = getTestCloudWatchSink ();
215227 testCloudWatchSink .doInitialize ();
216-
228+
217229 assertTrue (testCloudWatchSink .isReady ());
218230 }
219231 }
232+
233+ @ Test
234+ void WHEN_sink_has_no_dlq_config_THEN_retries_set_to_maxint () {
235+ when (mockCloudWatchLogsSinkConfig .getHeaderOverrides ()).thenReturn (mockHeaderOverrides );
236+
237+ try (MockedStatic <CloudWatchLogsClientFactory > mockedStatic = mockStatic (CloudWatchLogsClientFactory .class )) {
238+ final MockedConstruction <CloudWatchLogsDispatcher > dispatcherMock =
239+ mockConstruction (CloudWatchLogsDispatcher .class , (mock , context ) -> {
240+ numRetries = (int )context .arguments ().get (7 );
241+ });
242+
243+ mockedStatic .when (() -> CloudWatchLogsClientFactory .createCwlClient (any (AwsConfig .class ),
244+ any (AwsCredentialsSupplier .class ), any (), any ()))
245+ .thenReturn (mockClient );
246+
247+ CloudWatchLogsSink testCloudWatchSink = getTestCloudWatchSink ();
248+ testCloudWatchSink .doInitialize ();
249+ dispatcherMock .close ();
250+
251+ }
252+ assertThat (numRetries , equalTo (Integer .MAX_VALUE ));
253+ }
254+
255+ @ Test
256+ void WHEN_sink_has_dlq_config_THEN_retries_set_to_user_configured_value () {
257+ PluginModel dlqConfig = mock (PluginModel .class );
258+ when (mockCloudWatchLogsSinkConfig .getDlq ()).thenReturn (dlqConfig );
259+ when (mockCloudWatchLogsSinkConfig .getHeaderOverrides ()).thenReturn (mockHeaderOverrides );
260+ when (mockAwsConfig .getAwsRegion ()).thenReturn (Region .of ("us-west-2" ));
261+ when (mockAwsConfig .getAwsStsRoleArn ()).thenReturn ("role" );
262+
263+ try (MockedStatic <CloudWatchLogsClientFactory > mockedStatic = mockStatic (CloudWatchLogsClientFactory .class )) {
264+ final MockedConstruction <CloudWatchLogsDispatcher > dispatcherMock =
265+ mockConstruction (CloudWatchLogsDispatcher .class , (mock , context ) -> {
266+ numRetries = (int )context .arguments ().get (7 );
267+ });
268+ final MockedConstruction <DlqPushHandler > dlqMock =
269+ mockConstruction (DlqPushHandler .class , (mock , context ) -> {
270+ });
271+
272+ mockedStatic .when (() -> CloudWatchLogsClientFactory .createCwlClient (any (AwsConfig .class ),
273+ any (AwsCredentialsSupplier .class ), any (), any ()))
274+ .thenReturn (mockClient );
275+
276+ CloudWatchLogsSink testCloudWatchSink = getTestCloudWatchSink ();
277+ testCloudWatchSink .doInitialize ();
278+ dispatcherMock .close ();
279+ }
280+ assertThat (numRetries , equalTo (TEST_MAX_RETRIES ));
281+ }
282+
220283}
0 commit comments