Skip to content

Commit d7f492e

Browse files
authored
[AWS] fix(logs-forwarder): properly set AWS partition for log groups tags cache (#940)
* [AWS] fix(logs-forwarder): properly set AWS partition for log groups tags cache * fix template
1 parent 69a47d0 commit d7f492e

3 files changed

Lines changed: 95 additions & 23 deletions

File tree

aws/logs_monitoring/steps/handlers/aws_attributes.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,26 @@
22

33

44
class AwsAttributes:
5-
def __init__(self, log_group=None, log_stream=None, log_events=None, owner=None):
5+
def __init__(
6+
self, context, log_group=None, log_stream=None, log_events=None, owner=None
7+
):
68
self.log_group = log_group
79
self.log_stream = log_stream
810
self.log_events = log_events
911
self.owner = owner
12+
self.partition = self._get_aws_partition(context)
1013
self.lambda_arn = None
1114
self.account = None
1215
self.region = None
1316

17+
def _get_aws_partition(self, context):
18+
if context.invoked_function_arn.startswith("arn:aws-cn:"):
19+
return "aws-cn"
20+
elif context.invoked_function_arn.startswith("arn:aws-us-gov:"):
21+
return "aws-us-gov"
22+
else:
23+
return "aws"
24+
1425
def to_dict(self):
1526
awslogs = {
1627
"aws": {
@@ -30,7 +41,7 @@ def get_log_group(self):
3041
return self.log_group
3142

3243
def get_log_group_arn(self):
33-
return f"arn:aws:logs:{self.region}:{self.account}:log-group:{self.log_group}"
44+
return f"arn:{self.partition}:logs:{self.region}:{self.account}:log-group:{self.log_group}"
3445

3546
def get_log_stream(self):
3647
return self.log_stream

aws/logs_monitoring/steps/handlers/awslogs_handler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def handle(self, event):
3737
logs = self.extract_logs(event)
3838
# Build aws attributes
3939
aws_attributes = AwsAttributes(
40+
self.context,
4041
logs.get("logGroup"),
4142
logs.get("logStream"),
4243
logs.get("logEvents"),

aws/logs_monitoring/tests/test_awslogs_handler.py

Lines changed: 81 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333

3434

3535
class Context:
36+
def __init__(self, invoked_function_arn="invoked_function_arn"):
37+
self.invoked_function_arn = invoked_function_arn
38+
3639
function_version = "$LATEST"
3740
invoked_function_arn = "invoked_function_arn"
3841
function_name = "function_name"
@@ -45,6 +48,7 @@ def setUp(self):
4548
r"forwarder_version:\d+\.\d+\.\d+",
4649
"forwarder_version:<redacted>",
4750
)
51+
self.context = Context()
4852

4953
@patch("caching.cloudwatch_log_group_cache.CloudwatchLogGroupTagsCache.__init__")
5054
def test_handle_with_overridden_source(self, mock_cache_init):
@@ -79,13 +83,12 @@ def test_handle_with_overridden_source(self, mock_cache_init):
7983
}
8084

8185
# Create required args
82-
context = Context()
8386
mock_cache_init.return_value = None
8487
cache_layer = CacheLayer("")
8588
cache_layer._cloudwatch_log_group_cache.get = MagicMock(return_value=[])
8689

8790
# Process the event
88-
awslogs_handler = AwsLogsHandler(context, cache_layer)
91+
awslogs_handler = AwsLogsHandler(self.context, cache_layer)
8992

9093
# Verify
9194
verify_as_json(
@@ -127,14 +130,13 @@ def test_awslogs_handler_rds_postgresql(self, mock_cache_init):
127130
)
128131
}
129132
}
130-
context = Context()
131133
mock_cache_init.return_value = None
132134
cache_layer = CacheLayer("")
133135
cache_layer._cloudwatch_log_group_cache.get = MagicMock(
134136
return_value=["test_tag_key:test_tag_value"]
135137
)
136138

137-
awslogs_handler = AwsLogsHandler(context, cache_layer)
139+
awslogs_handler = AwsLogsHandler(self.context, cache_layer)
138140
verify_as_json(
139141
list(awslogs_handler.handle(event)),
140142
options=Options().with_scrubber(self.scrubber),
@@ -189,7 +191,6 @@ def test_awslogs_handler_step_functions_tags_added_properly(
189191
)
190192
}
191193
}
192-
context = Context()
193194
mock_forward_metrics.side_effect = MagicMock()
194195
mock_cache_init.return_value = None
195196
cache_layer = CacheLayer("")
@@ -198,7 +199,7 @@ def test_awslogs_handler_step_functions_tags_added_properly(
198199
)
199200
cache_layer._cloudwatch_log_group_cache.get = MagicMock()
200201

201-
awslogs_handler = AwsLogsHandler(context, cache_layer)
202+
awslogs_handler = AwsLogsHandler(self.context, cache_layer)
202203
verify_as_json(
203204
list(awslogs_handler.handle(event)),
204205
options=Options().with_scrubber(self.scrubber),
@@ -252,7 +253,6 @@ def test_awslogs_handler_step_functions_customized_log_group(
252253
)
253254
}
254255
}
255-
context = Context()
256256
mock_forward_metrics.side_effect = MagicMock()
257257
mock_cache_init.return_value = None
258258
cache_layer = CacheLayer("")
@@ -261,7 +261,7 @@ def test_awslogs_handler_step_functions_customized_log_group(
261261
)
262262
cache_layer._cloudwatch_log_group_cache.get = MagicMock()
263263

264-
awslogs_handler = AwsLogsHandler(context, cache_layer)
264+
awslogs_handler = AwsLogsHandler(self.context, cache_layer)
265265
# for some reasons, the below two are needed to update the context of the handler
266266
verify_as_json(
267267
list(awslogs_handler.handle(eventFromCustomizedLogGroup)),
@@ -304,14 +304,13 @@ def test_awslogs_handler_lambda_log(self):
304304
)
305305
}
306306
}
307-
context = Context()
308307
cache_layer = CacheLayer("")
309308
cache_layer._cloudwatch_log_group_cache.get = MagicMock()
310309
cache_layer._lambda_cache.get = MagicMock(
311310
return_value=["service:customtags_service"]
312311
)
313312

314-
awslogs_handler = AwsLogsHandler(context, cache_layer)
313+
awslogs_handler = AwsLogsHandler(self.context, cache_layer)
315314
verify_as_json(
316315
list(awslogs_handler.handle(event)),
317316
options=Options().with_scrubber(self.scrubber),
@@ -327,12 +326,12 @@ def test_process_lambda_logs(self):
327326
}
328327
metadata = {"ddsource": "postgresql", "ddtags": ""}
329328
aws_attributes = AwsAttributes(
329+
self.context,
330330
stepfunction_loggroup.get("logGroup"),
331331
stepfunction_loggroup.get("logStream"),
332332
stepfunction_loggroup.get("owner"),
333333
)
334-
context = Context()
335-
aws_handler = AwsLogsHandler(context, CacheLayer(""))
334+
aws_handler = AwsLogsHandler(self.context, CacheLayer(""))
336335

337336
aws_handler.process_lambda_logs(metadata, aws_attributes)
338337
self.assertEqual(metadata, {"ddsource": "postgresql", "ddtags": ""})
@@ -346,13 +345,13 @@ def test_process_lambda_logs(self):
346345
}
347346
metadata = {"ddsource": "postgresql", "ddtags": "env:dev"}
348347
aws_attributes = AwsAttributes(
348+
self.context,
349349
lambda_default_loggroup.get("logGroup"),
350350
lambda_default_loggroup.get("logStream"),
351351
lambda_default_loggroup.get("owner"),
352352
)
353-
context = Context()
354353

355-
aws_handler = AwsLogsHandler(context, CacheLayer(""))
354+
aws_handler = AwsLogsHandler(self.context, CacheLayer(""))
356355
aws_handler.process_lambda_logs(metadata, aws_attributes)
357356
self.assertEqual(
358357
metadata,
@@ -380,12 +379,13 @@ def test_process_lambda_logs(self):
380379

381380
class TestLambdaCustomizedLogGroup(unittest.TestCase):
382381
def setUp(self):
383-
self.aws_handler = AwsLogsHandler(None, None)
382+
self.context = Context()
383+
self.aws_handler = AwsLogsHandler(self.context, None)
384384

385385
def test_get_lower_cased_lambda_function_name(self):
386-
self.assertEqual(True, True)
387386
# Non Lambda log
388387
aws_attributes = AwsAttributes(
388+
self.context,
389389
"/aws/vendedlogs/states/logs-to-traces-sequential-Logs",
390390
"states/logs-to-traces-sequential/2022-11-10-15-50/7851b2d9",
391391
[],
@@ -396,6 +396,7 @@ def test_get_lower_cased_lambda_function_name(self):
396396
)
397397

398398
aws_attributes = AwsAttributes(
399+
self.context,
399400
"/aws/lambda/test-lambda-default-log-group",
400401
"2023/11/06/[$LATEST]b25b1f977b3e416faa45a00f427e7acb",
401402
[],
@@ -406,6 +407,7 @@ def test_get_lower_cased_lambda_function_name(self):
406407
)
407408

408409
aws_attributes = AwsAttributes(
410+
self.context,
409411
"customizeLambdaGrop",
410412
"2023/11/06/test-customized-log-group1[$LATEST]13e304cba4b9446eb7ef082a00038990",
411413
[],
@@ -418,20 +420,23 @@ def test_get_lower_cased_lambda_function_name(self):
418420

419421
class TestParsingStepFunctionLogs(unittest.TestCase):
420422
def setUp(self):
421-
self.aws_handler = AwsLogsHandler(None, None)
423+
self.context = Context()
424+
self.aws_handler = AwsLogsHandler(self.context, None)
422425

423426
def test_get_state_machine_arn(self):
424427
aws_attributes = AwsAttributes(
428+
context=self.context,
425429
log_events=[
426430
{
427431
"message": json.dumps({"no_execution_arn": "xxxx/yyy"}),
428432
}
429-
]
433+
],
430434
)
431435

432436
self.assertEqual(self.aws_handler.get_state_machine_arn(aws_attributes), "")
433437

434438
aws_attributes = AwsAttributes(
439+
context=self.context,
435440
log_events=[
436441
{
437442
"message": json.dumps(
@@ -442,14 +447,15 @@ def test_get_state_machine_arn(self):
442447
}
443448
),
444449
}
445-
]
450+
],
446451
)
447452
self.assertEqual(
448453
self.aws_handler.get_state_machine_arn(aws_attributes),
449454
"arn:aws:states:sa-east-1:425362996713:stateMachine:my-Various-States",
450455
)
451456

452457
aws_attributes = AwsAttributes(
458+
context=self.context,
453459
log_events=[
454460
{
455461
"message": json.dumps(
@@ -460,7 +466,7 @@ def test_get_state_machine_arn(self):
460466
}
461467
)
462468
}
463-
]
469+
],
464470
)
465471

466472
self.assertEqual(
@@ -469,6 +475,7 @@ def test_get_state_machine_arn(self):
469475
)
470476

471477
aws_attributes = AwsAttributes(
478+
context=self.context,
472479
log_events=[
473480
{
474481
"message": json.dumps(
@@ -479,13 +486,66 @@ def test_get_state_machine_arn(self):
479486
}
480487
)
481488
}
482-
]
489+
],
483490
)
484491
self.assertEqual(
485492
self.aws_handler.get_state_machine_arn(aws_attributes),
486493
"arn:aws:states:sa-east-1:425362996713:stateMachine:my-Various-States",
487494
)
488495

489496

497+
class TestAwsPartitionExtraction(unittest.TestCase):
498+
def test_get_log_group_aws_partition(self):
499+
# default partition
500+
context = Context(
501+
invoked_function_arn="arn:aws:lambda:us-east-1:12345678910:function:test-lambda"
502+
)
503+
aws_attributes = AwsAttributes(
504+
context=context,
505+
log_group="my-log-group",
506+
)
507+
508+
aws_attributes.set_account_region(
509+
"arn:aws:lambda:us-east-1:12345678910:function:test-lambda"
510+
)
511+
512+
self.assertEqual(
513+
aws_attributes.get_log_group_arn(),
514+
"arn:aws:logs:us-east-1:12345678910:log-group:my-log-group",
515+
)
516+
517+
# aws-cn partition
518+
context = Context(
519+
invoked_function_arn="arn:aws-cn:lambda:cn-north-1:12345678910:function:test-lambda"
520+
)
521+
aws_attributes = AwsAttributes(
522+
context=context,
523+
log_group="my-log-group",
524+
)
525+
aws_attributes.set_account_region(
526+
"arn:aws-cn:lambda:cn-north-1:12345678910:function:test-lambda"
527+
)
528+
self.assertEqual(
529+
aws_attributes.get_log_group_arn(),
530+
"arn:aws-cn:logs:cn-north-1:12345678910:log-group:my-log-group",
531+
)
532+
533+
# aws-us-gov partition
534+
context = Context(
535+
invoked_function_arn="arn:aws-us-gov:lambda:us-gov-west-1:12345678910:function:test-lambda"
536+
)
537+
aws_attributes = AwsAttributes(
538+
context=context,
539+
log_group="my-log-group",
540+
)
541+
aws_attributes.set_account_region(
542+
"arn:aws-us-gov:lambda:us-gov-west-1:12345678910:function:test-lambda"
543+
)
544+
self.assertEqual(
545+
aws_attributes.get_log_group_arn(),
546+
"arn:aws-us-gov:logs:us-gov-west-1:12345678910:log-group:my-log-group",
547+
)
548+
549+
490550
if __name__ == "__main__":
491551
unittest.main()

0 commit comments

Comments
 (0)