Skip to content

Commit e0e2fe8

Browse files
committed
Fix context size configuration precedence
User configured context should always be applied over the packaged context size Signed-off-by: Christopher Petito <chrisjpetito@gmail.com>
1 parent 9a742ad commit e0e2fe8

File tree

8 files changed

+146
-33
lines changed

8 files changed

+146
-33
lines changed

cmd/cli/commands/compose.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,17 @@ func newUpCommand() *cobra.Command {
8585
}
8686

8787
for _, model := range models {
88-
size := int32(ctxSize)
88+
backendConfig := inference.BackendConfiguration{
89+
Speculative: speculativeConfig,
90+
}
91+
if cmd.Flags().Changed("context-size") {
92+
size := int32(ctxSize)
93+
backendConfig.ContextSize = &size
94+
}
8995
if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{
90-
Model: model,
91-
BackendConfiguration: inference.BackendConfiguration{
92-
ContextSize: &size,
93-
Speculative: speculativeConfig,
94-
},
95-
RawRuntimeFlags: rawRuntimeFlags,
96+
Model: model,
97+
BackendConfiguration: backendConfig,
98+
RawRuntimeFlags: rawRuntimeFlags,
9699
}); err != nil {
97100
configErrFmtString := "failed to configure backend for model %s with context-size %d and runtime-flags %s"
98101
_ = sendErrorf(configErrFmtString+": %v", model, ctxSize, rawRuntimeFlags, err)

cmd/cli/commands/compose_test.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,67 @@ import (
88
"github.com/stretchr/testify/require"
99
)
1010

11+
// TestUpCommandContextSizeFlagBehavior verifies that the --context-size flag on
12+
// the compose up command is not "changed" by default (i.e. nil ContextSize
13+
// should be sent when the flag is absent) and is marked as changed after an
14+
// explicit value is provided.
15+
func TestUpCommandContextSizeFlagBehavior(t *testing.T) {
16+
t.Run("context-size flag not changed by default", func(t *testing.T) {
17+
cmd := newUpCommand()
18+
// Parse with just the required --model flag — no --context-size.
19+
err := cmd.ParseFlags([]string{"--model", "mymodel"})
20+
require.NoError(t, err)
21+
// The flag must NOT be marked as changed so that ContextSize is omitted
22+
// from the configure request (i.e. remains nil).
23+
assert.False(t, cmd.Flags().Changed("context-size"),
24+
"context-size must not be Changed when the flag is absent")
25+
})
26+
27+
t.Run("context-size flag changed after explicit value", func(t *testing.T) {
28+
cmd := newUpCommand()
29+
err := cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", "4096"})
30+
require.NoError(t, err)
31+
assert.True(t, cmd.Flags().Changed("context-size"),
32+
"context-size must be Changed when explicitly provided")
33+
})
34+
35+
t.Run("context-size flag changed with unlimited value -1", func(t *testing.T) {
36+
cmd := newUpCommand()
37+
err := cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", "-1"})
38+
require.NoError(t, err)
39+
assert.True(t, cmd.Flags().Changed("context-size"),
40+
"context-size must be Changed when explicitly set to -1 (unlimited)")
41+
})
42+
43+
t.Run("ContextSize is nil in BackendConfiguration when flag not set", func(t *testing.T) {
44+
cmd := newUpCommand()
45+
require.NoError(t, cmd.ParseFlags([]string{"--model", "mymodel"}))
46+
// Simulate the logic in compose.go RunE: only add ContextSize when Changed.
47+
backendConfig := inference.BackendConfiguration{}
48+
if cmd.Flags().Changed("context-size") {
49+
size := int32(-1) // default value
50+
backendConfig.ContextSize = &size
51+
}
52+
assert.Nil(t, backendConfig.ContextSize,
53+
"ContextSize must be nil in BackendConfiguration when --context-size is not provided")
54+
})
55+
56+
t.Run("ContextSize is non-nil in BackendConfiguration when flag is set", func(t *testing.T) {
57+
cmd := newUpCommand()
58+
require.NoError(t, cmd.ParseFlags([]string{"--model", "mymodel", "--context-size", "64000"}))
59+
ctxSize, err := cmd.Flags().GetInt64("context-size")
60+
require.NoError(t, err)
61+
backendConfig := inference.BackendConfiguration{}
62+
if cmd.Flags().Changed("context-size") {
63+
size := int32(ctxSize)
64+
backendConfig.ContextSize = &size
65+
}
66+
require.NotNil(t, backendConfig.ContextSize,
67+
"ContextSize must be non-nil when --context-size is provided")
68+
assert.Equal(t, int32(64000), *backendConfig.ContextSize)
69+
})
70+
}
71+
1172
func TestParseBackendMode(t *testing.T) {
1273
tests := []struct {
1374
name string

pkg/inference/backends/llamacpp/llamacpp_config.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,16 +95,16 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
9595
}
9696

9797
func GetContextSize(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 {
98-
// Model config takes precedence
98+
// Backend (runtime) config takes precedence — the user explicitly requested this value
99+
if backendCfg != nil && backendCfg.ContextSize != nil && (*backendCfg.ContextSize == UnlimitedContextSize || *backendCfg.ContextSize > 0) {
100+
return backendCfg.ContextSize
101+
}
102+
// Fallback to model config
99103
if modelCfg != nil {
100104
if ctxSize := modelCfg.GetContextSize(); ctxSize != nil && (*ctxSize == UnlimitedContextSize || *ctxSize > 0) {
101105
return ctxSize
102106
}
103107
}
104-
// Fallback to backend config
105-
if backendCfg != nil && backendCfg.ContextSize != nil && (*backendCfg.ContextSize == UnlimitedContextSize || *backendCfg.ContextSize > 0) {
106-
return backendCfg.ContextSize
107-
}
108108
return nil
109109
}
110110

pkg/inference/backends/llamacpp/llamacpp_config_test.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,25 @@ func TestGetArgs(t *testing.T) {
206206
"--model", modelPath,
207207
"--host", socket,
208208
"--embeddings",
209-
"--ctx-size", "2096", // model config takes precedence
209+
"--ctx-size", "1234", // backend config takes precedence
210+
"--jinja",
211+
),
212+
},
213+
{
214+
name: "context size from model config when no backend config",
215+
mode: inference.BackendModeEmbedding,
216+
bundle: &fakeBundle{
217+
ggufPath: modelPath,
218+
config: &types.Config{
219+
ContextSize: int32ptr(2096),
220+
},
221+
},
222+
config: nil,
223+
expected: append(slices.Clone(baseArgs),
224+
"--model", modelPath,
225+
"--host", socket,
226+
"--embeddings",
227+
"--ctx-size", "2096",
210228
"--jinja",
211229
),
212230
},

pkg/inference/backends/sglang/sglang_config.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,21 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
6363
return args, nil
6464
}
6565

66-
// GetContextLength returns the context length (context size) from model config or backend config.
67-
// Model config takes precedence over backend config.
66+
// GetContextLength returns the context length (context size) from backend config or model config.
67+
// Backend (runtime) config takes precedence over model config.
6868
// Returns nil if neither is specified (SGLang will auto-derive from model).
6969
func GetContextLength(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 {
70-
// Model config takes precedence
71-
if cs := modelCfg.GetContextSize(); cs != nil && *cs > 0 {
72-
return cs
73-
}
74-
// Fallback to backend config
70+
// Backend (runtime) config takes precedence
7571
if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 {
7672
return backendCfg.ContextSize
7773
}
74+
// Fallback to model config
75+
if modelCfg == nil {
76+
return nil
77+
}
78+
if cs := modelCfg.GetContextSize(); cs != nil && *cs > 0 {
79+
return cs
80+
}
7881
// Return nil to let SGLang auto-derive from model config
7982
return nil
8083
}

pkg/inference/backends/sglang/sglang_config_test.go

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ func TestGetArgs(t *testing.T) {
103103
},
104104
},
105105
{
106-
name: "with model context size (takes precedence)",
106+
name: "backend config takes precedence over model config",
107107
bundle: &mockModelBundle{
108108
safetensorsPath: "/path/to/model/model.safetensors",
109109
runtimeConfig: &types.Config{
@@ -124,7 +124,7 @@ func TestGetArgs(t *testing.T) {
124124
"--port",
125125
"30000",
126126
"--context-length",
127-
"16384",
127+
"8192",
128128
},
129129
},
130130
{
@@ -225,13 +225,21 @@ func TestGetContextLength(t *testing.T) {
225225
expectedValue: int32ptr(8192),
226226
},
227227
{
228-
name: "model config takes precedence",
228+
name: "backend config takes precedence",
229229
modelCfg: &types.Config{
230230
ContextSize: int32ptr(16384),
231231
},
232232
backendCfg: &inference.BackendConfiguration{
233233
ContextSize: int32ptr(4096),
234234
},
235+
expectedValue: int32ptr(4096),
236+
},
237+
{
238+
name: "model config used as fallback",
239+
modelCfg: &types.Config{
240+
ContextSize: int32ptr(16384),
241+
},
242+
backendCfg: nil,
235243
expectedValue: int32ptr(16384),
236244
},
237245
{
@@ -242,6 +250,18 @@ func TestGetContextLength(t *testing.T) {
242250
},
243251
expectedValue: nil,
244252
},
253+
{
254+
name: "nil model config with backend config",
255+
modelCfg: nil,
256+
backendCfg: &inference.BackendConfiguration{ContextSize: int32ptr(4096)},
257+
expectedValue: int32ptr(4096),
258+
},
259+
{
260+
name: "nil model config without backend config",
261+
modelCfg: nil,
262+
backendCfg: nil,
263+
expectedValue: nil,
264+
},
245265
}
246266

247267
for _, tt := range tests {

pkg/inference/backends/vllm/vllm_config.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,20 +87,20 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
8787
return args, nil
8888
}
8989

90-
// GetMaxModelLen returns the max model length (context size) from model config or backend config.
91-
// Model config takes precedence over backend config.
90+
// GetMaxModelLen returns the max model length (context size) from backend config or model config.
91+
// Backend (runtime) config takes precedence over model config.
9292
// Returns nil if neither is specified (vLLM will auto-derive from model).
9393
func GetMaxModelLen(modelCfg types.ModelConfig, backendCfg *inference.BackendConfiguration) *int32 {
94-
// Model config takes precedence
94+
// Backend (runtime) config takes precedence
95+
if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 {
96+
return backendCfg.ContextSize
97+
}
98+
// Fallback to model config
9599
if modelCfg != nil {
96100
if ctxSize := modelCfg.GetContextSize(); ctxSize != nil {
97101
return ctxSize
98102
}
99103
}
100-
// Fallback to backend config
101-
if backendCfg != nil && backendCfg.ContextSize != nil && *backendCfg.ContextSize > 0 {
102-
return backendCfg.ContextSize
103-
}
104104
// Return nil to let vLLM auto-derive from model config
105105
return nil
106106
}

pkg/inference/backends/vllm/vllm_config_test.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ func TestGetArgs(t *testing.T) {
109109
},
110110
},
111111
{
112-
name: "with model context size (takes precedence)",
112+
name: "backend config takes precedence over model config",
113113
bundle: &mockModelBundle{
114114
safetensorsPath: "/path/to/model",
115115
runtimeConfig: &types.Config{
@@ -125,7 +125,7 @@ func TestGetArgs(t *testing.T) {
125125
"--uds",
126126
"/tmp/socket",
127127
"--max-model-len",
128-
"16384",
128+
"8192",
129129
},
130130
},
131131
{
@@ -458,13 +458,21 @@ func TestGetMaxModelLen(t *testing.T) {
458458
expectedValue: int32ptr(8192),
459459
},
460460
{
461-
name: "model config takes precedence",
461+
name: "backend config takes precedence",
462462
modelCfg: &types.Config{
463463
ContextSize: int32ptr(16384),
464464
},
465465
backendCfg: &inference.BackendConfiguration{
466466
ContextSize: int32ptr(4096),
467467
},
468+
expectedValue: int32ptr(4096),
469+
},
470+
{
471+
name: "model config used as fallback",
472+
modelCfg: &types.Config{
473+
ContextSize: int32ptr(16384),
474+
},
475+
backendCfg: nil,
468476
expectedValue: int32ptr(16384),
469477
},
470478
}

0 commit comments

Comments
 (0)