Commit 9df7e71

Merge pull request #880 from docker/fix/vllm-metal-chat-template
fix: pass chat template to vllm-metal backend
2 parents f2c5744 + 7cf87c8 commit 9df7e71

3 files changed

Lines changed: 202 additions & 4 deletions

pkg/inference/backends/vllm/vllm_config.go

Lines changed: 7 additions & 0 deletions
@@ -40,6 +40,13 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
 	// Add socket arguments
 	args = append(args, "--uds", socket)
 
+	// Add chat template if available in the model bundle.
+	// Since transformers v4.44, vLLM no longer provides a default chat
+	// template so we must supply one when the tokenizer omits it.
+	if path := bundle.ChatTemplatePath(); path != "" {
+		args = append(args, "--chat-template", path)
+	}
+
 	// Add mode-specific arguments
 	switch mode {
 	case inference.BackendModeCompletion:
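
The `bundle` argument here is a `types.ModelBundle` from `pkg/distribution/types`. Its full definition is not part of this diff; the sketch below is inferred from the mock in the test file and lists only the methods relevant to this change, so treat it as an illustration rather than the actual declaration.

// Inferred subset of types.ModelBundle (based on mockModelBundle below);
// the real interface lives in pkg/distribution/types and may declare more methods.
type ModelBundle interface {
	GGUFPath() string         // path to a GGUF model file, if the bundle has one
	SafetensorsPath() string  // path to the safetensors weights vLLM serves
	ChatTemplatePath() string // path to a Jinja chat template, or "" when absent
	MMPROJPath() string       // path to a multimodal projector file, if any
}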

pkg/inference/backends/vllm/vllm_config_test.go

Lines changed: 186 additions & 3 deletions
@@ -8,8 +8,9 @@ import (
 )
 
 type mockModelBundle struct {
-	safetensorsPath string
-	runtimeConfig   *types.Config
+	safetensorsPath  string
+	chatTemplatePath string
+	runtimeConfig    *types.Config
 }
 
 func (m *mockModelBundle) GGUFPath() string {
@@ -21,7 +22,7 @@ func (m *mockModelBundle) SafetensorsPath() string {
 }
 
 func (m *mockModelBundle) ChatTemplatePath() string {
-	return ""
+	return m.chatTemplatePath
 }
 
 func (m *mockModelBundle) MMPROJPath() string {
@@ -74,6 +75,36 @@ func TestGetArgs(t *testing.T) {
 				"/tmp/socket",
 			},
 		},
+		{
+			name: "with chat template",
+			bundle: &mockModelBundle{
+				safetensorsPath:  "/path/to/model",
+				chatTemplatePath: "/path/to/bundle/template.jinja",
+			},
+			config: nil,
+			expected: []string{
+				"serve",
+				"/path/to",
+				"--uds",
+				"/tmp/socket",
+				"--chat-template",
+				"/path/to/bundle/template.jinja",
+			},
+		},
+		{
+			name: "without chat template omits flag",
+			bundle: &mockModelBundle{
+				safetensorsPath:  "/path/to/model",
+				chatTemplatePath: "",
+			},
+			config: nil,
+			expected: []string{
+				"serve",
+				"/path/to",
+				"--uds",
+				"/tmp/socket",
+			},
+		},
 		{
 			name: "with backend context size",
 			bundle: &mockModelBundle{
@@ -499,6 +530,158 @@ func TestGetMaxModelLen(t *testing.T) {
 	}
 }
 
+func TestBuildArgs(t *testing.T) {
+	tests := []struct {
+		name        string
+		bundle      *mockModelBundle
+		socket      string
+		model       string
+		modelRef    string
+		mode        inference.BackendMode
+		config      *inference.BackendConfiguration
+		expected    []string
+		expectError bool
+	}{
+		{
+			name: "basic completion mode",
+			bundle: &mockModelBundle{
+				safetensorsPath: "/models/bundle/model/safetensors",
+			},
+			socket: "127.0.0.1:30000",
+			model: "sha256:abc123",
+			modelRef: "ai/test-model:latest",
+			mode: inference.BackendModeCompletion,
+			expected: []string{
+				"-m", "vllm.entrypoints.openai.api_server",
+				"--model", "/models/bundle/model",
+				"--host", "127.0.0.1",
+				"--port", "30000",
+				"--enable-auto-tool-choice", "--tool-call-parser", "hermes",
+				"--served-model-name", "sha256:abc123", "ai/test-model:latest",
+			},
+		},
+		{
+			name: "with chat template",
+			bundle: &mockModelBundle{
+				safetensorsPath:  "/models/bundle/model/safetensors",
+				chatTemplatePath: "/models/bundle/template.jinja",
+			},
+			socket: "127.0.0.1:30000",
+			model: "sha256:abc123",
+			modelRef: "ai/test-model:latest",
+			mode: inference.BackendModeCompletion,
+			expected: []string{
+				"-m", "vllm.entrypoints.openai.api_server",
+				"--model", "/models/bundle/model",
+				"--host", "127.0.0.1",
+				"--port", "30000",
+				"--enable-auto-tool-choice", "--tool-call-parser", "hermes",
+				"--chat-template", "/models/bundle/template.jinja",
+				"--served-model-name", "sha256:abc123", "ai/test-model:latest",
+			},
+		},
+		{
+			name: "without chat template",
+			bundle: &mockModelBundle{
+				safetensorsPath:  "/models/bundle/model/safetensors",
+				chatTemplatePath: "",
+			},
+			socket: "127.0.0.1:30000",
+			model: "sha256:abc123",
+			modelRef: "ai/test-model:latest",
+			mode: inference.BackendModeCompletion,
+			expected: []string{
+				"-m", "vllm.entrypoints.openai.api_server",
+				"--model", "/models/bundle/model",
+				"--host", "127.0.0.1",
+				"--port", "30000",
+				"--enable-auto-tool-choice", "--tool-call-parser", "hermes",
+				"--served-model-name", "sha256:abc123", "ai/test-model:latest",
+			},
+		},
+		{
+			name: "empty safetensors path should error",
+			bundle: &mockModelBundle{
+				safetensorsPath: "",
+			},
+			socket: "127.0.0.1:30000",
+			model: "sha256:abc123",
+			modelRef: "ai/test-model:latest",
+			mode: inference.BackendModeCompletion,
+			expectError: true,
+		},
+		{
+			name: "embedding mode",
+			bundle: &mockModelBundle{
+				safetensorsPath: "/models/bundle/model/safetensors",
+			},
+			socket: "127.0.0.1:30000",
+			model: "sha256:abc123",
+			modelRef: "ai/test-model:latest",
+			mode: inference.BackendModeEmbedding,
+			expected: []string{
+				"-m", "vllm.entrypoints.openai.api_server",
+				"--model", "/models/bundle/model",
+				"--host", "127.0.0.1",
+				"--port", "30000",
+				"--enable-auto-tool-choice", "--tool-call-parser", "hermes",
+				"--runner", "pooling",
+				"--served-model-name", "sha256:abc123", "ai/test-model:latest",
+			},
+		},
+		{
+			name: "with context size",
+			bundle: &mockModelBundle{
+				safetensorsPath: "/models/bundle/model/safetensors",
+			},
+			socket: "127.0.0.1:30000",
+			model: "sha256:abc123",
+			modelRef: "ai/test-model:latest",
+			mode: inference.BackendModeCompletion,
+			config: &inference.BackendConfiguration{
+				ContextSize: int32ptr(4096),
+			},
+			expected: []string{
+				"-m", "vllm.entrypoints.openai.api_server",
+				"--model", "/models/bundle/model",
+				"--host", "127.0.0.1",
+				"--port", "30000",
+				"--enable-auto-tool-choice", "--tool-call-parser", "hermes",
+				"--served-model-name", "sha256:abc123", "ai/test-model:latest",
+				"--max-model-len", "4096",
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			v := &vllmMetal{}
+			args, err := v.buildArgs(tt.bundle, tt.socket, tt.model, tt.modelRef, tt.mode, tt.config)
+
+			if tt.expectError {
+				if err == nil {
+					t.Fatalf("expected error but got none")
+				}
+				return
+			}
+
+			if err != nil {
+				t.Fatalf("unexpected error: %v", err)
+			}
+
+			if len(args) != len(tt.expected) {
+				t.Fatalf("expected %d args, got %d\nexpected: %v\ngot: %v", len(tt.expected), len(args), tt.expected, args)
+			}
+
+			for i, arg := range args {
+				if arg != tt.expected[i] {
+					t.Errorf("arg[%d]: expected %q, got %q", i, tt.expected[i], arg)
+				}
+			}
+		})
+	}
+}
+
 func int32ptr(n int32) *int32 {
 	return &n
 }

pkg/inference/backends/vllm/vllm_metal.go

Lines changed: 9 additions & 1 deletion
@@ -13,6 +13,7 @@ import (
 	"strconv"
 	"strings"
 
+	"github.com/docker/model-runner/pkg/distribution/types"
 	"github.com/docker/model-runner/pkg/inference"
 	"github.com/docker/model-runner/pkg/inference/backends"
 	"github.com/docker/model-runner/pkg/inference/models"
@@ -252,7 +253,7 @@ func (v *vllmMetal) Run(ctx context.Context, socket, model string, modelRef stri
 // buildArgs builds the command line arguments for vllm-metal server.
 // vllm-metal is a vLLM platform plugin, so we launch vLLM's OpenAI-compatible
 // API server directly; the Metal plugin is auto-discovered via entry points.
-func (v *vllmMetal) buildArgs(bundle interface{ SafetensorsPath() string }, socket, model, modelRef string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error) {
+func (v *vllmMetal) buildArgs(bundle types.ModelBundle, socket, model, modelRef string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error) {
 	// Parse host:port from socket (vllm-metal uses TCP)
 	host, port, err := net.SplitHostPort(socket)
 	if err != nil {
@@ -274,6 +275,13 @@ func (v *vllmMetal) buildArgs(bundle interface{ SafetensorsPath() string }, sock
 		"--enable-auto-tool-choice", "--tool-call-parser", "hermes",
 	}
 
+	// Add chat template if available in the model bundle.
+	// Since transformers v4.44, vLLM no longer provides a default chat
+	// template so we must supply one when the tokenizer omits it.
+	if path := bundle.ChatTemplatePath(); path != "" {
+		args = append(args, "--chat-template", path)
+	}
+
 	// Add mode-specific arguments
 	switch mode {
 	case inference.BackendModeCompletion:
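
For context, the file that `--chat-template` points at is a Jinja template describing how a list of chat messages is rendered into a single prompt. This commit never shows the template's contents, so the constant below is only a hypothetical sketch of the general shape such a file can take, embedded as a Go raw string for illustration.

// Hypothetical example only; not the template shipped with any real model bundle.
// A chat template turns OpenAI-style messages (role + content) into one prompt string.
const exampleChatTemplate = `{% for message in messages %}
<|{{ message.role }}|>{{ message.content }}
{% endfor %}<|assistant|>`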
