diff --git a/pkg/action/archive_test.go b/pkg/action/archive_test.go index 03dd7619e..8f06951ae 100644 --- a/pkg/action/archive_test.go +++ b/pkg/action/archive_test.go @@ -83,6 +83,7 @@ func TestExtractionMultiple(t *testing.T) { if err != nil { t.Fatal(err) } + defer os.RemoveAll(dir) dirFiles, err := os.ReadDir(dir) if err != nil { t.Fatal(err) @@ -110,6 +111,7 @@ func TestExtractTar(t *testing.T) { if err != nil { t.Fatal(err) } + defer os.RemoveAll(dir) want := []string{ "apko_0.13.2_linux_arm64", } @@ -138,6 +140,7 @@ func TestExtractGzip(t *testing.T) { if err != nil { t.Fatal(err) } + defer os.RemoveAll(dir) want := []string{ "apko", } @@ -166,6 +169,7 @@ func TestExtractZip(t *testing.T) { if err != nil { t.Fatal(err) } + defer os.RemoveAll(dir) want := []string{ "apko_0.13.2_linux_arm64", } @@ -194,6 +198,7 @@ func TestExtractNestedArchive(t *testing.T) { if err != nil { t.Fatal(err) } + defer os.RemoveAll(dir) want := []string{ "apko_0.13.2_linux_arm64", } diff --git a/pkg/action/scan.go b/pkg/action/scan.go index c9621b8ab..8a4090cb8 100644 --- a/pkg/action/scan.go +++ b/pkg/action/scan.go @@ -62,12 +62,7 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF isArchive := archiveRoot != "" - f, err := os.Open(path) - if err != nil { - return nil, err - } - - fi, err := f.Stat() + fi, err := os.Stat(path) if err != nil { return nil, err } @@ -128,6 +123,7 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF scannerPool = pool.NewScannerPool(yrs, getMaxConcurrency(runtime.GOMAXPROCS(0))) }) scanner := scannerPool.Get(yrs) + defer scannerPool.Put(scanner) mrs, err := scanner.ScanFile(path) if err != nil { @@ -150,8 +146,14 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF // create a buffer sized to the minimum of the file's size or the default ReadBuffer // only do so if we actually need to retrieve the file's contents buf := readPool.Get(min(size, 
file.ReadBuffer)) //nolint:nilaway // the buffer pool is initialized in init() + defer readPool.Put(buf) + + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() - // Only retrieve the file's contents and calculate its checksum if we need to generate a report fc, err := file.GetContents(f, buf) if err != nil { return nil, err @@ -169,12 +171,6 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF return nil, NewFileReportError(err, path, TypeGenerateError) } - defer func() { - f.Close() - readPool.Put(buf) - scannerPool.Put(scanner) - }() - // Clean up the path if scanning an archive var clean string if isArchive || c.OCI { diff --git a/pkg/action/scan_test.go b/pkg/action/scan_test.go new file mode 100644 index 000000000..dc54304e4 --- /dev/null +++ b/pkg/action/scan_test.go @@ -0,0 +1,162 @@ +// Copyright 2026 Chainguard, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package action + +import ( + "context" + "io/fs" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/chainguard-dev/malcontent/pkg/malcontent" + "github.com/chainguard-dev/malcontent/rules" + thirdparty "github.com/chainguard-dev/malcontent/third_party" +) + +// countOpenFDs returns the number of open file descriptors for the current process. +// Returns -1 if unable to count (e.g., on unsupported platforms). +func countOpenFDs(t *testing.T) int { + t.Helper() + + // Linux: count entries in /proc/self/fd + if entries, err := os.ReadDir("/proc/self/fd"); err == nil { + return len(entries) + } + + // macOS: count entries in /dev/fd + if entries, err := os.ReadDir("/dev/fd"); err == nil { + return len(entries) + } + + return -1 +} + +// TestScanSinglePathNoFDLeak verifies that early return paths in scanSinglePath +// properly close file handles and don't leak file descriptors. 
+func TestScanSinglePathNoFDLeak(t *testing.T) {
+	ctx := context.Background()
+
+	// Probe first so unsupported platforms skip before any rules are loaded.
+	if countOpenFDs(t) == -1 {
+		t.Skip("cannot count file descriptors on this platform")
+	}
+
+	rfs := []fs.FS{rules.FS, thirdparty.FS}
+	yrs, err := CachedRules(ctx, rfs)
+	if err != nil {
+		t.Fatalf("rules: %v", err)
+	}
+
+	cfg := malcontent.Config{
+		Concurrency:      runtime.NumCPU(),
+		IgnoreSelf:       false,
+		IncludeDataFiles: false,
+		MinFileRisk:      0,
+		MinRisk:          0,
+		Rules:            yrs,
+		RuleFS:           rfs,
+	}
+
+	testFiles := []string{
+		filepath.Join("testdata", "empty"),
+		filepath.Join("testdata", "rando"),
+		filepath.Join("testdata", "short"),
+	}
+
+	// Warm-up pass: rule loading and anything scanSinglePath initialises lazily on
+	// first use may open descriptors that legitimately stay open. Run every file
+	// once so that one-time cost is not part of the baseline, and surface scan
+	// errors here (once) instead of silently discarding them.
+	for _, tf := range testFiles {
+		if _, scanErr := scanSinglePath(ctx, cfg, tf, rfs, tf, "", nil); scanErr != nil {
+			t.Logf("warm-up scan of %s: %v", tf, scanErr)
+		}
+	}
+
+	fdsBefore := countOpenFDs(t)
+	if fdsBefore == -1 {
+		t.Fatal("could not count file descriptors for the baseline")
+	}
+
+	iterations := runtime.GOMAXPROCS(0) * 10
+	scans := iterations * len(testFiles)
+	for range iterations {
+		for _, tf := range testFiles {
+			// Errors were already reported by the warm-up pass; only the
+			// descriptor count matters here.
+			_, _ = scanSinglePath(ctx, cfg, tf, rfs, tf, "", nil)
+		}
+	}
+
+	// Deliberately no runtime.GC() here: a collected *os.File may have its
+	// descriptor closed by a finalizer, which would hide exactly the kind of
+	// leak this test is meant to catch.
+	fdsAfter := countOpenFDs(t)
+	if fdsAfter == -1 {
+		// Without this check a failed count would yield a negative "leak" and pass.
+		t.Fatalf("could not count file descriptors after %d scans (possible descriptor exhaustion)", scans)
+	}
+
+	const maxAllowedGrowth = 0
+	if leaked := fdsAfter - fdsBefore; leaked > maxAllowedGrowth {
+		t.Errorf("file descriptor leak detected: before=%d after=%d leaked=%d (ran %d iterations)",
+			fdsBefore, fdsAfter, leaked, scans)
+	}
+}
+
+// TestScanSinglePathNonExistentFile verifies that scanning a non-existent file
+// returns an error without leaking resources.
+func TestScanSinglePathNonExistentFile(t *testing.T) {
+	ctx := context.Background()
+
+	// Probe first so unsupported platforms skip before any rules are loaded.
+	if countOpenFDs(t) == -1 {
+		t.Skip("cannot count file descriptors on this platform")
+	}
+
+	rfs := []fs.FS{rules.FS, thirdparty.FS}
+	yrs, err := CachedRules(ctx, rfs)
+	if err != nil {
+		t.Fatalf("rules: %v", err)
+	}
+
+	cfg := malcontent.Config{
+		Rules:  yrs,
+		RuleFS: rfs,
+	}
+
+	const missing = "/nonexistent/path/to/file"
+
+	// Warm-up call: keeps descriptors opened by one-time initialisation out of
+	// the baseline, and checks the error contract once up front.
+	if _, scanErr := scanSinglePath(ctx, cfg, missing, rfs, "", "", nil); scanErr == nil {
+		t.Fatal("expected error for non-existent file")
+	}
+
+	fdsBefore := countOpenFDs(t)
+	if fdsBefore == -1 {
+		t.Fatal("could not count file descriptors for the baseline")
+	}
+
+	iterations := runtime.GOMAXPROCS(0) * 10
+	for range iterations {
+		// Stop at the first violation rather than repeating the same message
+		// on every iteration.
+		if _, scanErr := scanSinglePath(ctx, cfg, missing, rfs, "", "", nil); scanErr == nil {
+			t.Fatal("expected error for non-existent file")
+		}
+	}
+
+	// Deliberately no runtime.GC() here: a collected *os.File may have its
+	// descriptor closed by a finalizer, which would mask a leak on the error path.
+	fdsAfter := countOpenFDs(t)
+	if fdsAfter == -1 {
+		// Without this check a failed count would yield a negative "leak" and pass.
+		t.Fatalf("could not count file descriptors after %d scans (possible descriptor exhaustion)", iterations)
+	}
+
+	const maxAllowedGrowth = 0
+	if leaked := fdsAfter - fdsBefore; leaked > maxAllowedGrowth {
+		t.Errorf("file descriptor leak on error path: before=%d after=%d leaked=%d",
+			fdsBefore, fdsAfter, leaked)
+	}
+}
+
+// TestScanRepeatedScansNoResourceExhaustion verifies that repeated scans
+// don't exhaust scanner pool or buffer pool resources.
+func TestScanRepeatedScansNoResourceExhaustion(t *testing.T) {
+	ctx := context.Background()
+
+	rfs := []fs.FS{rules.FS, thirdparty.FS}
+	yrs, err := CachedRules(ctx, rfs)
+	if err != nil {
+		t.Fatalf("rules: %v", err)
+	}
+
+	cfg := malcontent.Config{
+		Concurrency:      runtime.NumCPU(),
+		IgnoreSelf:       false,
+		IncludeDataFiles: false,
+		MinFileRisk:      0,
+		MinRisk:          0,
+		Rules:            yrs,
+		RuleFS:           rfs,
+	}
+
+	testFiles := []string{
+		filepath.Join("testdata", "empty"), // zero-sized, early return before scanner
+		filepath.Join("testdata", "rando"), // data file, early return before scanner
+		filepath.Join("testdata", "shell"), // actual script, full scan path
+	}
+
+	// Record how each file behaves on a first pass. The repeated passes below are
+	// then held to that outcome, so the test makes no assumption about whether a
+	// given fixture scans cleanly, only that repetition does not make it worse.
+	firstErr := make(map[string]error, len(testFiles))
+	for _, tf := range testFiles {
+		_, scanErr := scanSinglePath(ctx, cfg, tf, rfs, tf, "", nil)
+		firstErr[tf] = scanErr
+	}
+
+	iterations := runtime.GOMAXPROCS(0) * 10
+
+	for i := range iterations {
+		for _, tf := range testFiles {
+			_, scanErr := scanSinglePath(ctx, cfg, tf, rfs, tf, "", nil)
+			// A scan that worked at first but fails after many repetitions points
+			// at a resource that was not handed back between scans.
+			if scanErr != nil && firstErr[tf] == nil {
+				t.Fatalf("scan of %s succeeded initially but failed on iteration %d: %v", tf, i, scanErr)
+			}
+		}
+	}
+}
diff --git a/pkg/archive/rpm.go b/pkg/archive/rpm.go index b9aeadcc5..a3b48ef78 100644 --- a/pkg/archive/rpm.go +++ b/pkg/archive/rpm.go @@ -126,6 +126,7 @@ func ExtractRPM(ctx context.Context, d, f string) error { if err != nil { return fmt.Errorf("failed to create zstd reader: %w", err) } + defer zstdStream.Close() cr = cpio.NewReader(zstdStream) default: return fmt.Errorf("unsupported compression format: %s", compression)