Skip to content

Commit de1ef38

Browse files
authored
Merge pull request #93 from docker/context-size
Respect context size from model config
2 parents 26a0a73 + cbdbb83 commit de1ef38

6 files changed

Lines changed: 141 additions & 18 deletions

File tree

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ go 1.23.7
55
require (
66
github.com/containerd/containerd/v2 v2.0.4
77
github.com/containerd/platforms v1.0.0-rc.1
8-
github.com/docker/model-distribution v0.0.0-20250618082521-fb5c8332c857
8+
github.com/docker/model-distribution v0.0.0-20250627163720-aff34abcf3e0
99
github.com/google/go-containerregistry v0.20.3
1010
github.com/jaypipes/ghw v0.16.0
1111
github.com/mattn/go-shellwords v1.0.12

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBi
3838
github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w=
3939
github.com/docker/docker-credential-helpers v0.8.2 h1:bX3YxiGzFP5sOXWc3bTPEXdEaZSeVMrFgOr3T+zrFAo=
4040
github.com/docker/docker-credential-helpers v0.8.2/go.mod h1:P3ci7E3lwkZg6XiHdRKft1KckHiO9a2rNtyFbZ/ry9M=
41-
github.com/docker/model-distribution v0.0.0-20250618082521-fb5c8332c857 h1:2IvvpdPZvpNn06+RUh5DC5O64dnrKjdsBKCMrzR5QTk=
42-
github.com/docker/model-distribution v0.0.0-20250618082521-fb5c8332c857/go.mod h1:dThpO9JoG5Px3i+rTluAeZcqLGw8C0qepuEL4gL2o/c=
41+
github.com/docker/model-distribution v0.0.0-20250627163720-aff34abcf3e0 h1:bve4JZI06Admw+NewtPfrpJXsvRnGKTQvBOEICNC1C0=
42+
github.com/docker/model-distribution v0.0.0-20250627163720-aff34abcf3e0/go.mod h1:dThpO9JoG5Px3i+rTluAeZcqLGw8C0qepuEL4gL2o/c=
4343
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
4444
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
4545
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=

pkg/inference/backends/llamacpp/llamacpp.go

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111
"os/exec"
1212
"path/filepath"
1313
"runtime"
14-
"strconv"
1514
"strings"
1615

1716
"github.com/docker/model-runner/pkg/diskusage"
@@ -122,10 +121,9 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error {
122121

123122
// Run implements inference.Backend.Run.
124123
func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
125-
modelPath, err := l.modelManager.GetModelPath(model)
126-
l.log.Infof("Model path: %s", modelPath)
124+
mdl, err := l.modelManager.GetModel(model)
127125
if err != nil {
128-
return fmt.Errorf("failed to get model path: %w", err)
126+
return fmt.Errorf("failed to get model: %w", err)
129127
}
130128

131129
if err := os.RemoveAll(socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
@@ -138,13 +136,9 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
138136
binPath = l.updatedServerStoragePath
139137
}
140138

141-
args := l.config.GetArgs(modelPath, socket, mode)
142-
143-
if config != nil {
144-
if config.ContextSize >= 0 {
145-
args = append(args, "--ctx-size", strconv.Itoa(int(config.ContextSize)))
146-
}
147-
args = append(args, config.RuntimeFlags...)
139+
args, err := l.config.GetArgs(mdl, socket, mode, config)
140+
if err != nil {
141+
return fmt.Errorf("failed to get args for llama.cpp: %w", err)
148142
}
149143

150144
l.log.Infof("llamaCppArgs: %v", args)

pkg/inference/backends/llamacpp/llamacpp_config.go

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package llamacpp
22

33
import (
4+
"fmt"
45
"runtime"
56
"strconv"
67

8+
"github.com/docker/model-distribution/types"
79
"github.com/docker/model-runner/pkg/inference"
810
)
911

@@ -33,10 +35,20 @@ func NewDefaultLlamaCppConfig() *Config {
3335
}
3436

3537
// GetArgs implements BackendConfig.GetArgs.
36-
func (c *Config) GetArgs(modelPath, socket string, mode inference.BackendMode) []string {
38+
func (c *Config) GetArgs(model types.Model, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error) {
3739
// Start with the arguments from LlamaCppConfig
3840
args := append([]string{}, c.Args...)
3941

42+
modelPath, err := model.GGUFPath()
43+
if err != nil {
44+
return nil, fmt.Errorf("get gguf path: %w", err)
45+
}
46+
47+
modelCfg, err := model.Config()
48+
if err != nil {
49+
return nil, fmt.Errorf("get model config: %w", err)
50+
}
51+
4052
// Add model and socket arguments
4153
args = append(args, "--model", modelPath, "--host", socket)
4254

@@ -45,7 +57,20 @@ func (c *Config) GetArgs(modelPath, socket string, mode inference.BackendMode) [
4557
args = append(args, "--embeddings")
4658
}
4759

48-
return args
60+
// Add arguments from model config
61+
if modelCfg.ContextSize != nil {
62+
args = append(args, "--ctx-size", strconv.FormatUint(*modelCfg.ContextSize, 10))
63+
}
64+
65+
// Add arguments from backend config
66+
if config != nil {
67+
if config.ContextSize > 0 && !containsArg(args, "--ctx-size") {
68+
args = append(args, "--ctx-size", strconv.FormatInt(config.ContextSize, 10))
69+
}
70+
args = append(args, config.RuntimeFlags...)
71+
}
72+
73+
return args, nil
4974
}
5075

5176
// containsArg checks if the given argument is already in the args slice.

pkg/inference/backends/llamacpp/llamacpp_config_test.go

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"strconv"
66
"testing"
77

8+
"github.com/docker/model-distribution/types"
89
"github.com/docker/model-runner/pkg/inference"
910
)
1011

@@ -72,12 +73,17 @@ func TestGetArgs(t *testing.T) {
7273

7374
tests := []struct {
7475
name string
76+
model types.Model
7577
mode inference.BackendMode
78+
config *inference.BackendConfiguration
7679
expected []string
7780
}{
7881
{
7982
name: "completion mode",
8083
mode: inference.BackendModeCompletion,
84+
model: &fakeModel{
85+
ggufPath: modelPath,
86+
},
8187
expected: []string{
8288
"--jinja",
8389
"-ngl", "100",
@@ -89,20 +95,86 @@ func TestGetArgs(t *testing.T) {
8995
{
9096
name: "embedding mode",
9197
mode: inference.BackendModeEmbedding,
98+
model: &fakeModel{
99+
ggufPath: modelPath,
100+
},
101+
expected: []string{
102+
"--jinja",
103+
"-ngl", "100",
104+
"--metrics",
105+
"--model", modelPath,
106+
"--host", socket,
107+
"--embeddings",
108+
},
109+
},
110+
{
111+
name: "context size from backend config",
112+
mode: inference.BackendModeEmbedding,
113+
model: &fakeModel{
114+
ggufPath: modelPath,
115+
},
116+
config: &inference.BackendConfiguration{
117+
ContextSize: 1234,
118+
},
92119
expected: []string{
93120
"--jinja",
94121
"-ngl", "100",
95122
"--metrics",
96123
"--model", modelPath,
97124
"--host", socket,
98125
"--embeddings",
126+
"--ctx-size", "1234", // should add this flag
127+
},
128+
},
129+
{
130+
name: "context size from model config",
131+
mode: inference.BackendModeEmbedding,
132+
model: &fakeModel{
133+
ggufPath: modelPath,
134+
config: types.Config{
135+
ContextSize: uint64ptr(2096),
136+
},
137+
},
138+
config: &inference.BackendConfiguration{
139+
ContextSize: 1234,
140+
},
141+
expected: []string{
142+
"--jinja",
143+
"-ngl", "100",
144+
"--metrics",
145+
"--model", modelPath,
146+
"--host", socket,
147+
"--embeddings",
148+
"--ctx-size", "2096", // model config takes precedence
149+
},
150+
},
151+
{
152+
name: "raw flags from backend config",
153+
mode: inference.BackendModeEmbedding,
154+
model: &fakeModel{
155+
ggufPath: modelPath,
156+
},
157+
config: &inference.BackendConfiguration{
158+
RuntimeFlags: []string{"--some", "flag"},
159+
},
160+
expected: []string{
161+
"--jinja",
162+
"-ngl", "100",
163+
"--metrics",
164+
"--model", modelPath,
165+
"--host", socket,
166+
"--embeddings",
167+
"--some", "flag", // model config takes precedence
99168
},
100169
},
101170
}
102171

103172
for _, tt := range tests {
104173
t.Run(tt.name, func(t *testing.T) {
105-
args := config.GetArgs(modelPath, socket, tt.mode)
174+
args, err := config.GetArgs(tt.model, socket, tt.mode, tt.config)
175+
if err != nil {
176+
t.Errorf("GetArgs() error = %v", err)
177+
}
106178

107179
// Check that all expected arguments are present and in the correct order
108180
expectedIndex := 0
@@ -171,3 +243,34 @@ func TestContainsArg(t *testing.T) {
171243
})
172244
}
173245
}
246+
247+
var _ types.Model = &fakeModel{}
248+
249+
type fakeModel struct {
250+
ggufPath string
251+
config types.Config
252+
}
253+
254+
func (f *fakeModel) ID() (string, error) {
255+
panic("shouldn't be called")
256+
}
257+
258+
func (f *fakeModel) GGUFPath() (string, error) {
259+
return f.ggufPath, nil
260+
}
261+
262+
func (f *fakeModel) Config() (types.Config, error) {
263+
return f.config, nil
264+
}
265+
266+
func (f *fakeModel) Tags() []string {
267+
panic("shouldn't be called")
268+
}
269+
270+
func (f fakeModel) Descriptor() (types.Descriptor, error) {
271+
panic("shouldn't be called")
272+
}
273+
274+
func uint64ptr(n uint64) *uint64 {
275+
return &n
276+
}

pkg/inference/config/config.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package config
22

33
import (
4+
"github.com/docker/model-distribution/types"
45
"github.com/docker/model-runner/pkg/inference"
56
)
67

@@ -11,5 +12,5 @@ type BackendConfig interface {
1112
// GetArgs returns the command-line arguments for the backend.
1213
// It takes the model path, socket, and mode as input and returns
1314
// the appropriate arguments for the backend.
14-
GetArgs(modelPath, socket string, mode inference.BackendMode) []string
15+
GetArgs(model types.Model, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error)
1516
}

0 commit comments

Comments
 (0)