Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 39 additions & 42 deletions pkg/action/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,32 +53,17 @@ var (
// This avoids loading the entire file into memory while still using yara-x's byte slice scanning.
// scanFD also returns the file's contents for match string extraction,
// as well as the file's size and its checksum which were originally calculated separately as part of report generation.
func scanFD(scanner *yarax.Scanner, fd uintptr, logger *clog.Logger) ([]byte, *yarax.ScanResults, int64, string, error) {
var stat syscall.Stat_t
if err := syscall.Fstat(int(fd), &stat); err != nil {
return nil, nil, 0, "", fmt.Errorf("fstat failed: %w", err)
}

size := stat.Size
if size == 0 {
mrs, err := scanner.Scan([]byte{})
return nil, mrs, 0, "", err
}

if size < 0 {
return nil, nil, 0, "", fmt.Errorf("invalid file size: %d", size)
}

if size > maxMmapSize {
logger.Warn("file exceeds mmap limit, scanning first portion only",
"size", size, "limit", maxMmapSize)
size = maxMmapSize
func scanFD(scanner *yarax.Scanner, fd uintptr, size int64, logger *clog.Logger) ([]byte, *yarax.ScanResults, string, error) {
stat := &syscall.Stat_t{}
if err := syscall.Fstat(int(fd), stat); err != nil {
return nil, nil, "", fmt.Errorf("fstat failed: %w", err)
}

data, err := syscall.Mmap(int(fd), 0, int(size), syscall.PROT_READ, syscall.MAP_PRIVATE)
if err != nil {
return nil, nil, 0, "", fmt.Errorf("mmap failed: %w", err)
return nil, nil, "", fmt.Errorf("mmap failed: %w", err)
}

defer func() {
if unmapErr := syscall.Munmap(data); unmapErr != nil {
logger.Error("failed to unmap memory", "error", unmapErr)
Expand All @@ -89,17 +74,17 @@ func scanFD(scanner *yarax.Scanner, fd uintptr, logger *clog.Logger) ([]byte, *y
h.Write(data)
checksum := fmt.Sprintf("%x", h.Sum(nil))

mrs, err := scanner.Scan(data)
if err != nil {
return nil, nil, 0, "", err
}

// Create a copy of the data to return since the mmap will be unmapped
// This is necessary because report generation needs access to file content
// for match string extraction
fc := append([]byte(nil), data...)

return fc, mrs, size, checksum, err
mrs, err := scanner.Scan(data)
if err != nil {
return nil, nil, "", err
}

return fc, mrs, checksum, err
}

// scanSinglePath YARA scans a single path and converts it to a fileReport.
Expand All @@ -118,22 +103,27 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
return nil, err
}
fd := f.Fd()
defer f.Close()

fi, err := f.Stat()
if err != nil {
return nil, err
}

size := fi.Size()
if size == 0 {
if size <= 0 {
Comment thread
egibs marked this conversation as resolved.
Outdated
fr := &malcontent.FileReport{Skipped: "zero-sized file", Path: path}
if isArchive {
defer os.RemoveAll(path)
}
return fr, nil
}

if size > maxMmapSize {
logger.Warn("file exceeds mmap limit, scanning first portion only",
"size", size, "limit", maxMmapSize)
size = maxMmapSize
}

mime := "<unknown>"
kind, err := programkind.File(path)
if err != nil && !interactive(c) {
Expand All @@ -155,21 +145,23 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
logger = logger.With("mime", mime)

var yrs *yarax.Rules
if c.Rules == nil {
if c.Rules != nil {
yrs = c.Rules
} else {
yrs, err = CachedRules(ctx, ruleFS)
if err != nil {
return nil, fmt.Errorf("rules: %w", err)
}
} else {
yrs = c.Rules
}

initializeOnce.Do(func() {
scannerPool = pool.NewScannerPool(yrs, getMaxConcurrency(c.Concurrency))
// always create one scanner per available CPU core since the pool is used for the duration of
// a scan which may involve concurrent scans of individual files
scannerPool = pool.NewScannerPool(yrs, getMaxConcurrency(runtime.GOMAXPROCS(0)))
})
scanner := scannerPool.Get(yrs)

fc, mrs, size, checksum, err := scanFD(scanner, fd, logger)
fc, mrs, checksum, err := scanFD(scanner, fd, size, logger)
if err != nil {
logger.Debug("skipping", slog.Any("error", err))
return nil, err
Expand All @@ -193,6 +185,7 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
}

defer func() {
f.Close()
scannerPool.Put(scanner)
fc = nil
mrs = nil
Expand Down Expand Up @@ -433,7 +426,10 @@ func processPaths(ctx context.Context, paths []string, scanInfo scanPathInfo, c
return ctx.Err()
}

maxConcurrency := getMaxConcurrency(c.Concurrency)
// adjust concurrency if the number of paths to scan
// is lower than the configured value
numPaths := len(paths)
maxConcurrency := getMaxConcurrency(min(c.Concurrency, numPaths))

scanCtx, cancel := context.WithCancel(ctx)
defer cancel()
Expand All @@ -449,7 +445,7 @@ func processPaths(ctx context.Context, paths []string, scanInfo scanPathInfo, c

setupMatchHandler(gCtx, matchChan, c, cancel, logger)

pc := make(chan string, len(paths))
pc := make(chan string, numPaths)
go func() {
defer close(pc)
for _, path := range paths {
Expand Down Expand Up @@ -495,10 +491,7 @@ func processPaths(ctx context.Context, paths []string, scanInfo scanPathInfo, c
}

func getMaxConcurrency(configured int) int {
if configured < 1 {
return 1
}
return configured
return max(1, configured)
}

func setupMatchHandler(ctx context.Context, matchChan chan matchResult, c malcontent.Config, cancel context.CancelFunc, logger *clog.Logger) {
Expand Down Expand Up @@ -695,7 +688,9 @@ func processArchive(ctx context.Context, c malcontent.Config, rfs []fs.FS, archi
return nil, fmt.Errorf("find: %w", err)
}

ep := make(chan string, len(extractedPaths))
numPaths := len(extractedPaths)

ep := make(chan string, numPaths)
go func() {
defer close(ep)
for _, path := range extractedPaths {
Expand All @@ -707,7 +702,9 @@ func processArchive(ctx context.Context, c malcontent.Config, rfs []fs.FS, archi
}
}()

maxConcurrency := getMaxConcurrency(c.Concurrency)
// adjust concurrency if the number of paths to scan
// is lower than the configured value
maxConcurrency := getMaxConcurrency(min(c.Concurrency, numPaths))
scanCtx, cancel := context.WithCancel(ctx)
defer cancel()

Expand Down
Loading