diff --git a/.gitignore b/.gitignore index 3530a8e295..98832795b3 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,9 @@ *.pyc *.profile +ui/opensnitch/proto/ +daemon/ui/protocol/ + .vscode/ .idea/ .DS_Store diff --git a/daemon/rule/loader.go b/daemon/rule/loader.go index 155d2d1cd5..390453a6a5 100644 --- a/daemon/rule/loader.go +++ b/daemon/rule/loader.go @@ -10,6 +10,7 @@ import ( "sort" "strings" "sync" + "sync/atomic" "time" "github.com/evilsocket/opensnitch/daemon/conman" @@ -26,15 +27,20 @@ type Loader struct { watcher *fsnotify.Watcher rules map[string]*Rule activeRules []string + activeSnapshot atomic.Pointer[activeRulesSnapshot] Path string liveReload bool liveReloadRunning bool - checkSums bool + checkSums atomic.Bool stopLiveReload chan struct{} sync.RWMutex } +type activeRulesSnapshot struct { + rules []*Rule +} + // NewLoader loads rules from disk, and watches for changes made to the rules files // on disk. func NewLoader(liveReload bool) (*Loader, error) { @@ -69,7 +75,7 @@ func (l *Loader) GetAll() map[string]*Rule { // EnableChecksums enables checksums field for rules globally. func (l *Loader) EnableChecksums(enable bool) { log.Debug("[rules loader] EnableChecksums: %v", enable) - l.checkSums = enable + l.checkSums.Store(enable) procmon.EventsCache.SetComputeChecksums(enable) procmon.EventsCache.AddChecksumHash(string(OpProcessHashMD5)) } @@ -113,6 +119,7 @@ func (l *Loader) Reload(path string) error { l.Lock() l.activeRules = make([]string, 0) l.rules = make(map[string]*Rule) + l.activeSnapshot.Store(nil) l.Unlock() return l.Load(path) } @@ -367,6 +374,7 @@ func (l *Loader) unmarshalOperatorList(op *Operator) error { func (l *Loader) sortRules() { l.activeRules = make([]string, 0, len(l.rules)) + orderedRules := make([]*Rule, 0, len(l.rules)) for k, r := range l.rules { // exclude not enabled rules from the list of active rules if !r.Enabled { @@ -375,6 +383,10 @@ func (l *Loader) sortRules() { l.activeRules = append(l.activeRules, k) } sort.Strings(l.activeRules) + for _, name := range l.activeRules { + orderedRules = append(orderedRules, l.rules[name]) + } + l.activeSnapshot.Store(&activeRulesSnapshot{rules: orderedRules}) } func (l *Loader) addUserRule(rule *Rule) { @@ -495,12 +507,14 @@ Exit: // FindFirstMatch will try match the connection against the existing rule set. func (l *Loader) FindFirstMatch(con *conman.Connection) (match *Rule) { - l.RLock() - defer l.RUnlock() + snapshot := l.activeSnapshot.Load() + if snapshot == nil { + return nil + } + hasChecksums := l.checkSums.Load() - for _, idx := range l.activeRules { - rule, _ := l.rules[idx] - if rule.Match(con, l.checkSums) { + for _, rule := range snapshot.rules { + if rule.Match(con, hasChecksums) { // We have a match. // Save the rule in order to don't ask the user to take action, // and keep iterating until a Deny or a Priority rule appears. diff --git a/daemon/rule/operator.go b/daemon/rule/operator.go index 6bcb495714..4bce88bf86 100644 --- a/daemon/rule/operator.go +++ b/daemon/rule/operator.go @@ -8,6 +8,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "github.com/evilsocket/opensnitch/daemon/conman" "github.com/evilsocket/opensnitch/daemon/core" @@ -75,6 +76,20 @@ const ( type opCallback func(value string) bool type opGenericCallback func(value interface{}) bool +type listRegexEntry struct { + file string + re *regexp.Regexp +} + +type listCacheSnapshot struct { + lists map[string]interface{} + domainWildcards domainWildcardTrie + domainGlobs []string + listExact map[string]struct{} + listNets []*net.IPNet + regexEntries []listRegexEntry +} + // Operator represents what we want to filter of a connection, and how. type Operator struct { cb opCallback @@ -82,6 +97,11 @@ type Operator struct { re *regexp.Regexp netMask *net.IPNet lists map[string]interface{} + domainWildcards domainWildcardTrie + domainGlobs []string + listExact map[string]struct{} + listNets []*net.IPNet + listSnapshot atomic.Pointer[listCacheSnapshot] exitMonitorChan chan (struct{}) rangeMin uint64 rangeMax uint64 @@ -178,10 +198,10 @@ func (o *Operator) Compile() error { o.cb = o.reListCmp } else if o.Operand == OpIPLists { o.loadLists() - o.cb = o.simpleListsCmp + o.cbGeneric = o.ipListsCmp } else if o.Operand == OpNetLists { o.loadLists() - o.cbGeneric = o.ipNetCmp + o.cbGeneric = o.netListsCmp } else if o.Operand == OpHashMD5Lists { o.loadLists() o.cb = o.simpleListsCmp @@ -290,9 +310,12 @@ func (o *Operator) cmpNetwork(destIP interface{}) bool { } func (o *Operator) matchListsCmp(msg, what string) bool { - o.RLock() - item, found := o.lists[what] - o.RUnlock() + snapshot := o.listSnapshot.Load() + if snapshot == nil { + return false + } + + item, found := snapshot.lists[what] if found { log.Debug("%s: %s, %s", log.Red(msg), what, item) @@ -309,7 +332,29 @@ func (o *Operator) domainsListsCmp(data string) bool { data = strings.ToLower(data) } - return o.matchListsCmp("domains list match", data) + snapshot := o.listSnapshot.Load() + if snapshot == nil { + return false + } + + _, exactFound := snapshot.lists[data] + + if exactFound { + log.Debug("%s: %s", log.Red("domains list match"), data) + return true + } + if snapshot.domainWildcards.matchesHost(data) { + log.Debug("%s: %s", log.Red("domains wildcard match"), data) + return true + } + for _, g := range snapshot.domainGlobs { + if matchDomainGlob(g, data) { + log.Debug("%s: %s", log.Red("domains glob match"), data) + return true + } + } + + return false } func (o *Operator) simpleListsCmp(what string) bool { @@ -320,17 +365,52 @@ func (o *Operator) simpleListsCmp(what string) bool { return o.matchListsCmp("simple list match", what) } -func (o *Operator) ipNetCmp(dstIP interface{}) bool { - o.RLock() - defer o.RUnlock() +func (o *Operator) netListsCmp(dstIP interface{}) bool { + ip := dstIP.(net.IP) + ipText := ip.String() + snapshot := o.listSnapshot.Load() + if snapshot == nil { + return false + } + + _, exactFound := snapshot.listExact[ipText] + + if exactFound { + log.Debug("%s: %s", log.Red("Net exact list match"), ipText) + return true + } + + for _, netMask := range snapshot.listNets { + if netMask.Contains(ip) { + log.Debug("%s: %s, %s", log.Red("Net list match"), ipText, netMask.String()) + return true + } + } + return false +} + +func (o *Operator) ipListsCmp(dstIP interface{}) bool { + ip := dstIP.(net.IP) + ipText := ip.String() + snapshot := o.listSnapshot.Load() + if snapshot == nil { + return false + } + + _, exactFound := snapshot.listExact[ipText] + + if exactFound { + log.Debug("%s: %s", log.Red("IP list exact match"), ipText) + return true + } - for host, netMask := range o.lists { - n := netMask.(*net.IPNet) - if n.Contains(dstIP.(net.IP)) { - log.Debug("%s: %s, %s", log.Red("Net list match"), dstIP, host) + for _, netMask := range snapshot.listNets { + if netMask.Contains(ip) { + log.Debug("%s: %s, %s", log.Red("IP list cidr match"), ipText, netMask.String()) return true } } + return false } @@ -341,13 +421,14 @@ func (o *Operator) reListCmp(data string) bool { if o.Sensitive == false { data = strings.ToLower(data) } - o.RLock() - defer o.RUnlock() + snapshot := o.listSnapshot.Load() + if snapshot == nil { + return false + } - for file, re := range o.lists { - r := re.(*regexp.Regexp) - if r.MatchString(data) { - log.Debug("%s: %s, %s", log.Red("Regexp list match"), data, file) + for _, entry := range snapshot.regexEntries { + if entry.re.MatchString(data) { + log.Debug("%s: %s, %s", log.Red("Regexp list match"), data, entry.file) return true } } @@ -389,7 +470,7 @@ func (o *Operator) Match(con *conman.Connection, hasChecksums bool) bool { } else if o.Operand == OpDomainsLists { return o.cb(con.DstHost) } else if o.Operand == OpIPLists { - return o.cb(con.DstIP.String()) + return o.cbGeneric(con.DstIP) } else if o.Operand == OpHashMD5Lists { return o.cb(con.Process.Checksums[procmon.HashMD5]) } else if o.Operand == OpUserID || o.Operand == OpUserName { diff --git a/daemon/rule/operator_lists.go b/daemon/rule/operator_lists.go index f8572a50b9..1f94c89b8b 100644 --- a/daemon/rule/operator_lists.go +++ b/daemon/rule/operator_lists.go @@ -2,8 +2,8 @@ package rule import ( "fmt" - "io/ioutil" "net" + "os" "path/filepath" "regexp" "runtime/debug" @@ -14,6 +14,61 @@ import ( "github.com/evilsocket/opensnitch/daemon/log" ) +type domainWildcardTrieNode struct { + terminal bool + children map[string]*domainWildcardTrieNode +} + +type domainWildcardTrie struct { + root *domainWildcardTrieNode +} + +func newDomainWildcardTrie() domainWildcardTrie { + return domainWildcardTrie{root: &domainWildcardTrieNode{children: make(map[string]*domainWildcardTrieNode)}} +} + +func (t *domainWildcardTrie) insertSuffix(suffix string) { + if t.root == nil { + t.root = &domainWildcardTrieNode{children: make(map[string]*domainWildcardTrieNode)} + } + parts := strings.Split(suffix, ".") + node := t.root + for i := len(parts) - 1; i >= 0; i-- { + label := strings.TrimSpace(parts[i]) + if label == "" { + return + } + next, found := node.children[label] + if !found { + next = &domainWildcardTrieNode{children: make(map[string]*domainWildcardTrieNode)} + node.children[label] = next + } + node = next + } + node.terminal = true +} + +func (t *domainWildcardTrie) matchesHost(host string) bool { + if t.root == nil { + return false + } + parts := strings.Split(host, ".") + node := t.root + for i := len(parts) - 1; i >= 0; i-- { + label := strings.TrimSpace(parts[i]) + next, found := node.children[label] + if !found { + return false + } + node = next + // wildcard suffixes should only match subdomains, not the suffix root itself + if node.terminal && i > 0 { + return true + } + } + return false +} + func (o *Operator) monitorLists() { log.Info("monitor lists started: %s", o.Data) @@ -92,6 +147,11 @@ func (o *Operator) ClearLists() { for k := range o.lists { delete(o.lists, k) } + o.domainWildcards = newDomainWildcardTrie() + o.domainGlobs = nil + o.listExact = nil + o.listNets = nil + o.listSnapshot.Store(nil) debug.FreeOSMemory() } @@ -139,6 +199,18 @@ func (o *Operator) readTupleList(raw, fileName string, filter func(line, defValu continue } key = core.Trim(key) + if suffix := wildcardSuffix(key); suffix != "" { + o.domainWildcards.insertSuffix(suffix) + continue + } + if isDomainGlobPattern(key) { + if err := validateDomainGlobPattern(key); err != nil { + log.Warning("Error validating domain glob from list: %s, (%s)", err, fileName) + continue + } + o.domainGlobs = append(o.domainGlobs, key) + continue + } if _, found := o.lists[key]; found { dups++ continue @@ -163,12 +235,18 @@ func (o *Operator) readNetList(raw, fileName string) (dups uint64) { dups++ continue } + if ip := net.ParseIP(host); ip != nil { + o.lists[host] = fileName + o.listExact[host] = struct{}{} + continue + } _, netMask, err := net.ParseCIDR(host) if err != nil { log.Warning("Error parsing net from list: %s, (%s)", err, fileName) continue } - o.lists[host] = netMask + o.lists[host] = fileName + o.listNets = append(o.listNets, netMask) } lines = nil log.Info("%d nets loaded, %s", len(o.lists), fileName) @@ -217,6 +295,13 @@ func (o *Operator) readSimpleList(raw, fileName string) (dups uint64) { continue } o.lists[what] = fileName + if ip := net.ParseIP(what); ip != nil { + o.listExact[what] = struct{}{} + continue + } + if _, netMask, err := net.ParseCIDR(what); err == nil { + o.listNets = append(o.listNets, netMask) + } } lines = nil log.Info("%d entries loaded, %s", len(o.lists), fileName) @@ -232,6 +317,10 @@ func (o *Operator) readLists() error { o.Lock() defer o.Unlock() o.lists = make(map[string]interface{}) + o.domainWildcards = newDomainWildcardTrie() + o.domainGlobs = make([]string, 0) + o.listExact = make(map[string]struct{}) + o.listNets = make([]*net.IPNet, 0) expr := filepath.Join(o.Data, "*.*") fileList, err := filepath.Glob(expr) @@ -247,7 +336,7 @@ func (o *Operator) readLists() error { continue } - raw, err := ioutil.ReadFile(fileName) + raw, err := os.ReadFile(fileName) if err != nil { log.Warning("Error reading list of IPs (%s): %s", fileName, err) continue @@ -267,10 +356,89 @@ func (o *Operator) readLists() error { log.Warning("Unknown lists operand type: %s", o.Operand) } } + o.listSnapshot.Store(o.buildListSnapshot()) log.Info("%d lists loaded, %d domains, %d duplicated", len(fileList), len(o.lists), dups) return nil } +func (o *Operator) buildListSnapshot() *listCacheSnapshot { + snapshot := &listCacheSnapshot{ + lists: o.lists, + domainWildcards: o.domainWildcards, + domainGlobs: o.domainGlobs, + listExact: o.listExact, + listNets: o.listNets, + } + + if o.Operand == OpDomainsRegexpLists { + snapshot.regexEntries = make([]listRegexEntry, 0, len(o.lists)) + for file, re := range o.lists { + snapshot.regexEntries = append(snapshot.regexEntries, listRegexEntry{ + file: file, + re: re.(*regexp.Regexp), + }) + } + } + + return snapshot +} + +func wildcardSuffix(host string) string { + if strings.HasPrefix(host, "*.") { + return strings.Trim(host[2:], ".") + } + if strings.HasPrefix(host, ".") { + return strings.Trim(host[1:], ".") + } + return "" +} + +// isDomainGlobPattern reports whether host is a glob pattern that requires +// matchDomainGlob evaluation (i.e. it contains *, ?, or [...] but is NOT a +// plain wildcard suffix like *.example.org, which is handled by the trie). +// +// Known limitation: '{www,api}.example.org' alternation syntax is NOT +// supported. path.Match treats '{' as a literal. Such patterns are not +// detected here and fall through to the exact-map lookup where they will +// never match – a silent false negative. Use separate list entries instead. +func isDomainGlobPattern(host string) bool { + if wildcardSuffix(host) != "" { + return false + } + return strings.ContainsAny(host, "*?[]") +} + +// validateDomainGlobPattern checks that every DNS label in pattern is a valid +// filepath.Match expression (i.e. no unclosed '['). Returns non-nil on bad syntax. +func validateDomainGlobPattern(pattern string) error { + for _, label := range strings.Split(pattern, ".") { + if _, err := filepath.Match(label, ""); err != nil { + return err + } + } + return nil +} + +// matchDomainGlob reports whether host matches the DNS-aware glob pattern. +// The pattern is split on '.' and each label is matched independently with +// filepath.Match, so '*' and '?' are confined to a single DNS label and cannot +// cross dot boundaries. This preserves standard blocklist glob semantics. +// The pattern must have been validated by validateDomainGlobPattern at load +// time; invalid patterns silently fail to match. +func matchDomainGlob(pattern, host string) bool { + patLabels := strings.Split(pattern, ".") + hostLabels := strings.Split(host, ".") + if len(patLabels) != len(hostLabels) { + return false + } + for i, p := range patLabels { + if ok, _ := filepath.Match(p, hostLabels[i]); !ok { + return false + } + } + return true +} + func (o *Operator) loadLists() { log.Info("loading domains lists: %s, %s, %s", o.Type, o.Operand, o.Data) diff --git a/daemon/rule/operator_test.go b/daemon/rule/operator_test.go index 7234e4ae9e..62509d4a3f 100644 --- a/daemon/rule/operator_test.go +++ b/daemon/rule/operator_test.go @@ -4,6 +4,8 @@ import ( "encoding/json" "fmt" "net" + "regexp" + "sync" "testing" "time" @@ -52,6 +54,431 @@ func compileListOperators(list *[]Operator, t *testing.T) { } } +func BenchmarkOperatorDomainsSnapshotMatchParallel(b *testing.B) { + op := &Operator{ + Sensitive: false, + lists: make(map[string]interface{}), + domainWildcards: newDomainWildcardTrie(), + } + op.domainWildcards.insertSuffix("example.org") + const globPat = "api-??.example.org" + if err := validateDomainGlobPattern(globPat); err != nil { + b.Fatalf("invalid benchmark glob: %v", err) + } + op.domainGlobs = append(op.domainGlobs, globPat) + op.listSnapshot.Store(&listCacheSnapshot{ + lists: op.lists, + domainWildcards: op.domainWildcards, + domainGlobs: op.domainGlobs, + }) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if !op.domainsListsCmp("svc.example.org") { + b.Fatal("expected wildcard snapshot match") + } + } + }) +} + +func BenchmarkOperatorDomainsSnapshotMixedParallel(b *testing.B) { + op := &Operator{ + Sensitive: false, + lists: make(map[string]interface{}), + domainWildcards: newDomainWildcardTrie(), + } + op.lists["exact.example.org"] = "bench" + op.domainWildcards.insertSuffix("example.org") + const globPat = "api-??.example.org" + if err := validateDomainGlobPattern(globPat); err != nil { + b.Fatalf("invalid benchmark glob: %v", err) + } + op.domainGlobs = append(op.domainGlobs, globPat) + op.listSnapshot.Store(&listCacheSnapshot{ + lists: op.lists, + domainWildcards: op.domainWildcards, + domainGlobs: op.domainGlobs, + }) + + inputs := []string{ + "exact.example.org", // exact hit + "svc.example.org", // wildcard hit + "api-12.example.org", // glob hit + "no-match.invalid.local", // miss + "exact.example.org", + "svc.example.org", + "api-99.example.org", + "nope.nowhere", + "exact.example.org", + "svc.example.org", + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + _ = op.domainsListsCmp(inputs[i%len(inputs)]) + i++ + } + }) +} + +type rlockDomainMatcher struct { + sync.RWMutex + lists map[string]interface{} + wildcards domainWildcardTrie + domainGlobs []string +} + +func (m *rlockDomainMatcher) match(host string) bool { + m.RLock() + defer m.RUnlock() + if _, found := m.lists[host]; found { + return true + } + if m.wildcards.matchesHost(host) { + return true + } + for _, g := range m.domainGlobs { + if matchDomainGlob(g, host) { + return true + } + } + return false +} + +func BenchmarkOperatorDomainsRLockMixedParallel(b *testing.B) { + m := &rlockDomainMatcher{ + lists: make(map[string]interface{}), + wildcards: newDomainWildcardTrie(), + domainGlobs: make([]string, 0, 1), + } + m.lists["exact.example.org"] = "bench" + m.wildcards.insertSuffix("example.org") + const globPat = "api-??.example.org" + if err := validateDomainGlobPattern(globPat); err != nil { + b.Fatalf("invalid benchmark glob: %v", err) + } + m.domainGlobs = append(m.domainGlobs, globPat) + + inputs := []string{ + "exact.example.org", + "svc.example.org", + "api-12.example.org", + "no-match.invalid.local", + "exact.example.org", + "svc.example.org", + "api-99.example.org", + "nope.nowhere", + "exact.example.org", + "svc.example.org", + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + _ = m.match(inputs[i%len(inputs)]) + i++ + } + }) +} + +type rlockIPNetMatcher struct { + sync.RWMutex + exact map[string]struct{} + nets []*net.IPNet +} + +func (m *rlockIPNetMatcher) match(ip net.IP) bool { + m.RLock() + defer m.RUnlock() + if _, found := m.exact[ip.String()]; found { + return true + } + for _, n := range m.nets { + if n.Contains(ip) { + return true + } + } + return false +} + +func BenchmarkOperatorIPSnapshotMixedParallel(b *testing.B) { + _, cidrA, err := net.ParseCIDR("10.0.0.0/24") + if err != nil { + b.Fatalf("failed to parse benchmark CIDR A: %v", err) + } + _, cidrB, err := net.ParseCIDR("2002:dead:beef::/48") + if err != nil { + b.Fatalf("failed to parse benchmark CIDR B: %v", err) + } + + op := &Operator{} + exact := map[string]struct{}{ + "10.0.0.4": {}, + "2002:dead:beef::": {}, + } + nets := []*net.IPNet{cidrA, cidrB} + op.listSnapshot.Store(&listCacheSnapshot{ + listExact: exact, + listNets: nets, + }) + + inputs := []net.IP{ + net.ParseIP("10.0.0.4"), // exact + net.ParseIP("10.0.0.99"), // cidr + net.ParseIP("2002:dead:beef::"), // exact + net.ParseIP("2002:dead:beef::1234"), // cidr + net.ParseIP("172.16.0.1"), // miss + net.ParseIP("8.8.8.8"), // miss + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + _ = op.ipListsCmp(inputs[i%len(inputs)]) + i++ + } + }) +} + +func BenchmarkOperatorIPRLockMixedParallel(b *testing.B) { + _, cidrA, err := net.ParseCIDR("10.0.0.0/24") + if err != nil { + b.Fatalf("failed to parse benchmark CIDR A: %v", err) + } + _, cidrB, err := net.ParseCIDR("2002:dead:beef::/48") + if err != nil { + b.Fatalf("failed to parse benchmark CIDR B: %v", err) + } + + m := &rlockIPNetMatcher{ + exact: map[string]struct{}{ + "10.0.0.4": {}, + "2002:dead:beef::": {}, + }, + nets: []*net.IPNet{cidrA, cidrB}, + } + + inputs := []net.IP{ + net.ParseIP("10.0.0.4"), + net.ParseIP("10.0.0.99"), + net.ParseIP("2002:dead:beef::"), + net.ParseIP("2002:dead:beef::1234"), + net.ParseIP("172.16.0.1"), + net.ParseIP("8.8.8.8"), + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + _ = m.match(inputs[i%len(inputs)]) + i++ + } + }) +} + +func BenchmarkOperatorNetSnapshotMixedParallel(b *testing.B) { + _, cidrA, err := net.ParseCIDR("172.16.0.0/16") + if err != nil { + b.Fatalf("failed to parse benchmark CIDR A: %v", err) + } + _, cidrB, err := net.ParseCIDR("10.200.0.0/16") + if err != nil { + b.Fatalf("failed to parse benchmark CIDR B: %v", err) + } + + op := &Operator{} + exact := map[string]struct{}{ + "172.16.1.2": {}, + "10.200.8.9": {}, + } + nets := []*net.IPNet{cidrA, cidrB} + op.listSnapshot.Store(&listCacheSnapshot{ + listExact: exact, + listNets: nets, + }) + + inputs := []net.IP{ + net.ParseIP("172.16.1.2"), // exact + net.ParseIP("172.16.44.10"), // cidr + net.ParseIP("10.200.8.9"), // exact + net.ParseIP("10.200.77.1"), // cidr + net.ParseIP("192.168.1.10"), // miss + net.ParseIP("1.1.1.1"), // miss + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + _ = op.netListsCmp(inputs[i%len(inputs)]) + i++ + } + }) +} + +func BenchmarkOperatorNetRLockMixedParallel(b *testing.B) { + _, cidrA, err := net.ParseCIDR("172.16.0.0/16") + if err != nil { + b.Fatalf("failed to parse benchmark CIDR A: %v", err) + } + _, cidrB, err := net.ParseCIDR("10.200.0.0/16") + if err != nil { + b.Fatalf("failed to parse benchmark CIDR B: %v", err) + } + + m := &rlockIPNetMatcher{ + exact: map[string]struct{}{ + "172.16.1.2": {}, + "10.200.8.9": {}, + }, + nets: []*net.IPNet{cidrA, cidrB}, + } + + inputs := []net.IP{ + net.ParseIP("172.16.1.2"), + net.ParseIP("172.16.44.10"), + net.ParseIP("10.200.8.9"), + net.ParseIP("10.200.77.1"), + net.ParseIP("192.168.1.10"), + net.ParseIP("1.1.1.1"), + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + _ = m.match(inputs[i%len(inputs)]) + i++ + } + }) +} + +type rlockRegexpMatcher struct { + sync.RWMutex + entries []listRegexEntry +} + +func (m *rlockRegexpMatcher) match(host string) bool { + m.RLock() + defer m.RUnlock() + for _, entry := range m.entries { + if entry.re.MatchString(host) { + return true + } + } + return false +} + +func BenchmarkOperatorDomainsRegexpSnapshotMixedParallel(b *testing.B) { + op := &Operator{} + op.listSnapshot.Store(&listCacheSnapshot{ + regexEntries: []listRegexEntry{ + {file: "bench-a", re: mustCompileRegexpBench(b, `(^|\\.)example\\.org$`)}, + {file: "bench-b", re: mustCompileRegexpBench(b, `^api-[0-9]{2}\\.example\\.org$`)}, + {file: "bench-c", re: mustCompileRegexpBench(b, `^[a-z0-9-]+\\.service\\.internal$`)}, + }, + }) + + inputs := []string{ + "www.example.org", // hit + "api-12.example.org", // hit + "node-1.service.internal", // hit + "no-match.local", // miss + "api-aa.example.org", // miss + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + _ = op.reListCmp(inputs[i%len(inputs)]) + i++ + } + }) +} + +func BenchmarkOperatorDomainsRegexpRLockMixedParallel(b *testing.B) { + m := &rlockRegexpMatcher{ + entries: []listRegexEntry{ + {file: "bench-a", re: mustCompileRegexpBench(b, `(^|\\.)example\\.org$`)}, + {file: "bench-b", re: mustCompileRegexpBench(b, `^api-[0-9]{2}\\.example\\.org$`)}, + {file: "bench-c", re: mustCompileRegexpBench(b, `^[a-z0-9-]+\\.service\\.internal$`)}, + }, + } + + inputs := []string{ + "www.example.org", + "api-12.example.org", + "node-1.service.internal", + "no-match.local", + "api-aa.example.org", + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + _ = m.match(inputs[i%len(inputs)]) + i++ + } + }) +} + +func mustCompileRegexpBench(b *testing.B, pattern string) *regexp.Regexp { + b.Helper() + re, err := regexp.Compile(pattern) + if err != nil { + b.Fatalf("failed to compile benchmark regexp %q: %v", pattern, err) + } + return re +} + +func BenchmarkLoaderFindFirstMatchSnapshotParallel(b *testing.B) { + loader := &Loader{rules: make(map[string]*Rule)} + + dummyList := make([]Operator, 0) + nonMatchOp, err := NewOperator(Simple, false, OpDstHost, "does-not-match.example", dummyList) + if err != nil { + b.Fatalf("failed creating non-match operator: %v", err) + } + if err := nonMatchOp.Compile(); err != nil { + b.Fatalf("failed compiling non-match operator: %v", err) + } + + matchOp, err := NewOperator(Simple, false, OpDstHost, "opensnitch.io", dummyList) + if err != nil { + b.Fatalf("failed creating match operator: %v", err) + } + if err := matchOp.Compile(); err != nil { + b.Fatalf("failed compiling match operator: %v", err) + } + + for i := 0; i < 63; i++ { + r := Create(fmt.Sprintf("%03d-non-match", i), "", true, false, false, Allow, Always, nonMatchOp) + loader.rules[r.Name] = r + } + matchRule := Create("999-match", "", true, false, false, Allow, Always, matchOp) + loader.rules[matchRule.Name] = matchRule + loader.sortRules() + + conn := &conman.Connection{DstHost: "opensnitch.io"} + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if loader.FindFirstMatch(conn) == nil { + b.Fatal("expected non-nil matching rule") + } + } + }) +} + func unmarshalListData(data string, t *testing.T) (op *[]Operator) { if err := json.Unmarshal([]byte(data), &op); err != nil { t.Error("Error unmarshalling list data:", err, data) @@ -668,6 +1095,114 @@ func TestNewOperatorListsDomainsRegexp(t *testing.T) { restoreConnection() } +func TestDomainsListsWildcardAndGlobFallback(t *testing.T) { + op := &Operator{ + Sensitive: false, + lists: make(map[string]interface{}), + domainWildcards: newDomainWildcardTrie(), + } + op.domainWildcards.insertSuffix("example.org") + const globPat = "api-??.example.org" + if err := validateDomainGlobPattern(globPat); err != nil { + t.Fatalf("invalid test glob: %v", err) + } + op.domainGlobs = append(op.domainGlobs, globPat) + op.listSnapshot.Store(&listCacheSnapshot{ + lists: op.lists, + domainWildcards: op.domainWildcards, + domainGlobs: op.domainGlobs, + }) + + if !op.domainsListsCmp("svc.example.org") { + t.Fatal("expected wildcard trie fallback match") + } + if op.domainsListsCmp("example.org") { + t.Fatal("wildcard fallback must not match suffix root") + } + if !op.domainsListsCmp("api-12.example.org") { + t.Fatal("expected glob fallback match") + } +} + +func TestMatchDomainGlobLabelBoundary(t *testing.T) { + tests := []struct { + pattern string + host string + want bool + desc string + }{ + {"api-??.example.org", "api-12.example.org", true, "? matches single char in label"}, + {"api-??.example.org", "api-123.example.org", false, "? must not match more than one char"}, + {"api*.example.org", "apidev.example.org", true, "* matches within label"}, + {"api*.example.org", "api.v2.example.org", false, "* must not cross label boundary"}, + {"tracker-[0-9].example.org", "tracker-3.example.org", true, "character class in label"}, + {"tracker-[0-9].example.org", "tracker-x.example.org", false, "character class mismatch"}, + {"api-??.example.org", "api-12.sub.example.org", false, "different label count"}, + } + for _, tc := range tests { + got := matchDomainGlob(tc.pattern, tc.host) + if got != tc.want { + t.Errorf("%s: matchDomainGlob(%q, %q) = %v, want %v", tc.desc, tc.pattern, tc.host, got, tc.want) + } + } +} + +func TestIPListsCmpSupportsExactAndCIDRFallback(t *testing.T) { + _, cidr, err := net.ParseCIDR("10.0.0.0/24") + if err != nil { + t.Fatalf("failed to parse cidr: %v", err) + } + + op := &Operator{ + listExact: map[string]struct{}{ + "10.0.0.4": {}, + }, + listNets: []*net.IPNet{cidr}, + } + op.listSnapshot.Store(&listCacheSnapshot{ + listExact: op.listExact, + listNets: op.listNets, + }) + + if !op.ipListsCmp(net.ParseIP("10.0.0.4")) { + t.Fatal("expected exact ip list match") + } + if !op.ipListsCmp(net.ParseIP("10.0.0.99")) { + t.Fatal("expected cidr fallback match for ip list") + } + if op.ipListsCmp(net.ParseIP("192.168.1.10")) { + t.Fatal("unexpected ip list match") + } +} + +func TestNetListsCmpSupportsExactAndCIDRFallback(t *testing.T) { + _, cidr, err := net.ParseCIDR("10.1.0.0/16") + if err != nil { + t.Fatalf("failed to parse cidr: %v", err) + } + + op := &Operator{ + listExact: map[string]struct{}{ + "10.1.2.3": {}, + }, + listNets: []*net.IPNet{cidr}, + } + op.listSnapshot.Store(&listCacheSnapshot{ + listExact: op.listExact, + listNets: op.listNets, + }) + + if !op.netListsCmp(net.ParseIP("10.1.2.3")) { + t.Fatal("expected exact net list match") + } + if !op.netListsCmp(net.ParseIP("10.1.44.5")) { + t.Fatal("expected cidr fallback match for net list") + } + if op.netListsCmp(net.ParseIP("172.16.0.1")) { + t.Fatal("unexpected net list match") + } +} + // Must be launched with -race to test that we don't cause leaks // Race occured on operator.go:241 reListCmp().MathString() // fixed here: 53419fe