Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions pkg/action/archive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
package action

import (
"archive/tar"
"bytes"
"compress/gzip"
"context"
"io/fs"
"os"
Expand Down Expand Up @@ -592,6 +594,114 @@ func TestScanConflictingArchiveFiles(t *testing.T) {
}
}

// createBrokenNestedArchive creates a tar.gz file containing a nested
// file with an archive extension whose content is valid gzip but invalid tar.
func createBrokenNestedArchive(t *testing.T, dir string) string {
t.Helper()

outPath := filepath.Join(dir, "outer.tar.gz")
f, err := os.Create(outPath)
if err != nil {
t.Fatalf("failed to create outer archive: %v", err)
}
defer f.Close()

gw := gzip.NewWriter(f)
defer gw.Close()
tw := tar.NewWriter(gw)
defer tw.Close()

var innerBuf bytes.Buffer
innerGw := gzip.NewWriter(&innerBuf)
if _, err := innerGw.Write(bytes.Repeat([]byte("A"), 1024)); err != nil {
t.Fatalf("failed to write inner gzip data: %v", err)
}
if err := innerGw.Close(); err != nil {
t.Fatalf("failed to close inner gzip writer: %v", err)
}

innerData := innerBuf.Bytes()
if err := tw.WriteHeader(&tar.Header{
Name: "bad_nested.tar.gz",
Mode: 0o600,
Size: int64(len(innerData)),
}); err != nil {
t.Fatalf("failed to write tar header: %v", err)
}
if _, err := tw.Write(innerData); err != nil {
t.Fatalf("failed to write tar data: %v", err)
}

return outPath
}

// TestNestedFailureRetention verifies that when a nested archive
// extraction fails with ExitExtraction=false (default), the original nested archive
// file is retained in the extraction directory for scanning rather than being deleted.
func TestNestedFailureRetention(t *testing.T) {
t.Parallel()

tmpDir, err := os.MkdirTemp("", "nested-fail-retain-*")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)

outerArchive := createBrokenNestedArchive(t, tmpDir)

ctx := context.Background()
cfg := malcontent.Config{ExitExtraction: false}

extractDir, err := archive.ExtractArchiveToTempDir(ctx, cfg, outerArchive)
if err != nil {
t.Fatalf("ExtractArchiveToTempDir should not fail with ExitExtraction=false, got: %v", err)
}
defer os.RemoveAll(extractDir)

// The nested archive file must still exist so it can be scanned as a regular file
found := false
err = filepath.WalkDir(extractDir, func(_ string, d os.DirEntry, err error) error {
if err != nil {
return err
}
if d.Name() == "bad_nested.tar.gz" {
found = true
}
return nil
})
if err != nil {
t.Fatalf("failed to walk extraction directory: %v", err)
}
if !found {
t.Fatal("nested archive file was deleted after extraction failure but should be retained for scanning")
}
}

// TestNestedFailureRetentionError verifies that when ExitExtraction=true,
// a nested archive extraction failure propagates as an error.
func TestNestedFailureRetentionError(t *testing.T) {
t.Parallel()

tmpDir, err := os.MkdirTemp("", "nested-fail-exit-*")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)

outerArchive := createBrokenNestedArchive(t, tmpDir)

ctx := context.Background()
cfg := malcontent.Config{ExitExtraction: true}

extractDir, err := archive.ExtractArchiveToTempDir(ctx, cfg, outerArchive)
if extractDir != "" {
defer os.RemoveAll(extractDir)
}
if err == nil {
t.Fatal("ExtractArchiveToTempDir should return error with ExitExtraction=true for nested archives which cannot be extracted")
}
}

func TestIsValidPath(t *testing.T) {
tmpRoot, err := os.MkdirTemp("", "isValidPath-*")
if err != nil {
Expand Down
10 changes: 7 additions & 3 deletions pkg/archive/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,17 @@ func extractNestedArchive(ctx context.Context, c malcontent.Config, d string, f
if c.ExitExtraction {
return fmt.Errorf("failed to extract archive: %w", err)
}
logger.Debugf("ignoring extraction error for %s: %s", f, err.Error())
logger.Warnf("extraction failed for %s, retaining archive for scanning: %s", f, err.Error())
}

extracted.Store(f, true)

if err := os.Remove(fullPath); err != nil {
return fmt.Errorf("failed to remove archive file: %w", err)
// only attempt to remove the archive file if we don't encounter an extraction error
// any archives which cannot be extracted will be scanned like non-archive files
if err == nil {
if err := os.Remove(fullPath); err != nil {
return fmt.Errorf("failed to remove archive file: %w", err)
}
}

entries, err := os.ReadDir(d)
Expand Down