
Commit 5feda51

add some notes. support embeddings. support some extra vLLM args

1 parent b5d8315

2 files changed: 65 additions & 4 deletions

engine/extensions/python-engines/vllm_engine.cc

Lines changed: 62 additions & 3 deletions
@@ -1,3 +1,9 @@
+// Note on subprocess lifecycle
+// In LoadModel(), we will wait until /health returns 200. Thus, in subsequent
+// calls to the subprocess, if the server is working normally, /health is
+// guaranteed to return 200. If it doesn't, it either means the subprocess has
+// died or the server hangs (for whatever reason).
+
 #include "vllm_engine.h"
 #include <fstream>
 #include "services/engine_service.h"
@@ -82,6 +88,7 @@ std::vector<EngineVariantResponse> VllmEngine::GetVariants() {
   return variants;
 }
 
+// TODO: once llama-server is merged, check if checking 'v' is still needed
 void VllmEngine::Load(EngineLoadOption opts) {
   version_ = opts.engine_path;  // engine path actually contains version info
   if (version_[0] == 'v')
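
The hunk cuts off at the version check, but the intent is presumably to normalize tags like "v0.8.1" and "0.8.1" to one form. A hedged sketch of that normalization (hypothetical helper, not code from the commit); it also guards the empty string, which the bare version_[0] access would not:

// Sketch only: strip a leading 'v' release-tag prefix.
#include <string>

std::string NormalizeVersion(std::string version) {
  if (!version.empty() && version[0] == 'v')
    version.erase(0, 1);  // "v0.8.1" -> "0.8.1"
  return version;
}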
@@ -95,7 +102,7 @@ void VllmEngine::HandleChatCompletion(
     std::shared_ptr<Json::Value> json_body,
     std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
 
-  // request validation should be in controller
+  // NOTE: request validation should be in controller
   if (!json_body->isMember("model")) {
     auto [status, error] =
         CreateResponse("Missing required fields: model", 400);
@@ -188,11 +195,49 @@ void VllmEngine::HandleChatCompletion(
   }
 };
 
+// NOTE: we don't have an option to pass --task embed to vLLM spawn yet
 void VllmEngine::HandleEmbedding(
     std::shared_ptr<Json::Value> json_body,
     std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
-  auto [status, res] = CreateResponse("embedding is not yet supported", 400);
-  callback(std::move(status), std::move(res));
+
+  if (!json_body->isMember("model")) {
+    auto [status, error] =
+        CreateResponse("Missing required fields: model", 400);
+    callback(std::move(status), std::move(error));
+    return;
+  }
+
+  const std::string model = (*json_body)["model"].asString();
+  int port;
+  // check if model has started
+  {
+    std::shared_lock read_lock(mutex_);
+    if (model_process_map_.find(model) == model_process_map_.end()) {
+      const std::string msg = "Model " + model + " has not been loaded yet.";
+      auto [status, error] = CreateResponse(msg, 400);
+      callback(std::move(status), std::move(error));
+      return;
+    }
+    port = model_process_map_[model].port;
+  }
+
+  const std::string url =
+      "http://127.0.0.1:" + std::to_string(port) + "/v1/embeddings";
+  const std::string json_str = json_body->toStyledString();
+
+  auto result = curl_utils::SimplePostJson(url, json_str);
+
+  if (result.has_error()) {
+    auto [status, res] = CreateResponse(result.error(), 400);
+    callback(std::move(status), std::move(res));
+    return;  // avoid falling through to result.value() on error
+  }
+  Json::Value status;
+  status["is_done"] = true;
+  status["has_error"] = false;
+  status["is_stream"] = false;
+  status["status_code"] = 200;
+  callback(std::move(status), std::move(result.value()));
 };
 
 void VllmEngine::LoadModel(
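
For illustration, a caller-side sketch of the new embedding path, not part of the commit: the engine reference, model name, and input are hypothetical, and the response shape assumes vLLM's OpenAI-compatible /v1/embeddings output, which the handler proxies verbatim:

// Sketch only: invoke the handler and inspect the proxied response.
#include <iostream>
#include <memory>

void EmbedExample(VllmEngine& engine) {
  auto body = std::make_shared<Json::Value>();
  (*body)["model"] = "my-embed-model";  // hypothetical; must be loaded already
  (*body)["input"] = "hello world";     // body is forwarded verbatim to vLLM
  engine.HandleEmbedding(body, [](Json::Value&& status, Json::Value&& resp) {
    if (status["has_error"].asBool())
      std::cerr << resp.toStyledString();  // error built by CreateResponse()
    else
      std::cout << resp["data"][0]["embedding"].size() << " dims\n";
  });
}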
@@ -213,6 +258,10 @@ void VllmEngine::LoadModel(
   if (model_process_map_.find(model) != model_process_map_.end()) {
     auto proc = model_process_map_[model];
 
+    // NOTE: each vLLM instance can only serve 1 task. It means that the
+    // following logic will not allow serving the same model for 2 different
+    // tasks at the same time.
+    // To support it, we also need to know how vLLM decides the default task.
     if (proc.IsAlive()) {
       auto [status, error] = CreateResponse("Model already loaded!", 409);
       callback(std::move(status), std::move(error));
@@ -263,6 +312,16 @@ void VllmEngine::LoadModel(
   cmd.push_back("--served-model-name");
   cmd.push_back(model);
 
+  // NOTE: we might want to adjust max-model-len automatically, since vLLM
+  // may OOM for large models as it tries to allocate full context length.
+  const std::string EXTRA_ARGS[] = {"task", "max-model-len"};
+  for (const auto& arg : EXTRA_ARGS) {
+    if (json_body->isMember(arg)) {
+      cmd.push_back("--" + arg);
+      cmd.push_back((*json_body)[arg].asString());
+    }
+  }
+
   const auto stdout_file = env_dir / "stdout.log";
   const auto stderr_file = env_dir / "stderr.log";
 
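
A hedged sketch of a load request that exercises the new pass-through arguments (hypothetical model name and values; the field names come from the EXTRA_ARGS table above, and the exact LoadModel() signature is assumed). Per the single-task note earlier, loading the same model again for a different task would be rejected with 409 while the first instance is alive:

// Sketch only: request body for LoadModel() with the new optional args.
#include <memory>

void LoadWithExtraArgs(VllmEngine& engine) {
  auto load_body = std::make_shared<Json::Value>();
  (*load_body)["model"] = "some-embed-model";  // hypothetical model id
  (*load_body)["task"] = "embed";              // appended as: --task embed
  (*load_body)["max-model-len"] = "4096";      // appended as: --max-model-len 4096
  // engine.LoadModel(load_body, callback);    // spawns vLLM with the extra flags
}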
engine/services/inference_service.cc

Lines changed: 3 additions & 1 deletion
@@ -119,7 +119,9 @@ cpp::result<void, InferResult> InferenceService::HandleEmbedding(
     std::shared_ptr<SyncQueue> q, std::shared_ptr<Json::Value> json_body) {
   std::string engine_type;
   if (!HasFieldInReq(json_body, "engine")) {
-    engine_type = kLlamaRepo;
+    auto engine_type_maybe =
+        GetEngineByModelId((*json_body)["model"].asString());
+    engine_type = engine_type_maybe.empty() ? kLlamaRepo : engine_type_maybe;
   } else {
     engine_type = (*(json_body)).get("engine", kLlamaRepo).asString();
   }
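
The resolution order this hunk implements, restated as a hedged standalone helper (hypothetical function; GetEngineByModelId and kLlamaRepo are the existing names from the hunk, with GetEngineByModelId assumed to return an empty string on a miss):

// Sketch only: explicit "engine" field wins, else the engine recorded
// for the model id, else the llama.cpp default.
#include <string>

std::string ResolveEngineType(const Json::Value& body) {
  if (body.isMember("engine"))
    return body.get("engine", kLlamaRepo).asString();
  const auto by_model = GetEngineByModelId(body["model"].asString());
  return by_model.empty() ? kLlamaRepo : by_model;
}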
