Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ GLOBAL OPTIONS:
--ignore-tags string Rule tags to ignore (default: "false_positive,ignore")
--include-data-files Include files that are detected as non-program (binary or source) files
--jobs int, -j int Concurrently scan files within target scan paths (default: 12)
--max-depth int Maximum depth for archive extraction (-1 for unlimited) (default: 32)
--max-files int Maximum number of files to scan (-1 for unlimited) (default: 2097152)
--max-depth int Maximum depth for archive extraction (0 or -1 for unlimited) (default: 32)
--max-files int Maximum number of files to scan (0 or -1 for unlimited) (default: 2097152)
--max-image-size int Maximum OCI image size in bytes (0 or -1 for unlimited) (default: 17179869184)
--min-file-level int Obsoleted by --min-file-risk (default: -1)
--min-file-risk string Only show results for files which meet the given risk level (any, low, medium, high, critical) (default: "low")
--min-level int Obsoleted by --min-risk (default: -1)
Expand Down
13 changes: 11 additions & 2 deletions cmd/mal/mal.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ var (
ignoreTagsFlag string
includeDataFilesFlag bool
maxDepthFlag int
maxImageSizeFlag int64
maxScanFilesFlag int
minFileLevelFlag int
minFileRiskFlag string
Expand Down Expand Up @@ -261,6 +262,7 @@ func main() {
IgnoreTags: ignoreTags,
IncludeDataFiles: includeDataFiles,
MaxDepth: maxDepthFlag,
MaxImageSize: maxImageSizeFlag,
MaxScanFiles: maxScanFilesFlag,
MinFileRisk: minFileRisk,
MinRisk: minRisk,
Expand Down Expand Up @@ -348,17 +350,24 @@ func main() {
&cli.IntFlag{
Name: "max-depth",
Value: 32,
Usage: "Maximum depth for archive extraction (-1 for unlimited)",
Usage: "Maximum depth for archive extraction (0 or -1 for unlimited)",
Destination: &maxDepthFlag,
Local: false,
},
&cli.IntFlag{
Name: "max-files",
Value: 1 << 21, // ~2 million files
Usage: "Maximum number of files to scan (-1 for unlimited)",
Usage: "Maximum number of files to scan (0 or -1 for unlimited)",
Destination: &maxScanFilesFlag,
Local: false,
},
&cli.Int64Flag{
Name: "max-image-size",
Value: 1 << 34, // ~16 GB
Usage: "Maximum OCI image size in bytes (0 or -1 for unlimited)",
Destination: &maxImageSizeFlag,
Local: false,
},
&cli.IntFlag{
Name: "min-file-level",
Value: -1,
Expand Down
47 changes: 23 additions & 24 deletions pkg/action/archive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,30 @@ import (

"github.com/chainguard-dev/clog"
"github.com/chainguard-dev/malcontent/pkg/archive"
"github.com/chainguard-dev/malcontent/pkg/file"
"github.com/chainguard-dev/malcontent/pkg/malcontent"
"github.com/chainguard-dev/malcontent/pkg/render"
"github.com/chainguard-dev/malcontent/rules"
thirdparty "github.com/chainguard-dev/malcontent/third_party"
"github.com/google/go-cmp/cmp"
)

// readTestFile reads a file using file.GetContents for consistency with production code.
func readTestFile(t *testing.T, path string) []byte {
t.Helper()
f, err := os.Open(path)
if err != nil {
t.Fatalf("failed to open test file %s: %v", path, err)
}
defer f.Close()
buf := make([]byte, file.ExtractBuffer)
data, err := file.GetContents(f, buf)
if err != nil {
t.Fatalf("failed to read test file %s: %v", path, err)
}
return data
}

func TestExtractionMethod(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -256,10 +273,7 @@ func TestScanArchive(t *testing.T) {

got := out.String()

td, err := os.ReadFile("testdata/scan_archive")
if err != nil {
t.Fatalf("testdata read failed: %v", err)
}
td := readTestFile(t, "testdata/scan_archive")
want := string(td)

if diff := cmp.Diff(want, got); diff != "" {
Expand Down Expand Up @@ -303,10 +317,7 @@ func TestScanDeb(t *testing.T) {

got := out.String()

td, err := os.ReadFile("testdata/scan_deb")
if err != nil {
t.Fatalf("testdata read failed: %v", err)
}
td := readTestFile(t, "testdata/scan_deb")
want := string(td)

if diff := cmp.Diff(want, got); diff != "" {
Expand Down Expand Up @@ -350,10 +361,7 @@ func TestScanRPM(t *testing.T) {

got := out.String()

td, err := os.ReadFile("testdata/scan_rpm")
if err != nil {
t.Fatalf("testdata read failed: %v", err)
}
td := readTestFile(t, "testdata/scan_rpm")
want := string(td)

if diff := cmp.Diff(want, got); diff != "" {
Expand Down Expand Up @@ -397,10 +405,7 @@ func TestScanZlib(t *testing.T) {

got := out.String()

td, err := os.ReadFile("testdata/scan_zlib")
if err != nil {
t.Fatalf("testdata read failed: %v", err)
}
td := readTestFile(t, "testdata/scan_zlib")
want := string(td)

if diff := cmp.Diff(want, got); diff != "" {
Expand Down Expand Up @@ -444,10 +449,7 @@ func TestScanZstd(t *testing.T) {

got := out.String()

td, err := os.ReadFile("testdata/scan_zstd")
if err != nil {
t.Fatalf("testdata read failed: %v", err)
}
td := readTestFile(t, "testdata/scan_zstd")
want := string(td)

if diff := cmp.Diff(want, got); diff != "" {
Expand Down Expand Up @@ -582,10 +584,7 @@ func TestScanConflictingArchiveFiles(t *testing.T) {
}

got := out.String()
td, err := os.ReadFile("testdata/scan_conflict")
if err != nil {
t.Fatalf("testdata read failed: %v", err)
}
td := readTestFile(t, "testdata/scan_conflict")
want := string(td)

if diff := cmp.Diff(want, got); diff != "" {
Expand Down
4 changes: 2 additions & 2 deletions pkg/action/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,11 @@ func Diff(ctx context.Context, c malcontent.Config, _ *clog.Logger) (*malcontent
)

if c.OCI {
srcPath, err = archive.OCI(ctx, srcPath, c.OCIAuth)
srcPath, err = archive.OCI(ctx, srcPath, c.OCIAuth, c.MaxImageSize)
if err != nil {
return nil, fmt.Errorf("failed to prepare scan path: %w", err)
}
destPath, err = archive.OCI(ctx, destPath, c.OCIAuth)
destPath, err = archive.OCI(ctx, destPath, c.OCIAuth, c.MaxImageSize)
if err != nil {
return nil, fmt.Errorf("failed to prepare scan path: %w", err)
}
Expand Down
6 changes: 1 addition & 5 deletions pkg/action/oci_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"bytes"
"context"
"io/fs"
"os"
"runtime"
"testing"

Expand Down Expand Up @@ -55,10 +54,7 @@ func TestOCI(t *testing.T) {

got := out.String()

td, err := os.ReadFile("testdata/scan_oci")
if err != nil {
t.Fatalf("testdata read failed: %v", err)
}
td := readTestFile(t, "testdata/scan_oci")
want := string(td)

if diff := cmp.Diff(want, got); diff != "" {
Expand Down
6 changes: 3 additions & 3 deletions pkg/action/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ func handleScanPath(ctx context.Context, scanPath string, c malcontent.Config, r
c.Renderer.Scanning(ctx, scanPath)
}

scanInfo, err := prepareScanPath(ctx, scanPath, c.OCI, c.OCIAuth, logger)
scanInfo, err := prepareScanPath(ctx, scanPath, c.OCI, c.OCIAuth, c.MaxImageSize, logger)
if err != nil {
return fmt.Errorf("failed to prepare scan path: %w", err)
}
Expand All @@ -376,7 +376,7 @@ func handleScanPath(ctx context.Context, scanPath string, c malcontent.Config, r
return processPaths(ctx, paths, scanInfo, c, r, matchChan, matchOnce, logger)
}

func prepareScanPath(ctx context.Context, scanPath string, isOCI, useAuth bool, logger *clog.Logger) (scanPathInfo, error) {
func prepareScanPath(ctx context.Context, scanPath string, isOCI, useAuth bool, maxImageSize int64, logger *clog.Logger) (scanPathInfo, error) {
if ctx.Err() != nil {
return scanPathInfo{}, ctx.Err()
}
Expand All @@ -391,7 +391,7 @@ func prepareScanPath(ctx context.Context, scanPath string, isOCI, useAuth bool,
}

info.imageURI = scanPath
ociPath, err := archive.OCI(ctx, info.imageURI, useAuth)
ociPath, err := archive.OCI(ctx, info.imageURI, useAuth, maxImageSize)
if err != nil {
return info, fmt.Errorf("failed to prepare OCI image for scanning: %w", err)
}
Expand Down
60 changes: 50 additions & 10 deletions pkg/archive/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"runtime"
"strings"
"sync"
"time"

"github.com/chainguard-dev/clog"
"github.com/chainguard-dev/malcontent/pkg/file"
Expand All @@ -33,6 +32,34 @@ func init() {
zipPool = pool.NewBufferPool(runtime.GOMAXPROCS(0) * 2)
}

// ValidateResolvedPath checks that the target path still resides within the extraction directory
// after resolving symlinks in its parent directory.
func ValidateResolvedPath(target, dir, clean string) error {
resolvedParent, ok := evalSymlinks(filepath.Dir(target))
if !ok {
return nil
}
resolvedDir, ok := evalSymlinks(dir)
if !ok {
return nil
}
resolvedTarget := filepath.Join(resolvedParent, filepath.Base(target))
if !IsValidPath(resolvedTarget, resolvedDir) {
return fmt.Errorf("path traversal via symlink in parent directory: %s", clean)
}
return nil
}

// evalSymlinks resolves symlinks in the given path, returning the resolved path
// and true on success, or an empty string and false if resolution fails.
func evalSymlinks(path string) (string, bool) {
resolved, err := filepath.EvalSymlinks(path)
if err != nil {
return "", false
}
return resolved, true
}

// isValidPath checks if the target file is within the given directory.
func IsValidPath(target, dir string) bool {
if strings.Contains(target, "\x00") || strings.Contains(dir, "\x00") {
Expand Down Expand Up @@ -137,13 +164,15 @@ func extractNestedArchive(ctx context.Context, c malcontent.Config, d string, f
// Some packages may have archives and files with colliding names
// e.g., demo_page.css and demo_page.css.gz
// the former is the uncompressed version of the latter
// if we encounter this, replace the name with something that won't collide
// if we encounter this, use os.MkdirTemp to create a unique directory
if _, err := os.Stat(archivePath); err == nil {
logger.Debugf("duplicate file name already exists, modifying directory name for %s", archivePath)
archivePath = fmt.Sprintf("%s%d", archivePath, time.Now().UnixNano())
}

if err := os.MkdirAll(archivePath, 0o700); err != nil {
var mkErr error
archivePath, mkErr = os.MkdirTemp(filepath.Dir(archivePath), filepath.Base(archivePath)+"_*")
if mkErr != nil {
return fmt.Errorf("failed to create unique extraction directory: %w", mkErr)
}
} else if err := os.MkdirAll(archivePath, 0o700); err != nil {
return fmt.Errorf("failed to create extraction directory: %w", err)
}

Expand Down Expand Up @@ -331,9 +360,19 @@ func handleSymlink(dir, linkPath, linkTarget string) error {
return nil
}

parentDir := filepath.Dir(fullPath)
resolvedDir := dir
if rp, err := filepath.EvalSymlinks(parentDir); err == nil {
parentDir = rp
if rd, err := filepath.EvalSymlinks(dir); err == nil {
resolvedDir = rd
}
}

// Validate relative symlink target resolves within extraction directory
resolvedTarget := filepath.Clean(filepath.Join(filepath.Dir(fullPath), linkTarget))
if !IsValidPath(resolvedTarget, dir) {
// using the actual (resolved) parent directory
resolvedTarget := filepath.Clean(filepath.Join(parentDir, linkTarget))
if !IsValidPath(resolvedTarget, resolvedDir) {
return fmt.Errorf("symlink target escapes extraction directory: %s -> %s", linkPath, linkTarget)
}

Expand Down Expand Up @@ -363,8 +402,9 @@ func handleSymlink(dir, linkPath, linkTarget string) error {
return fmt.Errorf("symlink target mismatch: expected %s, got %s", linkTarget, actualTarget)
}

actualResolved := filepath.Clean(filepath.Join(filepath.Dir(fullPath), actualTarget))
if !IsValidPath(actualResolved, dir) {
// Post-creation validation using the resolved parent directory
actualResolved := filepath.Clean(filepath.Join(parentDir, actualTarget))
if !IsValidPath(actualResolved, resolvedDir) {
os.Remove(fullPath)
return fmt.Errorf("symlink target escapes extraction directory after creation: %s -> %s", linkPath, actualTarget)
}
Expand Down
2 changes: 0 additions & 2 deletions pkg/archive/bz2.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ func ExtractBz2(ctx context.Context, d, f string) error {
return fmt.Errorf("failed to create directory for file: %w", err)
}

// #nosec G115 // ignore Type conversion which leads to integer overflow
// header.Mode is int64 and FileMode is uint32
out, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
Expand Down
12 changes: 11 additions & 1 deletion pkg/archive/deb.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@ import (
)

// ExtractDeb extracts .deb packages.
func ExtractDeb(ctx context.Context, d, f string) error {
func ExtractDeb(ctx context.Context, d, f string) (retErr error) {
// Recover from panics in third-party deb parsing library.
defer func() {
if r := recover(); r != nil {
retErr = fmt.Errorf("recovered from panic during deb extraction: %v", r)
}
}()
if ctx.Err() != nil {
return ctx.Err()
}
Expand Down Expand Up @@ -57,6 +63,10 @@ func ExtractDeb(ctx context.Context, d, f string) error {
return fmt.Errorf("invalid file path: %s", target)
}

if err := ValidateResolvedPath(target, d, clean); err != nil {
return err
}

switch header.Typeflag {
case tar.TypeDir:
if err := handleDirectory(target); err != nil {
Expand Down
Loading