diff --git a/aws/logs_monitoring_go/internal/handling/cloudtrail.go b/aws/logs_monitoring_go/internal/handling/cloudtrail.go index f167cf96..96f1274b 100644 --- a/aws/logs_monitoring_go/internal/handling/cloudtrail.go +++ b/aws/logs_monitoring_go/internal/handling/cloudtrail.go @@ -6,7 +6,6 @@ package handling import ( - "compress/gzip" "encoding/json" "fmt" "io" @@ -25,14 +24,7 @@ var ( func decodeCloudTrail(r io.Reader) iter.Seq2[string, error] { return func(yield func(string, error) bool) { - gz, err := gzip.NewReader(r) - if err != nil { - yield("", fmt.Errorf("gzip: %w", err)) - return - } - defer gz.Close() //nolint:errcheck - - dec := json.NewDecoder(gz) + dec := json.NewDecoder(r) if err := parsing.SkipToRecords(dec); err != nil { yield("", fmt.Errorf("cloudtrail: %w", err)) return diff --git a/aws/logs_monitoring_go/internal/handling/cloudtrail_test.go b/aws/logs_monitoring_go/internal/handling/cloudtrail_test.go index 16478f94..02cbd429 100644 --- a/aws/logs_monitoring_go/internal/handling/cloudtrail_test.go +++ b/aws/logs_monitoring_go/internal/handling/cloudtrail_test.go @@ -9,7 +9,6 @@ import ( "bytes" "testing" - "github.com/DataDog/datadog-serverless-functions/aws/logs_monitoring_go/internal/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -117,44 +116,24 @@ func TestDecodeCloudTrail(t *testing.T) { wantErr bool }{ "single record": { - input: testutil.MustGzipJSON(t, map[string]any{ - "Records": []any{ - map[string]any{ - "eventName": "DescribeTable", - "userIdentity": map[string]any{ - "arn": "arn:aws:sts::601427279990:assumed-role/MyRole/i-08014e4f62ccf762d", - }, - }, - }, - }), + input: []byte(`{"Records":[{"eventName":"DescribeTable","userIdentity":{"arn":"arn:aws:sts::601427279990:assumed-role/MyRole/i-08014e4f62ccf762d"}}]}`), want: []string{ `{"eventName":"DescribeTable","userIdentity":{"arn":"arn:aws:sts::601427279990:assumed-role/MyRole/i-08014e4f62ccf762d"}}`, }, }, "multiple records": { - input: testutil.MustGzipJSON(t, map[string]any{ - "Records": []any{ - map[string]any{"eventName": "event1"}, - map[string]any{"eventName": "event2"}, - }, - }), + input: []byte(`{"Records":[{"eventName":"event1"},{"eventName":"event2"}]}`), want: []string{ `{"eventName":"event1"}`, `{"eventName":"event2"}`, }, }, "empty records array": { - input: testutil.MustGzipJSON(t, map[string]any{ - "Records": []any{}, - }), - want: nil, - }, - "invalid gzip": { - input: []byte("not gzip"), - wantErr: true, + input: []byte(`{"Records":[]}`), + want: nil, }, "invalid json": { - input: testutil.MustGzipJSON(t, "not an object"), + input: []byte("not json"), wantErr: true, }, } diff --git a/aws/logs_monitoring_go/internal/handling/eventbridge.go b/aws/logs_monitoring_go/internal/handling/eventbridge.go index 06ab7619..3de8d4cb 100644 --- a/aws/logs_monitoring_go/internal/handling/eventbridge.go +++ b/aws/logs_monitoring_go/internal/handling/eventbridge.go @@ -10,13 +10,13 @@ import ( "cmp" "context" "encoding/json" - "errors" "fmt" "strings" "github.com/DataDog/datadog-serverless-functions/aws/logs_monitoring_go/internal/concurrent" "github.com/DataDog/datadog-serverless-functions/aws/logs_monitoring_go/internal/config" "github.com/DataDog/datadog-serverless-functions/aws/logs_monitoring_go/internal/model" + "github.com/DataDog/datadog-serverless-functions/aws/logs_monitoring_go/internal/parsing" ) type EventBridgeHandler struct { @@ -35,60 +35,80 @@ func (h *EventBridgeHandler) Handle(ctx context.Context, event json.RawMessage, return fmt.Errorf("get lambda origin: %w", err) } - ebSource, err := decodeEventBridgeSource(event) + source, err := eventBridgeSource(event) if err != nil { - return err + return fmt.Errorf("source: %w", err) } - source := cmp.Or(h.cfg.Source, ebSource) - service := cmp.Or(h.cfg.Service, source) - entry := model.NewLogEntry() - entry.Message = string(event) - entry.Source = source - entry.Service = service - entry.Tags = h.cfg.Tags - entry.Metadata = lambdaOrigin + switch source { + case sourceSecurityHub: + return h.securityHub(ctx, event, source, out, lambdaOrigin) + default: + return h.eventBridge(ctx, event, source, out, lambdaOrigin) + } +} - if h.cfg.Filter.ShouldExclude(entry.Message) { +func (h *EventBridgeHandler) eventBridge(ctx context.Context, event json.RawMessage, source string, out chan<- model.LogEntry, lambdaOrigin model.LambdaOrigin) error { + message := string(event) + if h.cfg.Filter.ShouldExclude(message) { return nil } - entry.Message = h.cfg.Scrubber.Scrub(entry.Message) + entry := h.newEntry(source, lambdaOrigin) + entry.Message = h.cfg.Scrubber.Scrub(message) + return concurrent.SafeSender(ctx, out, entry) } -func decodeEventBridgeSource(event json.RawMessage) (string, error) { - dec := json.NewDecoder(bytes.NewReader(event)) - - if t, err := dec.Token(); err != nil || t != json.Delim('{') { - return "", errors.New("decode eventbridge source: expected '{'") +func (h *EventBridgeHandler) securityHub(ctx context.Context, event json.RawMessage, source string, out chan<- model.LogEntry, lambdaOrigin model.LambdaOrigin) error { + messages := separateFindings(event) + if len(messages) == 0 { + return h.eventBridge(ctx, event, source, out, lambdaOrigin) } - for dec.More() { - key, err := dec.Token() - if err != nil { - return "", fmt.Errorf("decode eventbridge source: read key: %w", err) - } - if key == "source" { - var source string - if err := dec.Decode(&source); err != nil { - return "", fmt.Errorf("decode eventbridge source: %w", err) - } - return eventBridgeSource(source), nil + base := h.newEntry(source, lambdaOrigin) + for _, message := range messages { + if h.cfg.Filter.ShouldExclude(message) { + continue } - var skip json.RawMessage - if err := dec.Decode(&skip); err != nil { - return "", fmt.Errorf("decode eventbridge source: skip field: %w", err) + + entry := base + entry.Message = h.cfg.Scrubber.Scrub(message) + + if err := concurrent.SafeSender(ctx, out, entry); err != nil { + return err } } + return nil +} - return "", nil +func (h *EventBridgeHandler) newEntry(source string, lambdaOrigin model.LambdaOrigin) model.LogEntry { + entry := model.NewLogEntry() + entry.Source = cmp.Or(h.cfg.Source, source) + entry.Service = cmp.Or(h.cfg.Service, entry.Source) + entry.Tags = h.cfg.Tags + entry.Metadata = lambdaOrigin + return entry } -func eventBridgeSource(source string) string { - _, after, found := strings.Cut(source, ".") +func eventBridgeSource(event json.RawMessage) (string, error) { + dec := json.NewDecoder(bytes.NewReader(event)) + if err := parsing.SkipBrace(dec); err != nil { + return "", err + } + + if err := parsing.SkipToKey(dec, "source"); err != nil { + return "", err + } + + var rawSource string + if err := dec.Decode(&rawSource); err != nil { + return "", fmt.Errorf("decode: %w", err) + } + + _, source, found := strings.Cut(rawSource, ".") if found { - return after + return source, nil } - return sourceCloudwatch + return sourceCloudwatch, nil } diff --git a/aws/logs_monitoring_go/internal/handling/eventbridge_test.go b/aws/logs_monitoring_go/internal/handling/eventbridge_test.go index 2455ee8a..8dee161f 100644 --- a/aws/logs_monitoring_go/internal/handling/eventbridge_test.go +++ b/aws/logs_monitoring_go/internal/handling/eventbridge_test.go @@ -71,6 +71,19 @@ func TestEventBridgeHandler_Handle(t *testing.T) { cfg: testutil.EmptyConfig(), wantErr: true, }, + "securityhub no findings falls back": { + event: json.RawMessage(`{"source":"aws.securityhub","detail":{}}`), + cfg: testutil.EmptyConfig(), + want: []model.LogEntry{ + { + Message: `{"source":"aws.securityhub","detail":{}}`, + Source: sourceSecurityHub, + SourceCategory: "aws", + Service: sourceSecurityHub, + Metadata: testutil.LambdaOrigin(), + }, + }, + }, } for name, tc := range tests { @@ -100,26 +113,57 @@ func TestEventBridgeHandler_Handle(t *testing.T) { } } -func TestEventBridgeSource(t *testing.T) { +func TestEventBridgeHandler_SecurityHub(t *testing.T) { t.Parallel() + ctx := testutil.LambdaContext(t) + tests := map[string]struct { - source string - want string + event json.RawMessage + cfg *config.Config + want []string }{ - "aws.events": {source: "aws.events", want: "events"}, - "aws.ec2": {source: "aws.ec2", want: "ec2"}, - "aws.s3": {source: "aws.s3", want: "s3"}, - "custom.app": {source: "custom.app", want: "app"}, - "no dot": {source: "nodot", want: "cloudwatch"}, - "empty string": {source: "", want: "cloudwatch"}, + "one finding": { + event: json.RawMessage(`{"source":"aws.securityhub","detail-type":"Security Hub Findings - Imported","detail":{"findings":[{"myattribute":"somevalue","Resources":[{"Region":"us-east-1","Type":"AwsEc2SecurityGroup"}]}]}}`), + cfg: testutil.EmptyConfig(), + want: []string{`{"source":"aws.securityhub","detail-type":"Security Hub Findings - Imported","detail":{"finding":{"myattribute":"somevalue","resources":{"AwsEc2SecurityGroup":{"Region":"us-east-1"}}}}}`}, + }, + "multiple findings": { + event: json.RawMessage(`{"source":"aws.securityhub","detail":{"findings":[{"id":"f1","Resources":[{"Type":"AwsEc2SecurityGroup","Region":"us-east-1"}]},{"id":"f2","Resources":[{"Type":"AwsIamRole","Region":"us-west-2"}]}]}}`), + cfg: testutil.EmptyConfig(), + want: []string{ + `{"source":"aws.securityhub","detail":{"finding":{"id":"f1","resources":{"AwsEc2SecurityGroup":{"Region":"us-east-1"}}}}}`, + `{"source":"aws.securityhub","detail":{"finding":{"id":"f2","resources":{"AwsIamRole":{"Region":"us-west-2"}}}}}`, + }, + }, + "with filtering": { + event: json.RawMessage(`{"source":"aws.securityhub","detail":{"findings":[{"id":"keep","Resources":[]},{"id":"drop","Resources":[]}]}}`), + cfg: testutil.Config(t, testutil.WithExcludeFilter(`"id":"drop"`)), + want: []string{`{"source":"aws.securityhub","detail":{"finding":{"id":"keep","resources":{}}}}`}, + }, } for name, tc := range tests { t.Run(name, func(t *testing.T) { t.Parallel() - got := eventBridgeSource(tc.source) - assert.Equal(t, tc.want, got) + + handler := NewEventBridge(tc.cfg) + out := make(chan model.LogEntry, len(tc.want)) + + err := handler.Handle(ctx, tc.event, out) + close(out) + + require.NoError(t, err) + + var got []model.LogEntry + for entry := range out { + got = append(got, entry) + } + + require.Len(t, got, len(tc.want)) + for i := range tc.want { + assert.JSONEq(t, tc.want[i], got[i].Message) + } }) } } diff --git a/aws/logs_monitoring_go/internal/handling/handling.go b/aws/logs_monitoring_go/internal/handling/handling.go new file mode 100644 index 00000000..809d0dfe --- /dev/null +++ b/aws/logs_monitoring_go/internal/handling/handling.go @@ -0,0 +1,59 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2026-Present Datadog, Inc. + +package handling + +import ( + "bufio" + "bytes" + "compress/gzip" + "fmt" + "io" +) + +var gzipMagic = []byte{0x1f, 0x8b} + +func gunzip(r io.Reader) (io.Reader, func() error, error) { + buf := bufio.NewReaderSize(r, len(gzipMagic)) + header, err := buf.Peek(len(gzipMagic)) + if err != nil || !bytes.Equal(header, gzipMagic) { + return buf, func() error { return nil }, nil + } + + gz, err := gzip.NewReader(buf) + if err != nil { + return nil, nil, fmt.Errorf("gzip: %w", err) + } + return gz, func() error { return gz.Close() }, nil +} + +func flattenByKey(src map[string]any, field, keyField, outputField string, alwaysWrite bool) { + arr, ok := src[field].([]any) + if !ok && !alwaysWrite { + return + } + + result := make(map[string]any, len(arr)) + for _, item := range arr { + obj, ok := item.(map[string]any) + if !ok { + continue + } + + key, _ := obj[keyField].(string) + if key == "" { + continue + } + delete(obj, keyField) + result[key] = obj + } + + out := field + if outputField != "" { + delete(src, field) + out = outputField + } + src[out] = result +} diff --git a/aws/logs_monitoring_go/internal/handling/s3.go b/aws/logs_monitoring_go/internal/handling/s3.go index 89a32e44..66b6e1d7 100644 --- a/aws/logs_monitoring_go/internal/handling/s3.go +++ b/aws/logs_monitoring_go/internal/handling/s3.go @@ -7,7 +7,6 @@ package handling import ( "cmp" - "compress/gzip" "context" "encoding/json" "fmt" @@ -73,14 +72,24 @@ func (h S3Handler) processRecord(ctx context.Context, client S3APIClient, out ch } }() + reader, close, err := gunzip(body) + if err != nil { + return err + } + defer func() { + if err := close(); err != nil { + slog.Warn("close gunzip", slog.Any("error", err)) + } + }() + source := S3Source(eventRecord.S3.Object.URLDecodedKey) switch source { case sourceCloudtrail: - err = h.CloudTrail(ctx, out, body, eventRecord, lambdaOrigin) + err = h.CloudTrail(ctx, out, reader, eventRecord, lambdaOrigin) case sourceWAF: - err = h.WAF(ctx, out, body, eventRecord, lambdaOrigin) + err = h.WAF(ctx, out, reader, eventRecord, lambdaOrigin) default: - err = h.S3(ctx, out, body, eventRecord, lambdaOrigin) + err = h.S3(ctx, out, reader, eventRecord, lambdaOrigin) } if err != nil { @@ -102,9 +111,9 @@ func S3Source(key string) string { return sourceS3 } -func (h S3Handler) S3(ctx context.Context, out chan<- model.LogEntry, body io.ReadCloser, eventRecord events.S3EventRecord, lambdaOrigin model.LambdaOrigin) error { +func (h S3Handler) S3(ctx context.Context, out chan<- model.LogEntry, r io.Reader, eventRecord events.S3EventRecord, lambdaOrigin model.LambdaOrigin) error { base := h.newBaseEntry(eventRecord, lambdaOrigin) - for message, err := range scan(body, h.cfg.S3MultilineLogRegex) { + for message, err := range scan(r, h.cfg.S3MultilineLogRegex) { if err != nil { return err } @@ -126,19 +135,9 @@ func (h S3Handler) S3(ctx context.Context, out chan<- model.LogEntry, body io.Re return nil } -func (h S3Handler) WAF(ctx context.Context, out chan<- model.LogEntry, body io.ReadCloser, eventRecord events.S3EventRecord, lambdaOrigin model.LambdaOrigin) error { - gz, err := gzip.NewReader(body) - if err != nil { - return err - } - defer func() { - if err := gz.Close(); err != nil { - slog.Warn("close gzip reader", slog.Any("error", err)) - } - }() - +func (h S3Handler) WAF(ctx context.Context, out chan<- model.LogEntry, r io.Reader, eventRecord events.S3EventRecord, lambdaOrigin model.LambdaOrigin) error { base := h.newBaseEntry(eventRecord, lambdaOrigin) - for message, err := range scan(gz, nil) { + for message, err := range scan(r, nil) { if err != nil { return err } @@ -158,9 +157,9 @@ func (h S3Handler) WAF(ctx context.Context, out chan<- model.LogEntry, body io.R return nil } -func (h S3Handler) CloudTrail(ctx context.Context, out chan<- model.LogEntry, body io.ReadCloser, eventRecord events.S3EventRecord, lambdaOrigin model.LambdaOrigin) error { +func (h S3Handler) CloudTrail(ctx context.Context, out chan<- model.LogEntry, r io.Reader, eventRecord events.S3EventRecord, lambdaOrigin model.LambdaOrigin) error { base := h.newBaseEntry(eventRecord, lambdaOrigin) - for message, err := range decodeCloudTrail(body) { + for message, err := range decodeCloudTrail(r) { if err != nil { return err } diff --git a/aws/logs_monitoring_go/internal/handling/s3_test.go b/aws/logs_monitoring_go/internal/handling/s3_test.go index 19933457..67bdef75 100644 --- a/aws/logs_monitoring_go/internal/handling/s3_test.go +++ b/aws/logs_monitoring_go/internal/handling/s3_test.go @@ -258,6 +258,17 @@ func TestProcessS3Record(t *testing.T) { wantWAFEntry(`{"action":"BLOCK"}`), }, }, + "waf non-gzipped": { + mockSetup: func(m *MockS3APIClient) { + m.EXPECT().GetObject(gomock.Any(), gomock.Any()). + Return(&s3.GetObjectOutput{ + Body: io.NopCloser(strings.NewReader(`{"httpRequest":{"headers":[{"name":"Host","value":"example.com"}]}}`)), + }, nil) + }, + cfg: testutil.EmptyConfig(), + eventRecord: testWAFEventRecord, + want: []model.LogEntry{wantWAFEntry(`{"httpRequest":{"headers":{"Host":"example.com"}}}`)}, + }, "waf exclude at match": { mockSetup: func(m *MockS3APIClient) { lines := `{"action":"ALLOW","httpRequest":{}}` + "\n" + `{"action":"BLOCK","httpRequest":{}}` diff --git a/aws/logs_monitoring_go/internal/handling/securityhub.go b/aws/logs_monitoring_go/internal/handling/securityhub.go new file mode 100644 index 00000000..de35f376 --- /dev/null +++ b/aws/logs_monitoring_go/internal/handling/securityhub.go @@ -0,0 +1,49 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2026-Present Datadog, Inc. + +package handling + +import ( + "encoding/json" + "log/slog" +) + +const ( + findingKey = "finding" + findingsKey = "findings" +) + +func separateFindings(event json.RawMessage) []string { + var raw map[string]any + if err := json.Unmarshal(event, &raw); err != nil { + return nil + } + + detail, _ := raw["detail"].(map[string]any) + findings, _ := detail[findingsKey].([]any) + if len(findings) == 0 { + return nil + } + + delete(detail, findingsKey) + + messages := make([]string, 0, len(findings)) + for _, f := range findings { + finding, ok := f.(map[string]any) + if !ok { + continue + } + flattenByKey(finding, "Resources", "Type", "resources", true) + detail[findingKey] = finding + + out, err := json.Marshal(raw) + if err != nil { + slog.Warn("marshal securityhub finding, skipped", slog.Any("error", err)) + continue + } + messages = append(messages, string(out)) + } + return messages +} diff --git a/aws/logs_monitoring_go/internal/handling/securityhub_test.go b/aws/logs_monitoring_go/internal/handling/securityhub_test.go new file mode 100644 index 00000000..6956870d --- /dev/null +++ b/aws/logs_monitoring_go/internal/handling/securityhub_test.go @@ -0,0 +1,65 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2026-Present Datadog, Inc. + +package handling + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSeparateFindings(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + input string + want []string + }{ + "no findings field": { + input: `{"source":"aws.securityhub"}`, + }, + "empty findings": { + input: `{"detail":{"findings":[]}}`, + }, + "invalid json": { + input: `not json`, + }, + "one finding no resources": { + input: `{"ddsource":"securityhub","detail":{"findings":[{"myattribute":"somevalue"}]}}`, + want: []string{ + `{"ddsource":"securityhub","detail":{"finding":{"myattribute":"somevalue","resources":{}}}}`, + }, + }, + "two findings one resource each": { + input: `{"ddsource":"securityhub","detail":{"findings":[{"myattribute":"somevalue","Resources":[{"Region":"us-east-1","Type":"AwsEc2SecurityGroup"}]},{"myattribute":"somevalue","Resources":[{"Region":"us-east-1","Type":"AwsEc2SecurityGroup"}]}]}}`, + want: []string{ + `{"ddsource":"securityhub","detail":{"finding":{"myattribute":"somevalue","resources":{"AwsEc2SecurityGroup":{"Region":"us-east-1"}}}}}`, + `{"ddsource":"securityhub","detail":{"finding":{"myattribute":"somevalue","resources":{"AwsEc2SecurityGroup":{"Region":"us-east-1"}}}}}`, + }, + }, + "multiple findings multiple resources": { + input: `{"ddsource":"securityhub","detail":{"findings":[{"myattribute":"somevalue","Resources":[{"Region":"us-east-1","Type":"AwsEc2SecurityGroup"}]},{"myattribute":"somevalue","Resources":[{"Region":"us-east-1","Type":"AwsEc2SecurityGroup"},{"Region":"us-east-1","Type":"AwsOtherSecurityGroup"}]},{"myattribute":"somevalue","Resources":[{"Region":"us-east-1","Type":"AwsEc2SecurityGroup"},{"Region":"us-east-1","Type":"AwsOtherSecurityGroup"},{"Region":"us-east-1","Type":"AwsAnotherSecurityGroup"}]}]}}`, + want: []string{ + `{"ddsource":"securityhub","detail":{"finding":{"myattribute":"somevalue","resources":{"AwsEc2SecurityGroup":{"Region":"us-east-1"}}}}}`, + `{"ddsource":"securityhub","detail":{"finding":{"myattribute":"somevalue","resources":{"AwsEc2SecurityGroup":{"Region":"us-east-1"},"AwsOtherSecurityGroup":{"Region":"us-east-1"}}}}}`, + `{"ddsource":"securityhub","detail":{"finding":{"myattribute":"somevalue","resources":{"AwsAnotherSecurityGroup":{"Region":"us-east-1"},"AwsEc2SecurityGroup":{"Region":"us-east-1"},"AwsOtherSecurityGroup":{"Region":"us-east-1"}}}}}`, + }, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + got := separateFindings(json.RawMessage(tc.input)) + require.Len(t, got, len(tc.want)) + for i, want := range tc.want { + assert.JSONEq(t, want, got[i]) + } + }) + } +} diff --git a/aws/logs_monitoring_go/internal/handling/source.go b/aws/logs_monitoring_go/internal/handling/source.go index 13071e15..41e05335 100644 --- a/aws/logs_monitoring_go/internal/handling/source.go +++ b/aws/logs_monitoring_go/internal/handling/source.go @@ -7,11 +7,12 @@ package handling const ( sourceCloudtrail = "cloudtrail" + sourceCloudwatch = "cloudwatch" sourceKinesis = "kinesis" sourceLambda = "lambda" - sourceWAF = "waf" sourceS3 = "s3" - sourceCloudwatch = "cloudwatch" + sourceSecurityHub = "securityhub" sourceSNS = "sns" sourceStepFunction = "stepfunction" + sourceWAF = "waf" ) diff --git a/aws/logs_monitoring_go/internal/handling/waf.go b/aws/logs_monitoring_go/internal/handling/waf.go index 9c1aa8d2..1149748a 100644 --- a/aws/logs_monitoring_go/internal/handling/waf.go +++ b/aws/logs_monitoring_go/internal/handling/waf.go @@ -7,7 +7,13 @@ package handling import "encoding/json" -const ruleIDKey = "ruleId" +const ( + headersKey = "headers" + nonTerminatingMatchingRulesKey = "nonTerminatingMatchingRules" + ruleGroupListKey = "ruleGroupList" + ruleGroupIdKey = "ruleGroupId" + ruleIDKey = "ruleId" +) func flattenWAFMessage(message string) string { var msg map[string]any @@ -17,8 +23,8 @@ func flattenWAFMessage(message string) string { flattenHeaders(msg) flattenRuleGroupList(msg) - flattenByKey(msg, "rateBasedRuleList", "rateBasedRuleName") - flattenByKey(msg, "nonTerminatingMatchingRules", ruleIDKey) + flattenByKey(msg, "rateBasedRuleList", "rateBasedRuleName", "", false) + flattenByKey(msg, nonTerminatingMatchingRulesKey, ruleIDKey, "", false) out, err := json.Marshal(msg) if err != nil { @@ -32,7 +38,7 @@ func flattenHeaders(msg map[string]any) { if !ok { return } - headers, ok := httpReq["headers"].([]any) + headers, ok := httpReq[headersKey].([]any) if !ok { return } @@ -43,39 +49,18 @@ func flattenHeaders(msg map[string]any) { if !ok { continue } + name, _ := header["name"].(string) if name == "" { continue } result[name] = header["value"] } - httpReq["headers"] = result -} - -func flattenByKey(msg map[string]any, field, keyField string) { - arr, ok := msg[field].([]any) - if !ok { - return - } - - result := make(map[string]any, len(arr)) - for _, item := range arr { - entry, ok := item.(map[string]any) - if !ok { - continue - } - key, _ := entry[keyField].(string) - if key == "" { - continue - } - delete(entry, keyField) - result[key] = entry - } - msg[field] = result + httpReq[headersKey] = result } func flattenRuleGroupList(msg map[string]any) { - arr, ok := msg["ruleGroupList"].([]any) + arr, ok := msg[ruleGroupListKey].([]any) if !ok { return } @@ -86,9 +71,9 @@ func flattenRuleGroupList(msg map[string]any) { if !ok { continue } - groupID, _ := group["ruleGroupId"].(string) - delete(group, "ruleGroupId") + groupID, _ := group[ruleGroupIdKey].(string) + delete(group, ruleGroupIdKey) existing, ok := result[groupID].(map[string]any) if !ok { existing = make(map[string]any) @@ -96,10 +81,10 @@ func flattenRuleGroupList(msg map[string]any) { } flattenRuleGroupField(group, existing, "terminatingRule") - flattenRuleGroupField(group, existing, "nonTerminatingMatchingRules") + flattenRuleGroupField(group, existing, nonTerminatingMatchingRulesKey) flattenRuleGroupField(group, existing, "excludedRules") } - msg["ruleGroupList"] = result + msg[ruleGroupListKey] = result } func flattenRuleGroupField(group, dest map[string]any, field string) {