diff --git a/cmd/cli/commands/compose.go b/cmd/cli/commands/compose.go index 966bf844e..7d10a6d27 100644 --- a/cmd/cli/commands/compose.go +++ b/cmd/cli/commands/compose.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "math" "net" "slices" "strconv" @@ -85,14 +86,20 @@ func newUpCommand() *cobra.Command { } for _, model := range models { - size := int32(ctxSize) + backendConfig := inference.BackendConfiguration{ + Speculative: speculativeConfig, + } + if cmd.Flags().Changed("context-size") { + if ctxSize > math.MaxInt32 || ctxSize < math.MinInt32 { + return fmt.Errorf("context-size %d is out of range (must be between %d and %d)", ctxSize, math.MinInt32, math.MaxInt32) + } + size := int32(ctxSize) + backendConfig.ContextSize = &size + } if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{ - Model: model, - BackendConfiguration: inference.BackendConfiguration{ - ContextSize: &size, - Speculative: speculativeConfig, - }, - RawRuntimeFlags: rawRuntimeFlags, + Model: model, + BackendConfiguration: backendConfig, + RawRuntimeFlags: rawRuntimeFlags, }); err != nil { configErrFmtString := "failed to configure backend for model %s with context-size %d and runtime-flags %s" _ = sendErrorf(configErrFmtString+": %v", model, ctxSize, rawRuntimeFlags, err) diff --git a/cmd/cli/commands/compose_test.go b/cmd/cli/commands/compose_test.go index f0960503f..df9ad37f6 100644 --- a/cmd/cli/commands/compose_test.go +++ b/cmd/cli/commands/compose_test.go @@ -1,6 +1,8 @@ package commands import ( + "fmt" + "math" "testing" "github.com/docker/model-runner/pkg/inference" @@ -8,6 +10,95 @@ import ( "github.com/stretchr/testify/require" ) +// TestUpCommandContextSizeFlagBehavior verifies that the --context-size flag on +// the compose up command is not "changed" by default (i.e. nil ContextSize +// should be sent when the flag is absent) and is marked as changed after an +// explicit value is provided. +func TestUpCommandContextSizeFlagBehavior(t *testing.T) { + t.Run("context-size flag not changed by default", func(t *testing.T) { + cmd := newUpCommand() + // Parse with just the required --model flag — no --context-size. + err := cmd.ParseFlags([]string{"--model", "mymodel"}) + require.NoError(t, err) + // The flag must NOT be marked as changed so that ContextSize is omitted + // from the configure request (i.e. remains nil). + assert.False(t, cmd.Flags().Changed("context-size"), + "context-size must not be Changed when the flag is absent") + }) + + t.Run("context-size flag changed after explicit value", func(t *testing.T) { + cmd := newUpCommand() + err := cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", "4096"}) + require.NoError(t, err) + assert.True(t, cmd.Flags().Changed("context-size"), + "context-size must be Changed when explicitly provided") + }) + + t.Run("context-size flag changed with unlimited value -1", func(t *testing.T) { + cmd := newUpCommand() + err := cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", "-1"}) + require.NoError(t, err) + assert.True(t, cmd.Flags().Changed("context-size"), + "context-size must be Changed when explicitly set to -1 (unlimited)") + }) + + t.Run("ContextSize is nil in BackendConfiguration when flag not set", func(t *testing.T) { + cmd := newUpCommand() + require.NoError(t, cmd.ParseFlags([]string{"--model", "mymodel"})) + // Simulate the logic in compose.go RunE: only add ContextSize when Changed. + backendConfig := inference.BackendConfiguration{} + if cmd.Flags().Changed("context-size") { + size := int32(-1) // default value + backendConfig.ContextSize = &size + } + assert.Nil(t, backendConfig.ContextSize, + "ContextSize must be nil in BackendConfiguration when --context-size is not provided") + }) + + t.Run("ContextSize is non-nil in BackendConfiguration when flag is set", func(t *testing.T) { + cmd := newUpCommand() + require.NoError(t, cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", "64000"})) + ctxSize, err := cmd.Flags().GetInt64("context-size") + require.NoError(t, err) + backendConfig := inference.BackendConfiguration{} + if cmd.Flags().Changed("context-size") { + size := int32(ctxSize) + backendConfig.ContextSize = &size + } + require.NotNil(t, backendConfig.ContextSize, + "ContextSize must be non-nil when --context-size is provided") + assert.Equal(t, int32(64000), *backendConfig.ContextSize) + }) + + t.Run("context-size above int32 max is out of range", func(t *testing.T) { + tooBig := int64(math.MaxInt32) + 1 + cmd := newUpCommand() + require.NoError(t, cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", fmt.Sprintf("%d", tooBig)})) + ctxSize, err := cmd.Flags().GetInt64("context-size") + require.NoError(t, err) + require.True(t, cmd.Flags().Changed("context-size")) + // Simulate the range check from compose.go RunE. + if ctxSize > math.MaxInt32 || ctxSize < math.MinInt32 { + // Expected: would return an error in RunE. + return + } + t.Fatal("expected out-of-range check to trigger for value above MaxInt32") + }) + + t.Run("context-size below int32 min is out of range", func(t *testing.T) { + tooSmall := int64(math.MinInt32) - 1 + cmd := newUpCommand() + require.NoError(t, cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", fmt.Sprintf("%d", tooSmall)})) + ctxSize, err := cmd.Flags().GetInt64("context-size") + require.NoError(t, err) + require.True(t, cmd.Flags().Changed("context-size")) + if ctxSize > math.MaxInt32 || ctxSize < math.MinInt32 { + return + } + t.Fatal("expected out-of-range check to trigger for value below MinInt32") + }) +} + func TestParseBackendMode(t *testing.T) { tests := []struct { name string diff --git a/pkg/inference/backends/llamacpp/llamacpp_config.go b/pkg/inference/backends/llamacpp/llamacpp_config.go index 87816410c..5f7ae1519 100644 --- a/pkg/inference/backends/llamacpp/llamacpp_config.go +++ b/pkg/inference/backends/llamacpp/llamacpp_config.go @@ -95,16 +95,16 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference } func GetContextSize(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 { - // Model config takes precedence + // Backend config takes precedence (runtime configuration via docker model configure / Ollama API num_ctx) + if backendCfg != nil && backendCfg.ContextSize != nil && (*backendCfg.ContextSize == UnlimitedContextSize || *backendCfg.ContextSize > 0) { + return backendCfg.ContextSize + } + // Fallback to model config (set at packaging time via docker model package --context-size) if modelCfg != nil { if ctxSize := modelCfg.GetContextSize(); ctxSize != nil && (*ctxSize == UnlimitedContextSize || *ctxSize > 0) { return ctxSize } } - // Fallback to backend config - if backendCfg != nil && backendCfg.ContextSize != nil && (*backendCfg.ContextSize == UnlimitedContextSize || *backendCfg.ContextSize > 0) { - return backendCfg.ContextSize - } return nil } diff --git a/pkg/inference/backends/llamacpp/llamacpp_config_test.go b/pkg/inference/backends/llamacpp/llamacpp_config_test.go index ee8223c15..1ef1262e1 100644 --- a/pkg/inference/backends/llamacpp/llamacpp_config_test.go +++ b/pkg/inference/backends/llamacpp/llamacpp_config_test.go @@ -191,7 +191,7 @@ func TestGetArgs(t *testing.T) { ), }, { - name: "context size from model config", + name: "backend config takes precedence over model config", mode: inference.BackendModeEmbedding, bundle: &fakeBundle{ ggufPath: modelPath, @@ -206,7 +206,25 @@ func TestGetArgs(t *testing.T) { "--model", modelPath, "--host", socket, "--embeddings", - "--ctx-size", "2096", // model config takes precedence + "--ctx-size", "1234", // backend config takes precedence + "--jinja", + ), + }, + { + name: "model config used when no backend config", + mode: inference.BackendModeEmbedding, + bundle: &fakeBundle{ + ggufPath: modelPath, + config: &types.Config{ + ContextSize: int32ptr(2096), + }, + }, + config: nil, + expected: append(slices.Clone(baseArgs), + "--model", modelPath, + "--host", socket, + "--embeddings", + "--ctx-size", "2096", // model config used as fallback "--jinja", ), }, diff --git a/pkg/inference/backends/mlx/mlx_config.go b/pkg/inference/backends/mlx/mlx_config.go index 29f98638f..e302e2b13 100644 --- a/pkg/inference/backends/mlx/mlx_config.go +++ b/pkg/inference/backends/mlx/mlx_config.go @@ -61,9 +61,22 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference return args, nil } -// GetMaxTokens returns the max tokens (context size) from model config or backend config. -// Model config takes precedence over backend config. +// GetMaxTokens returns the max tokens (context size) from backend config or model config. +// Backend config takes precedence over model config (runtime configuration). // Returns nil if neither is specified (MLX will use model defaults). func GetMaxTokens(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *uint64 { + // Backend config takes precedence (runtime configuration via docker model configure / Ollama API num_ctx) + if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 { + v := uint64(*backendCfg.ContextSize) + return &v + } + // Fallback to model config (set at packaging time via docker model package --context-size) + if modelCfg != nil { + if cs := modelCfg.GetContextSize(); cs != nil && *cs > 0 { + v := uint64(*cs) + return &v + } + } + // Return nil to let MLX use model defaults return nil } diff --git a/pkg/inference/backends/mlx/mlx_config_test.go b/pkg/inference/backends/mlx/mlx_config_test.go new file mode 100644 index 000000000..93d4088dd --- /dev/null +++ b/pkg/inference/backends/mlx/mlx_config_test.go @@ -0,0 +1,97 @@ +package mlx + +import ( + "testing" + + "github.com/docker/model-runner/pkg/distribution/types" + "github.com/docker/model-runner/pkg/inference" +) + +func TestGetMaxTokens(t *testing.T) { + tests := []struct { + name string + modelCfg types.ModelConfig + backendCfg *inference.BackendConfiguration + expectedValue *uint64 + }{ + { + name: "no config", + modelCfg: &types.Config{}, + backendCfg: nil, + expectedValue: nil, + }, + { + name: "backend config only", + modelCfg: &types.Config{}, + backendCfg: &inference.BackendConfiguration{ + ContextSize: int32ptr(4096), + }, + expectedValue: uint64ptr(4096), + }, + { + name: "model config only", + modelCfg: &types.Config{ + ContextSize: int32ptr(8192), + }, + backendCfg: nil, + expectedValue: uint64ptr(8192), + }, + { + name: "backend config takes precedence", + modelCfg: &types.Config{ + ContextSize: int32ptr(16384), + }, + backendCfg: &inference.BackendConfiguration{ + ContextSize: int32ptr(4096), + }, + expectedValue: uint64ptr(4096), + }, + { + name: "model config used as fallback", + modelCfg: &types.Config{ + ContextSize: int32ptr(16384), + }, + backendCfg: nil, + expectedValue: uint64ptr(16384), + }, + { + name: "zero context size in backend config returns nil", + modelCfg: &types.Config{}, + backendCfg: &inference.BackendConfiguration{ + ContextSize: int32ptr(0), + }, + expectedValue: nil, + }, + { + name: "nil model config with backend config", + modelCfg: nil, + backendCfg: &inference.BackendConfiguration{ContextSize: int32ptr(4096)}, + expectedValue: uint64ptr(4096), + }, + { + name: "nil model config without backend config", + modelCfg: nil, + backendCfg: nil, + expectedValue: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetMaxTokens(tt.modelCfg, tt.backendCfg) + if (result == nil) != (tt.expectedValue == nil) { + t.Errorf("expected nil=%v, got nil=%v", tt.expectedValue == nil, result == nil) + } else if result != nil && *result != *tt.expectedValue { + t.Errorf("expected %d, got %d", *tt.expectedValue, *result) + } + }) + } +} + +func int32ptr(n int32) *int32 { + return &n +} + +func uint64ptr(n uint64) *uint64 { + return &n +} diff --git a/pkg/inference/backends/mlx/testmain_test.go b/pkg/inference/backends/mlx/testmain_test.go new file mode 100644 index 000000000..0bd95b620 --- /dev/null +++ b/pkg/inference/backends/mlx/testmain_test.go @@ -0,0 +1,12 @@ +package mlx + +import ( + "testing" + + "go.uber.org/goleak" +) + +// TestMain runs goleak after the test suite to detect goroutine leaks. +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} diff --git a/pkg/inference/backends/sglang/sglang_config.go b/pkg/inference/backends/sglang/sglang_config.go index 814a516f2..ee2f6588d 100644 --- a/pkg/inference/backends/sglang/sglang_config.go +++ b/pkg/inference/backends/sglang/sglang_config.go @@ -63,18 +63,21 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference return args, nil } -// GetContextLength returns the context length (context size) from model config or backend config. -// Model config takes precedence over backend config. +// GetContextLength returns the context length (context size) from backend config or model config. +// Backend config takes precedence over model config (runtime configuration). // Returns nil if neither is specified (SGLang will auto-derive from model). func GetContextLength(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 { - // Model config takes precedence - if cs := modelCfg.GetContextSize(); cs != nil && *cs > 0 { - return cs - } - // Fallback to backend config + // Backend config takes precedence (runtime configuration via docker model configure / Ollama API num_ctx) if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 { return backendCfg.ContextSize } + // Fallback to model config (set at packaging time via docker model package --context-size) + if modelCfg == nil { + return nil + } + if cs := modelCfg.GetContextSize(); cs != nil && *cs > 0 { + return cs + } // Return nil to let SGLang auto-derive from model config return nil } diff --git a/pkg/inference/backends/sglang/sglang_config_test.go b/pkg/inference/backends/sglang/sglang_config_test.go index 2a96b0bc8..8231779ea 100644 --- a/pkg/inference/backends/sglang/sglang_config_test.go +++ b/pkg/inference/backends/sglang/sglang_config_test.go @@ -103,7 +103,7 @@ func TestGetArgs(t *testing.T) { }, }, { - name: "with model context size (takes precedence)", + name: "backend config takes precedence over model config", bundle: &mockModelBundle{ safetensorsPath: "/path/to/model/model.safetensors", runtimeConfig: &types.Config{ @@ -114,6 +114,29 @@ func TestGetArgs(t *testing.T) { config: &inference.BackendConfiguration{ ContextSize: int32ptr(8192), }, + expected: []string{ + "-m", + "sglang.launch_server", + "--model-path", + "/path/to/model", + "--host", + "127.0.0.1", + "--port", + "30000", + "--context-length", + "8192", + }, + }, + { + name: "model config used when no backend config", + bundle: &mockModelBundle{ + safetensorsPath: "/path/to/model/model.safetensors", + runtimeConfig: &types.Config{ + ContextSize: int32ptr(16384), + }, + }, + mode: inference.BackendModeCompletion, + config: nil, expected: []string{ "-m", "sglang.launch_server", @@ -225,14 +248,14 @@ func TestGetContextLength(t *testing.T) { expectedValue: int32ptr(8192), }, { - name: "model config takes precedence", + name: "backend config takes precedence", modelCfg: &types.Config{ ContextSize: int32ptr(16384), }, backendCfg: &inference.BackendConfiguration{ ContextSize: int32ptr(4096), }, - expectedValue: int32ptr(16384), + expectedValue: int32ptr(4096), }, { name: "zero context size in backend config returns nil", @@ -242,6 +265,18 @@ func TestGetContextLength(t *testing.T) { }, expectedValue: nil, }, + { + name: "nil model config with backend config", + modelCfg: nil, + backendCfg: &inference.BackendConfiguration{ContextSize: int32ptr(4096)}, + expectedValue: int32ptr(4096), + }, + { + name: "nil model config without backend config", + modelCfg: nil, + backendCfg: nil, + expectedValue: nil, + }, } for _, tt := range tests { diff --git a/pkg/inference/backends/vllm/vllm_config.go b/pkg/inference/backends/vllm/vllm_config.go index 3ad9e230d..2f525c664 100644 --- a/pkg/inference/backends/vllm/vllm_config.go +++ b/pkg/inference/backends/vllm/vllm_config.go @@ -87,20 +87,20 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference return args, nil } -// GetMaxModelLen returns the max model length (context size) from model config or backend config. -// Model config takes precedence over backend config. +// GetMaxModelLen returns the max model length (context size) from backend config or model config. +// Backend config takes precedence over model config (runtime configuration). // Returns nil if neither is specified (vLLM will auto-derive from model). func GetMaxModelLen(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 { - // Model config takes precedence + // Backend config takes precedence (runtime configuration via docker model configure / Ollama API num_ctx) + if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 { + return backendCfg.ContextSize + } + // Fallback to model config (set at packaging time via docker model package --context-size) if modelCfg != nil { - if ctxSize := modelCfg.GetContextSize(); ctxSize != nil { + if ctxSize := modelCfg.GetContextSize(); ctxSize != nil && *ctxSize > 0 { return ctxSize } } - // Fallback to backend config - if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 { - return backendCfg.ContextSize - } // Return nil to let vLLM auto-derive from model config return nil } diff --git a/pkg/inference/backends/vllm/vllm_config_test.go b/pkg/inference/backends/vllm/vllm_config_test.go index 6c33fd4c5..44f7f43c8 100644 --- a/pkg/inference/backends/vllm/vllm_config_test.go +++ b/pkg/inference/backends/vllm/vllm_config_test.go @@ -109,7 +109,7 @@ func TestGetArgs(t *testing.T) { }, }, { - name: "with model context size (takes precedence)", + name: "backend config takes precedence over model config", bundle: &mockModelBundle{ safetensorsPath: "/path/to/model", runtimeConfig: &types.Config{ @@ -119,6 +119,24 @@ func TestGetArgs(t *testing.T) { config: &inference.BackendConfiguration{ ContextSize: int32ptr(8192), }, + expected: []string{ + "serve", + "/path/to", + "--uds", + "/tmp/socket", + "--max-model-len", + "8192", + }, + }, + { + name: "model config used when no backend config", + bundle: &mockModelBundle{ + safetensorsPath: "/path/to/model", + runtimeConfig: &types.Config{ + ContextSize: int32ptr(16384), + }, + }, + config: nil, expected: []string{ "serve", "/path/to", @@ -458,14 +476,14 @@ func TestGetMaxModelLen(t *testing.T) { expectedValue: int32ptr(8192), }, { - name: "model config takes precedence", + name: "backend config takes precedence", modelCfg: &types.Config{ ContextSize: int32ptr(16384), }, backendCfg: &inference.BackendConfiguration{ ContextSize: int32ptr(4096), }, - expectedValue: int32ptr(16384), + expectedValue: int32ptr(4096), }, }