Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cmd/cli/commands/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ func handleClientError(err error, message string) error {
var buf bytes.Buffer
printNextSteps(&buf, []string{enableVLLM})
return fmt.Errorf("%w\n%s", err, strings.TrimRight(buf.String(), "\n"))
} else if strings.Contains(err.Error(), "try upgrading") {
// The model uses a newer config format than this client supports.
var buf bytes.Buffer
printNextSteps(&buf, []string{
"Upgrade Docker Desktop to the latest version to support this model",
})
return fmt.Errorf("%s: %w\n%s", message, err, strings.TrimRight(buf.String(), "\n"))
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
}
return fmt.Errorf("%s: %w", message, err)
}
Expand Down
12 changes: 8 additions & 4 deletions cmd/cli/desktop/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,13 @@ func (c *Client) Pull(model string, printer standalone.StatusPrinter) (string, b

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
err := fmt.Errorf("pulling %s failed with status %s: %s", model, resp.Status, string(body))
// Only retry on server errors (5xx), not client errors (4xx)
shouldRetry := resp.StatusCode >= 500 && resp.StatusCode < 600
err := fmt.Errorf("pulling %s failed with status %s: %s", model, resp.Status, strings.TrimSpace(string(body)))
Comment thread
ilopezluna marked this conversation as resolved.
Outdated
// Only retry on gateway/proxy errors (502, 503, 504).
// Do not retry 500 (usually deterministic server errors) or
// 4xx (client errors including 422 for unsupported media type).
shouldRetry := resp.StatusCode == http.StatusBadGateway ||
resp.StatusCode == http.StatusServiceUnavailable ||
resp.StatusCode == http.StatusGatewayTimeout
return "", false, err, shouldRetry
}

Expand Down Expand Up @@ -235,7 +239,7 @@ func (c *Client) withRetries(
}
}

return "", progressShown, fmt.Errorf("failed to %s after %d retries: %w", operationName, maxRetries, lastErr)
return "", progressShown, fmt.Errorf("%w (failed after %d retries)", lastErr, maxRetries)
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
}

func (c *Client) Push(model string, printer standalone.StatusPrinter) (string, bool, error) {
Expand Down
76 changes: 71 additions & 5 deletions cmd/cli/desktop/desktop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"io"
"net/http"
"strings"
"testing"

mockdesktop "github.com/docker/model-runner/cmd/cli/mocks"
Expand Down Expand Up @@ -59,7 +60,7 @@ func TestPullNoRetryOn4xxError(t *testing.T) {
assert.Contains(t, err.Error(), "Model not found")
}

func TestPullRetryOn5xxError(t *testing.T) {
func TestPullNoRetryOn500Error(t *testing.T) {
Comment thread
ilopezluna marked this conversation as resolved.
ctrl := gomock.NewController(t)
defer ctrl.Finish()

Expand All @@ -68,11 +69,55 @@ func TestPullRetryOn5xxError(t *testing.T) {
mockContext := NewContextForMock(mockClient)
client := New(mockContext)

// First attempt fails with 500, second succeeds
// 500 is not retried (deterministic server error), so only 1 call.
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(bytes.NewBufferString("Internal server error")),
}, nil).Times(1)

printer := NewSimplePrinter(func(s string) {})
_, _, err := client.Pull(modelName, printer)
assert.Error(t, err)
assert.Contains(t, err.Error(), "Internal server error")
}

// TestPullNoRetryOn422Error checks that Pull surfaces a 422 response as an
// error after a single HTTP call: the mock expects exactly one Do invocation,
// so any retry would fail the controller's expectations.
func TestPullNoRetryOn422Error(t *testing.T) {
	const (
		model = "test-model"
		// Response body the server sends for a model whose config type is
		// newer than what it supports.
		serverBody = `error while pulling model: config type "v0.3" is not supported - try upgrading`
	)

	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	httpMock := mockdesktop.NewMockDockerHttpClient(ctrl)
	client := New(NewContextForMock(httpMock))

	// A single 422 reply; Times(1) pins the "no retry" behaviour.
	resp := &http.Response{
		StatusCode: http.StatusUnprocessableEntity,
		Body:       io.NopCloser(bytes.NewBufferString(serverBody)),
	}
	httpMock.EXPECT().Do(gomock.Any()).Return(resp, nil).Times(1)

	_, _, err := client.Pull(model, NewSimplePrinter(func(string) {}))
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "try upgrading")
}

func TestPullRetryOn502Error(t *testing.T) {
Comment thread
ilopezluna marked this conversation as resolved.
Outdated
ctrl := gomock.NewController(t)
defer ctrl.Finish()

modelName := "test-model"
mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
mockContext := NewContextForMock(mockClient)
client := New(mockContext)

// 502 Bad Gateway is a transient proxy error and should be retried.
gomock.InOrder(
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(bytes.NewBufferString("Internal server error")),
StatusCode: http.StatusBadGateway,
Body: io.NopCloser(bytes.NewBufferString("Bad Gateway")),
}, nil),
mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
StatusCode: http.StatusOK,
Expand Down Expand Up @@ -127,7 +172,7 @@ func TestPullMaxRetriesExhausted(t *testing.T) {
printer := NewSimplePrinter(func(s string) {})
_, _, err := client.Pull(modelName, printer)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to download after 3 retries")
assert.Contains(t, err.Error(), "(failed after 3 retries)")
}

func TestPushRetryOnNetworkError(t *testing.T) {
Expand Down Expand Up @@ -341,3 +386,24 @@ func TestIsTemplateIncompatibleError(t *testing.T) {
})
}
}

func TestDisplayProgressNonJSONLines(t *testing.T) {
// Simulate a proxy returning an HTML error page instead of a progress stream.
htmlBody := "<html><body><h1>502 Bad Gateway</h1></body></html>\n"
printer := NewSimplePrinter(func(string) {})
_, _, err := DisplayProgress(strings.NewReader(htmlBody), printer)
require.Error(t, err)
assert.Contains(t, err.Error(), "unexpected response from server")
assert.Contains(t, err.Error(), "502 Bad Gateway")
}

func TestDisplayProgressMixedContent(t *testing.T) {
// Valid progress followed by some unparseable lines: the valid progress
// should be honoured and no error returned for the stray lines.
body := `{"type":"success","message":"Model pulled successfully"}` + "\n" +
"<html>some extra garbage</html>\n"
printer := NewSimplePrinter(func(string) {})
msg, _, err := DisplayProgress(strings.NewReader(body), printer)
require.NoError(t, err)
assert.Equal(t, "Model pulled successfully", msg)
}
53 changes: 52 additions & 1 deletion cmd/cli/desktop/progress.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ func DisplayProgress(body io.Reader, printer standalone.StatusPrinter) (string,
scanner := bufio.NewScanner(body)
var finalMessage string
progressShown := false // Track if we actually showed any progress bars
// nonJSONBytes collects raw unparseable lines for error reporting,
// capped at maxNonJSONBytes to avoid large allocations.
var nonJSONBytes []byte

for scanner.Scan() {
progressLine := scanner.Text()
Expand All @@ -53,7 +56,14 @@ func DisplayProgress(body io.Reader, printer standalone.StatusPrinter) (string,

var progressMsg oci.ProgressMessage
if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil {
// If we can't parse, just skip
// Collect unparseable lines (e.g. HTML error pages from proxies)
// so we can surface them if no valid progress arrives.
if len(nonJSONBytes) < maxNonJSONBytes {
if len(nonJSONBytes) > 0 {
nonJSONBytes = append(nonJSONBytes, '\n')
}
nonJSONBytes = append(nonJSONBytes, progressLine...)
}
continue
Comment thread
ilopezluna marked this conversation as resolved.
}

Expand Down Expand Up @@ -85,6 +95,17 @@ func DisplayProgress(body io.Reader, printer standalone.StatusPrinter) (string,
return "", false, err
}

// If we received only unparseable lines and no valid progress or success,
// surface the raw content as an error. This catches HTML error pages
// returned by proxies or CDNs in place of a proper progress stream.
if finalMessage == "" && !progressShown && len(nonJSONBytes) > 0 {
pw.Close()
return "", false, fmt.Errorf(
"unexpected response from server (not valid progress data): %s",
truncateBytes(nonJSONBytes, maxNonJSONBytes),
)
}

pw.Close()

// Wait for display to finish
Expand All @@ -102,6 +123,8 @@ func displayProgressSimple(body io.Reader, printer standalone.StatusPrinter) (st
layerProgress := make(map[string]uint64)
var finalMessage string
progressShown := false // Track if we actually showed any progress
// nonJSONBytes collects raw unparseable lines for error reporting.
var nonJSONBytes []byte

for scanner.Scan() {
progressLine := scanner.Text()
Expand All @@ -111,6 +134,13 @@ func displayProgressSimple(body io.Reader, printer standalone.StatusPrinter) (st

var progressMsg oci.ProgressMessage
if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil {
// Collect unparseable lines for error reporting.
if len(nonJSONBytes) < maxNonJSONBytes {
if len(nonJSONBytes) > 0 {
nonJSONBytes = append(nonJSONBytes, '\n')
}
nonJSONBytes = append(nonJSONBytes, progressLine...)
}
continue
}

Expand Down Expand Up @@ -146,6 +176,14 @@ func displayProgressSimple(body io.Reader, printer standalone.StatusPrinter) (st
return "", false, err
}

// Surface unparseable content if no valid progress was received.
if finalMessage == "" && !progressShown && len(nonJSONBytes) > 0 {
return "", false, fmt.Errorf(
"unexpected response from server (not valid progress data): %s",
truncateBytes(nonJSONBytes, maxNonJSONBytes),
)
}

return finalMessage, progressShown, nil
}

Expand Down Expand Up @@ -257,3 +295,16 @@ func NewSimplePrinter(printFunc func(string)) standalone.StatusPrinter {
printFunc: printFunc,
}
}

// maxNonJSONBytes is the size threshold, in bytes, for the unparseable
// (non-JSON) progress-stream lines that are kept for error reporting.
// Collectors stop appending once their buffer has reached this size; because
// the check happens before a whole line is appended, the buffer may exceed the
// threshold, and truncateBytes trims the reported text back to it.
const maxNonJSONBytes = 4096

// truncateBytes renders b as a string of at most n bytes of content.
//
// If len(b) <= n the whole slice is returned unchanged. Otherwise the content
// is cut and "..." is appended to signal truncation. When the cut would land
// in the middle of a multi-byte UTF-8 sequence, it is moved back to the start
// of that sequence so the result does not end in a partial character. A
// negative n is treated as 0 instead of panicking on the slice expression.
func truncateBytes(b []byte, n int) string {
	if n < 0 {
		n = 0
	}
	if len(b) <= n {
		return string(b)
	}

	// b[cut] is the first byte that will be dropped. If it is a UTF-8
	// continuation byte (bit pattern 10xxxxxx), the rune it belongs to started
	// earlier; back up to that lead byte. A rune is at most 4 bytes long, so
	// at most 3 steps are needed.
	cut := n
	for cut > 0 && n-cut < 3 && b[cut]&0xC0 == 0x80 {
		cut--
	}
	if b[cut]&0xC0 == 0x80 {
		// No lead byte nearby: the data is not valid UTF-8 around the
		// boundary, so fall back to the plain byte cut.
		cut = n
	}
	return string(b[:cut]) + "..."
}
9 changes: 8 additions & 1 deletion pkg/distribution/distribution/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,14 @@ func checkCompat(image types.ModelArtifact, log *slog.Logger, reference string,
return err
}
if manifest.Config.MediaType != types.MediaTypeModelConfigV01 && manifest.Config.MediaType != types.MediaTypeModelConfigV02 {
return fmt.Errorf("config type %q is unsupported: %w", manifest.Config.MediaType, ErrUnsupportedMediaType)
return fmt.Errorf(
"config type %q is not supported (supported: %q, %q)"+
" - try upgrading: %w",
manifest.Config.MediaType,
types.MediaTypeModelConfigV01,
types.MediaTypeModelConfigV02,
ErrUnsupportedMediaType,
)
}

// Check if the model format is supported
Expand Down
16 changes: 7 additions & 9 deletions pkg/distribution/distribution/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,19 @@ package distribution

import (
"errors"
"fmt"

"github.com/docker/model-runner/pkg/distribution/internal/store"
"github.com/docker/model-runner/pkg/distribution/registry"
"github.com/docker/model-runner/pkg/distribution/types"
)

var (
ErrInvalidReference = registry.ErrInvalidReference
ErrModelNotFound = store.ErrModelNotFound // model not found in store
ErrUnsupportedMediaType = fmt.Errorf(
"client supports only models of type %q and older - try upgrading",
types.MediaTypeModelConfigV01,
)
ErrConflict = errors.New("resource conflict")
ErrInvalidReference = registry.ErrInvalidReference
ErrModelNotFound = store.ErrModelNotFound // model not found in store
// ErrUnsupportedMediaType is returned when a model's config media type is
// not supported by this client. The caller should wrap this with a dynamic
// message that includes the actual and supported media types.
ErrUnsupportedMediaType = errors.New("unsupported model config media type")
ErrConflict = errors.New("resource conflict")
)

const warnUnsupportedFormat = "vLLM backend currently only implemented for x86_64 NVIDIA platforms"
5 changes: 5 additions & 0 deletions pkg/inference/models/http_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ func (h *HTTPHandler) handleCreateModel(w http.ResponseWriter, r *http.Request)
http.Error(w, "Model not found", http.StatusNotFound)
return
}
if errors.Is(err, distribution.ErrUnsupportedMediaType) {
h.log.Warn("Unsupported model config type", "model", sanitizedFrom, "error", err)
http.Error(w, err.Error(), http.StatusUnprocessableEntity)
return
}
// Note: ErrUnsupportedFormat is no longer treated as an error - it's a warning
// that's sent to the client via the progress stream
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down
Loading