Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 1 addition & 16 deletions pkg/action/path.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,7 @@ func findFilesRecursively(ctx context.Context, rootPath string) ([]string, error

// Ignore symlinked directories like regular directories
if info.Type()&fs.ModeSymlink == fs.ModeSymlink {
logger.Debugf("attempting to resolve symlink: %s", path)
eval, err := filepath.EvalSymlinks(path)
if err != nil {
logger.Debugf("eval: %s: %s", path, err)
return nil
}
fi, err := os.Stat(eval)
if err != nil {
logger.Debugf("stat: %s: %s", path, err)
return nil
}
if fi.IsDir() {
logger.Debugf("ignoring symlinked directory: %s", path)
return nil
}
path = eval
return nil
}

files = append(files, path)
Expand Down
108 changes: 71 additions & 37 deletions pkg/action/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"os"
Expand All @@ -17,6 +16,7 @@ import (
"strings"
"sync"
"sync/atomic"
"syscall"

"github.com/chainguard-dev/clog"
"github.com/chainguard-dev/malcontent/pkg/archive"
Expand All @@ -43,13 +43,59 @@ var (
ErrMatchedCondition = errors.New("matched exit criteria")
// initializeOnce ensures that the file and scanner pools are only initialized once.
initializeOnce sync.Once
filePool *pool.BufferPool
scannerPool *pool.ScannerPool
maxMmapSize int64 = 1 << 31
)

// scanFD scans the file behind fd by memory-mapping it read-only and passing the
// mapping to the yara-x scanner.
//
// Returns, in order:
//   - a heap copy of the scanned bytes (nil for an empty file); the mapping is
//     released before returning, so the mapping itself never escapes,
//   - the scan results,
//   - any error from fstat, mmap, or the scan.
//
// Files larger than maxMmapSize are scanned only up to that limit and a warning
// is logged. The limit is additionally capped at the largest value an int can
// hold on this platform, so the length handed to syscall.Mmap cannot overflow
// to a negative number where int is 32 bits wide.
func scanFD(scanner *yarax.Scanner, fd uintptr, logger *clog.Logger) ([]byte, *yarax.ScanResults, error) {
	// Largest value representable by int on this platform, expressed as int64.
	const maxInt = int64(^uint(0) >> 1)

	var stat syscall.Stat_t
	if err := syscall.Fstat(int(fd), &stat); err != nil {
		return nil, nil, fmt.Errorf("fstat failed: %w", err)
	}

	size := stat.Size
	switch {
	case size == 0:
		// Nothing to map; still run the scanner so callers get a results object.
		mrs, err := scanner.Scan([]byte{})
		return nil, mrs, err
	case size < 0:
		return nil, nil, fmt.Errorf("invalid file size: %d", size)
	}

	// Effective mapping limit: the configured cap, never more than int can hold.
	limit := min(maxMmapSize, maxInt)
	if size > limit {
		logger.Warn("file exceeds mmap limit, scanning first portion only",
			"size", size, "limit", limit)
		size = limit
	}

	data, err := syscall.Mmap(int(fd), 0, int(size), syscall.PROT_READ, syscall.MAP_PRIVATE)
	if err != nil {
		return nil, nil, fmt.Errorf("mmap failed: %w", err)
	}
	defer func() {
		if unmapErr := syscall.Munmap(data); unmapErr != nil {
			logger.Error("failed to unmap memory", "error", unmapErr)
		}
	}()

	mrs, err := scanner.Scan(data)
	if err != nil {
		return nil, nil, err
	}

	// Copy the mapped bytes before the deferred Munmap runs: report generation
	// needs the file content (checksum calculation, match string extraction)
	// after this function has returned and the mapping is gone.
	fc := make([]byte, len(data))
	copy(fc, data)

	return fc, mrs, nil
}

// scanSinglePath YARA scans a single path and converts it to a fileReport.
//
//nolint:cyclop // ignore complexity of 38
func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleFS []fs.FS, absPath string, archiveRoot string) (*malcontent.FileReport, error) {
if ctx.Err() != nil {
return &malcontent.FileReport{}, ctx.Err()
Expand All @@ -60,7 +106,14 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF

isArchive := archiveRoot != ""

fi, err := os.Stat(path)
f, err := os.Open(path)
if err != nil {
return nil, err
}
fd := f.Fd()
defer f.Close()

fi, err := f.Stat()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -105,43 +158,13 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
}

initializeOnce.Do(func() {
filePool = pool.NewBufferPool(c.Concurrency + 1)
scannerPool = pool.NewScannerPool(yrs, c.Concurrency+1)
scannerPool = pool.NewScannerPool(yrs, c.Concurrency)
})

scanner := scannerPool.Get()
if scanner == nil {
scanner = yarax.NewScanner(yrs)
}
defer scannerPool.Put(scanner)

f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()

fc := filePool.Get(size)
defer filePool.Put(fc)

var bytesRead int
var totalRead int64
for totalRead < size {
bytesRead, err = f.Read(fc[totalRead:])
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, err
}
totalRead += int64(bytesRead)
}

if totalRead < size && err != nil {
return nil, fmt.Errorf("incomplete read: got %d bytes, expected %d: %w", totalRead, size, err)
}

mrs, err := scanner.Scan(fc)
fc, mrs, err := scanFD(scanner, fd, logger)
if err != nil {
logger.Debug("skipping", slog.Any("error", err))
return nil, err
Expand All @@ -164,6 +187,11 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
return nil, NewFileReportError(err, path, TypeGenerateError)
}

defer func() {
fc = nil
mrs = nil
}()

// Clean up the path if scanning an archive
var clean string
if isArchive || c.OCI {
Expand Down Expand Up @@ -427,6 +455,12 @@ func processPaths(ctx context.Context, paths []string, scanInfo scanPathInfo, c
}
}()

// Zero-out the path strings and empty the slice once read into the path channel
defer func() {
clear(paths)
paths = paths[:0]
}()

for path := range pc {
g.Go(func() error {
if gCtx.Err() != nil {
Expand Down
6 changes: 0 additions & 6 deletions pkg/archive/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,6 @@ func ExtractArchiveToTempDir(ctx context.Context, path string) (string, error) {
return "", fmt.Errorf("failed to create temp dir: %w", err)
}

go func() {
<-ctx.Done()
logger.Debug("context cancelled, cleaning up temp dir")
os.RemoveAll(tmpDir)
}()

initializeOnce.Do(func() {
archivePool = pool.NewBufferPool(runtime.GOMAXPROCS(0))
})
Expand Down
55 changes: 36 additions & 19 deletions pkg/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pool

import (
"math"
"runtime"
"sync"

yarax "github.com/VirusTotal/yara-x/go"
Expand All @@ -23,7 +24,8 @@ func NewBufferPool(count int) *BufferPool {

bp.pool = sync.Pool{
New: func() any {
return make([]byte, defaultBuffer)
buffer := make([]byte, defaultBuffer)
return &buffer
},
}

Expand All @@ -43,18 +45,17 @@ func (bp *BufferPool) Get(size int64) []byte {

bufInterface := bp.pool.Get()

buf, ok := bufInterface.([]byte)
if !ok || buf == nil {
bufPtr, ok := bufInterface.(*[]byte)
if !ok || bufPtr == nil {
return make([]byte, size)
}

bufPtr := &buf
if cap(*bufPtr) < int(size) {
bp.pool.Put(bufPtr)
return make([]byte, size)
}

return buf[:size]
return (*bufPtr)[:size]
}

// Put returns a byte buffer to the pool for future reuse.
Expand All @@ -72,34 +73,50 @@ func (bp *BufferPool) Put(buf []byte) {

// ScannerPool provides a pool of yara-x scanners.
type ScannerPool struct {
pool sync.Pool
scanners chan *yarax.Scanner
pinner *runtime.Pinner
closeOnce sync.Once
}

// NewScannerPool creates a pool containing the specified number of yara-x scanners.
func NewScannerPool(yrs *yarax.Rules, count int) *ScannerPool {
sp := &ScannerPool{}

sp.pool = sync.Pool{
New: func() any {
return yarax.NewScanner(yrs)
},
sp := &ScannerPool{
scanners: make(chan *yarax.Scanner, count),
pinner: &runtime.Pinner{},
}

for range count {
sp.pool.Put(yarax.NewScanner(yrs))
scanner := yarax.NewScanner(yrs)
// Pin the scanner in memory to prevent GC movement
sp.pinner.Pin(scanner)
sp.scanners <- scanner
}
return sp
}

// Get retrieves a scanner from the scanner pool.
// Get retrieves a scanner from the scanner pool, blocking if none are available.
func (sp *ScannerPool) Get() *yarax.Scanner {
if scanner, ok := sp.pool.Get().(*yarax.Scanner); ok {
return scanner
}
return nil
return <-sp.scanners
}

// Put returns a scanner to the scanner pool.
func (sp *ScannerPool) Put(scanner *yarax.Scanner) {
sp.pool.Put(scanner)
if scanner != nil {
select {
case sp.scanners <- scanner:
default:
}
}
}

// Close shuts the pool down at most once: it releases the pinner's pins, closes
// the scanner channel, and destroys every scanner still buffered in that channel.
// Scanners that are checked out when Close runs are not in the channel and are
// therefore not destroyed here.
// NOTE(review): the channel is closed, so a Put after Close would send on a
// closed channel — confirm callers never return scanners after closing.
// Currently unused.
func (sp *ScannerPool) Close() {
	shutdown := func() {
		sp.pinner.Unpin()
		close(sp.scanners)
		// Drain whatever is left in the closed channel, destroying as we go.
		for {
			idle, open := <-sp.scanners
			if !open {
				return
			}
			idle.Destroy()
		}
	}
	sp.closeOnce.Do(shutdown)
}
1 change: 1 addition & 0 deletions pkg/report/report.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ func Generate(ctx context.Context, path string, mrs *yarax.ScanResults, c malcon

processor := newMatchProcessor(fc, matches, m.Patterns())
matchedStrings = processor.process()
processor.clearFileContent()
}

b := &malcontent.Behavior{
Expand Down
23 changes: 20 additions & 3 deletions pkg/report/strings.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ type StringPool struct {
strings map[string]string
}

// clear empties the pool's string map while holding the pool's lock, so the
// pool no longer keeps any of its strings alive.
func (sp *StringPool) clear() {
	sp.Lock()
	clear(sp.strings)
	sp.Unlock()
}

// NewStringPool creates a new string pool.
func NewStringPool(length int) *StringPool {
return &StringPool{
Expand Down Expand Up @@ -70,6 +77,12 @@ var matchResultPool = sync.Pool{
},
}

// clearFileContent drops the processor's reference to the file content (mp.fc)
// and then clears the processor's string pool, so neither keeps memory alive
// once processing is finished.
// NOTE(review): if mp.pool is shared with other processors, clearing it here
// also discards their interned strings — confirm the pool is per-processor.
func (mp *matchProcessor) clearFileContent() {
mp.fc = nil
mp.pool.clear()
}

// process performantly handles the conversion of matched data to strings.
// yara-x does not expose the rendered string via the API due to performance overhead.
func (mp *matchProcessor) process() []string {
Expand Down Expand Up @@ -115,20 +128,24 @@ func (mp *matchProcessor) process() []string {
if l <= cap(buffer) {
buffer = buffer[:l]
copy(buffer, matchBytes)
*result = append(*result, mp.pool.Intern(string(buffer)))
matchStr := string(buffer)
*result = append(*result, mp.pool.Intern(string([]byte(matchStr))))
} else {
*result = append(*result, mp.pool.Intern(string(matchBytes)))
matchStr := string(matchBytes)
*result = append(*result, mp.pool.Intern(string([]byte(matchStr))))
}
} else {
if patterns == nil || cap(patterns) < patternsCap {
patterns = make([]string, 0, patternsCap)
} else {
clear(patterns)
patterns = patterns[:0]
}
for _, p := range mp.patterns {
patterns = append(patterns, p.Identifier())
}
*result = append(*result, slices.Compact(patterns)...)
compacted := slices.Compact(patterns)
*result = append(*result, compacted...)
}
}

Expand Down
Loading