Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pkg/action/archive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ func TestExtractionMultiple(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
dirFiles, err := os.ReadDir(dir)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -110,6 +111,7 @@ func TestExtractTar(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
want := []string{
"apko_0.13.2_linux_arm64",
}
Expand Down Expand Up @@ -138,6 +140,7 @@ func TestExtractGzip(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
want := []string{
"apko",
}
Expand Down Expand Up @@ -166,6 +169,7 @@ func TestExtractZip(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
want := []string{
"apko_0.13.2_linux_arm64",
}
Expand Down Expand Up @@ -194,6 +198,7 @@ func TestExtractNestedArchive(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
want := []string{
"apko_0.13.2_linux_arm64",
}
Expand Down
22 changes: 9 additions & 13 deletions pkg/action/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,7 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF

isArchive := archiveRoot != ""

f, err := os.Open(path)
if err != nil {
return nil, err
}

fi, err := f.Stat()
fi, err := os.Stat(path)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -128,6 +123,7 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
scannerPool = pool.NewScannerPool(yrs, getMaxConcurrency(runtime.GOMAXPROCS(0)))
})
scanner := scannerPool.Get(yrs)
defer scannerPool.Put(scanner)

mrs, err := scanner.ScanFile(path)
if err != nil {
Expand All @@ -150,8 +146,14 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
// create a buffer sized to the minimum of the file's size or the default ReadBuffer
// only do so if we actually need to retrieve the file's contents
buf := readPool.Get(min(size, file.ReadBuffer)) //nolint:nilaway // the buffer pool is initialized in init()
defer readPool.Put(buf)

f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()

// Only retrieve the file's contents and calculate its checksum if we need to generate a report
fc, err := file.GetContents(f, buf)
if err != nil {
return nil, err
Expand All @@ -169,12 +171,6 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
return nil, NewFileReportError(err, path, TypeGenerateError)
}

defer func() {
f.Close()
readPool.Put(buf)
scannerPool.Put(scanner)
}()

// Clean up the path if scanning an archive
var clean string
if isArchive || c.OCI {
Expand Down
162 changes: 162 additions & 0 deletions pkg/action/scan_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// Copyright 2026 Chainguard, Inc.
// SPDX-License-Identifier: Apache-2.0

package action

import (
"context"
"io/fs"
"os"
"path/filepath"
"runtime"
"testing"

"github.com/chainguard-dev/malcontent/pkg/malcontent"
"github.com/chainguard-dev/malcontent/rules"
thirdparty "github.com/chainguard-dev/malcontent/third_party"
)

// countOpenFDs reports how many file descriptors the current process has open
// by listing a per-process descriptor directory. It tries /proc/self/fd
// (Linux) first and then /dev/fd (macOS); the first directory that can be
// read determines the result. It returns -1 when neither can be read.
func countOpenFDs(t *testing.T) int {
	t.Helper()

	for _, fdDir := range []string{"/proc/self/fd", "/dev/fd"} {
		listing, err := os.ReadDir(fdDir)
		if err != nil {
			// Not available on this platform; try the next candidate.
			continue
		}
		return len(listing)
	}

	return -1
}

// TestScanSinglePathNoFDLeak calls scanSinglePath repeatedly on a few
// testdata files and fails if the process has more open file descriptors
// afterwards than it had before rule loading and scanning began. Scan results
// and errors are deliberately ignored; only the descriptor count is checked.
// The test is skipped when descriptors cannot be counted.
func TestScanSinglePathNoFDLeak(t *testing.T) {
	// No growth in the descriptor count is tolerated.
	const maxAllowedGrowth = 0

	ctx := context.Background()

	// The baseline is taken before the rules are loaded.
	baseline := countOpenFDs(t)
	if baseline == -1 {
		t.Skip("cannot count file descriptors on this platform")
	}

	ruleSources := []fs.FS{rules.FS, thirdparty.FS}
	compiled, err := CachedRules(ctx, ruleSources)
	if err != nil {
		t.Fatalf("rules: %v", err)
	}

	// IgnoreSelf, IncludeDataFiles, MinFileRisk and MinRisk are left at their
	// zero values.
	cfg := malcontent.Config{
		Concurrency: runtime.NumCPU(),
		Rules:       compiled,
		RuleFS:      ruleSources,
	}

	var targets []string
	for _, name := range []string{"empty", "rando", "short"} {
		targets = append(targets, filepath.Join("testdata", name))
	}

	rounds := runtime.GOMAXPROCS(0) * 10
	for round := 0; round < rounds; round++ {
		for _, target := range targets {
			_, _ = scanSinglePath(ctx, cfg, target, ruleSources, target, "", nil)
		}
	}

	// Force a collection before recounting (presumably to give unreferenced
	// files a chance to be finalized; confirm).
	runtime.GC()

	current := countOpenFDs(t)
	if growth := current - baseline; growth > maxAllowedGrowth {
		t.Errorf("file descriptor leak detected: before=%d after=%d leaked=%d (ran %d iterations)",
			baseline, current, growth, rounds*len(targets))
	}
}

// TestScanSinglePathNonExistentFile calls scanSinglePath repeatedly with a
// path that does not exist. Every call must return an error, and the number
// of open file descriptors must not grow across the run. The test is skipped
// when descriptors cannot be counted.
func TestScanSinglePathNonExistentFile(t *testing.T) {
	// No growth in the descriptor count is tolerated.
	const maxAllowedGrowth = 0

	ctx := context.Background()

	// The baseline is taken before the rules are loaded.
	baseline := countOpenFDs(t)
	if baseline == -1 {
		t.Skip("cannot count file descriptors on this platform")
	}

	ruleSources := []fs.FS{rules.FS, thirdparty.FS}
	compiled, err := CachedRules(ctx, ruleSources)
	if err != nil {
		t.Fatalf("rules: %v", err)
	}

	cfg := malcontent.Config{
		Rules:  compiled,
		RuleFS: ruleSources,
	}

	rounds := runtime.GOMAXPROCS(0) * 10
	for round := 0; round < rounds; round++ {
		if _, scanErr := scanSinglePath(ctx, cfg, "/nonexistent/path/to/file", ruleSources, "", "", nil); scanErr == nil {
			t.Error("expected error for non-existent file")
		}
	}

	// Force a collection before recounting (presumably to give unreferenced
	// files a chance to be finalized; confirm).
	runtime.GC()

	current := countOpenFDs(t)
	if growth := current - baseline; growth > maxAllowedGrowth {
		t.Errorf("file descriptor leak on error path: before=%d after=%d leaked=%d",
			baseline, current, growth)
	}
}

// TestScanRepeatedScansNoResourceExhaustion calls scanSinglePath many times
// over a small set of testdata files. It makes no assertions on the results
// and ignores any errors; it passes as long as every call returns. The intent
// is to show that repeated scans do not exhaust the scanner pool or the
// buffer pool.
func TestScanRepeatedScansNoResourceExhaustion(t *testing.T) {
	ctx := context.Background()

	ruleSources := []fs.FS{rules.FS, thirdparty.FS}
	compiled, err := CachedRules(ctx, ruleSources)
	if err != nil {
		t.Fatalf("rules: %v", err)
	}

	// IgnoreSelf, IncludeDataFiles, MinFileRisk and MinRisk are left at their
	// zero values.
	cfg := malcontent.Config{
		Concurrency: runtime.NumCPU(),
		Rules:       compiled,
		RuleFS:      ruleSources,
	}

	names := []string{
		"empty", // zero-sized, early return before scanner
		"rando", // data file, early return before scanner
		"shell", // actual script, full scan path
	}
	targets := make([]string, 0, len(names))
	for _, name := range names {
		targets = append(targets, filepath.Join("testdata", name))
	}

	rounds := runtime.GOMAXPROCS(0) * 10
	for round := 0; round < rounds; round++ {
		for _, target := range targets {
			_, _ = scanSinglePath(ctx, cfg, target, ruleSources, target, "", nil)
		}
	}
}
1 change: 1 addition & 0 deletions pkg/archive/rpm.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ func ExtractRPM(ctx context.Context, d, f string) error {
if err != nil {
return fmt.Errorf("failed to create zstd reader: %w", err)
}
defer zstdStream.Close()
cr = cpio.NewReader(zstdStream)
default:
return fmt.Errorf("unsupported compression format: %s", compression)
Expand Down