diff --git a/cmd/cli/commands/compose.go b/cmd/cli/commands/compose.go index 966bf844e..7d10a6d27 100644 --- a/cmd/cli/commands/compose.go +++ b/cmd/cli/commands/compose.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "math" "net" "slices" "strconv" @@ -85,14 +86,20 @@ func newUpCommand() *cobra.Command { } for _, model := range models { - size := int32(ctxSize) + backendConfig := inference.BackendConfiguration{ + Speculative: speculativeConfig, + } + if cmd.Flags().Changed("context-size") { + if ctxSize > math.MaxInt32 || ctxSize < math.MinInt32 { + return fmt.Errorf("context-size %d is out of range (must be between %d and %d)", ctxSize, math.MinInt32, math.MaxInt32) + } + size := int32(ctxSize) + backendConfig.ContextSize = &size + } if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{ - Model: model, - BackendConfiguration: inference.BackendConfiguration{ - ContextSize: &size, - Speculative: speculativeConfig, - }, - RawRuntimeFlags: rawRuntimeFlags, + Model: model, + BackendConfiguration: backendConfig, + RawRuntimeFlags: rawRuntimeFlags, }); err != nil { configErrFmtString := "failed to configure backend for model %s with context-size %d and runtime-flags %s" _ = sendErrorf(configErrFmtString+": %v", model, ctxSize, rawRuntimeFlags, err) diff --git a/cmd/cli/commands/compose_test.go b/cmd/cli/commands/compose_test.go index f0960503f..df9ad37f6 100644 --- a/cmd/cli/commands/compose_test.go +++ b/cmd/cli/commands/compose_test.go @@ -1,6 +1,8 @@ package commands import ( + "fmt" + "math" "testing" "github.com/docker/model-runner/pkg/inference" @@ -8,6 +10,95 @@ import ( "github.com/stretchr/testify/require" ) +// TestUpCommandContextSizeFlagBehavior verifies that the --context-size flag on +// the compose up command is not "changed" by default (i.e. nil ContextSize +// should be sent when the flag is absent) and is marked as changed after an +// explicit value is provided. +func TestUpCommandContextSizeFlagBehavior(t *testing.T) { + t.Run("context-size flag not changed by default", func(t *testing.T) { + cmd := newUpCommand() + // Parse with just the required --model flag — no --context-size. + err := cmd.ParseFlags([]string{"--model", "mymodel"}) + require.NoError(t, err) + // The flag must NOT be marked as changed so that ContextSize is omitted + // from the configure request (i.e. remains nil). + assert.False(t, cmd.Flags().Changed("context-size"), + "context-size must not be Changed when the flag is absent") + }) + + t.Run("context-size flag changed after explicit value", func(t *testing.T) { + cmd := newUpCommand() + err := cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", "4096"}) + require.NoError(t, err) + assert.True(t, cmd.Flags().Changed("context-size"), + "context-size must be Changed when explicitly provided") + }) + + t.Run("context-size flag changed with unlimited value -1", func(t *testing.T) { + cmd := newUpCommand() + err := cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", "-1"}) + require.NoError(t, err) + assert.True(t, cmd.Flags().Changed("context-size"), + "context-size must be Changed when explicitly set to -1 (unlimited)") + }) + + t.Run("ContextSize is nil in BackendConfiguration when flag not set", func(t *testing.T) { + cmd := newUpCommand() + require.NoError(t, cmd.ParseFlags([]string{"--model", "mymodel"})) + // Simulate the logic in compose.go RunE: only add ContextSize when Changed. + backendConfig := inference.BackendConfiguration{} + if cmd.Flags().Changed("context-size") { + size := int32(-1) // default value + backendConfig.ContextSize = &size + } + assert.Nil(t, backendConfig.ContextSize, + "ContextSize must be nil in BackendConfiguration when --context-size is not provided") + }) + + t.Run("ContextSize is non-nil in BackendConfiguration when flag is set", func(t *testing.T) { + cmd := newUpCommand() + require.NoError(t, cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", "64000"})) + ctxSize, err := cmd.Flags().GetInt64("context-size") + require.NoError(t, err) + backendConfig := inference.BackendConfiguration{} + if cmd.Flags().Changed("context-size") { + size := int32(ctxSize) + backendConfig.ContextSize = &size + } + require.NotNil(t, backendConfig.ContextSize, + "ContextSize must be non-nil when --context-size is provided") + assert.Equal(t, int32(64000), *backendConfig.ContextSize) + }) + + t.Run("context-size above int32 max is out of range", func(t *testing.T) { + tooBig := int64(math.MaxInt32) + 1 + cmd := newUpCommand() + require.NoError(t, cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", fmt.Sprintf("%d", tooBig)})) + ctxSize, err := cmd.Flags().GetInt64("context-size") + require.NoError(t, err) + require.True(t, cmd.Flags().Changed("context-size")) + // Simulate the range check from compose.go RunE. + if ctxSize > math.MaxInt32 || ctxSize < math.MinInt32 { + // Expected: would return an error in RunE. + return + } + t.Fatal("expected out-of-range check to trigger for value above MaxInt32") + }) + + t.Run("context-size below int32 min is out of range", func(t *testing.T) { + tooSmall := int64(math.MinInt32) - 1 + cmd := newUpCommand() + require.NoError(t, cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", fmt.Sprintf("%d", tooSmall)})) + ctxSize, err := cmd.Flags().GetInt64("context-size") + require.NoError(t, err) + require.True(t, cmd.Flags().Changed("context-size")) + if ctxSize > math.MaxInt32 || ctxSize < math.MinInt32 { + return + } + t.Fatal("expected out-of-range check to trigger for value below MinInt32") + }) +} + func TestParseBackendMode(t *testing.T) { tests := []struct { name string diff --git a/pkg/inference/backends/llamacpp/llamacpp_config.go b/pkg/inference/backends/llamacpp/llamacpp_config.go index 87816410c..26c7c3efb 100644 --- a/pkg/inference/backends/llamacpp/llamacpp_config.go +++ b/pkg/inference/backends/llamacpp/llamacpp_config.go @@ -95,16 +95,16 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference } func GetContextSize(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 { - // Model config takes precedence + // Backend (runtime) config takes precedence — the user explicitly requested this value + if backendCfg != nil && backendCfg.ContextSize != nil && (*backendCfg.ContextSize == UnlimitedContextSize || *backendCfg.ContextSize > 0) { + return backendCfg.ContextSize + } + // Fallback to model config if modelCfg != nil { if ctxSize := modelCfg.GetContextSize(); ctxSize != nil && (*ctxSize == UnlimitedContextSize || *ctxSize > 0) { return ctxSize } } - // Fallback to backend config - if backendCfg != nil && backendCfg.ContextSize != nil && (*backendCfg.ContextSize == UnlimitedContextSize || *backendCfg.ContextSize > 0) { - return backendCfg.ContextSize - } return nil } diff --git a/pkg/inference/backends/llamacpp/llamacpp_config_test.go b/pkg/inference/backends/llamacpp/llamacpp_config_test.go index ee8223c15..318f19b31 100644 --- a/pkg/inference/backends/llamacpp/llamacpp_config_test.go +++ b/pkg/inference/backends/llamacpp/llamacpp_config_test.go @@ -206,7 +206,25 @@ func TestGetArgs(t *testing.T) { "--model", modelPath, "--host", socket, "--embeddings", - "--ctx-size", "2096", // model config takes precedence + "--ctx-size", "1234", // backend config takes precedence + "--jinja", + ), + }, + { + name: "context size from model config when no backend config", + mode: inference.BackendModeEmbedding, + bundle: &fakeBundle{ + ggufPath: modelPath, + config: &types.Config{ + ContextSize: int32ptr(2096), + }, + }, + config: nil, + expected: append(slices.Clone(baseArgs), + "--model", modelPath, + "--host", socket, + "--embeddings", + "--ctx-size", "2096", "--jinja", ), }, diff --git a/pkg/inference/backends/sglang/sglang_config.go b/pkg/inference/backends/sglang/sglang_config.go index 814a516f2..2ed032b4e 100644 --- a/pkg/inference/backends/sglang/sglang_config.go +++ b/pkg/inference/backends/sglang/sglang_config.go @@ -63,18 +63,21 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference return args, nil } -// GetContextLength returns the context length (context size) from model config or backend config. -// Model config takes precedence over backend config. +// GetContextLength returns the context length (context size) from backend config or model config. +// Backend (runtime) config takes precedence over model config. // Returns nil if neither is specified (SGLang will auto-derive from model). func GetContextLength(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 { - // Model config takes precedence - if cs := modelCfg.GetContextSize(); cs != nil && *cs > 0 { - return cs - } - // Fallback to backend config + // Backend (runtime) config takes precedence if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 { return backendCfg.ContextSize } + // Fallback to model config + if modelCfg == nil { + return nil + } + if cs := modelCfg.GetContextSize(); cs != nil && *cs > 0 { + return cs + } // Return nil to let SGLang auto-derive from model config return nil } diff --git a/pkg/inference/backends/sglang/sglang_config_test.go b/pkg/inference/backends/sglang/sglang_config_test.go index 2a96b0bc8..152a82f30 100644 --- a/pkg/inference/backends/sglang/sglang_config_test.go +++ b/pkg/inference/backends/sglang/sglang_config_test.go @@ -103,7 +103,7 @@ func TestGetArgs(t *testing.T) { }, }, { - name: "with model context size (takes precedence)", + name: "backend config takes precedence over model config", bundle: &mockModelBundle{ safetensorsPath: "/path/to/model/model.safetensors", runtimeConfig: &types.Config{ @@ -124,7 +124,7 @@ func TestGetArgs(t *testing.T) { "--port", "30000", "--context-length", - "16384", + "8192", }, }, { @@ -225,13 +225,21 @@ func TestGetContextLength(t *testing.T) { expectedValue: int32ptr(8192), }, { - name: "model config takes precedence", + name: "backend config takes precedence", modelCfg: &types.Config{ ContextSize: int32ptr(16384), }, backendCfg: &inference.BackendConfiguration{ ContextSize: int32ptr(4096), }, + expectedValue: int32ptr(4096), + }, + { + name: "model config used as fallback", + modelCfg: &types.Config{ + ContextSize: int32ptr(16384), + }, + backendCfg: nil, expectedValue: int32ptr(16384), }, { @@ -242,6 +250,18 @@ func TestGetContextLength(t *testing.T) { }, expectedValue: nil, }, + { + name: "nil model config with backend config", + modelCfg: nil, + backendCfg: &inference.BackendConfiguration{ContextSize: int32ptr(4096)}, + expectedValue: int32ptr(4096), + }, + { + name: "nil model config without backend config", + modelCfg: nil, + backendCfg: nil, + expectedValue: nil, + }, } for _, tt := range tests { diff --git a/pkg/inference/backends/vllm/vllm_config.go b/pkg/inference/backends/vllm/vllm_config.go index 3ad9e230d..4e91e3bff 100644 --- a/pkg/inference/backends/vllm/vllm_config.go +++ b/pkg/inference/backends/vllm/vllm_config.go @@ -87,20 +87,20 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference return args, nil } -// GetMaxModelLen returns the max model length (context size) from model config or backend config. -// Model config takes precedence over backend config. +// GetMaxModelLen returns the max model length (context size) from backend config or model config. +// Backend (runtime) config takes precedence over model config. // Returns nil if neither is specified (vLLM will auto-derive from model). func GetMaxModelLen(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 { - // Model config takes precedence + // Backend (runtime) config takes precedence + if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 { + return backendCfg.ContextSize + } + // Fallback to model config if modelCfg != nil { if ctxSize := modelCfg.GetContextSize(); ctxSize != nil { return ctxSize } } - // Fallback to backend config - if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 { - return backendCfg.ContextSize - } // Return nil to let vLLM auto-derive from model config return nil } diff --git a/pkg/inference/backends/vllm/vllm_config_test.go b/pkg/inference/backends/vllm/vllm_config_test.go index 6c33fd4c5..a397f3a57 100644 --- a/pkg/inference/backends/vllm/vllm_config_test.go +++ b/pkg/inference/backends/vllm/vllm_config_test.go @@ -109,7 +109,7 @@ func TestGetArgs(t *testing.T) { }, }, { - name: "with model context size (takes precedence)", + name: "backend config takes precedence over model config", bundle: &mockModelBundle{ safetensorsPath: "/path/to/model", runtimeConfig: &types.Config{ @@ -125,7 +125,7 @@ func TestGetArgs(t *testing.T) { "--uds", "/tmp/socket", "--max-model-len", - "16384", + "8192", }, }, { @@ -458,13 +458,21 @@ func TestGetMaxModelLen(t *testing.T) { expectedValue: int32ptr(8192), }, { - name: "model config takes precedence", + name: "backend config takes precedence", modelCfg: &types.Config{ ContextSize: int32ptr(16384), }, backendCfg: &inference.BackendConfiguration{ ContextSize: int32ptr(4096), }, + expectedValue: int32ptr(4096), + }, + { + name: "model config used as fallback", + modelCfg: &types.Config{ + ContextSize: int32ptr(16384), + }, + backendCfg: nil, expectedValue: int32ptr(16384), }, }