Skip to content

Commit 17a726c

Browse files
authored
Small memory optimizations (#1017)
Signed-off-by: egibs <20933572+egibs@users.noreply.github.com>
1 parent f545c64 commit 17a726c

6 files changed

Lines changed: 129 additions & 81 deletions

File tree

pkg/action/path.go

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,22 +49,7 @@ func findFilesRecursively(ctx context.Context, rootPath string) ([]string, error
4949

5050
// Ignore symlinked directories like regular directories
5151
if info.Type()&fs.ModeSymlink == fs.ModeSymlink {
52-
logger.Debugf("attempting to resolve symlink: %s", path)
53-
eval, err := filepath.EvalSymlinks(path)
54-
if err != nil {
55-
logger.Debugf("eval: %s: %s", path, err)
56-
return nil
57-
}
58-
fi, err := os.Stat(eval)
59-
if err != nil {
60-
logger.Debugf("stat: %s: %s", path, err)
61-
return nil
62-
}
63-
if fi.IsDir() {
64-
logger.Debugf("ignoring symlinked directory: %s", path)
65-
return nil
66-
}
67-
path = eval
52+
return nil
6853
}
6954

7055
files = append(files, path)

pkg/action/scan.go

Lines changed: 71 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"context"
88
"errors"
99
"fmt"
10-
"io"
1110
"io/fs"
1211
"log/slog"
1312
"os"
@@ -17,6 +16,7 @@ import (
1716
"strings"
1817
"sync"
1918
"sync/atomic"
19+
"syscall"
2020

2121
"github.com/chainguard-dev/clog"
2222
"github.com/chainguard-dev/malcontent/pkg/archive"
@@ -43,13 +43,59 @@ var (
4343
ErrMatchedCondition = errors.New("matched exit criteria")
4444
// initializeOnce ensures that the file and scanner pools are only initialized once.
4545
initializeOnce sync.Once
46-
filePool *pool.BufferPool
4746
scannerPool *pool.ScannerPool
47+
maxMmapSize int64 = 1 << 31
4848
)
4949

50+
// scanFD scans a file descriptor using memory mapping for efficient large file handling.
51+
// This avoids loading the entire file into memory while still using yara-x's byte slice scanning.
52+
func scanFD(scanner *yarax.Scanner, fd uintptr, logger *clog.Logger) ([]byte, *yarax.ScanResults, error) {
53+
var stat syscall.Stat_t
54+
if err := syscall.Fstat(int(fd), &stat); err != nil {
55+
return nil, nil, fmt.Errorf("fstat failed: %w", err)
56+
}
57+
58+
size := stat.Size
59+
if size == 0 {
60+
mrs, err := scanner.Scan([]byte{})
61+
return nil, mrs, err
62+
}
63+
64+
if size < 0 {
65+
return nil, nil, fmt.Errorf("invalid file size: %d", size)
66+
}
67+
68+
if size > maxMmapSize {
69+
logger.Warn("file exceeds mmap limit, scanning first portion only",
70+
"size", size, "limit", maxMmapSize)
71+
size = maxMmapSize
72+
}
73+
74+
data, err := syscall.Mmap(int(fd), 0, int(size), syscall.PROT_READ, syscall.MAP_PRIVATE)
75+
if err != nil {
76+
return nil, nil, fmt.Errorf("mmap failed: %w", err)
77+
}
78+
defer func() {
79+
if unmapErr := syscall.Munmap(data); unmapErr != nil {
80+
logger.Error("failed to unmap memory", "error", unmapErr)
81+
}
82+
}()
83+
84+
mrs, err := scanner.Scan(data)
85+
if err != nil {
86+
return nil, nil, err
87+
}
88+
89+
// Create a copy of the data to return since the mmap will be unmapped
90+
// This is necessary because report generation needs access to file content
91+
// for checksum calculation and match string extraction
92+
fc := make([]byte, len(data))
93+
copy(fc, data)
94+
95+
return fc, mrs, err
96+
}
97+
5098
// scanSinglePath YARA scans a single path and converts it to a fileReport.
51-
//
52-
//nolint:cyclop // ignore complexity of 38
5399
func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleFS []fs.FS, absPath string, archiveRoot string) (*malcontent.FileReport, error) {
54100
if ctx.Err() != nil {
55101
return &malcontent.FileReport{}, ctx.Err()
@@ -60,7 +106,14 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
60106

61107
isArchive := archiveRoot != ""
62108

63-
fi, err := os.Stat(path)
109+
f, err := os.Open(path)
110+
if err != nil {
111+
return nil, err
112+
}
113+
fd := f.Fd()
114+
defer f.Close()
115+
116+
fi, err := f.Stat()
64117
if err != nil {
65118
return nil, err
66119
}
@@ -105,43 +158,13 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
105158
}
106159

107160
initializeOnce.Do(func() {
108-
filePool = pool.NewBufferPool(c.Concurrency + 1)
109-
scannerPool = pool.NewScannerPool(yrs, c.Concurrency+1)
161+
scannerPool = pool.NewScannerPool(yrs, c.Concurrency)
110162
})
111163

112164
scanner := scannerPool.Get()
113-
if scanner == nil {
114-
scanner = yarax.NewScanner(yrs)
115-
}
116165
defer scannerPool.Put(scanner)
117166

118-
f, err := os.Open(path)
119-
if err != nil {
120-
return nil, err
121-
}
122-
defer f.Close()
123-
124-
fc := filePool.Get(size)
125-
defer filePool.Put(fc)
126-
127-
var bytesRead int
128-
var totalRead int64
129-
for totalRead < size {
130-
bytesRead, err = f.Read(fc[totalRead:])
131-
if errors.Is(err, io.EOF) {
132-
break
133-
}
134-
if err != nil {
135-
return nil, err
136-
}
137-
totalRead += int64(bytesRead)
138-
}
139-
140-
if totalRead < size && err != nil {
141-
return nil, fmt.Errorf("incomplete read: got %d bytes, expected %d: %w", totalRead, size, err)
142-
}
143-
144-
mrs, err := scanner.Scan(fc)
167+
fc, mrs, err := scanFD(scanner, fd, logger)
145168
if err != nil {
146169
logger.Debug("skipping", slog.Any("error", err))
147170
return nil, err
@@ -164,6 +187,11 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
164187
return nil, NewFileReportError(err, path, TypeGenerateError)
165188
}
166189

190+
defer func() {
191+
fc = nil
192+
mrs = nil
193+
}()
194+
167195
// Clean up the path if scanning an archive
168196
var clean string
169197
if isArchive || c.OCI {
@@ -427,6 +455,12 @@ func processPaths(ctx context.Context, paths []string, scanInfo scanPathInfo, c
427455
}
428456
}()
429457

458+
// Zero-out the path strings and empty the slice once read into the path channel
459+
defer func() {
460+
clear(paths)
461+
paths = paths[:0]
462+
}()
463+
430464
for path := range pc {
431465
g.Go(func() error {
432466
if gCtx.Err() != nil {

pkg/archive/archive.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -146,12 +146,6 @@ func ExtractArchiveToTempDir(ctx context.Context, path string) (string, error) {
146146
return "", fmt.Errorf("failed to create temp dir: %w", err)
147147
}
148148

149-
go func() {
150-
<-ctx.Done()
151-
logger.Debug("context cancelled, cleaning up temp dir")
152-
os.RemoveAll(tmpDir)
153-
}()
154-
155149
initializeOnce.Do(func() {
156150
archivePool = pool.NewBufferPool(runtime.GOMAXPROCS(0))
157151
})

pkg/pool/pool.go

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package pool
22

33
import (
44
"math"
5+
"runtime"
56
"sync"
67

78
yarax "github.com/VirusTotal/yara-x/go"
@@ -23,7 +24,8 @@ func NewBufferPool(count int) *BufferPool {
2324

2425
bp.pool = sync.Pool{
2526
New: func() any {
26-
return make([]byte, defaultBuffer)
27+
buffer := make([]byte, defaultBuffer)
28+
return &buffer
2729
},
2830
}
2931

@@ -43,18 +45,17 @@ func (bp *BufferPool) Get(size int64) []byte {
4345

4446
bufInterface := bp.pool.Get()
4547

46-
buf, ok := bufInterface.([]byte)
47-
if !ok || buf == nil {
48+
bufPtr, ok := bufInterface.(*[]byte)
49+
if !ok || bufPtr == nil {
4850
return make([]byte, size)
4951
}
5052

51-
bufPtr := &buf
5253
if cap(*bufPtr) < int(size) {
5354
bp.pool.Put(bufPtr)
5455
return make([]byte, size)
5556
}
5657

57-
return buf[:size]
58+
return (*bufPtr)[:size]
5859
}
5960

6061
// Put returns a byte buffer to the pool for future reuse.
@@ -72,34 +73,50 @@ func (bp *BufferPool) Put(buf []byte) {
7273

7374
// ScannerPool provides a pool of yara-x scanners.
7475
type ScannerPool struct {
75-
pool sync.Pool
76+
scanners chan *yarax.Scanner
77+
pinner *runtime.Pinner
78+
closeOnce sync.Once
7679
}
7780

7881
// NewScannerPool creates a pool containing the specified number of yara-x scanners.
7982
func NewScannerPool(yrs *yarax.Rules, count int) *ScannerPool {
80-
sp := &ScannerPool{}
81-
82-
sp.pool = sync.Pool{
83-
New: func() any {
84-
return yarax.NewScanner(yrs)
85-
},
83+
sp := &ScannerPool{
84+
scanners: make(chan *yarax.Scanner, count),
85+
pinner: &runtime.Pinner{},
8686
}
8787

8888
for range count {
89-
sp.pool.Put(yarax.NewScanner(yrs))
89+
scanner := yarax.NewScanner(yrs)
90+
// Pin the scanner in memory to prevent GC movement
91+
sp.pinner.Pin(scanner)
92+
sp.scanners <- scanner
9093
}
9194
return sp
9295
}
9396

94-
// Get retrieves a scanner from the scanner pool.
97+
// Get retrieves a scanner from the scanner pool, blocking if none are available.
9598
func (sp *ScannerPool) Get() *yarax.Scanner {
96-
if scanner, ok := sp.pool.Get().(*yarax.Scanner); ok {
97-
return scanner
98-
}
99-
return nil
99+
return <-sp.scanners
100100
}
101101

102102
// Put returns a scanner to the scanner pool.
103103
func (sp *ScannerPool) Put(scanner *yarax.Scanner) {
104-
sp.pool.Put(scanner)
104+
if scanner != nil {
105+
select {
106+
case sp.scanners <- scanner:
107+
default:
108+
}
109+
}
110+
}
111+
112+
// Close destroys all active scanners.
113+
// Currently unused.
114+
func (sp *ScannerPool) Close() {
115+
sp.closeOnce.Do(func() {
116+
sp.pinner.Unpin()
117+
close(sp.scanners)
118+
for scanner := range sp.scanners {
119+
scanner.Destroy()
120+
}
121+
})
105122
}

pkg/report/report.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,7 @@ func Generate(ctx context.Context, path string, mrs *yarax.ScanResults, c malcon
479479

480480
processor := newMatchProcessor(fc, matches, m.Patterns())
481481
matchedStrings = processor.process()
482+
processor.clearFileContent()
482483
}
483484

484485
b := &malcontent.Behavior{

pkg/report/strings.go

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ type StringPool struct {
1919
strings map[string]string
2020
}
2121

22+
// clear removes all strings from the pool to free memory.
23+
func (sp *StringPool) clear() {
24+
sp.Lock()
25+
defer sp.Unlock()
26+
clear(sp.strings)
27+
}
28+
2229
// NewStringPool creates a new string pool.
2330
func NewStringPool(length int) *StringPool {
2431
return &StringPool{
@@ -70,6 +77,12 @@ var matchResultPool = sync.Pool{
7077
},
7178
}
7279

80+
// clearFileContent releases the file content to free memory after processing.
81+
func (mp *matchProcessor) clearFileContent() {
82+
mp.fc = nil
83+
mp.pool.clear()
84+
}
85+
7386
// process performantly handles the conversion of matched data to strings.
7487
// yara-x does not expose the rendered string via the API due to performance overhead.
7588
func (mp *matchProcessor) process() []string {
@@ -115,20 +128,24 @@ func (mp *matchProcessor) process() []string {
115128
if l <= cap(buffer) {
116129
buffer = buffer[:l]
117130
copy(buffer, matchBytes)
118-
*result = append(*result, mp.pool.Intern(string(buffer)))
131+
matchStr := string(buffer)
132+
*result = append(*result, mp.pool.Intern(string([]byte(matchStr))))
119133
} else {
120-
*result = append(*result, mp.pool.Intern(string(matchBytes)))
134+
matchStr := string(matchBytes)
135+
*result = append(*result, mp.pool.Intern(string([]byte(matchStr))))
121136
}
122137
} else {
123138
if patterns == nil || cap(patterns) < patternsCap {
124139
patterns = make([]string, 0, patternsCap)
125140
} else {
141+
clear(patterns)
126142
patterns = patterns[:0]
127143
}
128144
for _, p := range mp.patterns {
129145
patterns = append(patterns, p.Identifier())
130146
}
131-
*result = append(*result, slices.Compact(patterns)...)
147+
compacted := slices.Compact(patterns)
148+
*result = append(*result, compacted...)
132149
}
133150
}
134151

0 commit comments

Comments
 (0)