@@ -111,20 +111,20 @@ func PullWithFetcher(ctx context.Context, imageRef string, cache *Cache, fetcher
111111 // extraction if layered extraction fails.
112112 if cache != nil {
113113 lc := cache .LayerCache ()
114- if err := extractImageLayered (img , tmpDir , lc ); err != nil {
114+ if err := extractImageLayered (ctx , img , tmpDir , lc ); err != nil {
115115 slog .Warn ("layered extraction failed, falling back to flat extraction" , "err" , err )
116116 // Clean tmpDir contents before retrying with flat extraction.
117117 _ = os .RemoveAll (tmpDir )
118118 if tmpDir , err = cache .TempDir (); err != nil {
119119 return nil , fmt .Errorf ("create temp dir for rootfs: %w" , err )
120120 }
121- if err := extractImage (img , tmpDir ); err != nil {
121+ if err := extractImage (ctx , img , tmpDir ); err != nil {
122122 _ = os .RemoveAll (tmpDir )
123123 return nil , fmt .Errorf ("extract image layers: %w" , err )
124124 }
125125 }
126126 } else {
127- if err := extractImage (img , tmpDir ); err != nil {
127+ if err := extractImage (ctx , img , tmpDir ); err != nil {
128128 _ = os .RemoveAll (tmpDir )
129129 return nil , fmt .Errorf ("extract image layers: %w" , err )
130130 }
@@ -180,20 +180,35 @@ func extractOCIConfig(img v1.Image) (*OCIConfig, error) {
180180 }, nil
181181}
182182
183+ // contextReader wraps an io.Reader with context cancellation support.
184+ // It checks the context before each Read call, enabling cancellation of
185+ // long-running I/O operations (e.g. slow registry downloads).
186+ type contextReader struct {
187+ ctx context.Context
188+ r io.Reader
189+ }
190+
191+ func (cr * contextReader ) Read (p []byte ) (int , error ) {
192+ if err := cr .ctx .Err (); err != nil {
193+ return 0 , err
194+ }
195+ return cr .r .Read (p )
196+ }
197+
183198// extractImage flattens all image layers into a single tar stream and extracts
184199// it to the destination directory. It includes security measures against path
185200// traversal, symlink attacks, and decompression bombs.
186- func extractImage (img v1.Image , dst string ) error {
201+ func extractImage (ctx context. Context , img v1.Image , dst string ) error {
187202 reader := mutate .Extract (img )
188203 defer func () { _ = reader .Close () }()
189204
190- return extractTar (reader , dst )
205+ return extractTar (ctx , & contextReader { ctx : ctx , r : reader } , dst )
191206}
192207
193208// extractImageLayered extracts each image layer individually into the layer
194209// cache, then composes them bottom-to-top into dst. Shared layers across
195210// images are extracted only once.
196- func extractImageLayered (img v1.Image , dst string , lc * LayerCache ) error {
211+ func extractImageLayered (ctx context. Context , img v1.Image , dst string , lc * LayerCache ) error {
197212 layers , err := img .Layers ()
198213 if err != nil {
199214 return fmt .Errorf ("get image layers: %w" , err )
@@ -220,7 +235,7 @@ func extractImageLayered(img v1.Image, dst string, lc *LayerCache) error {
220235 remaining := & atomic.Int64 {}
221236 remaining .Store (maxExtractSize )
222237
223- g := new ( errgroup.Group )
238+ g , gCtx := errgroup .WithContext ( ctx )
224239 g .SetLimit (concurrency )
225240
226241 for i , layer := range layers {
@@ -232,7 +247,7 @@ func extractImageLayered(img v1.Image, dst string, lc *LayerCache) error {
232247 }
233248
234249 g .Go (func () error {
235- return extractLayerToCache (layer , diffID , lc , remaining )
250+ return extractLayerToCache (gCtx , layer , diffID , lc , remaining )
236251 })
237252 }
238253
@@ -258,7 +273,7 @@ func extractImageLayered(img v1.Image, dst string, lc *LayerCache) error {
258273// extractLayerToCache extracts a single layer into the layer cache.
259274// The remaining counter is shared across concurrent layer extractions to
260275// enforce a global size budget (prevents decompression bombs across layers).
261- func extractLayerToCache (layer v1.Layer , diffID v1.Hash , lc * LayerCache , remaining * atomic.Int64 ) error {
276+ func extractLayerToCache (ctx context. Context , layer v1.Layer , diffID v1.Hash , lc * LayerCache , remaining * atomic.Int64 ) error {
262277 tmpDir , err := lc .TempDir ()
263278 if err != nil {
264279 return fmt .Errorf ("create temp dir for layer %s: %w" , diffID .String (), err )
@@ -271,7 +286,7 @@ func extractLayerToCache(layer v1.Layer, diffID v1.Hash, lc *LayerCache, remaini
271286 }
272287 defer func () { _ = rc .Close () }()
273288
274- if err := extractTarSharedLimit (rc , tmpDir , remaining ); err != nil {
289+ if err := extractTarSharedLimit (ctx , & contextReader { ctx : ctx , r : rc } , tmpDir , remaining ); err != nil {
275290 _ = os .RemoveAll (tmpDir )
276291 return fmt .Errorf ("extract layer %s: %w" , diffID .String (), err )
277292 }
@@ -434,14 +449,19 @@ func copyFileToDir(src, target, rootDir string, mode os.FileMode) error {
434449}
435450
436451// extractTar reads a tar stream and extracts it to dst with security checks.
437- func extractTar (r io.Reader , dst string ) error {
452+ // The context is checked on each tar entry to support cancellation.
453+ func extractTar (ctx context.Context , r io.Reader , dst string ) error {
438454 // Wrap in a LimitedReader to prevent decompression bombs.
439455 lr := & io.LimitedReader {R : r , N : maxExtractSize }
440456 tr := tar .NewReader (lr )
441457
442458 var entryCount int
443459
444460 for {
461+ if err := ctx .Err (); err != nil {
462+ return err
463+ }
464+
445465 hdr , err := tr .Next ()
446466 if errors .Is (err , io .EOF ) {
447467 break
@@ -469,7 +489,7 @@ func extractTar(r io.Reader, dst string) error {
469489 }
470490
471491 if entryCount == 0 {
472- return fmt . Errorf ("tar archive is empty or contains no valid entries " )
492+ slog . Debug ("tar archive has no entries, treating as empty layer " )
473493 }
474494
475495 return nil
@@ -478,14 +498,19 @@ func extractTar(r io.Reader, dst string) error {
478498// extractTarSharedLimit is like extractTar but uses a shared atomic counter
479499// for the size budget. This enforces a global maxExtractSize across all layers
480500// in a layered extraction, preventing decompression bombs via many layers.
481- func extractTarSharedLimit (r io.Reader , dst string , remaining * atomic.Int64 ) error {
501+ // The context is checked on each tar entry to support cancellation.
502+ func extractTarSharedLimit (ctx context.Context , r io.Reader , dst string , remaining * atomic.Int64 ) error {
482503 // Use an atomicLimitReader that decrements the shared counter.
483504 alr := & atomicLimitReader {R : r , Remaining : remaining }
484505 tr := tar .NewReader (alr )
485506
486507 var entryCount int
487508
488509 for {
510+ if err := ctx .Err (); err != nil {
511+ return err
512+ }
513+
489514 hdr , err := tr .Next ()
490515 if errors .Is (err , io .EOF ) {
491516 break
@@ -512,7 +537,7 @@ func extractTarSharedLimit(r io.Reader, dst string, remaining *atomic.Int64) err
512537 }
513538
514539 if entryCount == 0 {
515- return fmt . Errorf ("tar archive is empty or contains no valid entries " )
540+ slog . Debug ("tar archive has no entries, treating as empty layer " )
516541 }
517542
518543 return nil
0 commit comments