Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 44 additions & 15 deletions hooks/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"regexp"
"sort"
"strings"
"syscall"

"github.com/stacklok/go-microvm/guest/vmconfig"
"github.com/stacklok/go-microvm/image"
Expand Down Expand Up @@ -86,7 +87,7 @@ func InjectAuthorizedKeys(pubKey string, opts ...KeyOption) func(string, *image.
if err != nil {
return fmt.Errorf("validate .ssh path: %w", err)
}
if err := os.MkdirAll(sshDir, 0o700); err != nil {
if err := image.MkdirAllNoSymlink(rootfsPath, sshDir, 0o700); err != nil {
return fmt.Errorf("create .ssh dir: %w", err)
}
if err := cfg.chown(sshDir, cfg.uid, cfg.gid); err != nil {
Expand All @@ -98,7 +99,10 @@ func InjectAuthorizedKeys(pubKey string, opts ...KeyOption) func(string, *image.
if err != nil {
return fmt.Errorf("validate authorized_keys path: %w", err)
}
if err := os.WriteFile(akPath, []byte(pubKey+"\n"), 0o600); err != nil {
if err := image.ValidateNoSymlinkLeaf(akPath); err != nil {
return fmt.Errorf("validate authorized_keys: %w", err)
}
if err := writeFileNoFollow(akPath, []byte(pubKey+"\n"), 0o600); err != nil {
return fmt.Errorf("write authorized_keys: %w", err)
}
if err := cfg.chown(akPath, cfg.uid, cfg.gid); err != nil {
Expand All @@ -109,19 +113,40 @@ func InjectAuthorizedKeys(pubKey string, opts ...KeyOption) func(string, *image.
}
}

// writeFileNoFollow writes data to path with O_NOFOLLOW on the final open,
// refusing to write through a symlink leaf even if one races into place
// between the caller's validation and this open. Creates the file if absent,
// truncates if present. Parent directories must already exist.
func writeFileNoFollow(path string, data []byte, perm os.FileMode) error {
f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC|syscall.O_NOFOLLOW, perm)
if err != nil {
return err
}
if _, err := f.Write(data); err != nil {
_ = f.Close()
return err
}
return f.Close()
}

// InjectFile returns a RootFSHook that writes content to the specified guest
// path inside the rootfs with the given permissions. Parent directories are
// created as needed.
// created as needed. Symlink components — whether in parent directories or at
// the leaf — are refused so that a rootfs planted with hostile symlinks cannot
// redirect the write outside the rootfs.
func InjectFile(guestPath string, content []byte, perm os.FileMode) func(string, *image.OCIConfig) error {
return func(rootfsPath string, _ *image.OCIConfig) error {
dst, err := pathutil.Contains(rootfsPath, guestPath)
if err != nil {
return fmt.Errorf("validate path %s: %w", guestPath, err)
}
if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil {
if err := image.MkdirAllNoSymlink(rootfsPath, filepath.Dir(dst), 0o755); err != nil {
return fmt.Errorf("create parent dirs for %s: %w", guestPath, err)
}
if err := os.WriteFile(dst, content, perm); err != nil {
if err := image.ValidateNoSymlinkLeaf(dst); err != nil {
return fmt.Errorf("validate %s: %w", guestPath, err)
}
if err := writeFileNoFollow(dst, content, perm); err != nil {
return fmt.Errorf("write %s: %w", guestPath, err)
}
return nil
Expand Down Expand Up @@ -167,10 +192,13 @@ func InjectEnvFile(guestPath string, envMap map[string]string) func(string, *ima
if err != nil {
return fmt.Errorf("validate path %s: %w", guestPath, err)
}
if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil {
if err := image.MkdirAllNoSymlink(rootfsPath, filepath.Dir(dst), 0o755); err != nil {
return fmt.Errorf("create parent dirs for %s: %w", guestPath, err)
}
if err := os.WriteFile(dst, []byte(buf.String()), 0o600); err != nil {
if err := image.ValidateNoSymlinkLeaf(dst); err != nil {
return fmt.Errorf("validate %s: %w", guestPath, err)
}
if err := writeFileNoFollow(dst, []byte(buf.String()), 0o600); err != nil {
return fmt.Errorf("write %s: %w", guestPath, err)
}
return nil
Expand All @@ -183,20 +211,21 @@ func shellEscape(s string) string {
return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'"
}

// BestEffortLchown attempts os.Lchown and silently ignores permission errors,
// returning nil. On macOS non-root users cannot chown to a different UID;
// the guest init will fix ownership at boot time. Non-permission errors are
// logged at warn level and also swallowed. Callers that need strict chown
// should call os.Lchown directly instead.
// BestEffortLchown attempts os.Lchown and swallows permission errors (EPERM
// and EACCES). On non-root Linux and on macOS the hook process cannot lchown
// to a different UID; in those cases the override_stat xattr carries the
// intended ownership to the guest, and the guest init fixes up ownership at
// boot. Errors other than permission denied (e.g. ENOENT, EROFS, EIO) are
// returned to the caller rather than silently dropped.
// Lchown is used instead of Chown to avoid following symlinks in the rootfs.
func BestEffortLchown(path string, uid, gid int) error {
if err := os.Lchown(path, uid, gid); err != nil {
if !os.IsPermission(err) {
slog.Warn("lchown failed", "path", path, "uid", uid, "gid", gid, "err", err)
return fmt.Errorf("lchown %s: %w", path, err)
}
slog.Debug("lchown permission denied; relying on xattr + guest fixup",
"path", path, "uid", uid, "gid", gid)
}
// On macOS, set the override_stat xattr so libkrun's virtiofs reports
// correct ownership to the guest (non-root cannot Lchown to a different UID).
xattr.SetOverrideStatFromPath(path, uid, gid)
return nil
}
233 changes: 233 additions & 0 deletions hooks/hooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,24 @@ func TestInjectVMConfig(t *testing.T) {
}
}

func TestInjectVMConfig_RejectsSymlinkComponents(t *testing.T) {
t.Parallel()

// Guard the delegation chain: InjectVMConfig -> InjectFile.
// If InjectFile's symlink safety regresses, this test catches it.
rootfs := t.TempDir()
outside := t.TempDir()
stageSymlink(t, rootfs, "etc", outside)

hook := InjectVMConfig(vmconfig.Config{TmpSizeMiB: 512})
err := hook(rootfs, nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "symlink")

_, statErr := os.Stat(filepath.Join(outside, "go-microvm.json"))
assert.True(t, os.IsNotExist(statErr), "must not write under symlink target")
}

func TestInjectFile_WritesContent(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -325,6 +343,61 @@ func TestInjectBinary_RejectsPathTraversal(t *testing.T) {
assert.Contains(t, err.Error(), "path traversal")
}

func TestInjectFile_RejectsSymlinkComponents(t *testing.T) {
t.Parallel()

t.Run("parent directory is a symlink", func(t *testing.T) {
t.Parallel()

rootfs := t.TempDir()
outside := t.TempDir()
stageSymlink(t, rootfs, "etc", outside)

hook := InjectFile("/etc/myconfig.txt", []byte("hello"), 0o644)
err := hook(rootfs, nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "symlink")

_, statErr := os.Stat(filepath.Join(outside, "myconfig.txt"))
assert.True(t, os.IsNotExist(statErr), "must not write under symlink target")
})

t.Run("leaf is a symlink to a host file", func(t *testing.T) {
t.Parallel()

rootfs := t.TempDir()
require.NoError(t, os.MkdirAll(filepath.Join(rootfs, "etc"), 0o755))

victim := filepath.Join(t.TempDir(), "victim")
require.NoError(t, os.WriteFile(victim, []byte("original"), 0o600))
require.NoError(t, os.Symlink(victim, filepath.Join(rootfs, "etc", "myconfig.txt")))

hook := InjectFile("/etc/myconfig.txt", []byte("evil"), 0o644)
err := hook(rootfs, nil)
require.Error(t, err)

got, readErr := os.ReadFile(victim)
require.NoError(t, readErr)
assert.Equal(t, "original", string(got), "victim must not be overwritten")
})
}

func TestInjectBinary_RejectsSymlinkComponents(t *testing.T) {
t.Parallel()

rootfs := t.TempDir()
outside := t.TempDir()
stageSymlink(t, rootfs, "usr", outside)

hook := InjectBinary("/usr/bin/mytool", []byte("#!/bin/sh\necho hi"))
err := hook(rootfs, nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "symlink")

_, statErr := os.Stat(filepath.Join(outside, "bin", "mytool"))
assert.True(t, os.IsNotExist(statErr), "must not write under symlink target")
}

func TestInjectEnvFile_RejectsPathTraversal(t *testing.T) {
t.Parallel()

Expand All @@ -335,6 +408,45 @@ func TestInjectEnvFile_RejectsPathTraversal(t *testing.T) {
assert.Contains(t, err.Error(), "path traversal")
}

func TestInjectEnvFile_RejectsSymlinkComponents(t *testing.T) {
t.Parallel()

t.Run("parent directory is a symlink", func(t *testing.T) {
t.Parallel()

rootfs := t.TempDir()
outside := t.TempDir()
stageSymlink(t, rootfs, "etc", outside)

hook := InjectEnvFile("/etc/env", map[string]string{"FOO": "bar"})
err := hook(rootfs, nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "symlink")

_, statErr := os.Stat(filepath.Join(outside, "env"))
assert.True(t, os.IsNotExist(statErr), "must not write under symlink target")
})

t.Run("leaf is a symlink to a host file", func(t *testing.T) {
t.Parallel()

rootfs := t.TempDir()
require.NoError(t, os.MkdirAll(filepath.Join(rootfs, "etc"), 0o755))

victim := filepath.Join(t.TempDir(), "victim")
require.NoError(t, os.WriteFile(victim, []byte("original"), 0o600))
require.NoError(t, os.Symlink(victim, filepath.Join(rootfs, "etc", "env")))

hook := InjectEnvFile("/etc/env", map[string]string{"FOO": "evil"})
err := hook(rootfs, nil)
require.Error(t, err)

got, readErr := os.ReadFile(victim)
require.NoError(t, readErr)
assert.Equal(t, "original", string(got), "victim must not be overwritten")
})
}

// failingChown returns a ChownFunc that returns an error when the path
// ends with the given suffix, and succeeds otherwise.
func failingChown(pathSuffix string) ChownFunc {
Expand Down Expand Up @@ -427,6 +539,70 @@ func TestInjectAuthorizedKeys_RejectsPathTraversal(t *testing.T) {
assert.Contains(t, err.Error(), "path traversal")
}

func TestInjectAuthorizedKeys_RejectsSymlinkComponents(t *testing.T) {
t.Parallel()

chown, _ := recordingChown()
pubKey := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAATEST test@example.com"

t.Run("home component is an absolute symlink out of rootfs", func(t *testing.T) {
t.Parallel()

rootfs := t.TempDir()
outside := t.TempDir()
stageSymlink(t, rootfs, "home", outside)

hook := InjectAuthorizedKeys(pubKey, WithChown(chown))
err := hook(rootfs, nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "symlink")

// Nothing must have been written under the symlink target.
_, statErr := os.Stat(filepath.Join(outside, "sandbox", ".ssh", "authorized_keys"))
assert.True(t, os.IsNotExist(statErr), "must not write under symlink target")
})

t.Run("dot-ssh component is a relative escaping symlink", func(t *testing.T) {
t.Parallel()

rootfs := t.TempDir()
require.NoError(t, os.MkdirAll(filepath.Join(rootfs, "home", "sandbox"), 0o755))
// .ssh points two levels up, escaping the rootfs lexically after resolution.
outside := t.TempDir()
stageSymlink(t, rootfs, "home/sandbox/.ssh", outside)

hook := InjectAuthorizedKeys(pubKey, WithChown(chown))
err := hook(rootfs, nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "symlink")

_, statErr := os.Stat(filepath.Join(outside, "authorized_keys"))
assert.True(t, os.IsNotExist(statErr), "must not write under symlink target")
})

t.Run("authorized_keys leaf is a symlink to a host file", func(t *testing.T) {
t.Parallel()

rootfs := t.TempDir()
sshDir := filepath.Join(rootfs, "home", "sandbox", ".ssh")
require.NoError(t, os.MkdirAll(sshDir, 0o700))

// An attacker-planted symlink at the leaf points to an arbitrary host file.
victim := filepath.Join(t.TempDir(), "victim")
require.NoError(t, os.WriteFile(victim, []byte("original"), 0o600))
require.NoError(t, os.Symlink(victim, filepath.Join(sshDir, "authorized_keys")))

hook := InjectAuthorizedKeys(pubKey, WithChown(chown))
err := hook(rootfs, nil)
require.Error(t, err)

// The host-side file must be untouched.
got, readErr := os.ReadFile(victim)
require.NoError(t, readErr)
assert.Equal(t, "original", string(got), "victim must not be overwritten")
})
}

func TestInjectEnvFile_RejectsInvalidKeyNames(t *testing.T) {
t.Parallel()

Expand All @@ -453,6 +629,63 @@ func TestInjectEnvFile_RejectsInvalidKeyNames(t *testing.T) {
}
}

// stageSymlink places a symlink at rootfs/linkPath pointing to target.
// Parent directories of linkPath inside rootfs are created as 0o755.
// Use this to build rootfs fixtures that exercise symlink-following behavior
// in hook code — e.g. a malicious layer shipping `home/sandbox/.ssh` as a
// symlink to somewhere outside the rootfs.
func stageSymlink(t *testing.T, rootfs, linkPath, target string) {
t.Helper()
abs := filepath.Join(rootfs, linkPath)
require.NoError(t, os.MkdirAll(filepath.Dir(abs), 0o755))
require.NoError(t, os.Symlink(target, abs))
}

func TestStageSymlink(t *testing.T) {
t.Parallel()

rootfs := t.TempDir()
outside := t.TempDir()

stageSymlink(t, rootfs, "home/sandbox/.ssh", outside)

info, err := os.Lstat(filepath.Join(rootfs, "home", "sandbox", ".ssh"))
require.NoError(t, err)
assert.NotZero(t, info.Mode()&os.ModeSymlink, ".ssh should be a symlink")

dest, err := os.Readlink(filepath.Join(rootfs, "home", "sandbox", ".ssh"))
require.NoError(t, err)
assert.Equal(t, outside, dest)
}

func TestBestEffortLchown_PropagatesNonPermissionErrors(t *testing.T) {
t.Parallel()

// ENOENT from a non-existent path is not a permission error; the function
// must return it rather than silently swallowing.
missing := filepath.Join(t.TempDir(), "does-not-exist")
err := BestEffortLchown(missing, 1000, 1000)
require.Error(t, err)
assert.Contains(t, err.Error(), "lchown")
}

func TestBestEffortLchown_SwallowsPermissionErrors(t *testing.T) {
t.Parallel()

if os.Geteuid() == 0 {
t.Skip("requires non-root: root can chown to any UID, so no EPERM")
}

// Create a file we own; chowning to a UID we don't own should fail with
// EPERM on Linux and macOS as a non-root user. BestEffortLchown must
// swallow that specific error.
target := filepath.Join(t.TempDir(), "target")
require.NoError(t, os.WriteFile(target, []byte("x"), 0o600))

err := BestEffortLchown(target, 1, 1)
require.NoError(t, err, "permission error must be swallowed, got: %v", err)
}

func TestShellEscape(t *testing.T) {
t.Parallel()

Expand Down
Loading