diff --git a/image/pull.go b/image/pull.go index 6ceb6e7..b007009 100644 --- a/image/pull.go +++ b/image/pull.go @@ -33,6 +33,16 @@ import ( // 30 GB should be generous enough for any legitimate container image. const maxExtractSize int64 = 30 << 30 // 30 GiB +// maxExtractEntries caps the number of tar entries accepted during +// extraction. Bounds inode exhaustion / tar-bomb variants where millions +// of small files fit easily under maxExtractSize but exhaust host inodes +// and the dentry cache. One million entries is far above the largest +// legitimate container images seen in practice (a few hundred thousand). +// +// Declared as var so tests can override it; no production caller should +// mutate it. +var maxExtractEntries = 1_000_000 + // Pull fetches an OCI image, flattens its layers, and extracts to a directory. // If a Cache is provided, results are cached by image digest. The returned // RootFS contains the extraction path and parsed OCI config. @@ -509,6 +519,9 @@ func extractTar(ctx context.Context, r io.Reader, dst string) error { } entryCount++ + if entryCount > maxExtractEntries { + return fmt.Errorf("tar archive exceeds maximum entry count of %d", maxExtractEntries) + } target, err := sanitizeTarPath(dst, hdr.Name) if err != nil { @@ -558,6 +571,9 @@ func extractTarSharedLimit(ctx context.Context, r io.Reader, dst string, remaini } entryCount++ + if entryCount > maxExtractEntries { + return fmt.Errorf("tar archive exceeds maximum entry count of %d", maxExtractEntries) + } target, err := sanitizeTarPath(dst, hdr.Name) if err != nil { diff --git a/image/pull_test.go b/image/pull_test.go index 06975b8..64a17ea 100644 --- a/image/pull_test.go +++ b/image/pull_test.go @@ -228,6 +228,46 @@ func TestExtractTar_Symlinks(t *testing.T) { assert.Equal(t, "real binary", string(data)) } +func TestExtractTar_RejectsExcessiveEntryCount(t *testing.T) { + // Not parallel: mutates the package-level maxExtractEntries. + orig := maxExtractEntries + t.Cleanup(func() { maxExtractEntries = orig }) + maxExtractEntries = 3 + + entries := []tarEntry{ + {name: "a", typeflag: tar.TypeReg, mode: 0o644, content: ""}, + {name: "b", typeflag: tar.TypeReg, mode: 0o644, content: ""}, + {name: "c", typeflag: tar.TypeReg, mode: 0o644, content: ""}, + {name: "d", typeflag: tar.TypeReg, mode: 0o644, content: ""}, + } + buf := createTarBuffer(t, entries) + + err := extractTar(context.Background(), buf, t.TempDir()) + require.Error(t, err) + assert.Contains(t, err.Error(), "maximum entry count") +} + +func TestExtractTarSharedLimit_RejectsExcessiveEntryCount(t *testing.T) { + // Not parallel: mutates the package-level maxExtractEntries. + orig := maxExtractEntries + t.Cleanup(func() { maxExtractEntries = orig }) + maxExtractEntries = 3 + + entries := []tarEntry{ + {name: "a", typeflag: tar.TypeReg, mode: 0o644, content: ""}, + {name: "b", typeflag: tar.TypeReg, mode: 0o644, content: ""}, + {name: "c", typeflag: tar.TypeReg, mode: 0o644, content: ""}, + {name: "d", typeflag: tar.TypeReg, mode: 0o644, content: ""}, + } + buf := createTarBuffer(t, entries) + + var remaining atomic.Int64 + remaining.Store(maxExtractSize) + err := extractTarSharedLimit(context.Background(), buf, t.TempDir(), &remaining) + require.Error(t, err) + assert.Contains(t, err.Error(), "maximum entry count") +} + func TestExtractTar_RejectsOversizedPayload(t *testing.T) { t.Parallel() diff --git a/image/tar_security.go b/image/tar_security.go index 8019e0f..ee3a874 100644 --- a/image/tar_security.go +++ b/image/tar_security.go @@ -4,7 +4,10 @@ package image import ( + "fmt" "os" + "path/filepath" + "strings" "github.com/stacklok/go-microvm/internal/pathutil" ) @@ -26,3 +29,56 @@ func MkdirAllNoSymlink(destDir, targetDir string, mode os.FileMode) error { func ValidateNoSymlinkLeaf(target string) error { return validateNoSymlinkLeaf(target) } + +// SafeWalk resolves rel under root and verifies that no parent directory +// component along the way is a symlink. Returns the cleaned absolute path. +// +// The leaf itself is not inspected — callers that need to restrict leaf type +// (for example before ReadDir or Open) should Lstat the returned path and +// check ModeSymlink / IsDir explicitly. Callers that will call RemoveAll on +// the leaf may pass the returned path directly, since RemoveAll does not +// follow a symlink leaf. +// +// Use this before host-side operations on paths derived from untrusted tar +// metadata or similar sources. A malicious layer planting a symlink as an +// intermediate directory component would otherwise cause the subsequent +// host-side operation to redirect outside root. +func SafeWalk(root, rel string) (string, error) { + absPath, err := SanitizeTarPath(root, rel) + if err != nil { + return "", err + } + cleanRoot := filepath.Clean(root) + if absPath == cleanRoot { + return absPath, nil + } + parent := filepath.Dir(absPath) + relParent, err := filepath.Rel(cleanRoot, parent) + if err != nil { + return "", fmt.Errorf("compute relative parent: %w", err) + } + if relParent == "." { + return absPath, nil + } + cur := cleanRoot + for _, p := range strings.Split(relParent, string(os.PathSeparator)) { + if p == "" || p == "." { + continue + } + cur = filepath.Join(cur, p) + info, err := os.Lstat(cur) + if err != nil { + if os.IsNotExist(err) { + return "", fmt.Errorf("missing parent directory %s: %w", cur, err) + } + return "", fmt.Errorf("stat %s: %w", cur, err) + } + if info.Mode()&os.ModeSymlink != 0 { + return "", fmt.Errorf("refusing to traverse symlink: %s", cur) + } + if !info.IsDir() { + return "", fmt.Errorf("non-directory in path: %s", cur) + } + } + return absPath, nil +} diff --git a/image/tar_security_test.go b/image/tar_security_test.go new file mode 100644 index 0000000..461678b --- /dev/null +++ b/image/tar_security_test.go @@ -0,0 +1,104 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package image + +import ( + "errors" + "io/fs" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSafeWalk(t *testing.T) { + t.Parallel() + + t.Run("root itself resolves", func(t *testing.T) { + t.Parallel() + root := t.TempDir() + got, err := SafeWalk(root, ".") + require.NoError(t, err) + assert.Equal(t, filepath.Clean(root), got) + }) + + t.Run("normal path resolves", func(t *testing.T) { + t.Parallel() + root := t.TempDir() + require.NoError(t, os.MkdirAll(filepath.Join(root, "a", "b"), 0o755)) + got, err := SafeWalk(root, "a/b/file") + require.NoError(t, err) + assert.Equal(t, filepath.Join(root, "a", "b", "file"), got) + }) + + t.Run("nonexistent leaf is allowed", func(t *testing.T) { + t.Parallel() + root := t.TempDir() + require.NoError(t, os.MkdirAll(filepath.Join(root, "a"), 0o755)) + got, err := SafeWalk(root, "a/missing") + require.NoError(t, err) + assert.Equal(t, filepath.Join(root, "a", "missing"), got) + }) + + t.Run("leaf is a symlink — allowed, not inspected", func(t *testing.T) { + t.Parallel() + root := t.TempDir() + require.NoError(t, os.MkdirAll(filepath.Join(root, "a"), 0o755)) + outside := t.TempDir() + require.NoError(t, os.Symlink(outside, filepath.Join(root, "a", "link"))) + got, err := SafeWalk(root, "a/link") + require.NoError(t, err) + assert.Equal(t, filepath.Join(root, "a", "link"), got) + }) + + t.Run("mid-path symlink is refused", func(t *testing.T) { + t.Parallel() + root := t.TempDir() + outside := t.TempDir() + require.NoError(t, os.Symlink(outside, filepath.Join(root, "a"))) + _, err := SafeWalk(root, "a/b/file") + require.Error(t, err) + assert.Contains(t, err.Error(), "symlink") + }) + + t.Run("parent symlink is refused even when leaf exists through it", func(t *testing.T) { + t.Parallel() + root := t.TempDir() + outside := t.TempDir() + require.NoError(t, os.MkdirAll(filepath.Join(outside, "existing"), 0o755)) + require.NoError(t, os.Symlink(outside, filepath.Join(root, "a"))) + _, err := SafeWalk(root, "a/existing") + require.Error(t, err) + assert.Contains(t, err.Error(), "symlink") + }) + + t.Run("missing parent directory is refused", func(t *testing.T) { + t.Parallel() + root := t.TempDir() + _, err := SafeWalk(root, "nonexistent/child") + require.Error(t, err) + assert.Contains(t, err.Error(), "missing parent") + assert.True(t, errors.Is(err, fs.ErrNotExist), + "missing-parent error should unwrap to fs.ErrNotExist") + }) + + t.Run("non-directory parent is refused", func(t *testing.T) { + t.Parallel() + root := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(root, "file"), []byte("x"), 0o644)) + _, err := SafeWalk(root, "file/child") + require.Error(t, err) + assert.Contains(t, err.Error(), "non-directory") + }) + + t.Run("path escapes root", func(t *testing.T) { + t.Parallel() + root := t.TempDir() + _, err := SafeWalk(root, "../escape") + require.Error(t, err) + assert.Contains(t, err.Error(), "path traversal") + }) +} diff --git a/image/whiteout.go b/image/whiteout.go index 04cb6ff..d8b1cdc 100644 --- a/image/whiteout.go +++ b/image/whiteout.go @@ -4,7 +4,9 @@ package image import ( + "errors" "fmt" + "io/fs" "log/slog" "os" "path/filepath" @@ -29,16 +31,25 @@ func isOpaqueWhiteout(name string) bool { // applyWhiteout removes the file or directory targeted by a whiteout entry. // The name parameter is a relative path within the rootfs (e.g., "usr/lib/.wh.oldlib"). +// +// Parent directory components are validated via SafeWalk, so a symlink +// planted as an intermediate component cannot redirect the RemoveAll onto +// the host filesystem. RemoveAll itself does not follow a symlink leaf, so +// whiteout-on-symlink (a legitimate OCI pattern) still works correctly. func applyWhiteout(rootDir, name string) error { dirPart := filepath.Dir(name) base := filepath.Base(name) targetName := strings.TrimPrefix(base, whiteoutPrefix) + relPath := filepath.Join(dirPart, targetName) - fullPath := filepath.Clean(filepath.Join(rootDir, dirPart, targetName)) - - // Validate the resolved path stays within rootDir. - if rel, err := filepath.Rel(rootDir, fullPath); err != nil || strings.HasPrefix(rel, "..") { - return fmt.Errorf("whiteout target escapes rootfs: %s", name) + fullPath, err := SafeWalk(rootDir, relPath) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + // A parent directory does not exist — the whiteout target + // cannot exist either, so this is a no-op. + return nil + } + return fmt.Errorf("applying whiteout for %s: %w", name, err) } slog.Debug("applying whiteout", "target", fullPath) @@ -51,19 +62,36 @@ func applyWhiteout(rootDir, name string) error { // applyOpaqueWhiteout removes all entries inside a directory, keeping the directory itself. // The dirPath parameter is relative to rootDir (e.g., "usr/lib"). +// +// Parent directory components are validated via SafeWalk, and the leaf is +// refused if it is a symlink — os.ReadDir follows symlinks, so a leaf +// symlink pointing outside the rootfs would otherwise enumerate and remove +// host files. func applyOpaqueWhiteout(rootDir, dirPath string) error { - fullDir := filepath.Clean(filepath.Join(rootDir, dirPath)) - - // Validate the resolved path stays within rootDir. - if rel, err := filepath.Rel(rootDir, fullDir); err != nil || strings.HasPrefix(rel, "..") { - return fmt.Errorf("opaque whiteout target escapes rootfs: %s", dirPath) + fullDir, err := SafeWalk(rootDir, dirPath) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil + } + return fmt.Errorf("opaque whiteout for %s: %w", dirPath, err) } - entries, err := os.ReadDir(fullDir) + info, err := os.Lstat(fullDir) if err != nil { - if os.IsNotExist(err) { + if errors.Is(err, fs.ErrNotExist) { return nil } + return fmt.Errorf("stat opaque whiteout target %s: %w", dirPath, err) + } + if info.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("refusing to follow symlink for opaque whiteout: %s", dirPath) + } + if !info.IsDir() { + return fmt.Errorf("opaque whiteout target is not a directory: %s", dirPath) + } + + entries, err := os.ReadDir(fullDir) + if err != nil { return fmt.Errorf("reading directory for opaque whiteout %s: %w", dirPath, err) } diff --git a/image/whiteout_test.go b/image/whiteout_test.go index 1b230b6..257c876 100644 --- a/image/whiteout_test.go +++ b/image/whiteout_test.go @@ -103,7 +103,29 @@ func TestApplyWhiteout(t *testing.T) { err := applyWhiteout(root, "../../etc/.wh.passwd") require.Error(t, err) - assert.Contains(t, err.Error(), "escapes rootfs") + assert.Contains(t, err.Error(), "path traversal") + }) + + t.Run("refuses to walk through a symlink parent", func(t *testing.T) { + t.Parallel() + root := t.TempDir() + + // An attacker-planted layer substitutes etc with a symlink to a + // host-owned directory. Before SafeWalk, a subsequent whiteout on + // etc/.wh.passwd would RemoveAll through the symlink. + outside := t.TempDir() + victim := filepath.Join(outside, "passwd") + require.NoError(t, os.WriteFile(victim, []byte("original"), 0o600)) + require.NoError(t, os.Symlink(outside, filepath.Join(root, "etc"))) + + err := applyWhiteout(root, "etc/.wh.passwd") + require.Error(t, err) + assert.Contains(t, err.Error(), "symlink") + + // The host-side file must be untouched. + got, readErr := os.ReadFile(victim) + require.NoError(t, readErr) + assert.Equal(t, "original", string(got)) }) } @@ -148,6 +170,42 @@ func TestApplyOpaqueWhiteout(t *testing.T) { err := applyOpaqueWhiteout(root, "../../etc") require.Error(t, err) - assert.Contains(t, err.Error(), "escapes rootfs") + assert.Contains(t, err.Error(), "path traversal") + }) + + t.Run("refuses to walk through a symlink parent", func(t *testing.T) { + t.Parallel() + root := t.TempDir() + + outside := t.TempDir() + require.NoError(t, os.Symlink(outside, filepath.Join(root, "usr"))) + + err := applyOpaqueWhiteout(root, "usr/lib") + require.Error(t, err) + assert.Contains(t, err.Error(), "symlink") + }) + + t.Run("refuses to follow a symlink leaf", func(t *testing.T) { + t.Parallel() + root := t.TempDir() + + // The parent is a real dir but the target itself is a symlink + // pointing outside the rootfs. os.ReadDir would follow and + // enumerate host files without this guard. + outside := t.TempDir() + victim := filepath.Join(outside, "victim") + require.NoError(t, os.WriteFile(victim, []byte("original"), 0o600)) + + usrDir := filepath.Join(root, "usr") + require.NoError(t, os.MkdirAll(usrDir, 0o755)) + require.NoError(t, os.Symlink(outside, filepath.Join(usrDir, "lib"))) + + err := applyOpaqueWhiteout(root, "usr/lib") + require.Error(t, err) + assert.Contains(t, err.Error(), "symlink") + + got, readErr := os.ReadFile(victim) + require.NoError(t, readErr) + assert.Equal(t, "original", string(got)) }) }