Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/action/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ func CachedRules(ctx context.Context, fss []fs.FS) (*yarax.Rules, error) {
var err error
compileOnce.Do(func() {
var yrs *yarax.Rules
yrs, err = compile.Recursive(ctx, fss)
yrs, err = compile.RecursiveCached(ctx, fss)
if err != nil {
err = fmt.Errorf("compile: %w", err)
return
Expand Down
130 changes: 130 additions & 0 deletions pkg/compile/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@ import (
"context"
"fmt"
"io/fs"
"log/slog"
"os"
"path/filepath"
"regexp"
"strings"

"github.com/minio/sha256-simd"

"github.com/chainguard-dev/clog"
"github.com/chainguard-dev/malcontent/rules"

Expand Down Expand Up @@ -221,3 +225,129 @@ func Recursive(ctx context.Context, fss []fs.FS) (*yarax.Rules, error) {

return yxc.Build(), nil
}

// getCacheDir returns the directory for storing compiled rules.
func getCacheDir() (string, error) {
var cacheDir string

if userCacheDir, err := os.UserCacheDir(); err == nil {
cacheDir = filepath.Join(userCacheDir, "malcontent")
} else {
cacheDir = filepath.Join(os.TempDir(), "malcontent-cache")
}

if err := os.MkdirAll(cacheDir, 0o755); err != nil {
return "", fmt.Errorf("create cache dir: %w", err)
}

return cacheDir, nil
}

// loadCachedRules attempts to load rules from the local, compiled rules.
func loadCachedRules(cacheFile string) (*yarax.Rules, error) {
file, err := os.Open(cacheFile)
if err != nil {
return nil, err
}
defer file.Close()

compiledRules, err := yarax.ReadFrom(file)
if err != nil {
return nil, fmt.Errorf("read cached rules: %w", err)
}

return compiledRules, nil
}

// saveCachedRules saves rules to a local file.
func saveCachedRules(compiledRules *yarax.Rules, cacheFile string) error {
tmpFile := cacheFile + ".tmp"
file, err := os.OpenFile(tmpFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644)
if err != nil {
return fmt.Errorf("create cache file: %w", err)
}
defer file.Close()

if _, err := compiledRules.WriteTo(file); err != nil {
os.Remove(tmpFile)
return fmt.Errorf("write rules to cache: %w", err)
}

if err := os.Rename(tmpFile, cacheFile); err != nil {
os.Remove(tmpFile)
return fmt.Errorf("rename cache file: %w", err)
}

return nil
}

// getRulesHash computes a hash of the rule sources for cache validation.
func getRulesHash(ctx context.Context, fss []fs.FS) (string, error) {
if ctx.Err() != nil {
return "", ctx.Err()
}

hasher := sha256.New()

for _, fsys := range fss {
err := fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
return nil
}
if filepath.Ext(path) == ".yara" || filepath.Ext(path) == ".yar" {
hasher.Write([]byte(path))
content, err := fs.ReadFile(fsys, path)
if err != nil {
return err
}
hasher.Write(content)
}
return nil
})
if err != nil {
return "", err
}
}

return fmt.Sprintf("%x", hasher.Sum(nil)), nil
}

// RecursiveCached compiles rules with persistent disk caching to avoid penalizing successive executions with repeated rule compilations.
func RecursiveCached(ctx context.Context, fss []fs.FS) (*yarax.Rules, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}

cacheDir, cacheErr := getCacheDir()
if cacheErr != nil {
return Recursive(ctx, fss)
}

hash, hashErr := getRulesHash(ctx, fss)
if hashErr != nil {
return Recursive(ctx, fss)
}

cacheFile := filepath.Join(cacheDir, fmt.Sprintf("rules-%s.cache", hash))
if cachedRules, loadErr := loadCachedRules(cacheFile); loadErr == nil {
slog.Debug("Loaded rules from cache", "file", cacheFile)
return cachedRules, nil
}

slog.Debug("Cache miss, compiling rules", "file", cacheFile)
compiledRules, err := Recursive(ctx, fss)
if err != nil {
return nil, fmt.Errorf("compile: %w", err)
}

if saveErr := saveCachedRules(compiledRules, cacheFile); saveErr != nil {
slog.Warn("Failed to save rules to cache", "error", saveErr)
} else {
slog.Debug("Saved rules to cache", "file", cacheFile)
}

return compiledRules, nil
}
Loading