Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions cmd/cli/commands/compose.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
"net"
"slices"
"strconv"
Expand Down Expand Up @@ -85,14 +86,20 @@ func newUpCommand() *cobra.Command {
}

for _, model := range models {
size := int32(ctxSize)
backendConfig := inference.BackendConfiguration{
Speculative: speculativeConfig,
}
if cmd.Flags().Changed("context-size") {
if ctxSize > math.MaxInt32 || ctxSize < math.MinInt32 {
return fmt.Errorf("context-size %d is out of range (must be between %d and %d)", ctxSize, math.MinInt32, math.MaxInt32)
}
size := int32(ctxSize)
backendConfig.ContextSize = &size
}
if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{
Model: model,
BackendConfiguration: inference.BackendConfiguration{
ContextSize: &size,
Speculative: speculativeConfig,
},
RawRuntimeFlags: rawRuntimeFlags,
Model: model,
BackendConfiguration: backendConfig,
RawRuntimeFlags: rawRuntimeFlags,
}); err != nil {
configErrFmtString := "failed to configure backend for model %s with context-size %d and runtime-flags %s"
_ = sendErrorf(configErrFmtString+": %v", model, ctxSize, rawRuntimeFlags, err)
Expand Down
91 changes: 91 additions & 0 deletions cmd/cli/commands/compose_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,104 @@
package commands

import (
"fmt"
"math"
"testing"

"github.com/docker/model-runner/pkg/inference"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestUpCommandContextSizeFlagBehavior checks how the --context-size flag of
// the compose up command is reported by the flag set and how that state maps
// onto the configure request:
//   - absent flag: not "changed", so ContextSize stays nil;
//   - explicit value (including -1): "changed", so ContextSize is populated;
//   - values outside the int32 range are caught by the range check.
//
// The BackendConfiguration and range-check subtests replicate the logic of
// compose.go's RunE rather than invoking it.
func TestUpCommandContextSizeFlagBehavior(t *testing.T) {
	// Cases that only inspect whether pflag marks the flag as changed.
	changedCases := []struct {
		name        string
		args        []string
		wantChanged bool
		failMsg     string
	}{
		{
			name: "context-size flag not changed by default",
			// Only the required --model flag, no --context-size.
			args:        []string{"--model", "mymodel"},
			wantChanged: false,
			failMsg:     "context-size must not be Changed when the flag is absent",
		},
		{
			name:        "context-size flag changed after explicit value",
			args:        []string{"--model", "mymodel", "--context-size", "4096"},
			wantChanged: true,
			failMsg:     "context-size must be Changed when explicitly provided",
		},
		{
			name:        "context-size flag changed with unlimited value -1",
			args:        []string{"--model", "mymodel", "--context-size", "-1"},
			wantChanged: true,
			failMsg:     "context-size must be Changed when explicitly set to -1 (unlimited)",
		},
	}
	for _, tc := range changedCases {
		t.Run(tc.name, func(t *testing.T) {
			upCmd := newUpCommand()
			parseErr := upCmd.ParseFlags(tc.args)
			require.NoError(t, parseErr)

			changed := upCmd.Flags().Changed("context-size")
			if tc.wantChanged {
				assert.True(t, changed, tc.failMsg)
			} else {
				assert.False(t, changed, tc.failMsg)
			}
		})
	}

	t.Run("ContextSize is nil in BackendConfiguration when flag not set", func(t *testing.T) {
		upCmd := newUpCommand()
		require.NoError(t, upCmd.ParseFlags([]string{"--model", "mymodel"}))

		// Mirror of the RunE logic: ContextSize is only attached when Changed.
		var cfg inference.BackendConfiguration
		if upCmd.Flags().Changed("context-size") {
			defaultSize := int32(-1) // default value
			cfg.ContextSize = &defaultSize
		}

		assert.Nil(t, cfg.ContextSize,
			"ContextSize must be nil in BackendConfiguration when --context-size is not provided")
	})

	t.Run("ContextSize is non-nil in BackendConfiguration when flag is set", func(t *testing.T) {
		upCmd := newUpCommand()
		require.NoError(t, upCmd.ParseFlags([]string{"--model", "mymodel", "--context-size", "64000"}))

		parsed, err := upCmd.Flags().GetInt64("context-size")
		require.NoError(t, err)

		var cfg inference.BackendConfiguration
		if upCmd.Flags().Changed("context-size") {
			narrowed := int32(parsed)
			cfg.ContextSize = &narrowed
		}

		require.NotNil(t, cfg.ContextSize,
			"ContextSize must be non-nil when --context-size is provided")
		assert.Equal(t, int32(64000), *cfg.ContextSize)
	})

	// Cases whose value lies just outside the int32 range; the replicated
	// range check from RunE must reject each of them.
	rangeCases := []struct {
		name     string
		value    int64
		fatalMsg string
	}{
		{
			name:     "context-size above int32 max is out of range",
			value:    int64(math.MaxInt32) + 1,
			fatalMsg: "expected out-of-range check to trigger for value above MaxInt32",
		},
		{
			name:     "context-size below int32 min is out of range",
			value:    int64(math.MinInt32) - 1,
			fatalMsg: "expected out-of-range check to trigger for value below MinInt32",
		},
	}
	for _, tc := range rangeCases {
		t.Run(tc.name, func(t *testing.T) {
			upCmd := newUpCommand()
			require.NoError(t, upCmd.ParseFlags([]string{"--model", "mymodel", "--context-size", fmt.Sprintf("%d", tc.value)}))

			parsed, err := upCmd.Flags().GetInt64("context-size")
			require.NoError(t, err)
			require.True(t, upCmd.Flags().Changed("context-size"))

			// In RunE an out-of-range value returns an error; here the
			// subtest fails only when the value is wrongly seen as in range.
			withinInt32 := parsed >= math.MinInt32 && parsed <= math.MaxInt32
			if withinInt32 {
				t.Fatal(tc.fatalMsg)
			}
		})
	}
}

func TestParseBackendMode(t *testing.T) {
tests := []struct {
name string
Expand Down
10 changes: 5 additions & 5 deletions pkg/inference/backends/llamacpp/llamacpp_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,16 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
}

// GetContextSize resolves the context size from the backend (runtime)
// configuration and the model configuration.
//
// The backend configuration takes precedence because it carries the value the
// user explicitly requested; the model configuration is only a fallback. A
// candidate value is accepted when it equals UnlimitedContextSize or is
// strictly positive. Nil is returned when neither source provides an accepted
// value.
func GetContextSize(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 {
	// Backend (runtime) config takes precedence — the user explicitly requested this value.
	if backendCfg != nil && backendCfg.ContextSize != nil && (*backendCfg.ContextSize == UnlimitedContextSize || *backendCfg.ContextSize > 0) {
		return backendCfg.ContextSize
	}
	// Fallback to model config.
	if modelCfg != nil {
		if ctxSize := modelCfg.GetContextSize(); ctxSize != nil && (*ctxSize == UnlimitedContextSize || *ctxSize > 0) {
			return ctxSize
		}
	}
	// Neither source supplied a usable value.
	return nil
}

Expand Down
20 changes: 19 additions & 1 deletion pkg/inference/backends/llamacpp/llamacpp_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,25 @@ func TestGetArgs(t *testing.T) {
"--model", modelPath,
"--host", socket,
"--embeddings",
"--ctx-size", "2096", // model config takes precedence
"--ctx-size", "1234", // backend config takes precedence
"--jinja",
),
},
{
name: "context size from model config when no backend config",
mode: inference.BackendModeEmbedding,
bundle: &fakeBundle{
ggufPath: modelPath,
config: &types.Config{
ContextSize: int32ptr(2096),
},
},
config: nil,
expected: append(slices.Clone(baseArgs),
"--model", modelPath,
"--host", socket,
"--embeddings",
"--ctx-size", "2096",
"--jinja",
),
},
Expand Down
17 changes: 10 additions & 7 deletions pkg/inference/backends/sglang/sglang_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,21 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
return args, nil
}

// GetContextLength returns the context length (context size) from backend config or model config.
// Backend (runtime) config takes precedence over model config.
// Only strictly positive values are accepted from either source.
// Returns nil if neither is specified (SGLang will auto-derive from model).
func GetContextLength(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 {
	// Backend (runtime) config takes precedence.
	if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 {
		return backendCfg.ContextSize
	}
	// Fallback to model config. The nil guard must come before the method
	// call: invoking GetContextSize on a nil interface value would panic.
	if modelCfg == nil {
		return nil
	}
	if cs := modelCfg.GetContextSize(); cs != nil && *cs > 0 {
		return cs
	}
	// Return nil to let SGLang auto-derive from model config.
	return nil
}
26 changes: 23 additions & 3 deletions pkg/inference/backends/sglang/sglang_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func TestGetArgs(t *testing.T) {
},
},
{
name: "with model context size (takes precedence)",
name: "backend config takes precedence over model config",
bundle: &mockModelBundle{
safetensorsPath: "/path/to/model/model.safetensors",
runtimeConfig: &types.Config{
Expand All @@ -124,7 +124,7 @@ func TestGetArgs(t *testing.T) {
"--port",
"30000",
"--context-length",
"16384",
"8192",
},
},
{
Expand Down Expand Up @@ -225,13 +225,21 @@ func TestGetContextLength(t *testing.T) {
expectedValue: int32ptr(8192),
},
{
name: "model config takes precedence",
name: "backend config takes precedence",
modelCfg: &types.Config{
ContextSize: int32ptr(16384),
},
backendCfg: &inference.BackendConfiguration{
ContextSize: int32ptr(4096),
},
expectedValue: int32ptr(4096),
},
{
name: "model config used as fallback",
modelCfg: &types.Config{
ContextSize: int32ptr(16384),
},
backendCfg: nil,
expectedValue: int32ptr(16384),
},
{
Expand All @@ -242,6 +250,18 @@ func TestGetContextLength(t *testing.T) {
},
expectedValue: nil,
},
{
name: "nil model config with backend config",
modelCfg: nil,
backendCfg: &inference.BackendConfiguration{ContextSize: int32ptr(4096)},
expectedValue: int32ptr(4096),
},
{
name: "nil model config without backend config",
modelCfg: nil,
backendCfg: nil,
expectedValue: nil,
},
}

for _, tt := range tests {
Expand Down
14 changes: 7 additions & 7 deletions pkg/inference/backends/vllm/vllm_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,20 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
return args, nil
}

// GetMaxModelLen returns the max model length (context size) from backend config or model config.
// Backend (runtime) config takes precedence over model config.
// Returns nil if neither is specified (vLLM will auto-derive from model).
func GetMaxModelLen(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 {
	// Backend (runtime) config takes precedence; only strictly positive values are accepted.
	if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 {
		return backendCfg.ContextSize
	}
	// Fallback to model config.
	// NOTE(review): unlike the backend value, the model-config value is
	// returned without a positivity check — confirm this is intended.
	if modelCfg != nil {
		if ctxSize := modelCfg.GetContextSize(); ctxSize != nil {
			return ctxSize
		}
	}
	// Return nil to let vLLM auto-derive from model config.
	return nil
}
14 changes: 11 additions & 3 deletions pkg/inference/backends/vllm/vllm_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func TestGetArgs(t *testing.T) {
},
},
{
name: "with model context size (takes precedence)",
name: "backend config takes precedence over model config",
bundle: &mockModelBundle{
safetensorsPath: "/path/to/model",
runtimeConfig: &types.Config{
Expand All @@ -125,7 +125,7 @@ func TestGetArgs(t *testing.T) {
"--uds",
"/tmp/socket",
"--max-model-len",
"16384",
"8192",
},
},
{
Expand Down Expand Up @@ -458,13 +458,21 @@ func TestGetMaxModelLen(t *testing.T) {
expectedValue: int32ptr(8192),
},
{
name: "model config takes precedence",
name: "backend config takes precedence",
modelCfg: &types.Config{
ContextSize: int32ptr(16384),
},
backendCfg: &inference.BackendConfiguration{
ContextSize: int32ptr(4096),
},
expectedValue: int32ptr(4096),
},
{
name: "model config used as fallback",
modelCfg: &types.Config{
ContextSize: int32ptr(16384),
},
backendCfg: nil,
expectedValue: int32ptr(16384),
},
}
Expand Down
Loading