@@ -6,12 +6,20 @@ package tools
66
77import (
88 "bytes"
9+ "context"
10+ "errors"
11+ "fmt"
12+ "io"
913 "os"
1014 "path/filepath"
15+ "strings"
1116 "testing"
1217
18+ cerrdefs "github.com/containerd/errdefs"
1319 "github.com/docker/docker/api/types/container"
20+ imagetypes "github.com/docker/docker/api/types/image"
1421 "github.com/docker/docker/api/types/mount"
22+ "github.com/docker/docker/client"
1523 "github.com/stretchr/testify/assert"
1624 "github.com/stretchr/testify/require"
1725)
@@ -317,3 +325,86 @@ func TestIsEmptyDirectory_NonExistent(t *testing.T) {
317325 _ , err := IsEmptyDirectory ("/nonexistent/path/12345" )
318326 assert .Error (t , err )
319327}
328+
329+ type inspectFn func (ctx context.Context , imageID string , opts ... client.ImageInspectOption ) (imagetypes.InspectResponse , error )
330+ type pullFn func (ctx context.Context , refStr string , options imagetypes.PullOptions ) (io.ReadCloser , error )
331+
332+ // fakePuller implements imagePuller for testing PullImage in isolation.
333+ type fakePuller struct {
334+ inspect inspectFn
335+ pull pullFn
336+ pullCalled bool
337+ }
338+
339+ func (f * fakePuller ) ImageInspect (ctx context.Context , imageID string , opts ... client.ImageInspectOption ) (imagetypes.InspectResponse , error ) {
340+ if f .inspect != nil {
341+ return f .inspect (ctx , imageID , opts ... )
342+ }
343+ return imagetypes.InspectResponse {}, errors .New ("inspect not implemented" )
344+ }
345+
346+ func (f * fakePuller ) ImagePull (ctx context.Context , refStr string , options imagetypes.PullOptions ) (io.ReadCloser , error ) {
347+ f .pullCalled = true
348+ if f .pull != nil {
349+ return f .pull (ctx , refStr , options )
350+ }
351+ return nil , errors .New ("pull not implemented" )
352+ }
353+
354+ func TestPullImage (t * testing.T ) {
355+ notFound := fmt .Errorf ("no such image: %w" , cerrdefs .ErrNotFound )
356+ wrappedNotFound := fmt .Errorf ("inspect failed: %w" , notFound )
357+
358+ localImage := func (context.Context , string , ... client.ImageInspectOption ) (imagetypes.InspectResponse , error ) {
359+ return imagetypes.InspectResponse {ID : "sha256:abc" }, nil
360+ }
361+ notFoundInspect := func (context.Context , string , ... client.ImageInspectOption ) (imagetypes.InspectResponse , error ) {
362+ return imagetypes.InspectResponse {}, notFound
363+ }
364+ wrappedNotFoundInspect := func (context.Context , string , ... client.ImageInspectOption ) (imagetypes.InspectResponse , error ) {
365+ return imagetypes.InspectResponse {}, wrappedNotFound
366+ }
367+ daemonDownInspect := func (context.Context , string , ... client.ImageInspectOption ) (imagetypes.InspectResponse , error ) {
368+ return imagetypes.InspectResponse {}, errors .New ("daemon down" )
369+ }
370+
371+ emptyStream := func (context.Context , string , imagetypes.PullOptions ) (io.ReadCloser , error ) {
372+ return io .NopCloser (strings .NewReader ("" )), nil
373+ }
374+ errorStream := func (context.Context , string , imagetypes.PullOptions ) (io.ReadCloser , error ) {
375+ body := `{"errorDetail":{"message":"manifest for postgresai/extended-postgres:99-bogus not found: manifest unknown"},"error":"manifest unknown"}`
376+ return io .NopCloser (strings .NewReader (body )), nil
377+ }
378+
379+ testCases := []struct {
380+ name string
381+ inspect inspectFn
382+ pull pullFn
383+ wantErr bool
384+ wantErrContains string
385+ wantPullCalled bool
386+ }{
387+ {name : "image already local skips pull" , inspect : localImage },
388+ {name : "not found triggers pull" , inspect : notFoundInspect , pull : emptyStream , wantPullCalled : true },
389+ {name : "wrapped not found triggers pull" , inspect : wrappedNotFoundInspect , pull : emptyStream , wantPullCalled : true },
390+ {name : "non-not-found inspect error propagates" , inspect : daemonDownInspect , wantErr : true , wantErrContains : "failed to inspect image" },
391+ {name : "pull stream error propagates" , inspect : notFoundInspect , pull : errorStream , wantErr : true , wantErrContains : "failed to pull image" , wantPullCalled : true },
392+ }
393+
394+ for _ , tc := range testCases {
395+ t .Run (tc .name , func (t * testing.T ) {
396+ puller := & fakePuller {inspect : tc .inspect , pull : tc .pull }
397+
398+ err := PullImage (context .Background (), puller , "postgresai/extended-postgres:test" )
399+
400+ if tc .wantErr {
401+ require .Error (t , err )
402+ assert .Contains (t , err .Error (), tc .wantErrContains )
403+ } else {
404+ require .NoError (t , err )
405+ }
406+
407+ assert .Equal (t , tc .wantPullCalled , puller .pullCalled )
408+ })
409+ }
410+ }
0 commit comments