diff --git a/pkg/action/path.go b/pkg/action/path.go index ec1dcafbd..b5a210441 100644 --- a/pkg/action/path.go +++ b/pkg/action/path.go @@ -49,22 +49,7 @@ func findFilesRecursively(ctx context.Context, rootPath string) ([]string, error // Ignore symlinked directories like regular directories if info.Type()&fs.ModeSymlink == fs.ModeSymlink { - logger.Debugf("attempting to resolve symlink: %s", path) - eval, err := filepath.EvalSymlinks(path) - if err != nil { - logger.Debugf("eval: %s: %s", path, err) - return nil - } - fi, err := os.Stat(eval) - if err != nil { - logger.Debugf("stat: %s: %s", path, err) - return nil - } - if fi.IsDir() { - logger.Debugf("ignoring symlinked directory: %s", path) - return nil - } - path = eval + return nil } files = append(files, path) diff --git a/pkg/action/scan.go b/pkg/action/scan.go index b605aa813..e94eac9ca 100644 --- a/pkg/action/scan.go +++ b/pkg/action/scan.go @@ -7,7 +7,6 @@ import ( "context" "errors" "fmt" - "io" "io/fs" "log/slog" "os" @@ -17,6 +16,7 @@ import ( "strings" "sync" "sync/atomic" + "syscall" "github.com/chainguard-dev/clog" "github.com/chainguard-dev/malcontent/pkg/archive" @@ -43,13 +43,59 @@ var ( ErrMatchedCondition = errors.New("matched exit criteria") // initializeOnce ensures that the file and scanner pools are only initialized once. initializeOnce sync.Once - filePool *pool.BufferPool scannerPool *pool.ScannerPool + maxMmapSize int64 = 1 << 31 ) +// scanFD scans a file descriptor using memory mapping for efficient large file handling. +// This avoids loading the entire file into memory while still using yara-x's byte slice scanning. +func scanFD(scanner *yarax.Scanner, fd uintptr, logger *clog.Logger) ([]byte, *yarax.ScanResults, error) { + var stat syscall.Stat_t + if err := syscall.Fstat(int(fd), &stat); err != nil { + return nil, nil, fmt.Errorf("fstat failed: %w", err) + } + + size := stat.Size + if size == 0 { + mrs, err := scanner.Scan([]byte{}) + return nil, mrs, err + } + + if size < 0 { + return nil, nil, fmt.Errorf("invalid file size: %d", size) + } + + if size > maxMmapSize { + logger.Warn("file exceeds mmap limit, scanning first portion only", + "size", size, "limit", maxMmapSize) + size = maxMmapSize + } + + data, err := syscall.Mmap(int(fd), 0, int(size), syscall.PROT_READ, syscall.MAP_PRIVATE) + if err != nil { + return nil, nil, fmt.Errorf("mmap failed: %w", err) + } + defer func() { + if unmapErr := syscall.Munmap(data); unmapErr != nil { + logger.Error("failed to unmap memory", "error", unmapErr) + } + }() + + mrs, err := scanner.Scan(data) + if err != nil { + return nil, nil, err + } + + // Create a copy of the data to return since the mmap will be unmapped + // This is necessary because report generation needs access to file content + // for checksum calculation and match string extraction + fc := make([]byte, len(data)) + copy(fc, data) + + return fc, mrs, err +} + // scanSinglePath YARA scans a single path and converts it to a fileReport. -// -//nolint:cyclop // ignore complexity of 38 func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleFS []fs.FS, absPath string, archiveRoot string) (*malcontent.FileReport, error) { if ctx.Err() != nil { return &malcontent.FileReport{}, ctx.Err() @@ -60,7 +106,14 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF isArchive := archiveRoot != "" - fi, err := os.Stat(path) + f, err := os.Open(path) + if err != nil { + return nil, err + } + fd := f.Fd() + defer f.Close() + + fi, err := f.Stat() if err != nil { return nil, err } @@ -105,43 +158,13 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF } initializeOnce.Do(func() { - filePool = pool.NewBufferPool(c.Concurrency + 1) - scannerPool = pool.NewScannerPool(yrs, c.Concurrency+1) + scannerPool = pool.NewScannerPool(yrs, c.Concurrency) }) scanner := scannerPool.Get() - if scanner == nil { - scanner = yarax.NewScanner(yrs) - } defer scannerPool.Put(scanner) - f, err := os.Open(path) - if err != nil { - return nil, err - } - defer f.Close() - - fc := filePool.Get(size) - defer filePool.Put(fc) - - var bytesRead int - var totalRead int64 - for totalRead < size { - bytesRead, err = f.Read(fc[totalRead:]) - if errors.Is(err, io.EOF) { - break - } - if err != nil { - return nil, err - } - totalRead += int64(bytesRead) - } - - if totalRead < size && err != nil { - return nil, fmt.Errorf("incomplete read: got %d bytes, expected %d: %w", totalRead, size, err) - } - - mrs, err := scanner.Scan(fc) + fc, mrs, err := scanFD(scanner, fd, logger) if err != nil { logger.Debug("skipping", slog.Any("error", err)) return nil, err @@ -164,6 +187,11 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF return nil, NewFileReportError(err, path, TypeGenerateError) } + defer func() { + fc = nil + mrs = nil + }() + // Clean up the path if scanning an archive var clean string if isArchive || c.OCI { @@ -427,6 +455,12 @@ func processPaths(ctx context.Context, paths []string, scanInfo scanPathInfo, c } }() + // Zero-out the path strings and empty the slice once read into the path channel + defer func() { + clear(paths) + paths = paths[:0] + }() + for path := range pc { g.Go(func() error { if gCtx.Err() != nil { diff --git a/pkg/archive/archive.go b/pkg/archive/archive.go index 683c7fe3f..ac086a470 100644 --- a/pkg/archive/archive.go +++ b/pkg/archive/archive.go @@ -146,12 +146,6 @@ func ExtractArchiveToTempDir(ctx context.Context, path string) (string, error) { return "", fmt.Errorf("failed to create temp dir: %w", err) } - go func() { - <-ctx.Done() - logger.Debug("context cancelled, cleaning up temp dir") - os.RemoveAll(tmpDir) - }() - initializeOnce.Do(func() { archivePool = pool.NewBufferPool(runtime.GOMAXPROCS(0)) }) diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go index e1caa458a..0f98e36e9 100644 --- a/pkg/pool/pool.go +++ b/pkg/pool/pool.go @@ -2,6 +2,7 @@ package pool import ( "math" + "runtime" "sync" yarax "github.com/VirusTotal/yara-x/go" @@ -23,7 +24,8 @@ func NewBufferPool(count int) *BufferPool { bp.pool = sync.Pool{ New: func() any { - return make([]byte, defaultBuffer) + buffer := make([]byte, defaultBuffer) + return &buffer }, } @@ -43,18 +45,17 @@ func (bp *BufferPool) Get(size int64) []byte { bufInterface := bp.pool.Get() - buf, ok := bufInterface.([]byte) - if !ok || buf == nil { + bufPtr, ok := bufInterface.(*[]byte) + if !ok || bufPtr == nil { return make([]byte, size) } - bufPtr := &buf if cap(*bufPtr) < int(size) { bp.pool.Put(bufPtr) return make([]byte, size) } - return buf[:size] + return (*bufPtr)[:size] } // Put returns a byte buffer to the pool for future reuse. @@ -72,34 +73,50 @@ func (bp *BufferPool) Put(buf []byte) { // ScannerPool provides a pool of yara-x scanners. type ScannerPool struct { - pool sync.Pool + scanners chan *yarax.Scanner + pinner *runtime.Pinner + closeOnce sync.Once } // NewScannerPool creates a pool containing the specified number of yara-x scanners. func NewScannerPool(yrs *yarax.Rules, count int) *ScannerPool { - sp := &ScannerPool{} - - sp.pool = sync.Pool{ - New: func() any { - return yarax.NewScanner(yrs) - }, + sp := &ScannerPool{ + scanners: make(chan *yarax.Scanner, count), + pinner: &runtime.Pinner{}, } for range count { - sp.pool.Put(yarax.NewScanner(yrs)) + scanner := yarax.NewScanner(yrs) + // Pin the scanner in memory to prevent GC movement + sp.pinner.Pin(scanner) + sp.scanners <- scanner } return sp } -// Get retrieves a scanner from the scanner pool. +// Get retrieves a scanner from the scanner pool, blocking if none are available. func (sp *ScannerPool) Get() *yarax.Scanner { - if scanner, ok := sp.pool.Get().(*yarax.Scanner); ok { - return scanner - } - return nil + return <-sp.scanners } // Put returns a scanner to the scanner pool. func (sp *ScannerPool) Put(scanner *yarax.Scanner) { - sp.pool.Put(scanner) + if scanner != nil { + select { + case sp.scanners <- scanner: + default: + } + } +} + +// Close destroys all active scanners. +// Currently unused. +func (sp *ScannerPool) Close() { + sp.closeOnce.Do(func() { + sp.pinner.Unpin() + close(sp.scanners) + for scanner := range sp.scanners { + scanner.Destroy() + } + }) } diff --git a/pkg/report/report.go b/pkg/report/report.go index 3290df6eb..d933edc7f 100644 --- a/pkg/report/report.go +++ b/pkg/report/report.go @@ -479,6 +479,7 @@ func Generate(ctx context.Context, path string, mrs *yarax.ScanResults, c malcon processor := newMatchProcessor(fc, matches, m.Patterns()) matchedStrings = processor.process() + processor.clearFileContent() } b := &malcontent.Behavior{ diff --git a/pkg/report/strings.go b/pkg/report/strings.go index a214f9e1e..6f6bd559e 100644 --- a/pkg/report/strings.go +++ b/pkg/report/strings.go @@ -19,6 +19,13 @@ type StringPool struct { strings map[string]string } +// clear removes all strings from the pool to free memory. +func (sp *StringPool) clear() { + sp.Lock() + defer sp.Unlock() + clear(sp.strings) +} + // NewStringPool creates a new string pool. func NewStringPool(length int) *StringPool { return &StringPool{ @@ -70,6 +77,12 @@ var matchResultPool = sync.Pool{ }, } +// clearFileContent releases the file content to free memory after processing. +func (mp *matchProcessor) clearFileContent() { + mp.fc = nil + mp.pool.clear() +} + // process performantly handles the conversion of matched data to strings. // yara-x does not expose the rendered string via the API due to performance overhead. func (mp *matchProcessor) process() []string { @@ -115,20 +128,24 @@ func (mp *matchProcessor) process() []string { if l <= cap(buffer) { buffer = buffer[:l] copy(buffer, matchBytes) - *result = append(*result, mp.pool.Intern(string(buffer))) + matchStr := string(buffer) + *result = append(*result, mp.pool.Intern(string([]byte(matchStr)))) } else { - *result = append(*result, mp.pool.Intern(string(matchBytes))) + matchStr := string(matchBytes) + *result = append(*result, mp.pool.Intern(string([]byte(matchStr)))) } } else { if patterns == nil || cap(patterns) < patternsCap { patterns = make([]string, 0, patternsCap) } else { + clear(patterns) patterns = patterns[:0] } for _, p := range mp.patterns { patterns = append(patterns, p.Identifier()) } - *result = append(*result, slices.Compact(patterns)...) + compacted := slices.Compact(patterns) + *result = append(*result, compacted...) } }