// truncateToValidUTF8 returns the longest prefix of s that is valid UTF-8.
//
// A string that is already valid is returned unchanged. Otherwise the result
// ends just before the first byte that does not begin a well-formed encoding,
// which also drops a multi-byte rune that was cut short at the end of s.
//
// ExtractTopology calls this on a string value right after capping it with
// val[:maxStrLen]; that byte-based cap can split a rune, which is one way a
// truncated trailing sequence reaches this function.
//
// The work is bounded by one validation pass plus, for invalid input, one
// forward decode pass over the valid prefix.
func truncateToValidUTF8(s string) string {
	// Fast path: well-formed input needs no slicing.
	if utf8.ValidString(s) {
		return s
	}

	// Walk forward one rune at a time. i is both the scan position and the
	// byte length of the valid prefix found so far.
	i := 0
	for i < len(s) {
		r, size := utf8.DecodeRuneInString(s[i:])
		// (RuneError, 1) signals a malformed byte. A genuinely encoded
		// U+FFFD decodes with size 3 and must be kept, so the width check
		// is required in addition to the rune comparison.
		if r == utf8.RuneError && size == 1 {
			break
		}
		i += size
	}
	return s[:i]
}

// TestTruncateToValidUTF8 checks truncateToValidUTF8 against valid input,
// input that turns invalid part-way through, and input that is invalid from
// the first byte.
func TestTruncateToValidUTF8(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected string
	}{
		{
			name:     "Empty string",
			input:    "",
			expected: "",
		},
		{
			name:     "Valid ASCII",
			input:    "hello world",
			expected: "hello world",
		},
		{
			name:     "Valid UTF-8",
			input:    "Hello, 世界",
			expected: "Hello, 世界",
		},
		{
			name:     "Invalid at end",
			input:    "Hello, \xFF",
			expected: "Hello, ",
		},
		{
			name:     "Invalid in middle",
			input:    "Hello, \xFF World",
			expected: "Hello, ",
		},
		{
			name:     "Binary start",
			input:    "\xFF\xFE",
			expected: "",
		},
		{
			name:     "Partial multi-byte rune at end",
			input:    string([]byte{0xe4, 0xb8, 0x96, 0xe7, 0x95}), // "世" + first two bytes of "界"
			expected: "世",
		},
		{
			// An encoded U+FFFD is valid UTF-8 and must not be mistaken
			// for a decode failure.
			name:     "Encoded replacement character before invalid byte",
			input:    "a\uFFFDb\xFF",
			expected: "a\uFFFDb",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := truncateToValidUTF8(tt.input)
			if result != tt.expected {
				t.Errorf("truncateToValidUTF8(%q) = %q; want %q", tt.input, result, tt.expected)
			}
		})
	}
}