Skip to content

Commit ce7d3b9

Browse files
authored
Merge pull request #57 from stacklok/fix/empty-layer-context-threading
Fix empty OCI layer extraction and add context cancellation
2 parents 7c27de6 + 64f2048 commit ce7d3b9

2 files changed

Lines changed: 159 additions & 38 deletions

File tree

image/pull.go

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -111,20 +111,20 @@ func PullWithFetcher(ctx context.Context, imageRef string, cache *Cache, fetcher
111111
// extraction if layered extraction fails.
112112
if cache != nil {
113113
lc := cache.LayerCache()
114-
if err := extractImageLayered(img, tmpDir, lc); err != nil {
114+
if err := extractImageLayered(ctx, img, tmpDir, lc); err != nil {
115115
slog.Warn("layered extraction failed, falling back to flat extraction", "err", err)
116116
// Clean tmpDir contents before retrying with flat extraction.
117117
_ = os.RemoveAll(tmpDir)
118118
if tmpDir, err = cache.TempDir(); err != nil {
119119
return nil, fmt.Errorf("create temp dir for rootfs: %w", err)
120120
}
121-
if err := extractImage(img, tmpDir); err != nil {
121+
if err := extractImage(ctx, img, tmpDir); err != nil {
122122
_ = os.RemoveAll(tmpDir)
123123
return nil, fmt.Errorf("extract image layers: %w", err)
124124
}
125125
}
126126
} else {
127-
if err := extractImage(img, tmpDir); err != nil {
127+
if err := extractImage(ctx, img, tmpDir); err != nil {
128128
_ = os.RemoveAll(tmpDir)
129129
return nil, fmt.Errorf("extract image layers: %w", err)
130130
}
@@ -180,20 +180,35 @@ func extractOCIConfig(img v1.Image) (*OCIConfig, error) {
180180
}, nil
181181
}
182182

183+
// contextReader wraps an io.Reader with context cancellation support.
184+
// It checks the context before each Read call, enabling cancellation of
185+
// long-running I/O operations (e.g. slow registry downloads).
186+
type contextReader struct {
187+
ctx context.Context
188+
r io.Reader
189+
}
190+
191+
func (cr *contextReader) Read(p []byte) (int, error) {
192+
if err := cr.ctx.Err(); err != nil {
193+
return 0, err
194+
}
195+
return cr.r.Read(p)
196+
}
197+
183198
// extractImage flattens all image layers into a single tar stream and extracts
184199
// it to the destination directory. It includes security measures against path
185200
// traversal, symlink attacks, and decompression bombs.
186-
func extractImage(img v1.Image, dst string) error {
201+
func extractImage(ctx context.Context, img v1.Image, dst string) error {
187202
reader := mutate.Extract(img)
188203
defer func() { _ = reader.Close() }()
189204

190-
return extractTar(reader, dst)
205+
return extractTar(ctx, &contextReader{ctx: ctx, r: reader}, dst)
191206
}
192207

193208
// extractImageLayered extracts each image layer individually into the layer
194209
// cache, then composes them bottom-to-top into dst. Shared layers across
195210
// images are extracted only once.
196-
func extractImageLayered(img v1.Image, dst string, lc *LayerCache) error {
211+
func extractImageLayered(ctx context.Context, img v1.Image, dst string, lc *LayerCache) error {
197212
layers, err := img.Layers()
198213
if err != nil {
199214
return fmt.Errorf("get image layers: %w", err)
@@ -220,7 +235,7 @@ func extractImageLayered(img v1.Image, dst string, lc *LayerCache) error {
220235
remaining := &atomic.Int64{}
221236
remaining.Store(maxExtractSize)
222237

223-
g := new(errgroup.Group)
238+
g, gCtx := errgroup.WithContext(ctx)
224239
g.SetLimit(concurrency)
225240

226241
for i, layer := range layers {
@@ -232,7 +247,7 @@ func extractImageLayered(img v1.Image, dst string, lc *LayerCache) error {
232247
}
233248

234249
g.Go(func() error {
235-
return extractLayerToCache(layer, diffID, lc, remaining)
250+
return extractLayerToCache(gCtx, layer, diffID, lc, remaining)
236251
})
237252
}
238253

@@ -258,7 +273,7 @@ func extractImageLayered(img v1.Image, dst string, lc *LayerCache) error {
258273
// extractLayerToCache extracts a single layer into the layer cache.
259274
// The remaining counter is shared across concurrent layer extractions to
260275
// enforce a global size budget (prevents decompression bombs across layers).
261-
func extractLayerToCache(layer v1.Layer, diffID v1.Hash, lc *LayerCache, remaining *atomic.Int64) error {
276+
func extractLayerToCache(ctx context.Context, layer v1.Layer, diffID v1.Hash, lc *LayerCache, remaining *atomic.Int64) error {
262277
tmpDir, err := lc.TempDir()
263278
if err != nil {
264279
return fmt.Errorf("create temp dir for layer %s: %w", diffID.String(), err)
@@ -271,7 +286,7 @@ func extractLayerToCache(layer v1.Layer, diffID v1.Hash, lc *LayerCache, remaini
271286
}
272287
defer func() { _ = rc.Close() }()
273288

274-
if err := extractTarSharedLimit(rc, tmpDir, remaining); err != nil {
289+
if err := extractTarSharedLimit(ctx, &contextReader{ctx: ctx, r: rc}, tmpDir, remaining); err != nil {
275290
_ = os.RemoveAll(tmpDir)
276291
return fmt.Errorf("extract layer %s: %w", diffID.String(), err)
277292
}
@@ -434,14 +449,19 @@ func copyFileToDir(src, target, rootDir string, mode os.FileMode) error {
434449
}
435450

436451
// extractTar reads a tar stream and extracts it to dst with security checks.
437-
func extractTar(r io.Reader, dst string) error {
452+
// The context is checked on each tar entry to support cancellation.
453+
func extractTar(ctx context.Context, r io.Reader, dst string) error {
438454
// Wrap in a LimitedReader to prevent decompression bombs.
439455
lr := &io.LimitedReader{R: r, N: maxExtractSize}
440456
tr := tar.NewReader(lr)
441457

442458
var entryCount int
443459

444460
for {
461+
if err := ctx.Err(); err != nil {
462+
return err
463+
}
464+
445465
hdr, err := tr.Next()
446466
if errors.Is(err, io.EOF) {
447467
break
@@ -469,7 +489,7 @@ func extractTar(r io.Reader, dst string) error {
469489
}
470490

471491
if entryCount == 0 {
472-
return fmt.Errorf("tar archive is empty or contains no valid entries")
492+
slog.Debug("tar archive has no entries, treating as empty layer")
473493
}
474494

475495
return nil
@@ -478,14 +498,19 @@ func extractTar(r io.Reader, dst string) error {
478498
// extractTarSharedLimit is like extractTar but uses a shared atomic counter
479499
// for the size budget. This enforces a global maxExtractSize across all layers
480500
// in a layered extraction, preventing decompression bombs via many layers.
481-
func extractTarSharedLimit(r io.Reader, dst string, remaining *atomic.Int64) error {
501+
// The context is checked on each tar entry to support cancellation.
502+
func extractTarSharedLimit(ctx context.Context, r io.Reader, dst string, remaining *atomic.Int64) error {
482503
// Use an atomicLimitReader that decrements the shared counter.
483504
alr := &atomicLimitReader{R: r, Remaining: remaining}
484505
tr := tar.NewReader(alr)
485506

486507
var entryCount int
487508

488509
for {
510+
if err := ctx.Err(); err != nil {
511+
return err
512+
}
513+
489514
hdr, err := tr.Next()
490515
if errors.Is(err, io.EOF) {
491516
break
@@ -512,7 +537,7 @@ func extractTarSharedLimit(r io.Reader, dst string, remaining *atomic.Int64) err
512537
}
513538

514539
if entryCount == 0 {
515-
return fmt.Errorf("tar archive is empty or contains no valid entries")
540+
slog.Debug("tar archive has no entries, treating as empty layer")
516541
}
517542

518543
return nil

0 commit comments

Comments
 (0)