diff --git a/pkg/action/scan.go b/pkg/action/scan.go index 7d634dcb3..eb0a4d11b 100644 --- a/pkg/action/scan.go +++ b/pkg/action/scan.go @@ -4,6 +4,7 @@ package action import ( + "bytes" "context" "errors" "fmt" @@ -53,32 +54,17 @@ var ( // This avoids loading the entire file into memory while still using yara-x's byte slice scanning. // scanFD also returns the file's contents for match string extraction, // as well as the file's size and its checksum which were originally calculated separately as part of report generation. -func scanFD(scanner *yarax.Scanner, fd uintptr, logger *clog.Logger) ([]byte, *yarax.ScanResults, int64, string, error) { - var stat syscall.Stat_t - if err := syscall.Fstat(int(fd), &stat); err != nil { - return nil, nil, 0, "", fmt.Errorf("fstat failed: %w", err) - } - - size := stat.Size - if size == 0 { - mrs, err := scanner.Scan([]byte{}) - return nil, mrs, 0, "", err - } - - if size < 0 { - return nil, nil, 0, "", fmt.Errorf("invalid file size: %d", size) - } - - if size > maxMmapSize { - logger.Warn("file exceeds mmap limit, scanning first portion only", - "size", size, "limit", maxMmapSize) - size = maxMmapSize +func scanFD(scanner *yarax.Scanner, fd uintptr, size int64, logger *clog.Logger) ([]byte, *yarax.ScanResults, string, error) { + stat := &syscall.Stat_t{} + if err := syscall.Fstat(int(fd), stat); err != nil { + return nil, nil, "", fmt.Errorf("fstat failed: %w", err) } data, err := syscall.Mmap(int(fd), 0, int(size), syscall.PROT_READ, syscall.MAP_PRIVATE) if err != nil { - return nil, nil, 0, "", fmt.Errorf("mmap failed: %w", err) + return nil, nil, "", fmt.Errorf("mmap failed: %w", err) } + defer func() { if unmapErr := syscall.Munmap(data); unmapErr != nil { logger.Error("failed to unmap memory", "error", unmapErr) @@ -89,17 +75,17 @@ func scanFD(scanner *yarax.Scanner, fd uintptr, logger *clog.Logger) ([]byte, *y h.Write(data) checksum := fmt.Sprintf("%x", h.Sum(nil)) - mrs, err := scanner.Scan(data) - if err != nil { - return nil, nil, 0, "", err - } - // Create a copy of the data to return since the mmap will be unmapped // This is necessary because report generation needs access to file content // for match string extraction - fc := append([]byte(nil), data...) + fc := bytes.Clone(data) - return fc, mrs, size, checksum, err + mrs, err := scanner.Scan(data) + if err != nil { + return nil, nil, "", err + } + + return fc, mrs, checksum, err } // scanSinglePath YARA scans a single path and converts it to a fileReport. @@ -118,7 +104,6 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF return nil, err } fd := f.Fd() - defer f.Close() fi, err := f.Stat() if err != nil { @@ -134,6 +119,12 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF return fr, nil } + if size > maxMmapSize { + logger.Warn("file exceeds mmap limit, scanning first portion only", + "size", size, "limit", maxMmapSize) + size = maxMmapSize + } + mime := "" kind, err := programkind.File(path) if err != nil && !interactive(c) { @@ -155,21 +146,23 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF logger = logger.With("mime", mime) var yrs *yarax.Rules - if c.Rules == nil { + if c.Rules != nil { + yrs = c.Rules + } else { yrs, err = CachedRules(ctx, ruleFS) if err != nil { return nil, fmt.Errorf("rules: %w", err) } - } else { - yrs = c.Rules } initializeOnce.Do(func() { - scannerPool = pool.NewScannerPool(yrs, getMaxConcurrency(c.Concurrency)) + // always create one scanner per available CPU core since the pool is used for the duration of + // a scan which may involve concurrent scans of individual files + scannerPool = pool.NewScannerPool(yrs, getMaxConcurrency(runtime.GOMAXPROCS(0))) }) scanner := scannerPool.Get(yrs) - fc, mrs, size, checksum, err := scanFD(scanner, fd, logger) + fc, mrs, checksum, err := scanFD(scanner, fd, size, logger) if err != nil { logger.Debug("skipping", slog.Any("error", err)) return nil, err @@ -193,6 +186,7 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF } defer func() { + f.Close() scannerPool.Put(scanner) fc = nil mrs = nil @@ -433,7 +427,10 @@ func processPaths(ctx context.Context, paths []string, scanInfo scanPathInfo, c return ctx.Err() } - maxConcurrency := getMaxConcurrency(c.Concurrency) + // adjust concurrency if the number of paths to scan + // is lower than the configured value + numPaths := len(paths) + maxConcurrency := getMaxConcurrency(min(c.Concurrency, numPaths)) scanCtx, cancel := context.WithCancel(ctx) defer cancel() @@ -449,7 +446,7 @@ func processPaths(ctx context.Context, paths []string, scanInfo scanPathInfo, c setupMatchHandler(gCtx, matchChan, c, cancel, logger) - pc := make(chan string, len(paths)) + pc := make(chan string, numPaths) go func() { defer close(pc) for _, path := range paths { @@ -495,10 +492,7 @@ func processPaths(ctx context.Context, paths []string, scanInfo scanPathInfo, c } func getMaxConcurrency(configured int) int { - if configured < 1 { - return 1 - } - return configured + return max(1, configured) } func setupMatchHandler(ctx context.Context, matchChan chan matchResult, c malcontent.Config, cancel context.CancelFunc, logger *clog.Logger) { @@ -695,7 +689,9 @@ func processArchive(ctx context.Context, c malcontent.Config, rfs []fs.FS, archi return nil, fmt.Errorf("find: %w", err) } - ep := make(chan string, len(extractedPaths)) + numPaths := len(extractedPaths) + + ep := make(chan string, numPaths) go func() { defer close(ep) for _, path := range extractedPaths { @@ -707,7 +703,9 @@ func processArchive(ctx context.Context, c malcontent.Config, rfs []fs.FS, archi } }() - maxConcurrency := getMaxConcurrency(c.Concurrency) + // adjust concurrency if the number of paths to scan + // is lower than the configured value + maxConcurrency := getMaxConcurrency(min(c.Concurrency, numPaths)) scanCtx, cancel := context.WithCancel(ctx) defer cancel()