88)
99
1010type mockModelBundle struct {
11- safetensorsPath string
12- runtimeConfig * types.Config
11+ safetensorsPath string
12+ chatTemplatePath string
13+ runtimeConfig * types.Config
1314}
1415
1516func (m * mockModelBundle ) GGUFPath () string {
@@ -21,7 +22,7 @@ func (m *mockModelBundle) SafetensorsPath() string {
2122}
2223
2324func (m * mockModelBundle ) ChatTemplatePath () string {
24- return ""
25+ return m . chatTemplatePath
2526}
2627
2728func (m * mockModelBundle ) MMPROJPath () string {
@@ -74,6 +75,36 @@ func TestGetArgs(t *testing.T) {
7475 "/tmp/socket" ,
7576 },
7677 },
78+ {
79+ name : "with chat template" ,
80+ bundle : & mockModelBundle {
81+ safetensorsPath : "/path/to/model" ,
82+ chatTemplatePath : "/path/to/bundle/template.jinja" ,
83+ },
84+ config : nil ,
85+ expected : []string {
86+ "serve" ,
87+ "/path/to" ,
88+ "--uds" ,
89+ "/tmp/socket" ,
90+ "--chat-template" ,
91+ "/path/to/bundle/template.jinja" ,
92+ },
93+ },
94+ {
95+ name : "without chat template omits flag" ,
96+ bundle : & mockModelBundle {
97+ safetensorsPath : "/path/to/model" ,
98+ chatTemplatePath : "" ,
99+ },
100+ config : nil ,
101+ expected : []string {
102+ "serve" ,
103+ "/path/to" ,
104+ "--uds" ,
105+ "/tmp/socket" ,
106+ },
107+ },
77108 {
78109 name : "with backend context size" ,
79110 bundle : & mockModelBundle {
@@ -499,6 +530,158 @@ func TestGetMaxModelLen(t *testing.T) {
499530 }
500531}
501532
533+ func TestBuildArgs (t * testing.T ) {
534+ tests := []struct {
535+ name string
536+ bundle * mockModelBundle
537+ socket string
538+ model string
539+ modelRef string
540+ mode inference.BackendMode
541+ config * inference.BackendConfiguration
542+ expected []string
543+ expectError bool
544+ }{
545+ {
546+ name : "basic completion mode" ,
547+ bundle : & mockModelBundle {
548+ safetensorsPath : "/models/bundle/model/safetensors" ,
549+ },
550+ socket : "127.0.0.1:30000" ,
551+ model : "sha256:abc123" ,
552+ modelRef : "ai/test-model:latest" ,
553+ mode : inference .BackendModeCompletion ,
554+ expected : []string {
555+ "-m" , "vllm.entrypoints.openai.api_server" ,
556+ "--model" , "/models/bundle/model" ,
557+ "--host" , "127.0.0.1" ,
558+ "--port" , "30000" ,
559+ "--enable-auto-tool-choice" , "--tool-call-parser" , "hermes" ,
560+ "--served-model-name" , "sha256:abc123" , "ai/test-model:latest" ,
561+ },
562+ },
563+ {
564+ name : "with chat template" ,
565+ bundle : & mockModelBundle {
566+ safetensorsPath : "/models/bundle/model/safetensors" ,
567+ chatTemplatePath : "/models/bundle/template.jinja" ,
568+ },
569+ socket : "127.0.0.1:30000" ,
570+ model : "sha256:abc123" ,
571+ modelRef : "ai/test-model:latest" ,
572+ mode : inference .BackendModeCompletion ,
573+ expected : []string {
574+ "-m" , "vllm.entrypoints.openai.api_server" ,
575+ "--model" , "/models/bundle/model" ,
576+ "--host" , "127.0.0.1" ,
577+ "--port" , "30000" ,
578+ "--enable-auto-tool-choice" , "--tool-call-parser" , "hermes" ,
579+ "--chat-template" , "/models/bundle/template.jinja" ,
580+ "--served-model-name" , "sha256:abc123" , "ai/test-model:latest" ,
581+ },
582+ },
583+ {
584+ name : "without chat template" ,
585+ bundle : & mockModelBundle {
586+ safetensorsPath : "/models/bundle/model/safetensors" ,
587+ chatTemplatePath : "" ,
588+ },
589+ socket : "127.0.0.1:30000" ,
590+ model : "sha256:abc123" ,
591+ modelRef : "ai/test-model:latest" ,
592+ mode : inference .BackendModeCompletion ,
593+ expected : []string {
594+ "-m" , "vllm.entrypoints.openai.api_server" ,
595+ "--model" , "/models/bundle/model" ,
596+ "--host" , "127.0.0.1" ,
597+ "--port" , "30000" ,
598+ "--enable-auto-tool-choice" , "--tool-call-parser" , "hermes" ,
599+ "--served-model-name" , "sha256:abc123" , "ai/test-model:latest" ,
600+ },
601+ },
602+ {
603+ name : "empty safetensors path should error" ,
604+ bundle : & mockModelBundle {
605+ safetensorsPath : "" ,
606+ },
607+ socket : "127.0.0.1:30000" ,
608+ model : "sha256:abc123" ,
609+ modelRef : "ai/test-model:latest" ,
610+ mode : inference .BackendModeCompletion ,
611+ expectError : true ,
612+ },
613+ {
614+ name : "embedding mode" ,
615+ bundle : & mockModelBundle {
616+ safetensorsPath : "/models/bundle/model/safetensors" ,
617+ },
618+ socket : "127.0.0.1:30000" ,
619+ model : "sha256:abc123" ,
620+ modelRef : "ai/test-model:latest" ,
621+ mode : inference .BackendModeEmbedding ,
622+ expected : []string {
623+ "-m" , "vllm.entrypoints.openai.api_server" ,
624+ "--model" , "/models/bundle/model" ,
625+ "--host" , "127.0.0.1" ,
626+ "--port" , "30000" ,
627+ "--enable-auto-tool-choice" , "--tool-call-parser" , "hermes" ,
628+ "--runner" , "pooling" ,
629+ "--served-model-name" , "sha256:abc123" , "ai/test-model:latest" ,
630+ },
631+ },
632+ {
633+ name : "with context size" ,
634+ bundle : & mockModelBundle {
635+ safetensorsPath : "/models/bundle/model/safetensors" ,
636+ },
637+ socket : "127.0.0.1:30000" ,
638+ model : "sha256:abc123" ,
639+ modelRef : "ai/test-model:latest" ,
640+ mode : inference .BackendModeCompletion ,
641+ config : & inference.BackendConfiguration {
642+ ContextSize : int32ptr (4096 ),
643+ },
644+ expected : []string {
645+ "-m" , "vllm.entrypoints.openai.api_server" ,
646+ "--model" , "/models/bundle/model" ,
647+ "--host" , "127.0.0.1" ,
648+ "--port" , "30000" ,
649+ "--enable-auto-tool-choice" , "--tool-call-parser" , "hermes" ,
650+ "--served-model-name" , "sha256:abc123" , "ai/test-model:latest" ,
651+ "--max-model-len" , "4096" ,
652+ },
653+ },
654+ }
655+
656+ for _ , tt := range tests {
657+ t .Run (tt .name , func (t * testing.T ) {
658+ v := & vllmMetal {}
659+ args , err := v .buildArgs (tt .bundle , tt .socket , tt .model , tt .modelRef , tt .mode , tt .config )
660+
661+ if tt .expectError {
662+ if err == nil {
663+ t .Fatalf ("expected error but got none" )
664+ }
665+ return
666+ }
667+
668+ if err != nil {
669+ t .Fatalf ("unexpected error: %v" , err )
670+ }
671+
672+ if len (args ) != len (tt .expected ) {
673+ t .Fatalf ("expected %d args, got %d\n expected: %v\n got: %v" , len (tt .expected ), len (args ), tt .expected , args )
674+ }
675+
676+ for i , arg := range args {
677+ if arg != tt .expected [i ] {
678+ t .Errorf ("arg[%d]: expected %q, got %q" , i , tt .expected [i ], arg )
679+ }
680+ }
681+ })
682+ }
683+ }
684+
// int32ptr returns a pointer to a copy of n; useful for filling in
// optional *int32 fields in table-driven test fixtures.
func int32ptr(n int32) *int32 {
	v := n
	return &v
}
0 commit comments