Skip to content

Commit b48efb3

Browse files
authored
fix: update context size precedence to prioritize backend configuration over model configuration (#847)
* fix: update context size precedence to prioritize backend configuration over model configuration * fix: ensure context size from model configuration is positive before using it * fix: validate context size range and update backend configuration handling * fix: implement GetMaxTokens function to prioritize backend context size over model configuration in mlx also added goleak detector in mlx package
1 parent a527e8c commit b48efb3

File tree

11 files changed

+331
-37
lines changed

11 files changed

+331
-37
lines changed

cmd/cli/commands/compose.go

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"encoding/json"
55
"errors"
66
"fmt"
7+
"math"
78
"net"
89
"slices"
910
"strconv"
@@ -85,14 +86,20 @@ func newUpCommand() *cobra.Command {
8586
}
8687

8788
for _, model := range models {
88-
size := int32(ctxSize)
89+
backendConfig := inference.BackendConfiguration{
90+
Speculative: speculativeConfig,
91+
}
92+
if cmd.Flags().Changed("context-size") {
93+
if ctxSize > math.MaxInt32 || ctxSize < math.MinInt32 {
94+
return fmt.Errorf("context-size %d is out of range (must be between %d and %d)", ctxSize, math.MinInt32, math.MaxInt32)
95+
}
96+
size := int32(ctxSize)
97+
backendConfig.ContextSize = &size
98+
}
8999
if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{
90-
Model: model,
91-
BackendConfiguration: inference.BackendConfiguration{
92-
ContextSize: &size,
93-
Speculative: speculativeConfig,
94-
},
95-
RawRuntimeFlags: rawRuntimeFlags,
100+
Model: model,
101+
BackendConfiguration: backendConfig,
102+
RawRuntimeFlags: rawRuntimeFlags,
96103
}); err != nil {
97104
configErrFmtString := "failed to configure backend for model %s with context-size %d and runtime-flags %s"
98105
_ = sendErrorf(configErrFmtString+": %v", model, ctxSize, rawRuntimeFlags, err)

cmd/cli/commands/compose_test.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,104 @@
11
package commands
22

33
import (
4+
"fmt"
5+
"math"
46
"testing"
57

68
"github.com/docker/model-runner/pkg/inference"
79
"github.com/stretchr/testify/assert"
810
"github.com/stretchr/testify/require"
911
)
1012

13+
// TestUpCommandContextSizeFlagBehavior verifies that the --context-size flag on
14+
// the compose up command is not "changed" by default (i.e. nil ContextSize
15+
// should be sent when the flag is absent) and is marked as changed after an
16+
// explicit value is provided.
17+
func TestUpCommandContextSizeFlagBehavior(t *testing.T) {
18+
t.Run("context-size flag not changed by default", func(t *testing.T) {
19+
cmd := newUpCommand()
20+
// Parse with just the required --model flag — no --context-size.
21+
err := cmd.ParseFlags([]string{"--model", "mymodel"})
22+
require.NoError(t, err)
23+
// The flag must NOT be marked as changed so that ContextSize is omitted
24+
// from the configure request (i.e. remains nil).
25+
assert.False(t, cmd.Flags().Changed("context-size"),
26+
"context-size must not be Changed when the flag is absent")
27+
})
28+
29+
t.Run("context-size flag changed after explicit value", func(t *testing.T) {
30+
cmd := newUpCommand()
31+
err := cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", "4096"})
32+
require.NoError(t, err)
33+
assert.True(t, cmd.Flags().Changed("context-size"),
34+
"context-size must be Changed when explicitly provided")
35+
})
36+
37+
t.Run("context-size flag changed with unlimited value -1", func(t *testing.T) {
38+
cmd := newUpCommand()
39+
err := cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", "-1"})
40+
require.NoError(t, err)
41+
assert.True(t, cmd.Flags().Changed("context-size"),
42+
"context-size must be Changed when explicitly set to -1 (unlimited)")
43+
})
44+
45+
t.Run("ContextSize is nil in BackendConfiguration when flag not set", func(t *testing.T) {
46+
cmd := newUpCommand()
47+
require.NoError(t, cmd.ParseFlags([]string{"--model", "mymodel"}))
48+
// Simulate the logic in compose.go RunE: only add ContextSize when Changed.
49+
backendConfig := inference.BackendConfiguration{}
50+
if cmd.Flags().Changed("context-size") {
51+
size := int32(-1) // default value
52+
backendConfig.ContextSize = &size
53+
}
54+
assert.Nil(t, backendConfig.ContextSize,
55+
"ContextSize must be nil in BackendConfiguration when --context-size is not provided")
56+
})
57+
58+
t.Run("ContextSize is non-nil in BackendConfiguration when flag is set", func(t *testing.T) {
59+
cmd := newUpCommand()
60+
require.NoError(t, cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", "64000"}))
61+
ctxSize, err := cmd.Flags().GetInt64("context-size")
62+
require.NoError(t, err)
63+
backendConfig := inference.BackendConfiguration{}
64+
if cmd.Flags().Changed("context-size") {
65+
size := int32(ctxSize)
66+
backendConfig.ContextSize = &size
67+
}
68+
require.NotNil(t, backendConfig.ContextSize,
69+
"ContextSize must be non-nil when --context-size is provided")
70+
assert.Equal(t, int32(64000), *backendConfig.ContextSize)
71+
})
72+
73+
t.Run("context-size above int32 max is out of range", func(t *testing.T) {
74+
tooBig := int64(math.MaxInt32) + 1
75+
cmd := newUpCommand()
76+
require.NoError(t, cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", fmt.Sprintf("%d", tooBig)}))
77+
ctxSize, err := cmd.Flags().GetInt64("context-size")
78+
require.NoError(t, err)
79+
require.True(t, cmd.Flags().Changed("context-size"))
80+
// Simulate the range check from compose.go RunE.
81+
if ctxSize > math.MaxInt32 || ctxSize < math.MinInt32 {
82+
// Expected: would return an error in RunE.
83+
return
84+
}
85+
t.Fatal("expected out-of-range check to trigger for value above MaxInt32")
86+
})
87+
88+
t.Run("context-size below int32 min is out of range", func(t *testing.T) {
89+
tooSmall := int64(math.MinInt32) - 1
90+
cmd := newUpCommand()
91+
require.NoError(t, cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", fmt.Sprintf("%d", tooSmall)}))
92+
ctxSize, err := cmd.Flags().GetInt64("context-size")
93+
require.NoError(t, err)
94+
require.True(t, cmd.Flags().Changed("context-size"))
95+
if ctxSize > math.MaxInt32 || ctxSize < math.MinInt32 {
96+
return
97+
}
98+
t.Fatal("expected out-of-range check to trigger for value below MinInt32")
99+
})
100+
}
101+
11102
func TestParseBackendMode(t *testing.T) {
12103
tests := []struct {
13104
name string

pkg/inference/backends/llamacpp/llamacpp_config.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,16 +95,16 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
9595
}
9696

9797
func GetContextSize(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 {
98-
// Model config takes precedence
98+
// Backend config takes precedence (runtime configuration via docker model configure / Ollama API num_ctx)
99+
if backendCfg != nil && backendCfg.ContextSize != nil && (*backendCfg.ContextSize == UnlimitedContextSize || *backendCfg.ContextSize > 0) {
100+
return backendCfg.ContextSize
101+
}
102+
// Fallback to model config (set at packaging time via docker model package --context-size)
99103
if modelCfg != nil {
100104
if ctxSize := modelCfg.GetContextSize(); ctxSize != nil && (*ctxSize == UnlimitedContextSize || *ctxSize > 0) {
101105
return ctxSize
102106
}
103107
}
104-
// Fallback to backend config
105-
if backendCfg != nil && backendCfg.ContextSize != nil && (*backendCfg.ContextSize == UnlimitedContextSize || *backendCfg.ContextSize > 0) {
106-
return backendCfg.ContextSize
107-
}
108108
return nil
109109
}
110110

pkg/inference/backends/llamacpp/llamacpp_config_test.go

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ func TestGetArgs(t *testing.T) {
191191
),
192192
},
193193
{
194-
name: "context size from model config",
194+
name: "backend config takes precedence over model config",
195195
mode: inference.BackendModeEmbedding,
196196
bundle: &fakeBundle{
197197
ggufPath: modelPath,
@@ -206,7 +206,25 @@ func TestGetArgs(t *testing.T) {
206206
"--model", modelPath,
207207
"--host", socket,
208208
"--embeddings",
209-
"--ctx-size", "2096", // model config takes precedence
209+
"--ctx-size", "1234", // backend config takes precedence
210+
"--jinja",
211+
),
212+
},
213+
{
214+
name: "model config used when no backend config",
215+
mode: inference.BackendModeEmbedding,
216+
bundle: &fakeBundle{
217+
ggufPath: modelPath,
218+
config: &types.Config{
219+
ContextSize: int32ptr(2096),
220+
},
221+
},
222+
config: nil,
223+
expected: append(slices.Clone(baseArgs),
224+
"--model", modelPath,
225+
"--host", socket,
226+
"--embeddings",
227+
"--ctx-size", "2096", // model config used as fallback
210228
"--jinja",
211229
),
212230
},

pkg/inference/backends/mlx/mlx_config.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,22 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
6161
return args, nil
6262
}
6363

64-
// GetMaxTokens returns the max tokens (context size) from model config or backend config.
65-
// Model config takes precedence over backend config.
64+
// GetMaxTokens returns the max tokens (context size) from backend config or model config.
65+
// Backend config takes precedence over model config (runtime configuration).
6666
// Returns nil if neither is specified (MLX will use model defaults).
6767
func GetMaxTokens(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *uint64 {
68+
// Backend config takes precedence (runtime configuration via docker model configure / Ollama API num_ctx)
69+
if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 {
70+
v := uint64(*backendCfg.ContextSize)
71+
return &v
72+
}
73+
// Fallback to model config (set at packaging time via docker model package --context-size)
74+
if modelCfg != nil {
75+
if cs := modelCfg.GetContextSize(); cs != nil && *cs > 0 {
76+
v := uint64(*cs)
77+
return &v
78+
}
79+
}
80+
// Return nil to let MLX use model defaults
6881
return nil
6982
}
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package mlx
2+
3+
import (
4+
"testing"
5+
6+
"github.com/docker/model-runner/pkg/distribution/types"
7+
"github.com/docker/model-runner/pkg/inference"
8+
)
9+
10+
func TestGetMaxTokens(t *testing.T) {
11+
tests := []struct {
12+
name string
13+
modelCfg types.ModelConfig
14+
backendCfg *inference.BackendConfiguration
15+
expectedValue *uint64
16+
}{
17+
{
18+
name: "no config",
19+
modelCfg: &types.Config{},
20+
backendCfg: nil,
21+
expectedValue: nil,
22+
},
23+
{
24+
name: "backend config only",
25+
modelCfg: &types.Config{},
26+
backendCfg: &inference.BackendConfiguration{
27+
ContextSize: int32ptr(4096),
28+
},
29+
expectedValue: uint64ptr(4096),
30+
},
31+
{
32+
name: "model config only",
33+
modelCfg: &types.Config{
34+
ContextSize: int32ptr(8192),
35+
},
36+
backendCfg: nil,
37+
expectedValue: uint64ptr(8192),
38+
},
39+
{
40+
name: "backend config takes precedence",
41+
modelCfg: &types.Config{
42+
ContextSize: int32ptr(16384),
43+
},
44+
backendCfg: &inference.BackendConfiguration{
45+
ContextSize: int32ptr(4096),
46+
},
47+
expectedValue: uint64ptr(4096),
48+
},
49+
{
50+
name: "model config used as fallback",
51+
modelCfg: &types.Config{
52+
ContextSize: int32ptr(16384),
53+
},
54+
backendCfg: nil,
55+
expectedValue: uint64ptr(16384),
56+
},
57+
{
58+
name: "zero context size in backend config returns nil",
59+
modelCfg: &types.Config{},
60+
backendCfg: &inference.BackendConfiguration{
61+
ContextSize: int32ptr(0),
62+
},
63+
expectedValue: nil,
64+
},
65+
{
66+
name: "nil model config with backend config",
67+
modelCfg: nil,
68+
backendCfg: &inference.BackendConfiguration{ContextSize: int32ptr(4096)},
69+
expectedValue: uint64ptr(4096),
70+
},
71+
{
72+
name: "nil model config without backend config",
73+
modelCfg: nil,
74+
backendCfg: nil,
75+
expectedValue: nil,
76+
},
77+
}
78+
79+
for _, tt := range tests {
80+
t.Run(tt.name, func(t *testing.T) {
81+
result := GetMaxTokens(tt.modelCfg, tt.backendCfg)
82+
if (result == nil) != (tt.expectedValue == nil) {
83+
t.Errorf("expected nil=%v, got nil=%v", tt.expectedValue == nil, result == nil)
84+
} else if result != nil && *result != *tt.expectedValue {
85+
t.Errorf("expected %d, got %d", *tt.expectedValue, *result)
86+
}
87+
})
88+
}
89+
}
90+
91+
func int32ptr(n int32) *int32 {
92+
return &n
93+
}
94+
95+
func uint64ptr(n uint64) *uint64 {
96+
return &n
97+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package mlx
2+
3+
import (
4+
"testing"
5+
6+
"go.uber.org/goleak"
7+
)
8+
9+
// TestMain runs goleak after the test suite to detect goroutine leaks.
10+
func TestMain(m *testing.M) {
11+
goleak.VerifyTestMain(m)
12+
}

pkg/inference/backends/sglang/sglang_config.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,21 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
6363
return args, nil
6464
}
6565

66-
// GetContextLength returns the context length (context size) from model config or backend config.
67-
// Model config takes precedence over backend config.
66+
// GetContextLength returns the context length (context size) from backend config or model config.
67+
// Backend config takes precedence over model config (runtime configuration).
6868
// Returns nil if neither is specified (SGLang will auto-derive from model).
6969
func GetContextLength(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 {
70-
// Model config takes precedence
71-
if cs := modelCfg.GetContextSize(); cs != nil && *cs > 0 {
72-
return cs
73-
}
74-
// Fallback to backend config
70+
// Backend config takes precedence (runtime configuration via docker model configure / Ollama API num_ctx)
7571
if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 {
7672
return backendCfg.ContextSize
7773
}
74+
// Fallback to model config (set at packaging time via docker model package --context-size)
75+
if modelCfg == nil {
76+
return nil
77+
}
78+
if cs := modelCfg.GetContextSize(); cs != nil && *cs > 0 {
79+
return cs
80+
}
7881
// Return nil to let SGLang auto-derive from model config
7982
return nil
8083
}

0 commit comments

Comments
 (0)