diff --git a/extract/bundle.go b/extract/bundle.go index aff9c64..a623721 100644 --- a/extract/bundle.go +++ b/extract/bundle.go @@ -6,7 +6,9 @@ package extract import ( "crypto/sha256" "encoding/hex" + "encoding/json" "fmt" + "io" "os" "path/filepath" "sync" @@ -14,6 +16,24 @@ import ( "github.com/gofrs/flock" ) +// manifestName is the per-extract manifest file. It records the bundle +// hash plus a per-file SHA-256 of every extracted artifact so Ensure can +// re-verify the cache contents on subsequent calls and re-extract if +// any file has been tampered with or replaced. +const manifestName = ".manifest.json" + +// manifest is the on-disk format written alongside extracted files. +type manifest struct { + // Version is the bundle version string. + Version string `json:"version"` + // Hash is the aggregate bundle hash (same value used in the cache + // directory name). Kept for quick mismatch detection before walking + // individual files. + Hash string `json:"hash"` + // Files maps the bundle-relative file name to its SHA-256 hex digest. + Files map[string]string `json:"files"` +} + // File describes a single file to extract. type File struct { Name string @@ -60,7 +80,11 @@ func (b *Bundle) Ensure(cacheDir string) (string, error) { // Acquire cross-process file lock. lockPath := filepath.Join(cacheDir, ".extract.lock") - if err := os.MkdirAll(cacheDir, 0o755); err != nil { + // 0o700 keeps the extracted runner binaries and libraries readable only + // by the invoking user. The cache holds an executable binary that will + // later be spawned — a world-writable or even group-writable cache dir + // is a local-code-execution vector. + if err := os.MkdirAll(cacheDir, 0o700); err != nil { return "", fmt.Errorf("create cache dir: %w", err) } fl := flock.New(lockPath) @@ -86,12 +110,19 @@ func (b *Bundle) Ensure(cacheDir string) (string, error) { } }() - // Extract all files. + // Extract all files and collect per-file hashes for the manifest. + m := manifest{ + Version: b.version, + Hash: hash, + Files: make(map[string]string, len(b.files)), + } for _, f := range b.files { if extractErr := b.extractFile(tmpDir, f); extractErr != nil { err = extractErr return "", fmt.Errorf("extract %s: %w", f.Name, extractErr) } + fileHash := sha256.Sum256(f.Content) + m.Files[f.Name] = hex.EncodeToString(fileHash[:]) } // Create symlinks. @@ -102,11 +133,27 @@ func (b *Bundle) Ensure(cacheDir string) (string, error) { } } - // Write version file. - versionPath := filepath.Join(tmpDir, ".version") - if writeErr := os.WriteFile(versionPath, []byte(hash), 0o644); writeErr != nil { + // Write manifest atomically last so a partial extraction never looks + // valid to isValid. + manifestData, mErr := json.Marshal(m) + if mErr != nil { + err = mErr + return "", fmt.Errorf("marshal manifest: %w", mErr) + } + manifestPath := filepath.Join(tmpDir, manifestName) + if writeErr := os.WriteFile(manifestPath, manifestData, 0o600); writeErr != nil { err = writeErr - return "", fmt.Errorf("write version file: %w", writeErr) + return "", fmt.Errorf("write manifest: %w", writeErr) + } + + // If a previous invalid extraction exists at targetDir (e.g. tampered + // content detected by isValid), remove it before rename. The lock + // held above serializes this against other callers. + if _, statErr := os.Lstat(targetDir); statErr == nil { + if rmErr := os.RemoveAll(targetDir); rmErr != nil { + err = rmErr + return "", fmt.Errorf("remove stale target: %w", rmErr) + } } // Atomic rename to target. @@ -130,14 +177,43 @@ func (b *Bundle) computeHash() string { return hex.EncodeToString(h.Sum(nil)) } -// isValid checks whether targetDir exists and contains a .version file -// matching the expected hash. +// isValid checks whether targetDir is a complete, unaltered extraction +// for the expected bundle hash. It requires a manifest with a matching +// bundle hash and every listed file's SHA-256 matching its current on- +// disk content. Any mismatch causes a re-extract. func (b *Bundle) isValid(targetDir, hash string) bool { - data, err := os.ReadFile(filepath.Join(targetDir, ".version")) + data, err := os.ReadFile(filepath.Join(targetDir, manifestName)) if err != nil { return false } - return string(data) == hash + var m manifest + if err := json.Unmarshal(data, &m); err != nil { + return false + } + if m.Hash != hash { + return false + } + for name, want := range m.Files { + got, err := hashFileOnDisk(filepath.Join(targetDir, name)) + if err != nil || got != want { + return false + } + } + return true +} + +// hashFileOnDisk returns the SHA-256 hex digest of the file at path. +func hashFileOnDisk(path string) (string, error) { + f, err := os.Open(path) + if err != nil { + return "", err + } + defer func() { _ = f.Close() }() + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return "", err + } + return hex.EncodeToString(h.Sum(nil)), nil } // extractFile writes a single file atomically via a temp file and rename. diff --git a/extract/bundle_test.go b/extract/bundle_test.go index e8080da..903199c 100644 --- a/extract/bundle_test.go +++ b/extract/bundle_test.go @@ -14,6 +14,26 @@ import ( "github.com/stretchr/testify/require" ) +func TestBundle_Ensure_CacheDirIsPrivate(t *testing.T) { + t.Parallel() + + // The cache holds an executable binary that will later be spawned. A + // world- or group-writable cache permits local code injection; the + // directory must be 0o700. + cacheDir := filepath.Join(t.TempDir(), "cache") + b := NewBundle("v-private", []File{ + {Name: "f", Content: []byte("x"), Mode: 0o644}, + }) + + _, err := b.Ensure(cacheDir) + require.NoError(t, err) + + info, err := os.Stat(cacheDir) + require.NoError(t, err) + assert.Equal(t, os.FileMode(0o700), info.Mode().Perm(), + "cache dir must be 0o700; got %o", info.Mode().Perm()) +} + func TestBundle_Ensure_ExtractsFiles(t *testing.T) { t.Parallel() @@ -132,9 +152,9 @@ func TestBundle_Ensure_EmptyBundle(t *testing.T) { require.NoError(t, err) assert.True(t, info.IsDir()) - versionData, err := os.ReadFile(filepath.Join(dir, ".version")) + manifestData, err := os.ReadFile(filepath.Join(dir, manifestName)) require.NoError(t, err) - assert.NotEmpty(t, versionData) + assert.NotEmpty(t, manifestData) } func TestBundle_Ensure_ConcurrentAccess(t *testing.T) { @@ -261,15 +281,15 @@ func TestBundle_ComputeHash_EmptyBundle(t *testing.T) { func TestBundle_IsValid_MatchingHash(t *testing.T) { t.Parallel() + // End-to-end through Ensure: the manifest it wrote must satisfy + // isValid on the same inputs. b := NewBundle("v1", []File{ {Name: "x.txt", Content: []byte("data"), Mode: 0o644}, }) - hash := b.computeHash() - - dir := t.TempDir() - require.NoError(t, os.WriteFile(filepath.Join(dir, ".version"), []byte(hash), 0o644)) + dir, err := b.Ensure(t.TempDir()) + require.NoError(t, err) - assert.True(t, b.isValid(dir, hash)) + assert.True(t, b.isValid(dir, b.computeHash())) } func TestBundle_IsValid_WrongHash(t *testing.T) { @@ -278,25 +298,51 @@ func TestBundle_IsValid_WrongHash(t *testing.T) { b := NewBundle("v1", []File{ {Name: "x.txt", Content: []byte("data"), Mode: 0o644}, }) - hash := b.computeHash() - - dir := t.TempDir() - require.NoError(t, os.WriteFile(filepath.Join(dir, ".version"), []byte("wronghash"), 0o644)) + dir, err := b.Ensure(t.TempDir()) + require.NoError(t, err) - assert.False(t, b.isValid(dir, hash)) + assert.False(t, b.isValid(dir, "wronghash")) } -func TestBundle_IsValid_MissingVersionFile(t *testing.T) { +func TestBundle_IsValid_MissingManifest(t *testing.T) { t.Parallel() b := NewBundle("v1", nil) hash := b.computeHash() dir := t.TempDir() - // dir exists but has no .version file. + // dir exists but has no manifest file. assert.False(t, b.isValid(dir, hash)) } +func TestBundle_IsValid_TamperedFileTriggersReextract(t *testing.T) { + t.Parallel() + + // If a cached file has been modified after extraction, isValid must + // return false so Ensure re-extracts rather than spawning a tampered + // binary. + cacheDir := t.TempDir() + b := NewBundle("v-tamper", []File{ + {Name: "binary", Content: []byte("original-content"), Mode: 0o755}, + }) + dir, err := b.Ensure(cacheDir) + require.NoError(t, err) + + // Overwrite the cached file with different content. + require.NoError(t, os.WriteFile(filepath.Join(dir, "binary"), []byte("tampered"), 0o755)) + + assert.False(t, b.isValid(dir, b.computeHash()), + "tampered cached file must not be treated as valid") + + // A subsequent Ensure should re-extract the original content. + dir2, err := b.Ensure(cacheDir) + require.NoError(t, err) + got, err := os.ReadFile(filepath.Join(dir2, "binary")) + require.NoError(t, err) + assert.Equal(t, "original-content", string(got), + "Ensure must re-extract the original content after tamper detection") +} + func TestBundle_IsValid_NonexistentDir(t *testing.T) { t.Parallel()