Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions agent/server/modes/host/command/command_docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,28 @@ import (
"github.com/shellhub-io/shellhub/agent/pkg/osauth"
)

// statFn is a seam for os.Stat used when probing /proc/1/ns/* entries.
// It can be replaced in tests to avoid filesystem access.
var statFn = os.Stat

// nsenterArgs builds the nsenter flag slice from the present namespace map.
// Docker always shares the host time namespace, so -T is never passed.
func nsenterArgs(present map[string]string) []string {
args := []string{}

for _, flag := range present {
args = append(args, flag)
}

return args
}

// CheckCredentialSwitch is a no-op in Docker mode: the agent relies on
// nsenter+setpriv for credential switching, so this check is not applicable.
func CheckCredentialSwitch() error {
return nil
}

func NewCmd(u *osauth.User, shell, term, host string, envs []string, command ...string) *exec.Cmd {
groups, err := osauth.ListGroups(u.Username)
if err != nil {
Expand Down Expand Up @@ -69,7 +91,8 @@ func getWrappedCommand(nsArgs []string, uid, gid uint32, groups []uint32, home s
"1",
}, nsArgs...)

nsenterCmd = append(nsenterCmd,
nsenterCmd = append(
nsenterCmd,
[]string{
"-S",
strconv.Itoa(int(uid)),
Expand All @@ -80,30 +103,36 @@ func getWrappedCommand(nsArgs []string, uid, gid uint32, groups []uint32, home s
return append(setPrivCmd, nsenterCmd...)
}

// nsenterCommandWrapper builds the full nsenter+setpriv command slice.
// It probes /proc/1/ns/* for each namespace using statFn, then delegates
// flag assembly to nsenterArgs. The time namespace is never joined because
// Docker always shares the host time namespace.
func nsenterCommandWrapper(uid, gid uint32, groups []uint32, home string, command ...string) ([]string, error) {
if _, err := os.Stat("/usr/bin/nsenter"); err != nil && !os.IsNotExist(err) {
if _, err := statFn("/usr/bin/nsenter"); err != nil && !os.IsNotExist(err) {
return nil, err
}

paths := map[string]string{
namespaces := map[string]string{
"mnt": "-m",
"uts": "-u",
"ipc": "-i",
"net": "-n",
"pid": "-p",
"cgroup": "-C",
"time": "-T",
}

args := []string{}
for path, params := range paths {
if _, err := os.Stat(fmt.Sprintf("/proc/1/ns/%s", path)); err != nil {
present := map[string]string{}

for ns, flag := range namespaces {
if _, err := statFn(fmt.Sprintf("/proc/1/ns/%s", ns)); err != nil {
continue
}

args = append(args, params)
present[ns] = flag
}

args := nsenterArgs(present)

return append(getWrappedCommand(args, uid, gid, groups, home), command...), nil
}

Expand Down
24 changes: 24 additions & 0 deletions agent/server/modes/host/command/command_docker_compile_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//go:build docker
// +build docker

package command

import "testing"

// Compile-time assertion: CheckCredentialSwitch is exported and has the expected
// signature (func() error) under the docker build tag. This mirrors the
// identical assertion in command_native_test.go for the !docker tag so that
// neither tag set can silently break the build.
var _ func() error = CheckCredentialSwitch

// TestCheckCredentialSwitchCompiles is a named test that anchors the
// compile-time assertion above so it shows up in the test run output.
// The assertion is evaluated at compile time; the test body itself is a
// simple pass-through that confirms the symbol is reachable at runtime too.
func TestCheckCredentialSwitchCompiles(t *testing.T) {
// Calling the function ensures the linker includes the symbol and that
// it actually returns nil in docker mode (documented no-op).
if err := CheckCredentialSwitch(); err != nil {
t.Errorf("CheckCredentialSwitch() returned unexpected error under -tags docker: %v", err)
}
}
176 changes: 176 additions & 0 deletions agent/server/modes/host/command/command_docker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
//go:build docker
// +build docker

package command

import (
"os"
"sort"
"testing"

"github.com/stretchr/testify/assert"
)

// TestNsenterCommandWrapper is an integration-level test for nsenterCommandWrapper.
// It injects statFn (to control which /proc/1/ns/* files appear present) and verifies:
// - -T is never present in the assembled command (time namespace is never joined)
// - namespace flags from statFn are included when present
// - absent namespace flags are excluded
// - the setpriv/nsenter prefix ordering is unchanged
// - the user command appears at the end
//
//nolint:paralleltest
func TestNsenterCommandWrapper(t *testing.T) {
origStatFn := statFn

t.Cleanup(func() {
statFn = origStatFn
})

// presentNSFiles is the set of /proc/1/ns/* files that statFn will report as present.
// We choose a deterministic subset: mnt, net, pid.
presentNSFiles := map[string]bool{
"/proc/1/ns/mnt": true,
"/proc/1/ns/net": true,
"/proc/1/ns/pid": true,
}

// statFn stub: /usr/bin/nsenter always exists; only the listed ns files exist.
statFn = func(path string) (os.FileInfo, error) {
if path == "/usr/bin/nsenter" {
return nil, nil
}

if presentNSFiles[path] {
return nil, nil
}

return nil, os.ErrNotExist
}

// nsFlags that must always appear (from the present map above).
expectedNSFlags := []string{"-m", "-n", "-p"}

t.Run("present ns flags included and -T never present", func(t *testing.T) {
cmd, err := nsenterCommandWrapper(1000, 1000, []uint32{1000}, "/home/user", "/bin/sh")
assert.NoError(t, err)

// -T must never be present.
assert.NotContains(t, cmd, "-T")

// All namespace flags from statFn must be present.
for _, flag := range expectedNSFlags {
assert.Contains(t, cmd, flag)
}

// Flags for absent namespaces must NOT be present.
for _, absent := range []string{"-u", "-i", "-C"} {
assert.NotContains(t, cmd, absent)
}

// setpriv/nsenter prefix ordering: /bin/setpriv must come first,
// /usr/bin/nsenter must follow.
setprivIdx := indexOf(cmd, "/bin/setpriv")
nsenterIdx := indexOf(cmd, "/usr/bin/nsenter")
assert.NotEqual(t, -1, setprivIdx, "/bin/setpriv must be present")
assert.NotEqual(t, -1, nsenterIdx, "/usr/bin/nsenter must be present")
assert.Less(t, setprivIdx, nsenterIdx, "/bin/setpriv must precede /usr/bin/nsenter")

// The user command must appear at the end.
assert.Equal(t, "/bin/sh", cmd[len(cmd)-1])
})

t.Run("statFn controls which ns flags appear", func(t *testing.T) {
// Override statFn so only net is present.
statFn = func(path string) (os.FileInfo, error) {
if path == "/usr/bin/nsenter" || path == "/proc/1/ns/net" {
return nil, nil
}

return nil, os.ErrNotExist
}

cmd, err := nsenterCommandWrapper(1000, 1000, []uint32{1000}, "/home/user", "/bin/bash")
assert.NoError(t, err)

assert.Contains(t, cmd, "-n")
assert.NotContains(t, cmd, "-T")

for _, absent := range []string{"-m", "-u", "-i", "-p", "-C"} {
assert.NotContains(t, cmd, absent)
}
})

t.Run("multiple ns flags from statFn — no -T", func(t *testing.T) {
// Override statFn so mnt and uts are present.
statFn = func(path string) (os.FileInfo, error) {
if path == "/usr/bin/nsenter" || path == "/proc/1/ns/mnt" || path == "/proc/1/ns/uts" {
return nil, nil
}

return nil, os.ErrNotExist
}

cmd, err := nsenterCommandWrapper(1000, 1000, []uint32{1000}, "/home/user", "/bin/bash")
assert.NoError(t, err)

assert.Contains(t, cmd, "-m")
assert.Contains(t, cmd, "-u")
assert.NotContains(t, cmd, "-T")

for _, absent := range []string{"-i", "-n", "-p", "-C"} {
assert.NotContains(t, cmd, absent)
}
})
}

// indexOf returns the index of target in slice, or -1 if not found.
func indexOf(slice []string, target string) int {
for i, s := range slice {
if s == target {
return i
}
}

return -1
}

func TestNsenterArgs(t *testing.T) {
t.Run("all flags in present are forwarded unchanged", func(t *testing.T) {
present := map[string]string{
"mnt": "-m",
"uts": "-u",
"ipc": "-i",
"net": "-n",
"pid": "-p",
"cgroup": "-C",
}
args := nsenterArgs(present)

got := make([]string, len(args))
copy(got, args)
sort.Strings(got)

expected := []string{"-C", "-i", "-m", "-n", "-p", "-u"}
assert.Equal(t, expected, got)
})

t.Run("-T never appears regardless of input", func(t *testing.T) {
present := map[string]string{
"mnt": "-m",
"net": "-n",
}
args := nsenterArgs(present)

assert.NotContains(t, args, "-T")
assert.Contains(t, args, "-m")
assert.Contains(t, args, "-n")
})

t.Run("empty present returns empty slice", func(t *testing.T) {
present := map[string]string{}
args := nsenterArgs(present)

assert.Empty(t, args)
})
}
56 changes: 55 additions & 1 deletion agent/server/modes/host/command/command_native.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,68 @@
package command

import (
"errors"
"os"
"os/exec"
"strings"
"syscall"

"github.com/shellhub-io/shellhub/agent/pkg/osauth"
log "github.com/sirupsen/logrus"
)

// geteuidFn is a seam for os.Geteuid used in setgroupsDenied and NewCmd.
// Tests can replace it to simulate running as root or non-root.
var geteuidFn = os.Geteuid

// readSetgroupsPolicyFn is a seam for reading /proc/self/setgroups.
// Tests can replace it to control the kernel policy value without filesystem access.
var readSetgroupsPolicyFn = func() ([]byte, error) {
return os.ReadFile("/proc/self/setgroups")
}

// setgroupsDenied reports whether the kernel has denied setgroups(2) for this
// process by checking /proc/self/setgroups.
//
// Return values:
// - true: the policy file trims to "deny".
// - false: the file does not exist (kernel too old or not in a user-ns); silent.
// - false: any other read error; a warning is emitted via the logger.
func setgroupsDenied() bool {
data, err := readSetgroupsPolicyFn()
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
log.WithError(err).Warn("failed to read /proc/self/setgroups; assuming setgroups is allowed")
}

return false
}

return strings.TrimSpace(string(data)) == "deny"
}

// CheckCredentialSwitch reports whether the process can switch credentials via
// setgroups(2). It is a pre-flight check that must be called before attempting
// to execute a command as a different user.
//
// Short-circuit: when the effective UID is not root (euid != 0), credential
// switching is a no-op and the check always succeeds (nil).
//
// When euid == 0, the kernel may still forbid setgroups inside an unprivileged
// user namespace (Linux ≥ 3.19). In that case the function returns a sentinel
// error whose message contains "setgroups denied in unprivileged user namespace".
func CheckCredentialSwitch() error {
if geteuidFn() != 0 {
return nil
}

if setgroupsDenied() {
return errors.New("setgroups denied in unprivileged user namespace")
}

return nil
}

func NewCmd(u *osauth.User, shell, term, host string, envs []string, command ...string) *exec.Cmd {
groups, err := osauth.ListGroups(u.Username)
if err != nil {
Expand Down Expand Up @@ -44,7 +98,7 @@ func NewCmd(u *osauth.User, shell, term, host string, envs []string, command ...
cmd.Dir = u.HomeDir
}

if os.Geteuid() == 0 {
if geteuidFn() == 0 {
cmd.SysProcAttr = &syscall.SysProcAttr{}
cmd.SysProcAttr.Credential = &syscall.Credential{Uid: u.UID, Gid: u.GID, Groups: groups}
}
Expand Down
Loading
Loading