Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 86 additions & 10 deletions extract/bundle.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,34 @@ package extract
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"sync"

"github.com/gofrs/flock"
)

// manifestName is the file name of the per-extraction manifest, stored
// inside the extraction directory next to the extracted files. The
// manifest records the bundle hash plus a per-file SHA-256 of every
// extracted artifact so Ensure can re-verify the cache contents on
// subsequent calls and re-extract if any file has been tampered with or
// replaced.
const manifestName = ".manifest.json"

// manifest is the on-disk format written alongside extracted files. It is
// serialized with encoding/json, so the struct tags below are the JSON
// keys of the file format.
type manifest struct {
// Version is the bundle version string (Ensure copies it from the
// bundle's version field).
Version string `json:"version"`
// Hash is the aggregate bundle hash (same value used in the cache
// directory name). Kept for quick mismatch detection before walking
// individual files.
Hash string `json:"hash"`
// Files maps the bundle-relative file name to its SHA-256 hex digest.
// Ensure computes each digest from the in-memory File.Content, not from
// the bytes read back from disk.
Files map[string]string `json:"files"`
}

// File describes a single file to extract.
type File struct {
Name string
Expand Down Expand Up @@ -60,7 +80,11 @@ func (b *Bundle) Ensure(cacheDir string) (string, error) {

// Acquire cross-process file lock.
lockPath := filepath.Join(cacheDir, ".extract.lock")
if err := os.MkdirAll(cacheDir, 0o755); err != nil {
// 0o700 keeps the extracted runner binaries and libraries readable only
// by the invoking user. The cache holds an executable binary that will
// later be spawned — a world-writable or even group-writable cache dir
// is a local-code-execution vector.
if err := os.MkdirAll(cacheDir, 0o700); err != nil {
return "", fmt.Errorf("create cache dir: %w", err)
}
fl := flock.New(lockPath)
Expand All @@ -86,12 +110,19 @@ func (b *Bundle) Ensure(cacheDir string) (string, error) {
}
}()

// Extract all files.
// Extract all files and collect per-file hashes for the manifest.
m := manifest{
Version: b.version,
Hash: hash,
Files: make(map[string]string, len(b.files)),
}
for _, f := range b.files {
if extractErr := b.extractFile(tmpDir, f); extractErr != nil {
err = extractErr
return "", fmt.Errorf("extract %s: %w", f.Name, extractErr)
}
fileHash := sha256.Sum256(f.Content)
m.Files[f.Name] = hex.EncodeToString(fileHash[:])
}

// Create symlinks.
Expand All @@ -102,11 +133,27 @@ func (b *Bundle) Ensure(cacheDir string) (string, error) {
}
}

// Write version file.
versionPath := filepath.Join(tmpDir, ".version")
if writeErr := os.WriteFile(versionPath, []byte(hash), 0o644); writeErr != nil {
// Write manifest atomically last so a partial extraction never looks
// valid to isValid.
manifestData, mErr := json.Marshal(m)
if mErr != nil {
err = mErr
return "", fmt.Errorf("marshal manifest: %w", mErr)
}
manifestPath := filepath.Join(tmpDir, manifestName)
if writeErr := os.WriteFile(manifestPath, manifestData, 0o600); writeErr != nil {
err = writeErr
return "", fmt.Errorf("write version file: %w", writeErr)
return "", fmt.Errorf("write manifest: %w", writeErr)
}

// If a previous invalid extraction exists at targetDir (e.g. tampered
// content detected by isValid), remove it before rename. The lock
// held above serializes this against other callers.
if _, statErr := os.Lstat(targetDir); statErr == nil {
if rmErr := os.RemoveAll(targetDir); rmErr != nil {
err = rmErr
return "", fmt.Errorf("remove stale target: %w", rmErr)
}
}

// Atomic rename to target.
Expand All @@ -130,14 +177,43 @@ func (b *Bundle) computeHash() string {
return hex.EncodeToString(h.Sum(nil))
}

// isValid reports whether targetDir holds a complete, unaltered extraction
// of this bundle for the expected bundle hash.
//
// All of the following must hold:
//   - the manifest exists, parses as JSON, and records the expected hash;
//   - the manifest's per-file digests are exactly those of the in-memory
//     bundle content (same set of names, same SHA-256 values);
//   - every bundle file on disk hashes to that same digest.
//
// The expected digests are derived from b.files rather than read back from
// the manifest: the manifest sits in the same directory as the files it
// describes, so anything able to replace a cached file could also rewrite
// or empty the manifest to match. Any failure returns false so the caller
// re-extracts.
func (b *Bundle) isValid(targetDir, hash string) bool {
	data, err := os.ReadFile(filepath.Join(targetDir, manifestName))
	if err != nil {
		return false
	}
	var m manifest
	if err := json.Unmarshal(data, &m); err != nil {
		return false
	}
	// Cheap rejection before any file content is read.
	if m.Hash != hash {
		return false
	}

	// Build the expected name -> digest table the same way Ensure fills the
	// manifest, so a repeated file name resolves identically (last wins).
	want := make(map[string]string, len(b.files))
	for _, f := range b.files {
		sum := sha256.Sum256(f.Content)
		want[f.Name] = hex.EncodeToString(sum[:])
	}

	// The manifest must list exactly the bundle's files: no entries dropped
	// and none added.
	if len(m.Files) != len(want) {
		return false
	}
	for name, digest := range want {
		if recorded, ok := m.Files[name]; !ok || recorded != digest {
			return false
		}
		got, err := hashFileOnDisk(filepath.Join(targetDir, name))
		if err != nil || got != digest {
			return false
		}
	}
	return true
}

// hashFileOnDisk streams the file at path through SHA-256 and returns the
// digest as a lowercase hex string. On any open or read failure it returns
// the empty string together with the underlying error.
func hashFileOnDisk(path string) (string, error) {
	src, openErr := os.Open(path)
	if openErr != nil {
		return "", openErr
	}
	// The handle is read-only, so a failed Close cannot lose data; the
	// error is deliberately discarded.
	defer func() { _ = src.Close() }()

	digest := sha256.New()
	_, copyErr := io.Copy(digest, src)
	if copyErr != nil {
		return "", copyErr
	}

	sum := digest.Sum(nil)
	return hex.EncodeToString(sum), nil
}

// extractFile writes a single file atomically via a temp file and rename.
Expand Down
74 changes: 60 additions & 14 deletions extract/bundle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,26 @@ import (
"github.com/stretchr/testify/require"
)

func TestBundle_Ensure_CacheDirIsPrivate(t *testing.T) {
t.Parallel()

// The cache holds an executable binary that will later be spawned. A
// world- or group-writable cache permits local code injection; the
// directory must be 0o700.
cacheDir := filepath.Join(t.TempDir(), "cache")
b := NewBundle("v-private", []File{
{Name: "f", Content: []byte("x"), Mode: 0o644},
})

_, err := b.Ensure(cacheDir)
require.NoError(t, err)

info, err := os.Stat(cacheDir)
require.NoError(t, err)
assert.Equal(t, os.FileMode(0o700), info.Mode().Perm(),
"cache dir must be 0o700; got %o", info.Mode().Perm())
}

func TestBundle_Ensure_ExtractsFiles(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -132,9 +152,9 @@ func TestBundle_Ensure_EmptyBundle(t *testing.T) {
require.NoError(t, err)
assert.True(t, info.IsDir())

versionData, err := os.ReadFile(filepath.Join(dir, ".version"))
manifestData, err := os.ReadFile(filepath.Join(dir, manifestName))
require.NoError(t, err)
assert.NotEmpty(t, versionData)
assert.NotEmpty(t, manifestData)
}

func TestBundle_Ensure_ConcurrentAccess(t *testing.T) {
Expand Down Expand Up @@ -261,15 +281,15 @@ func TestBundle_ComputeHash_EmptyBundle(t *testing.T) {
// TestBundle_IsValid_MatchingHash checks that an extraction produced by
// Ensure is accepted by isValid for the bundle's own hash.
func TestBundle_IsValid_MatchingHash(t *testing.T) {
	t.Parallel()

	// End-to-end through Ensure: the manifest it wrote must satisfy
	// isValid on the same inputs. No hand-written fixture file is used, so
	// the test cannot drift from the real on-disk format.
	b := NewBundle("v1", []File{
		{Name: "x.txt", Content: []byte("data"), Mode: 0o644},
	})

	dir, err := b.Ensure(t.TempDir())
	require.NoError(t, err)

	assert.True(t, b.isValid(dir, b.computeHash()))
}

func TestBundle_IsValid_WrongHash(t *testing.T) {
Expand All @@ -278,25 +298,51 @@ func TestBundle_IsValid_WrongHash(t *testing.T) {
b := NewBundle("v1", []File{
{Name: "x.txt", Content: []byte("data"), Mode: 0o644},
})
hash := b.computeHash()

dir := t.TempDir()
require.NoError(t, os.WriteFile(filepath.Join(dir, ".version"), []byte("wronghash"), 0o644))
dir, err := b.Ensure(t.TempDir())
require.NoError(t, err)

assert.False(t, b.isValid(dir, hash))
assert.False(t, b.isValid(dir, "wronghash"))
}

func TestBundle_IsValid_MissingVersionFile(t *testing.T) {
func TestBundle_IsValid_MissingManifest(t *testing.T) {
t.Parallel()

b := NewBundle("v1", nil)
hash := b.computeHash()

dir := t.TempDir()
// dir exists but has no .version file.
// dir exists but has no manifest file.
assert.False(t, b.isValid(dir, hash))
}

// TestBundle_IsValid_TamperedFileTriggersReextract checks that modifying a
// cached file after extraction makes isValid return false, and that the
// next Ensure call restores the original content.
func TestBundle_IsValid_TamperedFileTriggersReextract(t *testing.T) {
	t.Parallel()

	cacheDir := t.TempDir()
	bundle := NewBundle("v-tamper", []File{
		{Name: "binary", Content: []byte("original-content"), Mode: 0o755},
	})

	extracted, err := bundle.Ensure(cacheDir)
	require.NoError(t, err)

	// Simulate tampering: replace the cached file's content in place.
	target := filepath.Join(extracted, "binary")
	require.NoError(t, os.WriteFile(target, []byte("tampered"), 0o755))

	// A tampered binary must never be accepted, otherwise it could be
	// spawned as-is.
	stillValid := bundle.isValid(extracted, bundle.computeHash())
	assert.False(t, stillValid,
		"tampered cached file must not be treated as valid")

	// The next Ensure on the same cache must put the original bytes back.
	reextracted, err := bundle.Ensure(cacheDir)
	require.NoError(t, err)
	content, err := os.ReadFile(filepath.Join(reextracted, "binary"))
	require.NoError(t, err)
	assert.Equal(t, "original-content", string(content),
		"Ensure must re-extract the original content after tamper detection")
}

func TestBundle_IsValid_NonexistentDir(t *testing.T) {
t.Parallel()

Expand Down
Loading