Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 39 additions & 14 deletions image/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,20 +111,20 @@ func PullWithFetcher(ctx context.Context, imageRef string, cache *Cache, fetcher
// extraction if layered extraction fails.
if cache != nil {
lc := cache.LayerCache()
if err := extractImageLayered(img, tmpDir, lc); err != nil {
if err := extractImageLayered(ctx, img, tmpDir, lc); err != nil {
slog.Warn("layered extraction failed, falling back to flat extraction", "err", err)
// Clean tmpDir contents before retrying with flat extraction.
_ = os.RemoveAll(tmpDir)
if tmpDir, err = cache.TempDir(); err != nil {
return nil, fmt.Errorf("create temp dir for rootfs: %w", err)
}
if err := extractImage(img, tmpDir); err != nil {
if err := extractImage(ctx, img, tmpDir); err != nil {
_ = os.RemoveAll(tmpDir)
return nil, fmt.Errorf("extract image layers: %w", err)
}
}
} else {
if err := extractImage(img, tmpDir); err != nil {
if err := extractImage(ctx, img, tmpDir); err != nil {
_ = os.RemoveAll(tmpDir)
return nil, fmt.Errorf("extract image layers: %w", err)
}
Expand Down Expand Up @@ -180,20 +180,35 @@ func extractOCIConfig(img v1.Image) (*OCIConfig, error) {
}, nil
}

// contextReader wraps an io.Reader with context cancellation support.
// It checks the context before each Read call, enabling cancellation of
// long-running I/O operations (e.g. slow registry downloads).
type contextReader struct {
ctx context.Context
r io.Reader
}

func (cr *contextReader) Read(p []byte) (int, error) {
if err := cr.ctx.Err(); err != nil {
return 0, err
}
return cr.r.Read(p)
}

// extractImage flattens all image layers into a single tar stream and extracts
// it to the destination directory. It includes security measures against path
// traversal, symlink attacks, and decompression bombs.
func extractImage(img v1.Image, dst string) error {
func extractImage(ctx context.Context, img v1.Image, dst string) error {
reader := mutate.Extract(img)
defer func() { _ = reader.Close() }()

return extractTar(reader, dst)
return extractTar(ctx, &contextReader{ctx: ctx, r: reader}, dst)
}

// extractImageLayered extracts each image layer individually into the layer
// cache, then composes them bottom-to-top into dst. Shared layers across
// images are extracted only once.
func extractImageLayered(img v1.Image, dst string, lc *LayerCache) error {
func extractImageLayered(ctx context.Context, img v1.Image, dst string, lc *LayerCache) error {
layers, err := img.Layers()
if err != nil {
return fmt.Errorf("get image layers: %w", err)
Expand All @@ -220,7 +235,7 @@ func extractImageLayered(img v1.Image, dst string, lc *LayerCache) error {
remaining := &atomic.Int64{}
remaining.Store(maxExtractSize)

g := new(errgroup.Group)
g, gCtx := errgroup.WithContext(ctx)
g.SetLimit(concurrency)

for i, layer := range layers {
Expand All @@ -232,7 +247,7 @@ func extractImageLayered(img v1.Image, dst string, lc *LayerCache) error {
}

g.Go(func() error {
return extractLayerToCache(layer, diffID, lc, remaining)
return extractLayerToCache(gCtx, layer, diffID, lc, remaining)
})
}

Expand All @@ -258,7 +273,7 @@ func extractImageLayered(img v1.Image, dst string, lc *LayerCache) error {
// extractLayerToCache extracts a single layer into the layer cache.
// The remaining counter is shared across concurrent layer extractions to
// enforce a global size budget (prevents decompression bombs across layers).
func extractLayerToCache(layer v1.Layer, diffID v1.Hash, lc *LayerCache, remaining *atomic.Int64) error {
func extractLayerToCache(ctx context.Context, layer v1.Layer, diffID v1.Hash, lc *LayerCache, remaining *atomic.Int64) error {
tmpDir, err := lc.TempDir()
if err != nil {
return fmt.Errorf("create temp dir for layer %s: %w", diffID.String(), err)
Expand All @@ -271,7 +286,7 @@ func extractLayerToCache(layer v1.Layer, diffID v1.Hash, lc *LayerCache, remaini
}
defer func() { _ = rc.Close() }()

if err := extractTarSharedLimit(rc, tmpDir, remaining); err != nil {
if err := extractTarSharedLimit(ctx, &contextReader{ctx: ctx, r: rc}, tmpDir, remaining); err != nil {
_ = os.RemoveAll(tmpDir)
return fmt.Errorf("extract layer %s: %w", diffID.String(), err)
}
Expand Down Expand Up @@ -434,14 +449,19 @@ func copyFileToDir(src, target, rootDir string, mode os.FileMode) error {
}

// extractTar reads a tar stream and extracts it to dst with security checks.
func extractTar(r io.Reader, dst string) error {
// The context is checked on each tar entry to support cancellation.
func extractTar(ctx context.Context, r io.Reader, dst string) error {
// Wrap in a LimitedReader to prevent decompression bombs.
lr := &io.LimitedReader{R: r, N: maxExtractSize}
tr := tar.NewReader(lr)

var entryCount int

for {
if err := ctx.Err(); err != nil {
return err
}

hdr, err := tr.Next()
if errors.Is(err, io.EOF) {
break
Expand Down Expand Up @@ -469,7 +489,7 @@ func extractTar(r io.Reader, dst string) error {
}

if entryCount == 0 {
return fmt.Errorf("tar archive is empty or contains no valid entries")
slog.Debug("tar archive has no entries, treating as empty layer")
}

return nil
Expand All @@ -478,14 +498,19 @@ func extractTar(r io.Reader, dst string) error {
// extractTarSharedLimit is like extractTar but uses a shared atomic counter
// for the size budget. This enforces a global maxExtractSize across all layers
// in a layered extraction, preventing decompression bombs via many layers.
func extractTarSharedLimit(r io.Reader, dst string, remaining *atomic.Int64) error {
// The context is checked on each tar entry to support cancellation.
func extractTarSharedLimit(ctx context.Context, r io.Reader, dst string, remaining *atomic.Int64) error {
// Use an atomicLimitReader that decrements the shared counter.
alr := &atomicLimitReader{R: r, Remaining: remaining}
tr := tar.NewReader(alr)

var entryCount int

for {
if err := ctx.Err(); err != nil {
return err
}

hdr, err := tr.Next()
if errors.Is(err, io.EOF) {
break
Expand All @@ -512,7 +537,7 @@ func extractTarSharedLimit(r io.Reader, dst string, remaining *atomic.Int64) err
}

if entryCount == 0 {
return fmt.Errorf("tar archive is empty or contains no valid entries")
slog.Debug("tar archive has no entries, treating as empty layer")
}

return nil
Expand Down
Loading
Loading