Skip to content

Commit 9b06e02

Browse files
committed
Add cleanup for VM state
1 parent 97bf26b commit 9b06e02

8 files changed

Lines changed: 265 additions & 2 deletions

File tree

image/cache.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ func NewCache(baseDir string) *Cache {
2323
return &Cache{baseDir: baseDir}
2424
}
2525

26+
// BaseDir returns the root directory used for cached rootfs entries.
27+
func (c *Cache) BaseDir() string {
28+
if c == nil {
29+
return ""
30+
}
31+
return c.baseDir
32+
}
33+
2634
// Get returns the path to a cached rootfs for the given digest, and true
2735
// if it exists and appears valid. Returns ("", false) on a cache miss.
2836
func (c *Cache) Get(digest string) (string, bool) {

image/cache_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,18 @@ func TestNewCache(t *testing.T) {
2222
assert.Equal(t, baseDir, c.baseDir)
2323
}
2424

25+
func TestCache_BaseDir(t *testing.T) {
26+
t.Parallel()
27+
28+
baseDir := "/some/cache/dir"
29+
c := NewCache(baseDir)
30+
31+
assert.Equal(t, baseDir, c.BaseDir())
32+
33+
var nilCache *Cache
34+
assert.Equal(t, "", nilCache.BaseDir())
35+
}
36+
2537
func TestCache_Has_EmptyCache(t *testing.T) {
2638
t.Parallel()
2739

options.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ type config struct {
7979
imageCache *image.Cache
8080
imageFetcher image.ImageFetcher // nil = default local-then-remote fallback
8181
spawner runner.Spawner // nil = default runner.Spawn
82+
cleanDataDir bool
83+
removeAll func(string) error
84+
stat func(string) (os.FileInfo, error)
8285
}
8386

8487
func defaultConfig() *config {
@@ -92,6 +95,8 @@ func defaultConfig() *config {
9295
preflight: preflight.Default(),
9396
imageCache: image.NewCache(filepath.Join(dataDir, "cache")),
9497
dataDir: dataDir,
98+
removeAll: os.RemoveAll,
99+
stat: os.Stat,
95100
}
96101
}
97102

@@ -234,6 +239,13 @@ func WithDataDir(path string) Option {
234239
})
235240
}
236241

242+
// WithCleanDataDir removes any existing data directory contents before boot.
243+
// Use only when the data dir is VM-scoped; this will also clear the image cache
244+
// if it lives under the data dir.
245+
func WithCleanDataDir() Option {
246+
return optionFunc(func(c *config) { c.cleanDataDir = true })
247+
}
248+
237249
// WithRunnerPath sets the path to the propolis-runner binary.
238250
// When empty, the runner is found via $PATH or alongside the calling binary.
239251
func WithRunnerPath(path string) Option {

options_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,16 @@ func TestWithDataDir(t *testing.T) {
122122
assert.NotNil(t, cfg.imageCache) // imageCache should be recreated
123123
}
124124

125+
func TestWithCleanDataDir(t *testing.T) {
126+
t.Parallel()
127+
128+
cfg := defaultConfig()
129+
assert.False(t, cfg.cleanDataDir)
130+
131+
WithCleanDataDir().apply(cfg)
132+
assert.True(t, cfg.cleanDataDir)
133+
}
134+
125135
func TestWithRootFSPath(t *testing.T) {
126136
t.Parallel()
127137

propolis.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ func Run(ctx context.Context, imageRef string, opts ...Option) (*VM, error) {
3838
for _, opt := range opts {
3939
opt.apply(cfg)
4040
}
41+
if cfg.cleanDataDir {
42+
if err := cleanDataDir(cfg); err != nil {
43+
return nil, err
44+
}
45+
}
4146

4247
if err := os.MkdirAll(cfg.dataDir, 0o700); err != nil {
4348
return nil, fmt.Errorf("create data dir: %w", err)
@@ -158,6 +163,8 @@ func Run(ctx context.Context, imageRef string, opts ...Option) (*VM, error) {
158163
dataDir: cfg.dataDir,
159164
rootfsPath: rootfs.Path,
160165
ports: cfg.ports,
166+
cacheDir: cacheDir(cfg),
167+
removeAll: cfg.removeAll,
161168
}
162169

163170
// Best-effort state persistence for crash recovery.
@@ -192,6 +199,46 @@ func Run(ctx context.Context, imageRef string, opts ...Option) (*VM, error) {
192199
return vm, nil
193200
}
194201

202+
func cleanDataDir(cfg *config) error {
203+
if cfg.dataDir == "" {
204+
return nil
205+
}
206+
_, err := cfg.stat(cfg.dataDir)
207+
if err != nil {
208+
if os.IsNotExist(err) {
209+
return nil
210+
}
211+
return fmt.Errorf("check data dir: %w", err)
212+
}
213+
214+
var keep []string
215+
cache := cacheDir(cfg)
216+
if cache != "" && isWithin(cfg.dataDir, cache) {
217+
keep = append(keep, cache)
218+
}
219+
if cfg.rootfsPath != "" && isWithin(cfg.dataDir, cfg.rootfsPath) {
220+
keep = append(keep, cfg.rootfsPath)
221+
}
222+
if len(keep) > 0 {
223+
if err := removeDataDirContentsExcept(cfg.removeAll, cfg.dataDir, keep); err != nil {
224+
return fmt.Errorf("clean data dir contents: %w", err)
225+
}
226+
return nil
227+
}
228+
229+
if err := cfg.removeAll(cfg.dataDir); err != nil {
230+
return fmt.Errorf("clean data dir: %w", err)
231+
}
232+
return nil
233+
}
234+
235+
func cacheDir(cfg *config) string {
236+
if cfg == nil || cfg.imageCache == nil {
237+
return ""
238+
}
239+
return cfg.imageCache.BaseDir()
240+
}
241+
195242
func toRunnerPortForwards(ports []PortForward) []runner.PortForward {
196243
out := make([]runner.PortForward, len(ports))
197244
for i, p := range ports {

propolis_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,44 @@ func TestRun_SpawnFailure(t *testing.T) {
192192
assert.True(t, netProv.stopped)
193193
}
194194

195+
func TestRun_WithCleanDataDir_RemovesStaleState(t *testing.T) {
196+
t.Parallel()
197+
198+
dataDir := t.TempDir()
199+
cacheDir := filepath.Join(dataDir, "cache")
200+
stalePath := filepath.Join(dataDir, "stale.sock")
201+
202+
require.NoError(t, os.MkdirAll(cacheDir, 0o755))
203+
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "marker"), []byte("cache"), 0o644))
204+
require.NoError(t, os.WriteFile(stalePath, []byte("stale"), 0o644))
205+
206+
rootfsDir := filepath.Join(dataDir, "rootfs")
207+
require.NoError(t, os.MkdirAll(rootfsDir, 0o755))
208+
209+
proc := &mockProcessHandle{pid: 1234, alive: true}
210+
netProv := &mockNetProvider{sockPath: "/tmp/fake.sock"}
211+
212+
vm, err := Run(context.Background(), "test:latest",
213+
WithDataDir(dataDir),
214+
WithCleanDataDir(),
215+
WithPreflightChecker(preflight.NewEmpty()),
216+
WithRootFSPath(rootfsDir),
217+
WithNetProvider(netProv),
218+
WithSpawner(&mockSpawner{proc: proc}),
219+
)
220+
require.NoError(t, err)
221+
require.NotNil(t, vm)
222+
223+
_, err = os.Stat(stalePath)
224+
assert.True(t, os.IsNotExist(err))
225+
226+
_, err = os.Stat(cacheDir)
227+
require.NoError(t, err)
228+
229+
_, err = os.Stat(rootfsDir)
230+
require.NoError(t, err)
231+
}
232+
195233
func TestRun_Success(t *testing.T) {
196234
t.Parallel()
197235

vm.go

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ import (
77
"context"
88
"fmt"
99
"log/slog"
10+
"os"
11+
"path/filepath"
12+
"strings"
1013
"time"
1114

1215
"github.com/stacklok/propolis/net"
@@ -22,6 +25,8 @@ type VM struct {
2225
dataDir string
2326
rootfsPath string
2427
ports []PortForward
28+
cacheDir string
29+
removeAll func(string) error
2530
}
2631

2732
// VMInfo contains status information about a VM.
@@ -80,14 +85,39 @@ func (vm *VM) Status(_ context.Context) (*VMInfo, error) {
8085
}
8186

8287
// Remove stops the VM and cleans up its rootfs and state.
88+
// If the image cache lives under the data dir, its contents are preserved.
8389
func (vm *VM) Remove(ctx context.Context) error {
8490
if vm.proc.IsAlive() {
8591
if err := vm.Stop(ctx); err != nil {
8692
return fmt.Errorf("stop before remove: %w", err)
8793
}
8894
}
89-
// Note: we intentionally do NOT remove the image cache —
90-
// only the VM-specific state and rootfs extraction.
95+
if vm.removeAll == nil {
96+
vm.removeAll = os.RemoveAll
97+
}
98+
99+
if vm.rootfsPath != "" && !isWithin(vm.cacheDir, vm.rootfsPath) {
100+
if err := vm.removeAll(vm.rootfsPath); err != nil {
101+
return fmt.Errorf("remove rootfs: %w", err)
102+
}
103+
}
104+
105+
if vm.dataDir != "" {
106+
var keep []string
107+
if vm.cacheDir != "" && isWithin(vm.dataDir, vm.cacheDir) {
108+
keep = append(keep, vm.cacheDir)
109+
}
110+
if len(keep) > 0 {
111+
if err := removeDataDirContentsExcept(vm.removeAll, vm.dataDir, keep); err != nil {
112+
return fmt.Errorf("remove data dir contents: %w", err)
113+
}
114+
} else {
115+
if err := vm.removeAll(vm.dataDir); err != nil {
116+
return fmt.Errorf("remove data dir: %w", err)
117+
}
118+
}
119+
}
120+
91121
return nil
92122
}
93123

@@ -105,3 +135,44 @@ func (vm *VM) RootFSPath() string { return vm.rootfsPath }
105135

106136
// Ports returns the configured port forwards.
107137
func (vm *VM) Ports() []PortForward { return vm.ports }
138+
139+
func isWithin(base string, target string) bool {
140+
if base == "" || target == "" {
141+
return false
142+
}
143+
rel, err := filepath.Rel(base, target)
144+
if err != nil {
145+
return false
146+
}
147+
if rel == "." {
148+
return true
149+
}
150+
if rel == ".." {
151+
return false
152+
}
153+
return !strings.HasPrefix(rel, ".."+string(filepath.Separator))
154+
}
155+
156+
func removeDataDirContentsExcept(removeAll func(string) error, dataDir string, keepPaths []string) error {
157+
entries, err := os.ReadDir(dataDir)
158+
if err != nil {
159+
return fmt.Errorf("read data dir: %w", err)
160+
}
161+
keep := make(map[string]struct{}, len(keepPaths))
162+
for _, path := range keepPaths {
163+
if path == "" {
164+
continue
165+
}
166+
keep[filepath.Clean(path)] = struct{}{}
167+
}
168+
for _, entry := range entries {
169+
entryPath := filepath.Join(dataDir, entry.Name())
170+
if _, ok := keep[filepath.Clean(entryPath)]; ok {
171+
continue
172+
}
173+
if err := removeAll(entryPath); err != nil {
174+
return fmt.Errorf("remove data dir entry %s: %w", entryPath, err)
175+
}
176+
}
177+
return nil
178+
}

vm_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ package propolis
66
import (
77
"context"
88
"fmt"
9+
"os"
10+
"path/filepath"
911
"testing"
1012

1113
"github.com/stretchr/testify/assert"
@@ -160,6 +162,69 @@ func TestVM_Remove_StillRunning(t *testing.T) {
160162
assert.True(t, netProv.stopped)
161163
}
162164

165+
func TestVM_Remove_PreservesCacheContents(t *testing.T) {
166+
t.Parallel()
167+
168+
dataDir := t.TempDir()
169+
cacheDir := filepath.Join(dataDir, "cache")
170+
rootfsDir := filepath.Join(cacheDir, "rootfs")
171+
stalePath := filepath.Join(dataDir, "stale.sock")
172+
173+
require.NoError(t, os.MkdirAll(rootfsDir, 0o755))
174+
require.NoError(t, os.WriteFile(filepath.Join(rootfsDir, "marker"), []byte("rootfs"), 0o644))
175+
require.NoError(t, os.WriteFile(stalePath, []byte("stale"), 0o644))
176+
177+
vm := &VM{
178+
name: "test-vm",
179+
proc: &mockProcessHandle{pid: 42, alive: false},
180+
dataDir: dataDir,
181+
rootfsPath: rootfsDir,
182+
cacheDir: cacheDir,
183+
removeAll: os.RemoveAll,
184+
}
185+
186+
err := vm.Remove(context.Background())
187+
require.NoError(t, err)
188+
189+
_, err = os.Stat(stalePath)
190+
assert.True(t, os.IsNotExist(err))
191+
192+
_, err = os.Stat(rootfsDir)
193+
require.NoError(t, err)
194+
195+
_, err = os.Stat(cacheDir)
196+
require.NoError(t, err)
197+
}
198+
199+
func TestVM_Remove_RemovesRootfsOutsideCache(t *testing.T) {
200+
t.Parallel()
201+
202+
dataDir := t.TempDir()
203+
cacheDir := filepath.Join(dataDir, "cache")
204+
rootfsDir := t.TempDir()
205+
206+
require.NoError(t, os.MkdirAll(cacheDir, 0o755))
207+
require.NoError(t, os.WriteFile(filepath.Join(rootfsDir, "marker"), []byte("rootfs"), 0o644))
208+
209+
vm := &VM{
210+
name: "test-vm",
211+
proc: &mockProcessHandle{pid: 42, alive: false},
212+
dataDir: dataDir,
213+
rootfsPath: rootfsDir,
214+
cacheDir: cacheDir,
215+
removeAll: os.RemoveAll,
216+
}
217+
218+
err := vm.Remove(context.Background())
219+
require.NoError(t, err)
220+
221+
_, err = os.Stat(rootfsDir)
222+
assert.True(t, os.IsNotExist(err))
223+
224+
_, err = os.Stat(cacheDir)
225+
require.NoError(t, err)
226+
}
227+
163228
func TestVM_Accessors(t *testing.T) {
164229
t.Parallel()
165230

0 commit comments

Comments
 (0)