Skip to content

Commit b8e3312

Browse files
fix: use atomic writes for state.json to prevent corruption
Signed-off-by: mayanksharmaCSE <mayanksharmacse1@gmail.com>
1 parent 032325e commit b8e3312

3 files changed

Lines changed: 87 additions & 13 deletions

File tree

pkg/unikontainers/unikontainers.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,7 @@ func (u *Unikontainer) saveContainerState() error {
706706
}
707707

708708
stateName := filepath.Join(u.BaseDir, stateFilename)
709-
return os.WriteFile(stateName, data, 0o644) //nolint: gosec
709+
return atomicWriteFile(stateName, data, 0o644) //nolint: gosec
710710
}
711711

712712
// getHooksByName returns the hooks for a given lifecycle stage

pkg/unikontainers/utils.go

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -106,24 +106,44 @@ func loadSpec(bundleDir string) (*specs.Spec, error) {
106106
return &spec, nil
107107
}
108108

109-
// writePidFile writes the content of pid to the file defined by path
110-
func writePidFile(path string, pid int) error {
111-
var (
112-
tmpDir = filepath.Dir(path)
113-
tmpName = filepath.Join(tmpDir, "."+filepath.Base(path))
114-
)
115-
f, err := os.OpenFile(tmpName, os.O_RDWR|os.O_CREATE|os.O_EXCL|os.O_SYNC, 0o666)
109+
// atomicWriteFile writes data to a file atomically by first writing to a
110+
// temporary file in the same directory, syncing it, and then renaming it
111+
// to the target path. This prevents partial/corrupt files if the process
112+
// is killed mid-write.
113+
func atomicWriteFile(path string, data []byte, perm os.FileMode) error {
114+
dir := filepath.Dir(path)
115+
tmpName := filepath.Join(dir, "."+filepath.Base(path)+".tmp")
116+
117+
f, err := os.OpenFile(tmpName, os.O_RDWR|os.O_CREATE|os.O_TRUNC|os.O_SYNC, perm)
116118
if err != nil {
117-
return err
119+
return fmt.Errorf("failed to create temp file: %w", err)
118120
}
119-
_, err = f.WriteString(strconv.Itoa(pid))
120-
f.Close()
121-
if err != nil {
122-
return err
121+
122+
_, writeErr := f.Write(data)
123+
syncErr := f.Sync()
124+
closeErr := f.Close()
125+
126+
if writeErr != nil {
127+
os.Remove(tmpName)
128+
return fmt.Errorf("failed to write temp file: %w", writeErr)
123129
}
130+
if syncErr != nil {
131+
os.Remove(tmpName)
132+
return fmt.Errorf("failed to sync temp file: %w", syncErr)
133+
}
134+
if closeErr != nil {
135+
os.Remove(tmpName)
136+
return fmt.Errorf("failed to close temp file: %w", closeErr)
137+
}
138+
124139
return os.Rename(tmpName, path)
125140
}
126141

142+
// writePidFile writes the content of pid to the file defined by path
143+
func writePidFile(path string, pid int) error {
144+
return atomicWriteFile(path, []byte(strconv.Itoa(pid)), 0o666)
145+
}
146+
127147
// handleQueueProxy adds a hardcoded IP to the process's environment.
128148
// Then, the container is identified as a non-bima container
129149
// is spawned using runc.

pkg/unikontainers/utils_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,60 @@ import (
2525
"github.com/stretchr/testify/assert"
2626
)
2727

28+
func TestAtomicWriteFile(t *testing.T) {
29+
t.Run("writes file atomically", func(t *testing.T) {
30+
t.Parallel()
31+
tmpDir := t.TempDir()
32+
target := filepath.Join(tmpDir, "state.json")
33+
data := []byte(`{"status":"running"}`)
34+
35+
err := atomicWriteFile(target, data, 0o644)
36+
assert.NoError(t, err)
37+
38+
content, err := os.ReadFile(target)
39+
assert.NoError(t, err)
40+
assert.Equal(t, data, content)
41+
})
42+
43+
t.Run("overwrites existing file", func(t *testing.T) {
44+
t.Parallel()
45+
tmpDir := t.TempDir()
46+
target := filepath.Join(tmpDir, "state.json")
47+
48+
err := os.WriteFile(target, []byte("old"), 0o600)
49+
assert.NoError(t, err)
50+
51+
newData := []byte("new content")
52+
err = atomicWriteFile(target, newData, 0o644)
53+
assert.NoError(t, err)
54+
55+
content, err := os.ReadFile(target)
56+
assert.NoError(t, err)
57+
assert.Equal(t, newData, content)
58+
})
59+
60+
t.Run("no temp file left on success", func(t *testing.T) {
61+
t.Parallel()
62+
tmpDir := t.TempDir()
63+
target := filepath.Join(tmpDir, "state.json")
64+
65+
err := atomicWriteFile(target, []byte("data"), 0o644)
66+
assert.NoError(t, err)
67+
68+
tmpFile := filepath.Join(tmpDir, ".state.json.tmp")
69+
_, err = os.Stat(tmpFile)
70+
assert.True(t, os.IsNotExist(err), "Temp file should not exist after successful write")
71+
})
72+
73+
t.Run("fails on invalid directory", func(t *testing.T) {
74+
t.Parallel()
75+
target := filepath.Join("/nonexistent/dir", "state.json")
76+
77+
err := atomicWriteFile(target, []byte("data"), 0o644)
78+
assert.Error(t, err)
79+
})
80+
}
81+
2882
func TestWritePidFile(t *testing.T) {
2983
tmpDir := t.TempDir() // Create a temporary directory for the test
3084
pidFilePath := filepath.Join(tmpDir, "test.pid")

0 commit comments

Comments
 (0)