+// Note on subprocess lifecycle
+// In LoadModel(), we wait until /health returns 200. Thus, in subsequent
+// calls to the subprocess, if the server is working normally, /health is
+// guaranteed to return 200. If it doesn't, either the subprocess has died
+// or the server is hanging (for whatever reason).
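+//
+// A minimal sketch of that wait loop (curl_utils::SimpleGet, the helper
+// name, and the timeout values are assumptions for illustration; the real
+// implementation may differ):
+//
+//   bool WaitUntilHealthy(int port, std::chrono::seconds timeout) {
+//     const auto deadline = std::chrono::steady_clock::now() + timeout;
+//     while (std::chrono::steady_clock::now() < deadline) {
+//       auto res = curl_utils::SimpleGet(
+//           "http://127.0.0.1:" + std::to_string(port) + "/health");
+//       if (res.has_value()) return true;  // /health answered 200
+//       std::this_thread::sleep_for(std::chrono::milliseconds(500));
+//     }
+//     return false;  // subprocess died or server is hanging
+//   }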
+
 #include "vllm_engine.h"
 #include <fstream>
 #include "services/engine_service.h"
@@ -82,6 +88,7 @@ std::vector<EngineVariantResponse> VllmEngine::GetVariants() {
   return variants;
 }
 
+// TODO: once llama-server is merged, check if this 'v' check is still needed
 void VllmEngine::Load(EngineLoadOption opts) {
   version_ = opts.engine_path;  // engine path actually contains version info
   if (version_[0] == 'v')
@@ -95,7 +102,7 @@ void VllmEngine::HandleChatCompletion(
     std::shared_ptr<Json::Value> json_body,
     std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
 
-  // request validation should be in controller
+  // NOTE: request validation should be in controller
   if (!json_body->isMember("model")) {
     auto [status, error] =
         CreateResponse("Missing required fields: model", 400);
@@ -188,11 +195,49 @@ void VllmEngine::HandleChatCompletion(
   }
 };
 
+// NOTE: we don't have an option to pass --task embed to the vLLM spawn command yet
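+//
+// The handler below just proxies vLLM's OpenAI-compatible /v1/embeddings
+// endpoint, so a request body is expected to look like (model name and
+// input are illustrative):
+//
+//   {"model": "my-embedding-model", "input": "hello world"}
+//
+// and the vLLM response is forwarded to the callback unchanged.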
 void VllmEngine::HandleEmbedding(
     std::shared_ptr<Json::Value> json_body,
     std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
-  auto [status, res] = CreateResponse("embedding is not yet supported", 400);
-  callback(std::move(status), std::move(res));
+
+  if (!json_body->isMember("model")) {
+    auto [status, error] =
+        CreateResponse("Missing required fields: model", 400);
+    callback(std::move(status), std::move(error));
+    return;
+  }
+
+  const std::string model = (*json_body)["model"].asString();
+  int port;
+  // check if model has started; use find() rather than operator[], which
+  // could insert into the map and needs exclusive (not shared) access
+  {
+    std::shared_lock read_lock(mutex_);
+    auto it = model_process_map_.find(model);
+    if (it == model_process_map_.end()) {
+      const std::string msg = "Model " + model + " has not been loaded yet.";
+      auto [status, error] = CreateResponse(msg, 400);
+      callback(std::move(status), std::move(error));
+      return;
+    }
+    port = it->second.port;
+  }
+
+  const std::string url =
+      "http://127.0.0.1:" + std::to_string(port) + "/v1/embeddings";
+  const std::string json_str = json_body->toStyledString();
+
+  auto result = curl_utils::SimplePostJson(url, json_str);
+
+  if (result.has_error()) {
+    auto [status, res] = CreateResponse(result.error(), 400);
+    callback(std::move(status), std::move(res));
+    return;  // must not fall through to result.value() on error
+  }
+
+  Json::Value status;
+  status["is_done"] = true;
+  status["has_error"] = false;
+  status["is_stream"] = false;
+  status["status_code"] = 200;
+  callback(std::move(status), std::move(result.value()));
 };
 
 void VllmEngine::LoadModel(
@@ -213,6 +258,10 @@ void VllmEngine::LoadModel(
   if (model_process_map_.find(model) != model_process_map_.end()) {
     auto proc = model_process_map_[model];
 
+    // NOTE: each vLLM instance can only serve 1 task, so the logic below
+    // does not allow serving the same model for 2 different tasks at the
+    // same time. To support that, we would also need to know how vLLM
+    // decides the default task.
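+    // e.g. loading the same model once with {"task": "generate"} and again
+    // with {"task": "embed"} (values illustrative) is rejected with 409 here.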
     if (proc.IsAlive()) {
       auto [status, error] = CreateResponse("Model already loaded!", 409);
       callback(std::move(status), std::move(error));
@@ -263,6 +312,16 @@ void VllmEngine::LoadModel(
   cmd.push_back("--served-model-name");
   cmd.push_back(model);
 
+  // NOTE: we might want to adjust max-model-len automatically, since vLLM
+  // may OOM for large models as it tries to allocate the full context length.
+  const std::string EXTRA_ARGS[] = {"task", "max-model-len"};
+  for (const auto& arg : EXTRA_ARGS) {  // by reference: avoid copying strings
+    if (json_body->isMember(arg)) {
+      cmd.push_back("--" + arg);
+      cmd.push_back((*json_body)[arg].asString());
+    }
+  }
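+  // e.g. a request carrying {"task": "embed", "max-model-len": "4096"}
+  // (values illustrative) appends "--task embed --max-model-len 4096"
+  // to the vLLM command line.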
+
   const auto stdout_file = env_dir / "stdout.log";
   const auto stderr_file = env_dir / "stderr.log";
 