Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions cmd/cli/commands/compose.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
"net"
"slices"
"strconv"
Expand Down Expand Up @@ -85,14 +86,20 @@ func newUpCommand() *cobra.Command {
}

for _, model := range models {
size := int32(ctxSize)
backendConfig := inference.BackendConfiguration{
Speculative: speculativeConfig,
}
if cmd.Flags().Changed("context-size") {
if ctxSize > math.MaxInt32 || ctxSize < math.MinInt32 {
return fmt.Errorf("context-size %d is out of range (must be between %d and %d)", ctxSize, math.MinInt32, math.MaxInt32)
}
size := int32(ctxSize)
backendConfig.ContextSize = &size
}
if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{
Model: model,
BackendConfiguration: inference.BackendConfiguration{
ContextSize: &size,
Speculative: speculativeConfig,
},
RawRuntimeFlags: rawRuntimeFlags,
Model: model,
BackendConfiguration: backendConfig,
RawRuntimeFlags: rawRuntimeFlags,
}); err != nil {
configErrFmtString := "failed to configure backend for model %s with context-size %d and runtime-flags %s"
_ = sendErrorf(configErrFmtString+": %v", model, ctxSize, rawRuntimeFlags, err)
Expand Down
91 changes: 91 additions & 0 deletions cmd/cli/commands/compose_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,104 @@
package commands

import (
"fmt"
"math"
"testing"

"github.com/docker/model-runner/pkg/inference"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestUpCommandContextSizeFlagBehavior exercises the --context-size flag of the
// compose up command.
//
// The first group of subtests checks the flag's "Changed" state after parsing:
// it must stay unchanged when the flag is absent and become changed once any
// explicit value (including -1) is supplied. The remaining subtests replay, in
// the test itself, the decision logic used by compose.go RunE: ContextSize is
// only populated when the flag was changed, and values outside the int32 range
// are treated as out of range.
func TestUpCommandContextSizeFlagBehavior(t *testing.T) {
	const flagName = "context-size"

	// Changed-state expectations for different argument lists.
	changedCases := []struct {
		name        string
		args        []string
		wantChanged bool
		msg         string
	}{
		{
			// Only the required --model flag is passed — no --context-size.
			name:        "context-size flag not changed by default",
			args:        []string{"--model", "mymodel"},
			wantChanged: false,
			msg:         "context-size must not be Changed when the flag is absent",
		},
		{
			name:        "context-size flag changed after explicit value",
			args:        []string{"--model", "mymodel", "--context-size", "4096"},
			wantChanged: true,
			msg:         "context-size must be Changed when explicitly provided",
		},
		{
			name:        "context-size flag changed with unlimited value -1",
			args:        []string{"--model", "mymodel", "--context-size", "-1"},
			wantChanged: true,
			msg:         "context-size must be Changed when explicitly set to -1 (unlimited)",
		},
	}
	for _, tc := range changedCases {
		t.Run(tc.name, func(t *testing.T) {
			upCmd := newUpCommand()
			parseErr := upCmd.ParseFlags(tc.args)
			require.NoError(t, parseErr)

			gotChanged := upCmd.Flags().Changed(flagName)
			if tc.wantChanged {
				assert.True(t, gotChanged, tc.msg)
			} else {
				assert.False(t, gotChanged, tc.msg)
			}
		})
	}

	t.Run("ContextSize is nil in BackendConfiguration when flag not set", func(t *testing.T) {
		upCmd := newUpCommand()
		require.NoError(t, upCmd.ParseFlags([]string{"--model", "mymodel"}))

		// Mirror of the compose.go RunE logic: ContextSize is set only when
		// the flag was explicitly changed.
		var cfg inference.BackendConfiguration
		if upCmd.Flags().Changed(flagName) {
			defaultSize := int32(-1) // default value
			cfg.ContextSize = &defaultSize
		}
		assert.Nil(t, cfg.ContextSize,
			"ContextSize must be nil in BackendConfiguration when --context-size is not provided")
	})

	t.Run("ContextSize is non-nil in BackendConfiguration when flag is set", func(t *testing.T) {
		upCmd := newUpCommand()
		require.NoError(t, upCmd.ParseFlags([]string{"--model", "mymodel", "--context-size", "64000"}))

		parsed, err := upCmd.Flags().GetInt64(flagName)
		require.NoError(t, err)

		var cfg inference.BackendConfiguration
		if upCmd.Flags().Changed(flagName) {
			narrowed := int32(parsed)
			cfg.ContextSize = &narrowed
		}
		require.NotNil(t, cfg.ContextSize,
			"ContextSize must be non-nil when --context-size is provided")
		assert.Equal(t, int32(64000), *cfg.ContextSize)
	})

	// Values just outside the int32 range; each must trip the range check
	// that compose.go RunE applies before narrowing to int32.
	rangeCases := []struct {
		name    string
		value   int64
		failMsg string
	}{
		{
			name:    "context-size above int32 max is out of range",
			value:   int64(math.MaxInt32) + 1,
			failMsg: "expected out-of-range check to trigger for value above MaxInt32",
		},
		{
			name:    "context-size below int32 min is out of range",
			value:   int64(math.MinInt32) - 1,
			failMsg: "expected out-of-range check to trigger for value below MinInt32",
		},
	}
	for _, tc := range rangeCases {
		t.Run(tc.name, func(t *testing.T) {
			upCmd := newUpCommand()
			require.NoError(t, upCmd.ParseFlags([]string{"--model", "mymodel", "--context-size", fmt.Sprintf("%d", tc.value)}))

			parsed, err := upCmd.Flags().GetInt64(flagName)
			require.NoError(t, err)
			require.True(t, upCmd.Flags().Changed(flagName))

			// Simulated range check: RunE would return an error for an
			// out-of-range value, so only an in-range value is a failure here.
			if parsed >= math.MinInt32 && parsed <= math.MaxInt32 {
				t.Fatal(tc.failMsg)
			}
		})
	}
}

func TestParseBackendMode(t *testing.T) {
tests := []struct {
name string
Expand Down
10 changes: 5 additions & 5 deletions pkg/inference/backends/llamacpp/llamacpp_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,16 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
}

// GetContextSize resolves the context size from the backend and model
// configurations.
//
// Precedence:
//  1. backendCfg.ContextSize — runtime configuration (docker model configure /
//     Ollama API num_ctx).
//  2. modelCfg.GetContextSize() — set at packaging time (docker model package
//     --context-size).
//
// A candidate is accepted only when it is non-nil and is either
// UnlimitedContextSize or strictly positive. It returns nil when neither
// source supplies an acceptable value.
func GetContextSize(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 {
	// usable reports whether a candidate is set and holds a meaningful value.
	usable := func(v *int32) bool {
		return v != nil && (*v == UnlimitedContextSize || *v > 0)
	}

	// Backend config takes precedence (runtime configuration).
	if backendCfg != nil && usable(backendCfg.ContextSize) {
		return backendCfg.ContextSize
	}
	// Fall back to the model config (packaging-time configuration).
	if modelCfg != nil {
		if ctxSize := modelCfg.GetContextSize(); usable(ctxSize) {
			return ctxSize
		}
	}
	return nil
}

Expand Down
22 changes: 20 additions & 2 deletions pkg/inference/backends/llamacpp/llamacpp_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func TestGetArgs(t *testing.T) {
),
},
{
name: "context size from model config",
name: "backend config takes precedence over model config",
mode: inference.BackendModeEmbedding,
bundle: &fakeBundle{
ggufPath: modelPath,
Expand All @@ -206,7 +206,25 @@ func TestGetArgs(t *testing.T) {
"--model", modelPath,
"--host", socket,
"--embeddings",
"--ctx-size", "2096", // model config takes precedence
"--ctx-size", "1234", // backend config takes precedence
"--jinja",
),
},
{
name: "model config used when no backend config",
mode: inference.BackendModeEmbedding,
bundle: &fakeBundle{
ggufPath: modelPath,
config: &types.Config{
ContextSize: int32ptr(2096),
},
},
config: nil,
expected: append(slices.Clone(baseArgs),
"--model", modelPath,
"--host", socket,
"--embeddings",
"--ctx-size", "2096", // model config used as fallback
"--jinja",
),
},
Expand Down
17 changes: 15 additions & 2 deletions pkg/inference/backends/mlx/mlx_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,22 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
return args, nil
}

// GetMaxTokens resolves the max tokens (context size) value from the backend
// config or, failing that, the model config.
//
// The backend config (runtime configuration via docker model configure /
// Ollama API num_ctx) wins when it carries a strictly positive ContextSize.
// Otherwise a strictly positive value from the model config (set at packaging
// time via docker model package --context-size) is used. The selected value is
// returned as a freshly allocated uint64.
//
// Returns nil if neither is specified (MLX will use model defaults).
func GetMaxTokens(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *uint64 {
	var chosen *int32
	switch {
	case backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0:
		// Runtime configuration takes precedence.
		chosen = backendCfg.ContextSize
	case modelCfg != nil:
		// Packaging-time value is only a fallback.
		if fromModel := modelCfg.GetContextSize(); fromModel != nil && *fromModel > 0 {
			chosen = fromModel
		}
	}

	if chosen == nil {
		// Nothing usable: let MLX use model defaults.
		return nil
	}
	maxTokens := uint64(*chosen)
	return &maxTokens
}
97 changes: 97 additions & 0 deletions pkg/inference/backends/mlx/mlx_config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package mlx

import (
"testing"

"github.com/docker/model-runner/pkg/distribution/types"
"github.com/docker/model-runner/pkg/inference"
)

// TestGetMaxTokens checks GetMaxTokens against a table of model/backend
// configuration combinations, covering nil inputs, each source on its own,
// both sources together, and a zero backend value.
func TestGetMaxTokens(t *testing.T) {
	cases := []struct {
		name    string
		model   types.ModelConfig
		backend *inference.BackendConfiguration
		want    *uint64
	}{
		{
			name:    "no config",
			model:   &types.Config{},
			backend: nil,
			want:    nil,
		},
		{
			name:    "backend config only",
			model:   &types.Config{},
			backend: &inference.BackendConfiguration{ContextSize: int32ptr(4096)},
			want:    uint64ptr(4096),
		},
		{
			name:    "model config only",
			model:   &types.Config{ContextSize: int32ptr(8192)},
			backend: nil,
			want:    uint64ptr(8192),
		},
		{
			name:    "backend config takes precedence",
			model:   &types.Config{ContextSize: int32ptr(16384)},
			backend: &inference.BackendConfiguration{ContextSize: int32ptr(4096)},
			want:    uint64ptr(4096),
		},
		{
			name:    "model config used as fallback",
			model:   &types.Config{ContextSize: int32ptr(16384)},
			backend: nil,
			want:    uint64ptr(16384),
		},
		{
			name:    "zero context size in backend config returns nil",
			model:   &types.Config{},
			backend: &inference.BackendConfiguration{ContextSize: int32ptr(0)},
			want:    nil,
		},
		{
			name:    "nil model config with backend config",
			model:   nil,
			backend: &inference.BackendConfiguration{ContextSize: int32ptr(4096)},
			want:    uint64ptr(4096),
		},
		{
			name:    "nil model config without backend config",
			model:   nil,
			backend: nil,
			want:    nil,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := GetMaxTokens(tc.model, tc.backend)
			wantNil, gotNil := tc.want == nil, got == nil
			switch {
			case wantNil != gotNil:
				// Nil-ness must agree before values can be compared.
				t.Errorf("expected nil=%v, got nil=%v", wantNil, gotNil)
			case !gotNil && *got != *tc.want:
				t.Errorf("expected %d, got %d", *tc.want, *got)
			}
		})
	}
}

// int32ptr returns a pointer to a fresh int32 holding n, for building
// optional-field fixtures in tests.
func int32ptr(n int32) *int32 {
	p := new(int32)
	*p = n
	return p
}

// uint64ptr returns a pointer to a fresh uint64 holding n, for expressing
// expected optional values in tests.
func uint64ptr(n uint64) *uint64 {
	p := new(uint64)
	*p = n
	return p
}
12 changes: 12 additions & 0 deletions pkg/inference/backends/mlx/testmain_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package mlx

import (
"testing"

"go.uber.org/goleak"
)

// TestMain is the test entry point for the mlx package. It hands the run to
// goleak so that goroutine leaks are detected after the test suite.
func TestMain(m *testing.M) {
// NOTE(review): goleak.VerifyTestMain presumably runs the suite through m and
// sets the process exit status itself — confirm against the goleak docs.
goleak.VerifyTestMain(m)
}
17 changes: 10 additions & 7 deletions pkg/inference/backends/sglang/sglang_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,21 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
return args, nil
}

// GetContextLength returns the context length (context size) from backend
// config or model config.
//
// Backend config takes precedence over model config: it is runtime
// configuration (docker model configure / Ollama API num_ctx), whereas the
// model config value is set at packaging time (docker model package
// --context-size). Only strictly positive values are accepted from either
// source. A nil modelCfg is tolerated and treated as "no model value".
//
// Returns nil if neither is specified (SGLang will auto-derive from model).
func GetContextLength(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 {
	// Backend config takes precedence (runtime configuration).
	if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 {
		return backendCfg.ContextSize
	}
	// Guard before touching modelCfg: calling a method on a nil interface
	// would panic.
	if modelCfg == nil {
		return nil
	}
	// Fall back to the model config (packaging-time configuration).
	if cs := modelCfg.GetContextSize(); cs != nil && *cs > 0 {
		return cs
	}
	// Return nil to let SGLang auto-derive from model config.
	return nil
}
Loading
Loading