55 "strconv"
66 "testing"
77
8+ "github.com/docker/model-distribution/types"
89 "github.com/docker/model-runner/pkg/inference"
910)
1011
@@ -72,12 +73,17 @@ func TestGetArgs(t *testing.T) {
7273
7374 tests := []struct {
7475 name string
76+ model types.Model
7577 mode inference.BackendMode
78+ cfg * inference.BackendConfiguration
7679 expected []string
7780 }{
7881 {
7982 name : "completion mode" ,
8083 mode : inference .BackendModeCompletion ,
84+ model : & fakeModel {
85+ ggufPath : modelPath ,
86+ },
8187 expected : []string {
8288 "--jinja" ,
8389 "-ngl" , "100" ,
@@ -89,20 +95,86 @@ func TestGetArgs(t *testing.T) {
8995 {
9096 name : "embedding mode" ,
9197 mode : inference .BackendModeEmbedding ,
98+ model : & fakeModel {
99+ ggufPath : modelPath ,
100+ },
101+ expected : []string {
102+ "--jinja" ,
103+ "-ngl" , "100" ,
104+ "--metrics" ,
105+ "--model" , modelPath ,
106+ "--host" , socket ,
107+ "--embeddings" ,
108+ },
109+ },
110+ {
111+ name : "context size from backend config" ,
112+ mode : inference .BackendModeEmbedding ,
113+ model : & fakeModel {
114+ ggufPath : modelPath ,
115+ },
116+ cfg : & inference.BackendConfiguration {
117+ ContextSize : 1234 ,
118+ },
92119 expected : []string {
93120 "--jinja" ,
94121 "-ngl" , "100" ,
95122 "--metrics" ,
96123 "--model" , modelPath ,
97124 "--host" , socket ,
98125 "--embeddings" ,
126+ "--ctx-size" , "1234" , // should add this flag
127+ },
128+ },
129+ {
130+ name : "context size from model config" ,
131+ mode : inference .BackendModeEmbedding ,
132+ model : & fakeModel {
133+ ggufPath : modelPath ,
134+ config : types.Config {
135+ ContextSize : uint64ptr (2096 ),
136+ },
137+ },
138+ cfg : & inference.BackendConfiguration {
139+ ContextSize : 1234 ,
140+ },
141+ expected : []string {
142+ "--jinja" ,
143+ "-ngl" , "100" ,
144+ "--metrics" ,
145+ "--model" , modelPath ,
146+ "--host" , socket ,
147+ "--embeddings" ,
148+ "--ctx-size" , "2096" , // model config takes precedence
149+ },
150+ },
151+ {
152+ name : "raw flags from backend config" ,
153+ mode : inference .BackendModeEmbedding ,
154+ model : & fakeModel {
155+ ggufPath : modelPath ,
156+ },
157+ cfg : & inference.BackendConfiguration {
158+ RawFlags : []string {"--some" , "flag" },
159+ },
160+ expected : []string {
161+ "--jinja" ,
162+ "-ngl" , "100" ,
163+ "--metrics" ,
164+ "--model" , modelPath ,
165+ "--host" , socket ,
166+ "--embeddings" ,
167+ "--some" , "flag" , // raw flags from the backend config are passed through
99168 },
100169 },
101170 }
102171
103172 for _ , tt := range tests {
104173 t .Run (tt .name , func (t * testing.T ) {
105- args := config .GetArgs (modelPath , socket , tt .mode )
174+ args , err := config .GetArgs (tt .model , socket , tt .mode , tt .cfg )
175+ if err != nil {
176+ t .Errorf ("GetArgs() error = %v" , err )
177+ }
106178
107179 // Check that all expected arguments are present and in the correct order
108180 expectedIndex := 0
@@ -171,3 +243,34 @@ func TestContainsArg(t *testing.T) {
171243 })
172244 }
173245}
246+
247+ var _ types.Model = & fakeModel {}
248+
249+ type fakeModel struct {
250+ ggufPath string
251+ config types.Config
252+ }
253+
254+ func (f * fakeModel ) ID () (string , error ) {
255+ panic ("shouldn't be called" )
256+ }
257+
258+ func (f * fakeModel ) GGUFPath () (string , error ) {
259+ return f .ggufPath , nil
260+ }
261+
262+ func (f * fakeModel ) Config () (types.Config , error ) {
263+ return f .config , nil
264+ }
265+
266+ func (f * fakeModel ) Tags () []string {
267+ panic ("shouldn't be called" )
268+ }
269+
270+ func (f fakeModel ) Descriptor () (types.Descriptor , error ) {
271+ panic ("shouldn't be called" )
272+ }
273+
// uint64ptr returns a pointer to a freshly allocated uint64 holding n, for
// filling in optional *uint64 fields in test fixtures.
func uint64ptr(n uint64) *uint64 {
	p := new(uint64)
	*p = n
	return p
}
0 commit comments