Skip to content

Commit f98a5da

Browse files
authored
Merge pull request #74 from stacklok/jaosorior/extract-cache-integrity
extract: private cache dir and per-file integrity manifest
2 parents ddbbc01 + 0474d61 commit f98a5da

File tree

2 files changed

+146
-24
lines changed

2 files changed

+146
-24
lines changed

extract/bundle.go

Lines changed: 86 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,34 @@ package extract
66
import (
77
"crypto/sha256"
88
"encoding/hex"
9+
"encoding/json"
910
"fmt"
11+
"io"
1012
"os"
1113
"path/filepath"
1214
"sync"
1315

1416
"github.com/gofrs/flock"
1517
)
1618

19+
// manifestName is the name of the manifest file stored in each extraction
// directory. The manifest records the bundle version, the aggregate bundle
// hash, and a SHA-256 digest for every extracted file, so Ensure can
// re-verify the cache contents on subsequent calls and re-extract if any
// file has been tampered with or replaced.
const manifestName = ".manifest.json"
24+
25+
// manifest is the on-disk JSON format written alongside the extracted files
// under the name manifestName.
type manifest struct {
	// Version is the bundle version string.
	Version string `json:"version"`
	// Hash is the aggregate bundle hash (same value used in the cache
	// directory name). It is kept so a mismatch can be detected quickly,
	// before any individual file is hashed.
	Hash string `json:"hash"`
	// Files maps each bundle-relative file name to the hex-encoded SHA-256
	// digest of that file's content.
	Files map[string]string `json:"files"`
}
36+
1737
// File describes a single file to extract.
1838
type File struct {
1939
Name string
@@ -60,7 +80,11 @@ func (b *Bundle) Ensure(cacheDir string) (string, error) {
6080

6181
// Acquire cross-process file lock.
6282
lockPath := filepath.Join(cacheDir, ".extract.lock")
63-
if err := os.MkdirAll(cacheDir, 0o755); err != nil {
83+
// 0o700 keeps the extracted runner binaries and libraries readable only
84+
// by the invoking user. The cache holds an executable binary that will
85+
// later be spawned — a world-writable or even group-writable cache dir
86+
// is a local-code-execution vector.
87+
if err := os.MkdirAll(cacheDir, 0o700); err != nil {
6488
return "", fmt.Errorf("create cache dir: %w", err)
6589
}
6690
fl := flock.New(lockPath)
@@ -86,12 +110,19 @@ func (b *Bundle) Ensure(cacheDir string) (string, error) {
86110
}
87111
}()
88112

89-
// Extract all files.
113+
// Extract all files and collect per-file hashes for the manifest.
114+
m := manifest{
115+
Version: b.version,
116+
Hash: hash,
117+
Files: make(map[string]string, len(b.files)),
118+
}
90119
for _, f := range b.files {
91120
if extractErr := b.extractFile(tmpDir, f); extractErr != nil {
92121
err = extractErr
93122
return "", fmt.Errorf("extract %s: %w", f.Name, extractErr)
94123
}
124+
fileHash := sha256.Sum256(f.Content)
125+
m.Files[f.Name] = hex.EncodeToString(fileHash[:])
95126
}
96127

97128
// Create symlinks.
@@ -102,11 +133,27 @@ func (b *Bundle) Ensure(cacheDir string) (string, error) {
102133
}
103134
}
104135

105-
// Write version file.
106-
versionPath := filepath.Join(tmpDir, ".version")
107-
if writeErr := os.WriteFile(versionPath, []byte(hash), 0o644); writeErr != nil {
136+
// Write manifest atomically last so a partial extraction never looks
137+
// valid to isValid.
138+
manifestData, mErr := json.Marshal(m)
139+
if mErr != nil {
140+
err = mErr
141+
return "", fmt.Errorf("marshal manifest: %w", mErr)
142+
}
143+
manifestPath := filepath.Join(tmpDir, manifestName)
144+
if writeErr := os.WriteFile(manifestPath, manifestData, 0o600); writeErr != nil {
108145
err = writeErr
109-
return "", fmt.Errorf("write version file: %w", writeErr)
146+
return "", fmt.Errorf("write manifest: %w", writeErr)
147+
}
148+
149+
// If a previous invalid extraction exists at targetDir (e.g. tampered
150+
// content detected by isValid), remove it before rename. The lock
151+
// held above serializes this against other callers.
152+
if _, statErr := os.Lstat(targetDir); statErr == nil {
153+
if rmErr := os.RemoveAll(targetDir); rmErr != nil {
154+
err = rmErr
155+
return "", fmt.Errorf("remove stale target: %w", rmErr)
156+
}
110157
}
111158

112159
// Atomic rename to target.
@@ -130,14 +177,43 @@ func (b *Bundle) computeHash() string {
130177
return hex.EncodeToString(h.Sum(nil))
131178
}
132179

133-
// isValid checks whether targetDir exists and contains a .version file
134-
// matching the expected hash.
180+
// isValid checks whether targetDir is a complete, unaltered extraction
181+
// for the expected bundle hash. It requires a manifest with a matching
182+
// bundle hash and every listed file's SHA-256 matching its current on-
183+
// disk content. Any mismatch causes a re-extract.
135184
func (b *Bundle) isValid(targetDir, hash string) bool {
136-
data, err := os.ReadFile(filepath.Join(targetDir, ".version"))
185+
data, err := os.ReadFile(filepath.Join(targetDir, manifestName))
137186
if err != nil {
138187
return false
139188
}
140-
return string(data) == hash
189+
var m manifest
190+
if err := json.Unmarshal(data, &m); err != nil {
191+
return false
192+
}
193+
if m.Hash != hash {
194+
return false
195+
}
196+
for name, want := range m.Files {
197+
got, err := hashFileOnDisk(filepath.Join(targetDir, name))
198+
if err != nil || got != want {
199+
return false
200+
}
201+
}
202+
return true
203+
}
204+
205+
// hashFileOnDisk returns the hex-encoded SHA-256 digest of the regular file
// at path.
//
// The path is inspected with Lstat before it is opened, and anything other
// than a regular file is rejected with an error. This stops a symlink from
// redirecting the read to another location, and avoids blocking forever on
// a FIFO or device node placed where a cached file is expected. After
// opening, the handle is compared with the Lstat result so a path swapped
// between the two calls is also rejected.
//
// An error is returned if the path cannot be inspected, opened, or read, or
// if it does not refer to a regular file.
func hashFileOnDisk(path string) (string, error) {
	info, err := os.Lstat(path)
	if err != nil {
		return "", err
	}
	if !info.Mode().IsRegular() {
		return "", fmt.Errorf("hash %s: not a regular file (mode %s)", path, info.Mode())
	}

	f, err := os.Open(path)
	if err != nil {
		return "", err
	}
	// The file is only read, so a Close error cannot lose data.
	defer func() { _ = f.Close() }()

	opened, err := f.Stat()
	if err != nil {
		return "", err
	}
	if !os.SameFile(info, opened) {
		return "", fmt.Errorf("hash %s: file changed while opening", path)
	}

	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return "", err
	}
	return hex.EncodeToString(h.Sum(nil)), nil
}
142218

143219
// extractFile writes a single file atomically via a temp file and rename.

extract/bundle_test.go

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,26 @@ import (
1414
"github.com/stretchr/testify/require"
1515
)
1616

17+
func TestBundle_Ensure_CacheDirIsPrivate(t *testing.T) {
18+
t.Parallel()
19+
20+
// The cache holds an executable binary that will later be spawned. A
21+
// world- or group-writable cache permits local code injection; the
22+
// directory must be 0o700.
23+
cacheDir := filepath.Join(t.TempDir(), "cache")
24+
b := NewBundle("v-private", []File{
25+
{Name: "f", Content: []byte("x"), Mode: 0o644},
26+
})
27+
28+
_, err := b.Ensure(cacheDir)
29+
require.NoError(t, err)
30+
31+
info, err := os.Stat(cacheDir)
32+
require.NoError(t, err)
33+
assert.Equal(t, os.FileMode(0o700), info.Mode().Perm(),
34+
"cache dir must be 0o700; got %o", info.Mode().Perm())
35+
}
36+
1737
func TestBundle_Ensure_ExtractsFiles(t *testing.T) {
1838
t.Parallel()
1939

@@ -132,9 +152,9 @@ func TestBundle_Ensure_EmptyBundle(t *testing.T) {
132152
require.NoError(t, err)
133153
assert.True(t, info.IsDir())
134154

135-
versionData, err := os.ReadFile(filepath.Join(dir, ".version"))
155+
manifestData, err := os.ReadFile(filepath.Join(dir, manifestName))
136156
require.NoError(t, err)
137-
assert.NotEmpty(t, versionData)
157+
assert.NotEmpty(t, manifestData)
138158
}
139159

140160
func TestBundle_Ensure_ConcurrentAccess(t *testing.T) {
@@ -261,15 +281,15 @@ func TestBundle_ComputeHash_EmptyBundle(t *testing.T) {
261281
func TestBundle_IsValid_MatchingHash(t *testing.T) {
262282
t.Parallel()
263283

284+
// End-to-end through Ensure: the manifest it wrote must satisfy
285+
// isValid on the same inputs.
264286
b := NewBundle("v1", []File{
265287
{Name: "x.txt", Content: []byte("data"), Mode: 0o644},
266288
})
267-
hash := b.computeHash()
268-
269-
dir := t.TempDir()
270-
require.NoError(t, os.WriteFile(filepath.Join(dir, ".version"), []byte(hash), 0o644))
289+
dir, err := b.Ensure(t.TempDir())
290+
require.NoError(t, err)
271291

272-
assert.True(t, b.isValid(dir, hash))
292+
assert.True(t, b.isValid(dir, b.computeHash()))
273293
}
274294

275295
func TestBundle_IsValid_WrongHash(t *testing.T) {
@@ -278,25 +298,51 @@ func TestBundle_IsValid_WrongHash(t *testing.T) {
278298
b := NewBundle("v1", []File{
279299
{Name: "x.txt", Content: []byte("data"), Mode: 0o644},
280300
})
281-
hash := b.computeHash()
282-
283-
dir := t.TempDir()
284-
require.NoError(t, os.WriteFile(filepath.Join(dir, ".version"), []byte("wronghash"), 0o644))
301+
dir, err := b.Ensure(t.TempDir())
302+
require.NoError(t, err)
285303

286-
assert.False(t, b.isValid(dir, hash))
304+
assert.False(t, b.isValid(dir, "wronghash"))
287305
}
288306

289-
func TestBundle_IsValid_MissingVersionFile(t *testing.T) {
307+
func TestBundle_IsValid_MissingManifest(t *testing.T) {
290308
t.Parallel()
291309

292310
b := NewBundle("v1", nil)
293311
hash := b.computeHash()
294312

295313
dir := t.TempDir()
296-
// dir exists but has no .version file.
314+
// dir exists but has no manifest file.
297315
assert.False(t, b.isValid(dir, hash))
298316
}
299317

318+
func TestBundle_IsValid_TamperedFileTriggersReextract(t *testing.T) {
319+
t.Parallel()
320+
321+
// If a cached file has been modified after extraction, isValid must
322+
// return false so Ensure re-extracts rather than spawning a tampered
323+
// binary.
324+
cacheDir := t.TempDir()
325+
b := NewBundle("v-tamper", []File{
326+
{Name: "binary", Content: []byte("original-content"), Mode: 0o755},
327+
})
328+
dir, err := b.Ensure(cacheDir)
329+
require.NoError(t, err)
330+
331+
// Overwrite the cached file with different content.
332+
require.NoError(t, os.WriteFile(filepath.Join(dir, "binary"), []byte("tampered"), 0o755))
333+
334+
assert.False(t, b.isValid(dir, b.computeHash()),
335+
"tampered cached file must not be treated as valid")
336+
337+
// A subsequent Ensure should re-extract the original content.
338+
dir2, err := b.Ensure(cacheDir)
339+
require.NoError(t, err)
340+
got, err := os.ReadFile(filepath.Join(dir2, "binary"))
341+
require.NoError(t, err)
342+
assert.Equal(t, "original-content", string(got),
343+
"Ensure must re-extract the original content after tamper detection")
344+
}
345+
300346
func TestBundle_IsValid_NonexistentDir(t *testing.T) {
301347
t.Parallel()
302348

0 commit comments

Comments
 (0)