From 64f2048dc3f177f20a86503b678d81d313c3d6ac Mon Sep 17 00:00:00 2001 From: Juan Antonio Osorio Date: Fri, 3 Apr 2026 17:38:31 -0400 Subject: [PATCH] Fix empty OCI layer extraction and add context cancellation Skip empty tar archives (0-entry layers from ENV/LABEL/CMD instructions) instead of treating them as errors, which was causing layered extraction to fail and fall back to flat extraction. Thread context.Context through all extraction functions so slow or stalled operations can be cancelled. Closes #56 Co-Authored-By: Claude Opus 4.6 (1M context) --- image/pull.go | 53 ++++++++++++----- image/pull_test.go | 144 +++++++++++++++++++++++++++++++++++++-------- 2 files changed, 159 insertions(+), 38 deletions(-) diff --git a/image/pull.go b/image/pull.go index 572ca66..00e3cb6 100644 --- a/image/pull.go +++ b/image/pull.go @@ -111,20 +111,20 @@ func PullWithFetcher(ctx context.Context, imageRef string, cache *Cache, fetcher // extraction if layered extraction fails. if cache != nil { lc := cache.LayerCache() - if err := extractImageLayered(img, tmpDir, lc); err != nil { + if err := extractImageLayered(ctx, img, tmpDir, lc); err != nil { slog.Warn("layered extraction failed, falling back to flat extraction", "err", err) // Clean tmpDir contents before retrying with flat extraction. _ = os.RemoveAll(tmpDir) if tmpDir, err = cache.TempDir(); err != nil { return nil, fmt.Errorf("create temp dir for rootfs: %w", err) } - if err := extractImage(img, tmpDir); err != nil { + if err := extractImage(ctx, img, tmpDir); err != nil { _ = os.RemoveAll(tmpDir) return nil, fmt.Errorf("extract image layers: %w", err) } } } else { - if err := extractImage(img, tmpDir); err != nil { + if err := extractImage(ctx, img, tmpDir); err != nil { _ = os.RemoveAll(tmpDir) return nil, fmt.Errorf("extract image layers: %w", err) } @@ -180,20 +180,35 @@ func extractOCIConfig(img v1.Image) (*OCIConfig, error) { }, nil } +// contextReader wraps an io.Reader with context cancellation support. +// It checks the context before each Read call, enabling cancellation of +// long-running I/O operations (e.g. slow registry downloads). +type contextReader struct { + ctx context.Context + r io.Reader +} + +func (cr *contextReader) Read(p []byte) (int, error) { + if err := cr.ctx.Err(); err != nil { + return 0, err + } + return cr.r.Read(p) +} + // extractImage flattens all image layers into a single tar stream and extracts // it to the destination directory. It includes security measures against path // traversal, symlink attacks, and decompression bombs. -func extractImage(img v1.Image, dst string) error { +func extractImage(ctx context.Context, img v1.Image, dst string) error { reader := mutate.Extract(img) defer func() { _ = reader.Close() }() - return extractTar(reader, dst) + return extractTar(ctx, &contextReader{ctx: ctx, r: reader}, dst) } // extractImageLayered extracts each image layer individually into the layer // cache, then composes them bottom-to-top into dst. Shared layers across // images are extracted only once. -func extractImageLayered(img v1.Image, dst string, lc *LayerCache) error { +func extractImageLayered(ctx context.Context, img v1.Image, dst string, lc *LayerCache) error { layers, err := img.Layers() if err != nil { return fmt.Errorf("get image layers: %w", err) @@ -220,7 +235,7 @@ func extractImageLayered(img v1.Image, dst string, lc *LayerCache) error { remaining := &atomic.Int64{} remaining.Store(maxExtractSize) - g := new(errgroup.Group) + g, gCtx := errgroup.WithContext(ctx) g.SetLimit(concurrency) for i, layer := range layers { @@ -232,7 +247,7 @@ func extractImageLayered(img v1.Image, dst string, lc *LayerCache) error { } g.Go(func() error { - return extractLayerToCache(layer, diffID, lc, remaining) + return extractLayerToCache(gCtx, layer, diffID, lc, remaining) }) } @@ -258,7 +273,7 @@ func extractImageLayered(img v1.Image, dst string, lc *LayerCache) error { // extractLayerToCache extracts a single layer into the layer cache. // The remaining counter is shared across concurrent layer extractions to // enforce a global size budget (prevents decompression bombs across layers). -func extractLayerToCache(layer v1.Layer, diffID v1.Hash, lc *LayerCache, remaining *atomic.Int64) error { +func extractLayerToCache(ctx context.Context, layer v1.Layer, diffID v1.Hash, lc *LayerCache, remaining *atomic.Int64) error { tmpDir, err := lc.TempDir() if err != nil { return fmt.Errorf("create temp dir for layer %s: %w", diffID.String(), err) @@ -271,7 +286,7 @@ func extractLayerToCache(layer v1.Layer, diffID v1.Hash, lc *LayerCache, remaini } defer func() { _ = rc.Close() }() - if err := extractTarSharedLimit(rc, tmpDir, remaining); err != nil { + if err := extractTarSharedLimit(ctx, &contextReader{ctx: ctx, r: rc}, tmpDir, remaining); err != nil { _ = os.RemoveAll(tmpDir) return fmt.Errorf("extract layer %s: %w", diffID.String(), err) } @@ -434,7 +449,8 @@ func copyFileToDir(src, target, rootDir string, mode os.FileMode) error { } // extractTar reads a tar stream and extracts it to dst with security checks. -func extractTar(r io.Reader, dst string) error { +// The context is checked on each tar entry to support cancellation. +func extractTar(ctx context.Context, r io.Reader, dst string) error { // Wrap in a LimitedReader to prevent decompression bombs. lr := &io.LimitedReader{R: r, N: maxExtractSize} tr := tar.NewReader(lr) @@ -442,6 +458,10 @@ func extractTar(r io.Reader, dst string) error { var entryCount int for { + if err := ctx.Err(); err != nil { + return err + } + hdr, err := tr.Next() if errors.Is(err, io.EOF) { break @@ -469,7 +489,7 @@ func extractTar(r io.Reader, dst string) error { } if entryCount == 0 { - return fmt.Errorf("tar archive is empty or contains no valid entries") + slog.Debug("tar archive has no entries, treating as empty layer") } return nil @@ -478,7 +498,8 @@ func extractTar(r io.Reader, dst string) error { // extractTarSharedLimit is like extractTar but uses a shared atomic counter // for the size budget. This enforces a global maxExtractSize across all layers // in a layered extraction, preventing decompression bombs via many layers. -func extractTarSharedLimit(r io.Reader, dst string, remaining *atomic.Int64) error { +// The context is checked on each tar entry to support cancellation. +func extractTarSharedLimit(ctx context.Context, r io.Reader, dst string, remaining *atomic.Int64) error { // Use an atomicLimitReader that decrements the shared counter. alr := &atomicLimitReader{R: r, Remaining: remaining} tr := tar.NewReader(alr) @@ -486,6 +507,10 @@ func extractTarSharedLimit(r io.Reader, dst string, remaining *atomic.Int64) err var entryCount int for { + if err := ctx.Err(); err != nil { + return err + } + hdr, err := tr.Next() if errors.Is(err, io.EOF) { break @@ -512,7 +537,7 @@ func extractTarSharedLimit(r io.Reader, dst string, remaining *atomic.Int64) err } if entryCount == 0 { - return fmt.Errorf("tar archive is empty or contains no valid entries") + slog.Debug("tar archive has no entries, treating as empty layer") } return nil diff --git a/image/pull_test.go b/image/pull_test.go index 0e72513..06975b8 100644 --- a/image/pull_test.go +++ b/image/pull_test.go @@ -12,6 +12,7 @@ import ( "os" "path/filepath" "strings" + "sync/atomic" "syscall" "testing" @@ -167,7 +168,7 @@ func TestExtractTar_DirectoriesAndFiles(t *testing.T) { buf := createTarBuffer(t, entries) dst := t.TempDir() - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.NoError(t, err) // Verify directory was created. @@ -212,7 +213,7 @@ func TestExtractTar_Symlinks(t *testing.T) { buf := createTarBuffer(t, entries) dst := t.TempDir() - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.NoError(t, err) // Verify the symlink exists and points to the right target. @@ -304,7 +305,7 @@ func TestExtractTar_SkipsPathTraversal(t *testing.T) { require.NoError(t, err) dst := t.TempDir() - err = extractTar(&buf, dst) + err = extractTar(context.Background(), &buf, dst) require.NoError(t, err) // The malicious entry should not have been extracted. @@ -356,7 +357,7 @@ func TestExtractTar_PreservesOwnershipBestEffort(t *testing.T) { buf := createTarBuffer(t, entries) dst := t.TempDir() - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.NoError(t, err) // Verify files were extracted regardless of uid/gid. @@ -544,7 +545,7 @@ func TestExtractHardlink_ValidLink(t *testing.T) { } buf := createTarBuffer(t, entries) - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.NoError(t, err) origInfo, err := os.Lstat(filepath.Join(dst, "original.txt")) @@ -567,7 +568,7 @@ func TestExtractHardlink_SourceOutsideRootfs(t *testing.T) { } buf := createTarBuffer(t, entries) - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.Error(t, err) assert.Contains(t, err.Error(), "hardlink") assert.Contains(t, err.Error(), "outside rootfs") @@ -587,7 +588,7 @@ func TestExtractHardlink_SourceIsSymlink(t *testing.T) { } buf := createTarBuffer(t, entries) - err = extractTar(buf, dst) + err = extractTar(context.Background(), buf, dst) require.Error(t, err) assert.Contains(t, err.Error(), "refusing hardlink to symlink") } @@ -606,7 +607,7 @@ func TestExtractHardlink_SourceIsDirectory(t *testing.T) { } buf := createTarBuffer(t, entries) - err = extractTar(buf, dst) + err = extractTar(context.Background(), buf, dst) require.Error(t, err) assert.Contains(t, err.Error(), "refusing hardlink to non-regular file") } @@ -621,7 +622,7 @@ func TestExtractHardlink_SourceNotExtracted(t *testing.T) { } buf := createTarBuffer(t, entries) - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.Error(t, err) assert.Contains(t, err.Error(), "stat hardlink source") } @@ -644,7 +645,7 @@ func TestExtractHardlink_TargetIsExistingSymlink(t *testing.T) { } buf := createTarBuffer(t, entries) - err = extractTar(buf, dst) + err = extractTar(context.Background(), buf, dst) require.Error(t, err) assert.Contains(t, err.Error(), "refusing to write through symlink") } @@ -663,7 +664,7 @@ func TestExtractSymlink_AbsoluteEscapeAttempt(t *testing.T) { } buf := createTarBuffer(t, entries) - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.Error(t, err) assert.Contains(t, err.Error(), "points outside rootfs") } @@ -678,7 +679,7 @@ func TestExtractSymlink_RelativeEscapeAttempt(t *testing.T) { } buf := createTarBuffer(t, entries) - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.Error(t, err) assert.Contains(t, err.Error(), "points outside rootfs") } @@ -696,7 +697,7 @@ func TestExtractSymlink_ReplacesExistingFile(t *testing.T) { } buf := createTarBuffer(t, entries) - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.NoError(t, err) // "overwrite" should now be a symlink to "target.txt". @@ -721,7 +722,7 @@ func TestExtractSymlink_RefusesToReplaceDirectory(t *testing.T) { } buf := createTarBuffer(t, entries) - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.Error(t, err) assert.Contains(t, err.Error(), "refusing to replace directory with symlink") } @@ -837,15 +838,20 @@ func TestExtractTar_EmptyArchive(t *testing.T) { t.Parallel() // Create an empty tar archive (no entries). + // Empty layers are legitimate in OCI images (produced by ENV, LABEL, CMD, etc.). var buf bytes.Buffer tw := tar.NewWriter(&buf) err := tw.Close() require.NoError(t, err) dst := t.TempDir() - err = extractTar(&buf, dst) - require.Error(t, err) - assert.Contains(t, err.Error(), "empty or contains no valid entries") + err = extractTar(context.Background(), &buf, dst) + require.NoError(t, err) + + // Destination directory should be empty (no files extracted). + entries, err := os.ReadDir(dst) + require.NoError(t, err) + assert.Empty(t, entries) } func TestExtractTar_UnsupportedEntryType(t *testing.T) { @@ -860,7 +866,7 @@ func TestExtractTar_UnsupportedEntryType(t *testing.T) { buf := createTarBuffer(t, entries) dst := t.TempDir() - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.NoError(t, err) // FIFO should not exist. @@ -919,7 +925,7 @@ func TestExtractImageLayered_BasicExtraction(t *testing.T) { lc := NewLayerCache(cacheDir) dst := t.TempDir() - err := extractImageLayered(img, dst, lc) + err := extractImageLayered(context.Background(), img, dst, lc) require.NoError(t, err) // Verify files from both layers are present. @@ -956,7 +962,7 @@ func TestExtractImageLayered_SharedLayers(t *testing.T) { // Extract first image. dst1 := t.TempDir() - err := extractImageLayered(img1, dst1, lc) + err := extractImageLayered(context.Background(), img1, dst1, lc) require.NoError(t, err) // Get the base layer DiffID to check cache state. @@ -969,7 +975,7 @@ func TestExtractImageLayered_SharedLayers(t *testing.T) { // Extract second image — base layer should be a cache hit. dst2 := t.TempDir() - err = extractImageLayered(img2, dst2, lc) + err = extractImageLayered(context.Background(), img2, dst2, lc) require.NoError(t, err) // Verify both images have the shared base content. @@ -1017,7 +1023,7 @@ func TestExtractImageLayered_Whiteouts(t *testing.T) { lc := NewLayerCache(cacheDir) dst := t.TempDir() - err := extractImageLayered(img, dst, lc) + err := extractImageLayered(context.Background(), img, dst, lc) require.NoError(t, err) // keep.txt should still exist. @@ -1059,7 +1065,7 @@ func TestExtractImageLayered_UpperLayerOverridesFile(t *testing.T) { lc := NewLayerCache(cacheDir) dst := t.TempDir() - err := extractImageLayered(img, dst, lc) + err := extractImageLayered(context.Background(), img, dst, lc) require.NoError(t, err) // Upper layer should win. @@ -1083,7 +1089,7 @@ func TestExtractImageLayered_Symlinks(t *testing.T) { lc := NewLayerCache(cacheDir) dst := t.TempDir() - err := extractImageLayered(img, dst, lc) + err := extractImageLayered(context.Background(), img, dst, lc) require.NoError(t, err) target, err := os.Readlink(filepath.Join(dst, "usr", "bin", "link")) @@ -1128,3 +1134,93 @@ func TestPullWithFetcher_LayeredExtraction(t *testing.T) { } assert.Equal(t, 2, len(layerEntries), "should have 2 cached layers") } + +// --------------------------------------------------------------------------- +// Empty layer tests (Bug 1: empty OCI layers are valid no-ops) +// --------------------------------------------------------------------------- + +func TestExtractTarSharedLimit_EmptyArchive(t *testing.T) { + t.Parallel() + + // Empty tar archive — should succeed (empty layers are valid OCI artifacts). + var buf bytes.Buffer + tw := tar.NewWriter(&buf) + err := tw.Close() + require.NoError(t, err) + + remaining := &atomic.Int64{} + remaining.Store(maxExtractSize) + + dst := t.TempDir() + err = extractTarSharedLimit(context.Background(), &buf, dst, remaining) + require.NoError(t, err) + + entries, err := os.ReadDir(dst) + require.NoError(t, err) + assert.Empty(t, entries) +} + +func TestExtractImageLayered_EmptyLayers(t *testing.T) { + t.Parallel() + + // Create a real layer with content + an empty layer (simulates ENV/LABEL/CMD). + realLayer := createLayerFromEntries(t, []tarEntry{ + {name: "app/", typeflag: tar.TypeDir, mode: 0o755}, + {name: "app/main", typeflag: tar.TypeReg, mode: 0o755, content: "#!/bin/sh\necho hello"}, + }) + + emptyLayer := createLayerFromEntries(t, nil) + + img := createImageFromLayers(t, realLayer, emptyLayer) + + cacheDir := t.TempDir() + lc := NewLayerCache(cacheDir) + dst := t.TempDir() + + err := extractImageLayered(context.Background(), img, dst, lc) + require.NoError(t, err) + + // Verify the real layer's content is present. + data, err := os.ReadFile(filepath.Join(dst, "app", "main")) + require.NoError(t, err) + assert.Equal(t, "#!/bin/sh\necho hello", string(data)) +} + +// --------------------------------------------------------------------------- +// Context cancellation tests (Bug 2: extraction respects context) +// --------------------------------------------------------------------------- + +func TestExtractTar_ContextCancellation(t *testing.T) { + t.Parallel() + + entries := []tarEntry{ + {name: "file.txt", typeflag: tar.TypeReg, mode: 0o644, content: "data"}, + } + buf := createTarBuffer(t, entries) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately. + + dst := t.TempDir() + err := extractTar(ctx, buf, dst) + require.ErrorIs(t, err, context.Canceled) +} + +func TestExtractTarSharedLimit_ContextCancellation(t *testing.T) { + t.Parallel() + + entries := []tarEntry{ + {name: "file.txt", typeflag: tar.TypeReg, mode: 0o644, content: "data"}, + } + buf := createTarBuffer(t, entries) + + remaining := &atomic.Int64{} + remaining.Store(maxExtractSize) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately. + + dst := t.TempDir() + err := extractTarSharedLimit(ctx, buf, dst, remaining) + require.ErrorIs(t, err, context.Canceled) +}