Skip to content

Commit 2f71e8f

Browse files
authored
Merge pull request #207 from thaJeztah/reexec_cleanup_tests
reexec: improve test-coverage, and use blackbox testing
2 parents bdb4ad6 + ad92e49 commit 2f71e8f

File tree

3 files changed

+147
-32
lines changed

3 files changed

+147
-32
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Package reexecoverride provides test utilities for overriding argv0 as
2+
// observed by reexec.Self within the current process.
3+
4+
package reexecoverride
5+
6+
import "sync/atomic"
7+
8+
// argv0Override holds an optional override for os.Args[0] used by reexec.Self.
9+
var argv0Override atomic.Pointer[string]
10+
11+
// Argv0 returns the overridden argv0 if set.
12+
func Argv0() (string, bool) {
13+
p := argv0Override.Load()
14+
if p == nil {
15+
return "", false
16+
}
17+
return *p, true
18+
}
19+
20+
// TestingTB is the minimal subset of [testing.TB] used by this package.
21+
type TestingTB interface {
22+
Helper()
23+
Cleanup(func())
24+
}
25+
26+
// OverrideArgv0 overrides the argv0 value observed by reexec.Self for the
27+
// lifetime of the calling test and restores it via [testing.TB.Cleanup].
28+
//
29+
// The override is process-global. Tests using OverrideArgv0 must not run in
30+
// parallel with other tests that call reexec.Self. OverrideArgv0 panics if an
31+
// override is already active.
32+
func OverrideArgv0(t TestingTB, argv0 string) {
33+
t.Helper()
34+
35+
s := argv0
36+
if !argv0Override.CompareAndSwap(nil, &s) {
37+
panic("testing: test using reexecoverride.OverrideArgv0 cannot use t.Parallel")
38+
}
39+
40+
t.Cleanup(func() {
41+
if !argv0Override.CompareAndSwap(&s, nil) {
42+
panic("testing: cleanup for reexecoverride.OverrideArgv0 detected parallel use of reexec.Self")
43+
}
44+
})
45+
}

reexec/reexec.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"os/exec"
1414
"path/filepath"
1515
"runtime"
16+
17+
"github.com/moby/sys/reexec/internal/reexecoverride"
1618
)
1719

1820
var registeredInitializers = make(map[string]func())
@@ -78,14 +80,18 @@ func CommandContext(ctx context.Context, args ...string) *exec.Cmd {
7880
// "my-binary" at "/usr/bin/" (or "my-binary.exe" at "C:\" on Windows),
7981
// then it returns "/usr/bin/my-binary" and "C:\my-binary.exe" respectively.
8082
func Self() string {
83+
if argv0, ok := reexecoverride.Argv0(); ok {
84+
return naiveSelf(argv0)
85+
}
8186
if runtime.GOOS == "linux" {
8287
return "/proc/self/exe"
8388
}
84-
return naiveSelf()
89+
return naiveSelf(os.Args[0])
8590
}
8691

87-
func naiveSelf() string {
88-
name := os.Args[0]
92+
// naiveSelf is a separate function to allow testing in isolation on Linux.
93+
func naiveSelf(argv0 string) string {
94+
name := argv0
8995
if filepath.Base(name) == name {
9096
if lp, err := exec.LookPath(name); err == nil {
9197
return lp

reexec/reexec_test.go

Lines changed: 93 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
1-
package reexec
1+
package reexec_test
22

33
import (
44
"context"
55
"errors"
66
"fmt"
77
"os"
8-
"os/exec"
98
"path/filepath"
109
"reflect"
10+
"runtime"
1111
"strings"
1212
"testing"
1313
"time"
14+
15+
"github.com/moby/sys/reexec"
16+
"github.com/moby/sys/reexec/internal/reexecoverride"
1417
)
1518

1619
const (
@@ -20,23 +23,26 @@ const (
2023
)
2124

2225
func init() {
23-
Register(testReExec, func() {
26+
reexec.Register(testReExec, func() {
2427
panic("Return Error")
2528
})
26-
Register(testReExec2, func() {
29+
reexec.Register(testReExec2, func() {
2730
var args string
2831
if len(os.Args) > 1 {
2932
args = fmt.Sprintf("(args: %#v)", os.Args[1:])
3033
}
3134
fmt.Println("Hello", testReExec2, args)
3235
os.Exit(0)
3336
})
34-
Register(testReExec3, func() {
37+
reexec.Register(testReExec3, func() {
3538
fmt.Println("Hello " + testReExec3)
3639
time.Sleep(1 * time.Second)
3740
os.Exit(0)
3841
})
39-
Init()
42+
if reexec.Init() {
43+
// Make sure we exit in case re-exec didn't os.Exit on its own.
44+
os.Exit(0)
45+
}
4046
}
4147

4248
func TestRegister(t *testing.T) {
@@ -69,7 +75,7 @@ func TestRegister(t *testing.T) {
6975
t.Errorf("got %q, want %q", r, tc.expectedErr)
7076
}
7177
}()
72-
Register(tc.name, func() {})
78+
reexec.Register(tc.name, func() {})
7379
})
7480
}
7581
}
@@ -98,7 +104,7 @@ func TestCommand(t *testing.T) {
98104
}
99105
for _, tc := range tests {
100106
t.Run(tc.doc, func(t *testing.T) {
101-
cmd := Command(tc.cmdAndArgs...)
107+
cmd := reexec.Command(tc.cmdAndArgs...)
102108
if !reflect.DeepEqual(cmd.Args, tc.cmdAndArgs) {
103109
t.Fatalf("got %+v, want %+v", cmd.Args, tc.cmdAndArgs)
104110
}
@@ -165,7 +171,7 @@ func TestCommandContext(t *testing.T) {
165171
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
166172
defer cancel()
167173

168-
cmd := CommandContext(ctx, tc.cmdAndArgs...)
174+
cmd := reexec.CommandContext(ctx, tc.cmdAndArgs...)
169175
if !reflect.DeepEqual(cmd.Args, tc.cmdAndArgs) {
170176
t.Fatalf("got %+v, want %+v", cmd.Args, tc.cmdAndArgs)
171177
}
@@ -194,30 +200,88 @@ func TestCommandContext(t *testing.T) {
194200
}
195201
}
196202

197-
func TestNaiveSelf(t *testing.T) {
198-
if os.Getenv("TEST_CHECK") == "1" {
199-
os.Exit(2)
200-
}
201-
cmd := exec.Command(naiveSelf(), "-test.run=TestNaiveSelf")
202-
cmd.Env = append(os.Environ(), "TEST_CHECK=1")
203-
err := cmd.Start()
203+
// TestRunNaiveSelf verifies that reexec.Self() (and thus CommandContext)
204+
// can resolve a path that can be used to re-execute the current test binary
205+
// when it falls back to the argv[0]-based implementation.
206+
//
207+
// It forces Self() to bypass the Linux /proc/self/exe fast-path via
208+
// [reexecoverride.OverrideArgv0] so that the fallback logic is exercised
209+
// consistently across platforms.
210+
func TestRunNaiveSelf(t *testing.T) {
211+
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
212+
defer cancel()
213+
214+
// Force Self() to use naiveSelf(os.Args[0]), instead of "/proc/self/exe" on Linux.
215+
reexecoverride.OverrideArgv0(t, os.Args[0])
216+
217+
cmd := reexec.CommandContext(ctx, testReExec2)
218+
out, err := cmd.CombinedOutput()
204219
if err != nil {
205220
t.Fatalf("Unable to start command: %v", err)
206221
}
207-
err = cmd.Wait()
208222

209-
var expError *exec.ExitError
210-
if !errors.As(err, &expError) {
211-
t.Fatalf("got %T, want %T", err, expError)
212-
}
213-
214-
const expected = "exit status 2"
215-
if err.Error() != expected {
216-
t.Fatalf("got %v, want %v", err, expected)
223+
expOut := "Hello test-reexec2"
224+
actual := strings.TrimSpace(string(out))
225+
if actual != expOut {
226+
t.Errorf("got %v, want %v", actual, expOut)
217227
}
228+
}
218229

219-
os.Args[0] = "mkdir"
220-
if naiveSelf() == os.Args[0] {
221-
t.Fatalf("Expected naiveSelf to resolve the location of mkdir")
222-
}
230+
func TestNaiveSelfResolve(t *testing.T) {
231+
t.Run("fast path on Linux", func(t *testing.T) {
232+
if runtime.GOOS != "linux" {
233+
t.Skip("only supported on Linux")
234+
}
235+
resolved := reexec.Self()
236+
expected := "/proc/self/exe"
237+
if resolved != expected {
238+
t.Errorf("got %v, want %v", resolved, expected)
239+
}
240+
})
241+
t.Run("resolve in PATH", func(t *testing.T) {
242+
executable := "sh"
243+
if runtime.GOOS == "windows" {
244+
executable = "cmd"
245+
}
246+
reexecoverride.OverrideArgv0(t, executable)
247+
resolved := reexec.Self()
248+
if resolved == executable {
249+
t.Errorf("did not resolve via PATH; got %q", resolved)
250+
}
251+
if !filepath.IsAbs(resolved) {
252+
t.Errorf("expected absolute path; got %q", resolved)
253+
}
254+
})
255+
t.Run("not in PATH", func(t *testing.T) {
256+
const executable = "some-nonexistent-executable"
257+
want, err := filepath.Abs(executable)
258+
if err != nil {
259+
t.Fatal(err)
260+
}
261+
reexecoverride.OverrideArgv0(t, executable)
262+
resolved := reexec.Self()
263+
if resolved != want {
264+
t.Errorf("expected absolute path; got %q, want %q", resolved, want)
265+
}
266+
})
267+
t.Run("relative path", func(t *testing.T) {
268+
executable := filepath.Join(".", "some-executable")
269+
want, err := filepath.Abs(executable)
270+
if err != nil {
271+
t.Fatal(err)
272+
}
273+
reexecoverride.OverrideArgv0(t, executable)
274+
resolved := reexec.Self()
275+
if resolved != want {
276+
t.Errorf("expected absolute path; got %q, want %q", resolved, want)
277+
}
278+
})
279+
t.Run("absolute path unchanged", func(t *testing.T) {
280+
executable := filepath.Join(os.TempDir(), "some-executable")
281+
reexecoverride.OverrideArgv0(t, executable)
282+
resolved := reexec.Self()
283+
if resolved != executable {
284+
t.Errorf("should not modify absolute paths; got %q, want %q", resolved, executable)
285+
}
286+
})
223287
}

0 commit comments

Comments
 (0)