3333
3434
3535class Context :
36+ def __init__ (self , invoked_function_arn = "invoked_function_arn" ):
37+ self .invoked_function_arn = invoked_function_arn
38+
3639 function_version = "$LATEST"
3740 invoked_function_arn = "invoked_function_arn"
3841 function_name = "function_name"
@@ -45,6 +48,7 @@ def setUp(self):
4548 r"forwarder_version:\d+\.\d+\.\d+" ,
4649 "forwarder_version:<redacted>" ,
4750 )
51+ self .context = Context ()
4852
4953 @patch ("caching.cloudwatch_log_group_cache.CloudwatchLogGroupTagsCache.__init__" )
5054 def test_handle_with_overridden_source (self , mock_cache_init ):
@@ -79,13 +83,12 @@ def test_handle_with_overridden_source(self, mock_cache_init):
7983 }
8084
8185 # Create required args
82- context = Context ()
8386 mock_cache_init .return_value = None
8487 cache_layer = CacheLayer ("" )
8588 cache_layer ._cloudwatch_log_group_cache .get = MagicMock (return_value = [])
8689
8790 # Process the event
88- awslogs_handler = AwsLogsHandler (context , cache_layer )
91+ awslogs_handler = AwsLogsHandler (self . context , cache_layer )
8992
9093 # Verify
9194 verify_as_json (
@@ -127,14 +130,13 @@ def test_awslogs_handler_rds_postgresql(self, mock_cache_init):
127130 )
128131 }
129132 }
130- context = Context ()
131133 mock_cache_init .return_value = None
132134 cache_layer = CacheLayer ("" )
133135 cache_layer ._cloudwatch_log_group_cache .get = MagicMock (
134136 return_value = ["test_tag_key:test_tag_value" ]
135137 )
136138
137- awslogs_handler = AwsLogsHandler (context , cache_layer )
139+ awslogs_handler = AwsLogsHandler (self . context , cache_layer )
138140 verify_as_json (
139141 list (awslogs_handler .handle (event )),
140142 options = Options ().with_scrubber (self .scrubber ),
@@ -189,7 +191,6 @@ def test_awslogs_handler_step_functions_tags_added_properly(
189191 )
190192 }
191193 }
192- context = Context ()
193194 mock_forward_metrics .side_effect = MagicMock ()
194195 mock_cache_init .return_value = None
195196 cache_layer = CacheLayer ("" )
@@ -198,7 +199,7 @@ def test_awslogs_handler_step_functions_tags_added_properly(
198199 )
199200 cache_layer ._cloudwatch_log_group_cache .get = MagicMock ()
200201
201- awslogs_handler = AwsLogsHandler (context , cache_layer )
202+ awslogs_handler = AwsLogsHandler (self . context , cache_layer )
202203 verify_as_json (
203204 list (awslogs_handler .handle (event )),
204205 options = Options ().with_scrubber (self .scrubber ),
@@ -252,7 +253,6 @@ def test_awslogs_handler_step_functions_customized_log_group(
252253 )
253254 }
254255 }
255- context = Context ()
256256 mock_forward_metrics .side_effect = MagicMock ()
257257 mock_cache_init .return_value = None
258258 cache_layer = CacheLayer ("" )
@@ -261,7 +261,7 @@ def test_awslogs_handler_step_functions_customized_log_group(
261261 )
262262 cache_layer ._cloudwatch_log_group_cache .get = MagicMock ()
263263
264- awslogs_handler = AwsLogsHandler (context , cache_layer )
264+ awslogs_handler = AwsLogsHandler (self . context , cache_layer )
265265 # for some reasons, the below two are needed to update the context of the handler
266266 verify_as_json (
267267 list (awslogs_handler .handle (eventFromCustomizedLogGroup )),
@@ -304,14 +304,13 @@ def test_awslogs_handler_lambda_log(self):
304304 )
305305 }
306306 }
307- context = Context ()
308307 cache_layer = CacheLayer ("" )
309308 cache_layer ._cloudwatch_log_group_cache .get = MagicMock ()
310309 cache_layer ._lambda_cache .get = MagicMock (
311310 return_value = ["service:customtags_service" ]
312311 )
313312
314- awslogs_handler = AwsLogsHandler (context , cache_layer )
313+ awslogs_handler = AwsLogsHandler (self . context , cache_layer )
315314 verify_as_json (
316315 list (awslogs_handler .handle (event )),
317316 options = Options ().with_scrubber (self .scrubber ),
@@ -327,12 +326,12 @@ def test_process_lambda_logs(self):
327326 }
328327 metadata = {"ddsource" : "postgresql" , "ddtags" : "" }
329328 aws_attributes = AwsAttributes (
329+ self .context ,
330330 stepfunction_loggroup .get ("logGroup" ),
331331 stepfunction_loggroup .get ("logStream" ),
332332 stepfunction_loggroup .get ("owner" ),
333333 )
334- context = Context ()
335- aws_handler = AwsLogsHandler (context , CacheLayer ("" ))
334+ aws_handler = AwsLogsHandler (self .context , CacheLayer ("" ))
336335
337336 aws_handler .process_lambda_logs (metadata , aws_attributes )
338337 self .assertEqual (metadata , {"ddsource" : "postgresql" , "ddtags" : "" })
@@ -346,13 +345,13 @@ def test_process_lambda_logs(self):
346345 }
347346 metadata = {"ddsource" : "postgresql" , "ddtags" : "env:dev" }
348347 aws_attributes = AwsAttributes (
348+ self .context ,
349349 lambda_default_loggroup .get ("logGroup" ),
350350 lambda_default_loggroup .get ("logStream" ),
351351 lambda_default_loggroup .get ("owner" ),
352352 )
353- context = Context ()
354353
355- aws_handler = AwsLogsHandler (context , CacheLayer ("" ))
354+ aws_handler = AwsLogsHandler (self . context , CacheLayer ("" ))
356355 aws_handler .process_lambda_logs (metadata , aws_attributes )
357356 self .assertEqual (
358357 metadata ,
@@ -380,12 +379,13 @@ def test_process_lambda_logs(self):
380379
381380class TestLambdaCustomizedLogGroup (unittest .TestCase ):
382381 def setUp (self ):
383- self .aws_handler = AwsLogsHandler (None , None )
382+ self .context = Context ()
383+ self .aws_handler = AwsLogsHandler (self .context , None )
384384
385385 def test_get_lower_cased_lambda_function_name (self ):
386- self .assertEqual (True , True )
387386 # Non Lambda log
388387 aws_attributes = AwsAttributes (
388+ self .context ,
389389 "/aws/vendedlogs/states/logs-to-traces-sequential-Logs" ,
390390 "states/logs-to-traces-sequential/2022-11-10-15-50/7851b2d9" ,
391391 [],
@@ -396,6 +396,7 @@ def test_get_lower_cased_lambda_function_name(self):
396396 )
397397
398398 aws_attributes = AwsAttributes (
399+ self .context ,
399400 "/aws/lambda/test-lambda-default-log-group" ,
400401 "2023/11/06/[$LATEST]b25b1f977b3e416faa45a00f427e7acb" ,
401402 [],
@@ -406,6 +407,7 @@ def test_get_lower_cased_lambda_function_name(self):
406407 )
407408
408409 aws_attributes = AwsAttributes (
410+ self .context ,
409411 "customizeLambdaGrop" ,
410412 "2023/11/06/test-customized-log-group1[$LATEST]13e304cba4b9446eb7ef082a00038990" ,
411413 [],
@@ -418,20 +420,23 @@ def test_get_lower_cased_lambda_function_name(self):
418420
419421class TestParsingStepFunctionLogs (unittest .TestCase ):
420422 def setUp (self ):
421- self .aws_handler = AwsLogsHandler (None , None )
423+ self .context = Context ()
424+ self .aws_handler = AwsLogsHandler (self .context , None )
422425
423426 def test_get_state_machine_arn (self ):
424427 aws_attributes = AwsAttributes (
428+ context = self .context ,
425429 log_events = [
426430 {
427431 "message" : json .dumps ({"no_execution_arn" : "xxxx/yyy" }),
428432 }
429- ]
433+ ],
430434 )
431435
432436 self .assertEqual (self .aws_handler .get_state_machine_arn (aws_attributes ), "" )
433437
434438 aws_attributes = AwsAttributes (
439+ context = self .context ,
435440 log_events = [
436441 {
437442 "message" : json .dumps (
@@ -442,14 +447,15 @@ def test_get_state_machine_arn(self):
442447 }
443448 ),
444449 }
445- ]
450+ ],
446451 )
447452 self .assertEqual (
448453 self .aws_handler .get_state_machine_arn (aws_attributes ),
449454 "arn:aws:states:sa-east-1:425362996713:stateMachine:my-Various-States" ,
450455 )
451456
452457 aws_attributes = AwsAttributes (
458+ context = self .context ,
453459 log_events = [
454460 {
455461 "message" : json .dumps (
@@ -460,7 +466,7 @@ def test_get_state_machine_arn(self):
460466 }
461467 )
462468 }
463- ]
469+ ],
464470 )
465471
466472 self .assertEqual (
@@ -469,6 +475,7 @@ def test_get_state_machine_arn(self):
469475 )
470476
471477 aws_attributes = AwsAttributes (
478+ context = self .context ,
472479 log_events = [
473480 {
474481 "message" : json .dumps (
@@ -479,13 +486,66 @@ def test_get_state_machine_arn(self):
479486 }
480487 )
481488 }
482- ]
489+ ],
483490 )
484491 self .assertEqual (
485492 self .aws_handler .get_state_machine_arn (aws_attributes ),
486493 "arn:aws:states:sa-east-1:425362996713:stateMachine:my-Various-States" ,
487494 )
488495
489496
497+ class TestAwsPartitionExtraction (unittest .TestCase ):
498+ def test_get_log_group_aws_partition (self ):
499+ # default partition
500+ context = Context (
501+ invoked_function_arn = "arn:aws:lambda:us-east-1:12345678910:function:test-lambda"
502+ )
503+ aws_attributes = AwsAttributes (
504+ context = context ,
505+ log_group = "my-log-group" ,
506+ )
507+
508+ aws_attributes .set_account_region (
509+ "arn:aws:lambda:us-east-1:12345678910:function:test-lambda"
510+ )
511+
512+ self .assertEqual (
513+ aws_attributes .get_log_group_arn (),
514+ "arn:aws:logs:us-east-1:12345678910:log-group:my-log-group" ,
515+ )
516+
517+ # aws-cn partition
518+ context = Context (
519+ invoked_function_arn = "arn:aws-cn:lambda:cn-north-1:12345678910:function:test-lambda"
520+ )
521+ aws_attributes = AwsAttributes (
522+ context = context ,
523+ log_group = "my-log-group" ,
524+ )
525+ aws_attributes .set_account_region (
526+ "arn:aws-cn:lambda:cn-north-1:12345678910:function:test-lambda"
527+ )
528+ self .assertEqual (
529+ aws_attributes .get_log_group_arn (),
530+ "arn:aws-cn:logs:cn-north-1:12345678910:log-group:my-log-group" ,
531+ )
532+
533+ # aws-us-gov partition
534+ context = Context (
535+ invoked_function_arn = "arn:aws-us-gov:lambda:us-gov-west-1:12345678910:function:test-lambda"
536+ )
537+ aws_attributes = AwsAttributes (
538+ context = context ,
539+ log_group = "my-log-group" ,
540+ )
541+ aws_attributes .set_account_region (
542+ "arn:aws-us-gov:lambda:us-gov-west-1:12345678910:function:test-lambda"
543+ )
544+ self .assertEqual (
545+ aws_attributes .get_log_group_arn (),
546+ "arn:aws-us-gov:logs:us-gov-west-1:12345678910:log-group:my-log-group" ,
547+ )
548+
549+
490550if __name__ == "__main__" :
491551 unittest .main ()
0 commit comments