Skip to content

Commit 9ae3f1a

Browse files
committed
Fix context size configuration precedence
User configured context should always be applied over the packaged context size Signed-off-by: Christopher Petito <chrisjpetito@gmail.com>
1 parent 9a742ad commit 9ae3f1a

File tree

8 files changed

+180
-33
lines changed

8 files changed

+180
-33
lines changed

cmd/cli/commands/compose.go

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"encoding/json"
55
"errors"
66
"fmt"
7+
"math"
78
"net"
89
"slices"
910
"strconv"
@@ -85,14 +86,20 @@ func newUpCommand() *cobra.Command {
8586
}
8687

8788
for _, model := range models {
88-
size := int32(ctxSize)
89+
backendConfig := inference.BackendConfiguration{
90+
Speculative: speculativeConfig,
91+
}
92+
if cmd.Flags().Changed("context-size") {
93+
if ctxSize > math.MaxInt32 || ctxSize < math.MinInt32 {
94+
return fmt.Errorf("context-size %d is out of range (must be between %d and %d)", ctxSize, math.MinInt32, math.MaxInt32)
95+
}
96+
size := int32(ctxSize)
97+
backendConfig.ContextSize = &size
98+
}
8999
if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{
90-
Model: model,
91-
BackendConfiguration: inference.BackendConfiguration{
92-
ContextSize: &size,
93-
Speculative: speculativeConfig,
94-
},
95-
RawRuntimeFlags: rawRuntimeFlags,
100+
Model: model,
101+
BackendConfiguration: backendConfig,
102+
RawRuntimeFlags: rawRuntimeFlags,
96103
}); err != nil {
97104
configErrFmtString := "failed to configure backend for model %s with context-size %d and runtime-flags %s"
98105
_ = sendErrorf(configErrFmtString+": %v", model, ctxSize, rawRuntimeFlags, err)

cmd/cli/commands/compose_test.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,104 @@
11
package commands
22

33
import (
4+
"fmt"
5+
"math"
46
"testing"
57

68
"github.com/docker/model-runner/pkg/inference"
79
"github.com/stretchr/testify/assert"
810
"github.com/stretchr/testify/require"
911
)
1012

13+
// TestUpCommandContextSizeFlagBehavior verifies that the --context-size flag on
14+
// the compose up command is not "changed" by default (i.e. nil ContextSize
15+
// should be sent when the flag is absent) and is marked as changed after an
16+
// explicit value is provided.
17+
func TestUpCommandContextSizeFlagBehavior(t *testing.T) {
18+
t.Run("context-size flag not changed by default", func(t *testing.T) {
19+
cmd := newUpCommand()
20+
// Parse with just the required --model flag — no --context-size.
21+
err := cmd.ParseFlags([]string{"--model", "mymodel"})
22+
require.NoError(t, err)
23+
// The flag must NOT be marked as changed so that ContextSize is omitted
24+
// from the configure request (i.e. remains nil).
25+
assert.False(t, cmd.Flags().Changed("context-size"),
26+
"context-size must not be Changed when the flag is absent")
27+
})
28+
29+
t.Run("context-size flag changed after explicit value", func(t *testing.T) {
30+
cmd := newUpCommand()
31+
err := cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", "4096"})
32+
require.NoError(t, err)
33+
assert.True(t, cmd.Flags().Changed("context-size"),
34+
"context-size must be Changed when explicitly provided")
35+
})
36+
37+
t.Run("context-size flag changed with unlimited value -1", func(t *testing.T) {
38+
cmd := newUpCommand()
39+
err := cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", "-1"})
40+
require.NoError(t, err)
41+
assert.True(t, cmd.Flags().Changed("context-size"),
42+
"context-size must be Changed when explicitly set to -1 (unlimited)")
43+
})
44+
45+
t.Run("ContextSize is nil in BackendConfiguration when flag not set", func(t *testing.T) {
46+
cmd := newUpCommand()
47+
require.NoError(t, cmd.ParseFlags([]string{"--model", "mymodel"}))
48+
// Simulate the logic in compose.go RunE: only add ContextSize when Changed.
49+
backendConfig := inference.BackendConfiguration{}
50+
if cmd.Flags().Changed("context-size") {
51+
size := int32(-1) // default value
52+
backendConfig.ContextSize = &size
53+
}
54+
assert.Nil(t, backendConfig.ContextSize,
55+
"ContextSize must be nil in BackendConfiguration when --context-size is not provided")
56+
})
57+
58+
t.Run("ContextSize is non-nil in BackendConfiguration when flag is set", func(t *testing.T) {
59+
cmd := newUpCommand()
60+
require.NoError(t, cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", "64000"}))
61+
ctxSize, err := cmd.Flags().GetInt64("context-size")
62+
require.NoError(t, err)
63+
backendConfig := inference.BackendConfiguration{}
64+
if cmd.Flags().Changed("context-size") {
65+
size := int32(ctxSize)
66+
backendConfig.ContextSize = &size
67+
}
68+
require.NotNil(t, backendConfig.ContextSize,
69+
"ContextSize must be non-nil when --context-size is provided")
70+
assert.Equal(t, int32(64000), *backendConfig.ContextSize)
71+
})
72+
73+
t.Run("context-size above int32 max is out of range", func(t *testing.T) {
74+
tooBig := int64(math.MaxInt32) + 1
75+
cmd := newUpCommand()
76+
require.NoError(t, cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", fmt.Sprintf("%d", tooBig)}))
77+
ctxSize, err := cmd.Flags().GetInt64("context-size")
78+
require.NoError(t, err)
79+
require.True(t, cmd.Flags().Changed("context-size"))
80+
// Simulate the range check from compose.go RunE.
81+
if ctxSize > math.MaxInt32 || ctxSize < math.MinInt32 {
82+
// Expected: would return an error in RunE.
83+
return
84+
}
85+
t.Fatal("expected out-of-range check to trigger for value above MaxInt32")
86+
})
87+
88+
t.Run("context-size below int32 min is out of range", func(t *testing.T) {
89+
tooSmall := int64(math.MinInt32) - 1
90+
cmd := newUpCommand()
91+
require.NoError(t, cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", fmt.Sprintf("%d", tooSmall)}))
92+
ctxSize, err := cmd.Flags().GetInt64("context-size")
93+
require.NoError(t, err)
94+
require.True(t, cmd.Flags().Changed("context-size"))
95+
if ctxSize > math.MaxInt32 || ctxSize < math.MinInt32 {
96+
return
97+
}
98+
t.Fatal("expected out-of-range check to trigger for value below MinInt32")
99+
})
100+
}
101+
11102
func TestParseBackendMode(t *testing.T) {
12103
tests := []struct {
13104
name string

pkg/inference/backends/llamacpp/llamacpp_config.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,16 +95,16 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
9595
}
9696

9797
func GetContextSize(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 {
98-
// Model config takes precedence
98+
// Backend (runtime) config takes precedence — the user explicitly requested this value
99+
if backendCfg != nil && backendCfg.ContextSize != nil && (*backendCfg.ContextSize == UnlimitedContextSize || *backendCfg.ContextSize > 0) {
100+
return backendCfg.ContextSize
101+
}
102+
// Fallback to model config
99103
if modelCfg != nil {
100104
if ctxSize := modelCfg.GetContextSize(); ctxSize != nil && (*ctxSize == UnlimitedContextSize || *ctxSize > 0) {
101105
return ctxSize
102106
}
103107
}
104-
// Fallback to backend config
105-
if backendCfg != nil && backendCfg.ContextSize != nil && (*backendCfg.ContextSize == UnlimitedContextSize || *backendCfg.ContextSize > 0) {
106-
return backendCfg.ContextSize
107-
}
108108
return nil
109109
}
110110

pkg/inference/backends/llamacpp/llamacpp_config_test.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,25 @@ func TestGetArgs(t *testing.T) {
206206
"--model", modelPath,
207207
"--host", socket,
208208
"--embeddings",
209-
"--ctx-size", "2096", // model config takes precedence
209+
"--ctx-size", "1234", // backend config takes precedence
210+
"--jinja",
211+
),
212+
},
213+
{
214+
name: "context size from model config when no backend config",
215+
mode: inference.BackendModeEmbedding,
216+
bundle: &fakeBundle{
217+
ggufPath: modelPath,
218+
config: &types.Config{
219+
ContextSize: int32ptr(2096),
220+
},
221+
},
222+
config: nil,
223+
expected: append(slices.Clone(baseArgs),
224+
"--model", modelPath,
225+
"--host", socket,
226+
"--embeddings",
227+
"--ctx-size", "2096",
210228
"--jinja",
211229
),
212230
},

pkg/inference/backends/sglang/sglang_config.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,21 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
6363
return args, nil
6464
}
6565

66-
// GetContextLength returns the context length (context size) from model config or backend config.
67-
// Model config takes precedence over backend config.
66+
// GetContextLength returns the context length (context size) from backend config or model config.
67+
// Backend (runtime) config takes precedence over model config.
6868
// Returns nil if neither is specified (SGLang will auto-derive from model).
6969
func GetContextLength(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 {
70-
// Model config takes precedence
71-
if cs := modelCfg.GetContextSize(); cs != nil && *cs > 0 {
72-
return cs
73-
}
74-
// Fallback to backend config
70+
// Backend (runtime) config takes precedence
7571
if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 {
7672
return backendCfg.ContextSize
7773
}
74+
// Fallback to model config
75+
if modelCfg == nil {
76+
return nil
77+
}
78+
if cs := modelCfg.GetContextSize(); cs != nil && *cs > 0 {
79+
return cs
80+
}
7881
// Return nil to let SGLang auto-derive from model config
7982
return nil
8083
}

pkg/inference/backends/sglang/sglang_config_test.go

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ func TestGetArgs(t *testing.T) {
103103
},
104104
},
105105
{
106-
name: "with model context size (takes precedence)",
106+
name: "backend config takes precedence over model config",
107107
bundle: &mockModelBundle{
108108
safetensorsPath: "/path/to/model/model.safetensors",
109109
runtimeConfig: &types.Config{
@@ -124,7 +124,7 @@ func TestGetArgs(t *testing.T) {
124124
"--port",
125125
"30000",
126126
"--context-length",
127-
"16384",
127+
"8192",
128128
},
129129
},
130130
{
@@ -225,13 +225,21 @@ func TestGetContextLength(t *testing.T) {
225225
expectedValue: int32ptr(8192),
226226
},
227227
{
228-
name: "model config takes precedence",
228+
name: "backend config takes precedence",
229229
modelCfg: &types.Config{
230230
ContextSize: int32ptr(16384),
231231
},
232232
backendCfg: &inference.BackendConfiguration{
233233
ContextSize: int32ptr(4096),
234234
},
235+
expectedValue: int32ptr(4096),
236+
},
237+
{
238+
name: "model config used as fallback",
239+
modelCfg: &types.Config{
240+
ContextSize: int32ptr(16384),
241+
},
242+
backendCfg: nil,
235243
expectedValue: int32ptr(16384),
236244
},
237245
{
@@ -242,6 +250,18 @@ func TestGetContextLength(t *testing.T) {
242250
},
243251
expectedValue: nil,
244252
},
253+
{
254+
name: "nil model config with backend config",
255+
modelCfg: nil,
256+
backendCfg: &inference.BackendConfiguration{ContextSize: int32ptr(4096)},
257+
expectedValue: int32ptr(4096),
258+
},
259+
{
260+
name: "nil model config without backend config",
261+
modelCfg: nil,
262+
backendCfg: nil,
263+
expectedValue: nil,
264+
},
245265
}
246266

247267
for _, tt := range tests {

pkg/inference/backends/vllm/vllm_config.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,20 +87,20 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
8787
return args, nil
8888
}
8989

90-
// GetMaxModelLen returns the max model length (context size) from model config or backend config.
91-
// Model config takes precedence over backend config.
90+
// GetMaxModelLen returns the max model length (context size) from backend config or model config.
91+
// Backend (runtime) config takes precedence over model config.
9292
// Returns nil if neither is specified (vLLM will auto-derive from model).
9393
func GetMaxModelLen(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 {
94-
// Model config takes precedence
94+
// Backend (runtime) config takes precedence
95+
if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 {
96+
return backendCfg.ContextSize
97+
}
98+
// Fallback to model config
9599
if modelCfg != nil {
96100
if ctxSize := modelCfg.GetContextSize(); ctxSize != nil {
97101
return ctxSize
98102
}
99103
}
100-
// Fallback to backend config
101-
if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 {
102-
return backendCfg.ContextSize
103-
}
104104
// Return nil to let vLLM auto-derive from model config
105105
return nil
106106
}

pkg/inference/backends/vllm/vllm_config_test.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ func TestGetArgs(t *testing.T) {
109109
},
110110
},
111111
{
112-
name: "with model context size (takes precedence)",
112+
name: "backend config takes precedence over model config",
113113
bundle: &mockModelBundle{
114114
safetensorsPath: "/path/to/model",
115115
runtimeConfig: &types.Config{
@@ -125,7 +125,7 @@ func TestGetArgs(t *testing.T) {
125125
"--uds",
126126
"/tmp/socket",
127127
"--max-model-len",
128-
"16384",
128+
"8192",
129129
},
130130
},
131131
{
@@ -458,13 +458,21 @@ func TestGetMaxModelLen(t *testing.T) {
458458
expectedValue: int32ptr(8192),
459459
},
460460
{
461-
name: "model config takes precedence",
461+
name: "backend config takes precedence",
462462
modelCfg: &types.Config{
463463
ContextSize: int32ptr(16384),
464464
},
465465
backendCfg: &inference.BackendConfiguration{
466466
ContextSize: int32ptr(4096),
467467
},
468+
expectedValue: int32ptr(4096),
469+
},
470+
{
471+
name: "model config used as fallback",
472+
modelCfg: &types.Config{
473+
ContextSize: int32ptr(16384),
474+
},
475+
backendCfg: nil,
468476
expectedValue: int32ptr(16384),
469477
},
470478
}

0 commit comments

Comments
 (0)