Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 14 additions & 16 deletions aws/logs_monitoring/steps/handlers/s3_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@
)


def create_s3_client():
"""Create a boto3 S3 client with VPC-aware configuration when applicable."""
if DD_USE_VPC:
return boto3.client(
"s3",
os.environ["AWS_REGION"],
config=botocore.config.Config(s3={"addressing_style": "path"}),
)
return boto3.client("s3")


class S3EventDataStore:
def __init__(self):
self.bucket = None
Expand All @@ -48,14 +59,15 @@ def __init__(self):


class S3EventHandler:
def __init__(self, context, metadata, cache_layer):
def __init__(self, context, metadata, cache_layer, s3_client=None):
self.logger = logging.getLogger()
self.logger.setLevel(
logging.getLevelName(os.environ.get("DD_LOG_LEVEL", "INFO").upper())
)
self.context = context
self.metadata = metadata
self.cache_layer = cache_layer
self._s3_client = s3_client or create_s3_client()
self.multiline_regex_start_pattern = _MULTILINE_REGEX_START_PATTERN
self.multiline_regex_pattern = _MULTILINE_REGEX_PATTERN
self.data_store = S3EventDataStore()
Expand Down Expand Up @@ -122,26 +134,12 @@ def _add_s3_tags_from_cache(self):
)

def _extract_data(self):
s3_client = self._get_s3_client()
response = s3_client.get_object(
response = self._s3_client.get_object(
Bucket=self.data_store.bucket, Key=self.data_store.key
)
body = response.get("Body")
self.data_store.data = body.read()

def _get_s3_client(self):
# Need to use path style to access s3 via VPC Endpoints
# https://github.com/gford1000-aws/lambda_s3_access_using_vpc_endpoint#boto3-specific-notes
if DD_USE_VPC:
s3 = boto3.client(
"s3",
os.environ["AWS_REGION"],
config=botocore.config.Config(s3={"addressing_style": "path"}),
)
else:
s3 = boto3.client("s3")
return s3

def _get_structured_lines_for_s3_handler(self):
self._decompress_data()

Expand Down
5 changes: 3 additions & 2 deletions aws/logs_monitoring/steps/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from steps.enums import AwsEventSource, AwsEventType, AwsEventTypeKeyword
from steps.handlers.awslogs_handler import AwsLogsHandler
from steps.handlers.s3_handler import S3EventHandler
from steps.handlers.s3_handler import S3EventHandler, create_s3_client
from telemetry import send_event_metric, set_forwarder_telemetry_tags

logger = logging.getLogger()
Expand Down Expand Up @@ -97,14 +97,15 @@ def parse_event_type(event):

# Handle S3 events delivered via SQS (S3 -> SQS or S3 -> SNS -> SQS)
def sqs_handler(event, context, cache_layer):
s3_client = create_s3_client()
for record in event["Records"]:
inner_event = _extract_inner_event_from_sqs(record)
if inner_event is None:
continue
# Fresh metadata per SQS record: S3EventHandler mutates metadata
# (DD_SOURCE, tags, service), so each record needs its own copy.
metadata = generate_metadata(context)
s3_handler = S3EventHandler(context, metadata, cache_layer)
s3_handler = S3EventHandler(context, metadata, cache_layer, s3_client=s3_client)
for log_event in s3_handler.handle(inner_event):
if isinstance(log_event, dict):
yield merge_dicts(log_event, metadata)
Expand Down
5 changes: 2 additions & 3 deletions aws/logs_monitoring/tests/test_s3_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,17 +232,16 @@ def test_s3_handler_with_sns(self):
self.assertEqual(self.s3_handler.metadata["ddsource"], "s3")

@patch("caching.cloudwatch_log_group_cache.CloudwatchLogGroupTagsCache.__init__")
@patch("steps.handlers.s3_handler.S3EventHandler._get_s3_client")
def test_s3_tags_added_to_metadata(
self,
mock_get_s3_client,
mock_cache_init,
):
mock_get_s3_client.side_effect = MagicMock()
mock_cache_init.return_value = None
cache_layer = CacheLayer("")
cache_layer._s3_tags_cache.get = MagicMock(return_value=["s3_tag:tag_value"])
self.s3_handler.cache_layer = cache_layer
self.s3_handler._extract_data = MagicMock()
self.s3_handler.data_store.data = b""
event = {
"Records": [
{
Expand Down
Loading