diff --git a/image/pull.go b/image/pull.go index 572ca66..00e3cb6 100644 --- a/image/pull.go +++ b/image/pull.go @@ -111,20 +111,20 @@ func PullWithFetcher(ctx context.Context, imageRef string, cache *Cache, fetcher // extraction if layered extraction fails. if cache != nil { lc := cache.LayerCache() - if err := extractImageLayered(img, tmpDir, lc); err != nil { + if err := extractImageLayered(ctx, img, tmpDir, lc); err != nil { slog.Warn("layered extraction failed, falling back to flat extraction", "err", err) // Clean tmpDir contents before retrying with flat extraction. _ = os.RemoveAll(tmpDir) if tmpDir, err = cache.TempDir(); err != nil { return nil, fmt.Errorf("create temp dir for rootfs: %w", err) } - if err := extractImage(img, tmpDir); err != nil { + if err := extractImage(ctx, img, tmpDir); err != nil { _ = os.RemoveAll(tmpDir) return nil, fmt.Errorf("extract image layers: %w", err) } } } else { - if err := extractImage(img, tmpDir); err != nil { + if err := extractImage(ctx, img, tmpDir); err != nil { _ = os.RemoveAll(tmpDir) return nil, fmt.Errorf("extract image layers: %w", err) } @@ -180,20 +180,35 @@ func extractOCIConfig(img v1.Image) (*OCIConfig, error) { }, nil } +// contextReader wraps an io.Reader with context cancellation support. +// It checks the context before each Read call, enabling cancellation of +// long-running I/O operations (e.g. slow registry downloads). +type contextReader struct { + ctx context.Context + r io.Reader +} + +func (cr *contextReader) Read(p []byte) (int, error) { + if err := cr.ctx.Err(); err != nil { + return 0, err + } + return cr.r.Read(p) +} + // extractImage flattens all image layers into a single tar stream and extracts // it to the destination directory. It includes security measures against path // traversal, symlink attacks, and decompression bombs. -func extractImage(img v1.Image, dst string) error { +func extractImage(ctx context.Context, img v1.Image, dst string) error { reader := mutate.Extract(img) defer func() { _ = reader.Close() }() - return extractTar(reader, dst) + return extractTar(ctx, &contextReader{ctx: ctx, r: reader}, dst) } // extractImageLayered extracts each image layer individually into the layer // cache, then composes them bottom-to-top into dst. Shared layers across // images are extracted only once. -func extractImageLayered(img v1.Image, dst string, lc *LayerCache) error { +func extractImageLayered(ctx context.Context, img v1.Image, dst string, lc *LayerCache) error { layers, err := img.Layers() if err != nil { return fmt.Errorf("get image layers: %w", err) @@ -220,7 +235,7 @@ func extractImageLayered(img v1.Image, dst string, lc *LayerCache) error { remaining := &atomic.Int64{} remaining.Store(maxExtractSize) - g := new(errgroup.Group) + g, gCtx := errgroup.WithContext(ctx) g.SetLimit(concurrency) for i, layer := range layers { @@ -232,7 +247,7 @@ func extractImageLayered(img v1.Image, dst string, lc *LayerCache) error { } g.Go(func() error { - return extractLayerToCache(layer, diffID, lc, remaining) + return extractLayerToCache(gCtx, layer, diffID, lc, remaining) }) } @@ -258,7 +273,7 @@ func extractImageLayered(img v1.Image, dst string, lc *LayerCache) error { // extractLayerToCache extracts a single layer into the layer cache. // The remaining counter is shared across concurrent layer extractions to // enforce a global size budget (prevents decompression bombs across layers). -func extractLayerToCache(layer v1.Layer, diffID v1.Hash, lc *LayerCache, remaining *atomic.Int64) error { +func extractLayerToCache(ctx context.Context, layer v1.Layer, diffID v1.Hash, lc *LayerCache, remaining *atomic.Int64) error { tmpDir, err := lc.TempDir() if err != nil { return fmt.Errorf("create temp dir for layer %s: %w", diffID.String(), err) @@ -271,7 +286,7 @@ func extractLayerToCache(layer v1.Layer, diffID v1.Hash, lc *LayerCache, remaini } defer func() { _ = rc.Close() }() - if err := extractTarSharedLimit(rc, tmpDir, remaining); err != nil { + if err := extractTarSharedLimit(ctx, &contextReader{ctx: ctx, r: rc}, tmpDir, remaining); err != nil { _ = os.RemoveAll(tmpDir) return fmt.Errorf("extract layer %s: %w", diffID.String(), err) } @@ -434,7 +449,8 @@ func copyFileToDir(src, target, rootDir string, mode os.FileMode) error { } // extractTar reads a tar stream and extracts it to dst with security checks. -func extractTar(r io.Reader, dst string) error { +// The context is checked on each tar entry to support cancellation. +func extractTar(ctx context.Context, r io.Reader, dst string) error { // Wrap in a LimitedReader to prevent decompression bombs. lr := &io.LimitedReader{R: r, N: maxExtractSize} tr := tar.NewReader(lr) @@ -442,6 +458,10 @@ func extractTar(r io.Reader, dst string) error { var entryCount int for { + if err := ctx.Err(); err != nil { + return err + } + hdr, err := tr.Next() if errors.Is(err, io.EOF) { break @@ -469,7 +489,7 @@ func extractTar(r io.Reader, dst string) error { } if entryCount == 0 { - return fmt.Errorf("tar archive is empty or contains no valid entries") + slog.Debug("tar archive has no entries, treating as empty layer") } return nil @@ -478,7 +498,8 @@ func extractTar(r io.Reader, dst string) error { // extractTarSharedLimit is like extractTar but uses a shared atomic counter // for the size budget. This enforces a global maxExtractSize across all layers // in a layered extraction, preventing decompression bombs via many layers. -func extractTarSharedLimit(r io.Reader, dst string, remaining *atomic.Int64) error { +// The context is checked on each tar entry to support cancellation. +func extractTarSharedLimit(ctx context.Context, r io.Reader, dst string, remaining *atomic.Int64) error { // Use an atomicLimitReader that decrements the shared counter. alr := &atomicLimitReader{R: r, Remaining: remaining} tr := tar.NewReader(alr) @@ -486,6 +507,10 @@ func extractTarSharedLimit(r io.Reader, dst string, remaining *atomic.Int64) err var entryCount int for { + if err := ctx.Err(); err != nil { + return err + } + hdr, err := tr.Next() if errors.Is(err, io.EOF) { break @@ -512,7 +537,7 @@ func extractTarSharedLimit(r io.Reader, dst string, remaining *atomic.Int64) err } if entryCount == 0 { - return fmt.Errorf("tar archive is empty or contains no valid entries") + slog.Debug("tar archive has no entries, treating as empty layer") } return nil diff --git a/image/pull_test.go b/image/pull_test.go index 0e72513..06975b8 100644 --- a/image/pull_test.go +++ b/image/pull_test.go @@ -12,6 +12,7 @@ import ( "os" "path/filepath" "strings" + "sync/atomic" "syscall" "testing" @@ -167,7 +168,7 @@ func TestExtractTar_DirectoriesAndFiles(t *testing.T) { buf := createTarBuffer(t, entries) dst := t.TempDir() - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.NoError(t, err) // Verify directory was created. @@ -212,7 +213,7 @@ func TestExtractTar_Symlinks(t *testing.T) { buf := createTarBuffer(t, entries) dst := t.TempDir() - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.NoError(t, err) // Verify the symlink exists and points to the right target. @@ -304,7 +305,7 @@ func TestExtractTar_SkipsPathTraversal(t *testing.T) { require.NoError(t, err) dst := t.TempDir() - err = extractTar(&buf, dst) + err = extractTar(context.Background(), &buf, dst) require.NoError(t, err) // The malicious entry should not have been extracted. @@ -356,7 +357,7 @@ func TestExtractTar_PreservesOwnershipBestEffort(t *testing.T) { buf := createTarBuffer(t, entries) dst := t.TempDir() - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.NoError(t, err) // Verify files were extracted regardless of uid/gid. @@ -544,7 +545,7 @@ func TestExtractHardlink_ValidLink(t *testing.T) { } buf := createTarBuffer(t, entries) - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.NoError(t, err) origInfo, err := os.Lstat(filepath.Join(dst, "original.txt")) @@ -567,7 +568,7 @@ func TestExtractHardlink_SourceOutsideRootfs(t *testing.T) { } buf := createTarBuffer(t, entries) - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.Error(t, err) assert.Contains(t, err.Error(), "hardlink") assert.Contains(t, err.Error(), "outside rootfs") @@ -587,7 +588,7 @@ func TestExtractHardlink_SourceIsSymlink(t *testing.T) { } buf := createTarBuffer(t, entries) - err = extractTar(buf, dst) + err = extractTar(context.Background(), buf, dst) require.Error(t, err) assert.Contains(t, err.Error(), "refusing hardlink to symlink") } @@ -606,7 +607,7 @@ func TestExtractHardlink_SourceIsDirectory(t *testing.T) { } buf := createTarBuffer(t, entries) - err = extractTar(buf, dst) + err = extractTar(context.Background(), buf, dst) require.Error(t, err) assert.Contains(t, err.Error(), "refusing hardlink to non-regular file") } @@ -621,7 +622,7 @@ func TestExtractHardlink_SourceNotExtracted(t *testing.T) { } buf := createTarBuffer(t, entries) - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.Error(t, err) assert.Contains(t, err.Error(), "stat hardlink source") } @@ -644,7 +645,7 @@ func TestExtractHardlink_TargetIsExistingSymlink(t *testing.T) { } buf := createTarBuffer(t, entries) - err = extractTar(buf, dst) + err = extractTar(context.Background(), buf, dst) require.Error(t, err) assert.Contains(t, err.Error(), "refusing to write through symlink") } @@ -663,7 +664,7 @@ func TestExtractSymlink_AbsoluteEscapeAttempt(t *testing.T) { } buf := createTarBuffer(t, entries) - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.Error(t, err) assert.Contains(t, err.Error(), "points outside rootfs") } @@ -678,7 +679,7 @@ func TestExtractSymlink_RelativeEscapeAttempt(t *testing.T) { } buf := createTarBuffer(t, entries) - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.Error(t, err) assert.Contains(t, err.Error(), "points outside rootfs") } @@ -696,7 +697,7 @@ func TestExtractSymlink_ReplacesExistingFile(t *testing.T) { } buf := createTarBuffer(t, entries) - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.NoError(t, err) // "overwrite" should now be a symlink to "target.txt". @@ -721,7 +722,7 @@ func TestExtractSymlink_RefusesToReplaceDirectory(t *testing.T) { } buf := createTarBuffer(t, entries) - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.Error(t, err) assert.Contains(t, err.Error(), "refusing to replace directory with symlink") } @@ -837,15 +838,20 @@ func TestExtractTar_EmptyArchive(t *testing.T) { t.Parallel() // Create an empty tar archive (no entries). + // Empty layers are legitimate in OCI images (produced by ENV, LABEL, CMD, etc.). var buf bytes.Buffer tw := tar.NewWriter(&buf) err := tw.Close() require.NoError(t, err) dst := t.TempDir() - err = extractTar(&buf, dst) - require.Error(t, err) - assert.Contains(t, err.Error(), "empty or contains no valid entries") + err = extractTar(context.Background(), &buf, dst) + require.NoError(t, err) + + // Destination directory should be empty (no files extracted). + entries, err := os.ReadDir(dst) + require.NoError(t, err) + assert.Empty(t, entries) } func TestExtractTar_UnsupportedEntryType(t *testing.T) { @@ -860,7 +866,7 @@ func TestExtractTar_UnsupportedEntryType(t *testing.T) { buf := createTarBuffer(t, entries) dst := t.TempDir() - err := extractTar(buf, dst) + err := extractTar(context.Background(), buf, dst) require.NoError(t, err) // FIFO should not exist. @@ -919,7 +925,7 @@ func TestExtractImageLayered_BasicExtraction(t *testing.T) { lc := NewLayerCache(cacheDir) dst := t.TempDir() - err := extractImageLayered(img, dst, lc) + err := extractImageLayered(context.Background(), img, dst, lc) require.NoError(t, err) // Verify files from both layers are present. @@ -956,7 +962,7 @@ func TestExtractImageLayered_SharedLayers(t *testing.T) { // Extract first image. dst1 := t.TempDir() - err := extractImageLayered(img1, dst1, lc) + err := extractImageLayered(context.Background(), img1, dst1, lc) require.NoError(t, err) // Get the base layer DiffID to check cache state. @@ -969,7 +975,7 @@ func TestExtractImageLayered_SharedLayers(t *testing.T) { // Extract second image — base layer should be a cache hit. dst2 := t.TempDir() - err = extractImageLayered(img2, dst2, lc) + err = extractImageLayered(context.Background(), img2, dst2, lc) require.NoError(t, err) // Verify both images have the shared base content. @@ -1017,7 +1023,7 @@ func TestExtractImageLayered_Whiteouts(t *testing.T) { lc := NewLayerCache(cacheDir) dst := t.TempDir() - err := extractImageLayered(img, dst, lc) + err := extractImageLayered(context.Background(), img, dst, lc) require.NoError(t, err) // keep.txt should still exist. @@ -1059,7 +1065,7 @@ func TestExtractImageLayered_UpperLayerOverridesFile(t *testing.T) { lc := NewLayerCache(cacheDir) dst := t.TempDir() - err := extractImageLayered(img, dst, lc) + err := extractImageLayered(context.Background(), img, dst, lc) require.NoError(t, err) // Upper layer should win. @@ -1083,7 +1089,7 @@ func TestExtractImageLayered_Symlinks(t *testing.T) { lc := NewLayerCache(cacheDir) dst := t.TempDir() - err := extractImageLayered(img, dst, lc) + err := extractImageLayered(context.Background(), img, dst, lc) require.NoError(t, err) target, err := os.Readlink(filepath.Join(dst, "usr", "bin", "link")) @@ -1128,3 +1134,93 @@ func TestPullWithFetcher_LayeredExtraction(t *testing.T) { } assert.Equal(t, 2, len(layerEntries), "should have 2 cached layers") } + +// --------------------------------------------------------------------------- +// Empty layer tests (Bug 1: empty OCI layers are valid no-ops) +// --------------------------------------------------------------------------- + +func TestExtractTarSharedLimit_EmptyArchive(t *testing.T) { + t.Parallel() + + // Empty tar archive — should succeed (empty layers are valid OCI artifacts). + var buf bytes.Buffer + tw := tar.NewWriter(&buf) + err := tw.Close() + require.NoError(t, err) + + remaining := &atomic.Int64{} + remaining.Store(maxExtractSize) + + dst := t.TempDir() + err = extractTarSharedLimit(context.Background(), &buf, dst, remaining) + require.NoError(t, err) + + entries, err := os.ReadDir(dst) + require.NoError(t, err) + assert.Empty(t, entries) +} + +func TestExtractImageLayered_EmptyLayers(t *testing.T) { + t.Parallel() + + // Create a real layer with content + an empty layer (simulates ENV/LABEL/CMD). + realLayer := createLayerFromEntries(t, []tarEntry{ + {name: "app/", typeflag: tar.TypeDir, mode: 0o755}, + {name: "app/main", typeflag: tar.TypeReg, mode: 0o755, content: "#!/bin/sh\necho hello"}, + }) + + emptyLayer := createLayerFromEntries(t, nil) + + img := createImageFromLayers(t, realLayer, emptyLayer) + + cacheDir := t.TempDir() + lc := NewLayerCache(cacheDir) + dst := t.TempDir() + + err := extractImageLayered(context.Background(), img, dst, lc) + require.NoError(t, err) + + // Verify the real layer's content is present. + data, err := os.ReadFile(filepath.Join(dst, "app", "main")) + require.NoError(t, err) + assert.Equal(t, "#!/bin/sh\necho hello", string(data)) +} + +// --------------------------------------------------------------------------- +// Context cancellation tests (Bug 2: extraction respects context) +// --------------------------------------------------------------------------- + +func TestExtractTar_ContextCancellation(t *testing.T) { + t.Parallel() + + entries := []tarEntry{ + {name: "file.txt", typeflag: tar.TypeReg, mode: 0o644, content: "data"}, + } + buf := createTarBuffer(t, entries) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately. + + dst := t.TempDir() + err := extractTar(ctx, buf, dst) + require.ErrorIs(t, err, context.Canceled) +} + +func TestExtractTarSharedLimit_ContextCancellation(t *testing.T) { + t.Parallel() + + entries := []tarEntry{ + {name: "file.txt", typeflag: tar.TypeReg, mode: 0o644, content: "data"}, + } + buf := createTarBuffer(t, entries) + + remaining := &atomic.Int64{} + remaining.Store(maxExtractSize) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately. + + dst := t.TempDir() + err := extractTarSharedLimit(ctx, buf, dst, remaining) + require.ErrorIs(t, err, context.Canceled) +}