Skip to content

Commit 3899aba

Browse files
committed
feat: embedding service support rdma arpc
1 parent 0fb6c88 commit 3899aba

11 files changed

Lines changed: 110 additions & 32 deletions

File tree

rtp_llm/config/engine_config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from rtp_llm.config.py_config_modules import (
1111
MIN_WORKER_INFO_PORT_NUM,
1212
WORKER_INFO_PORT_NUM,
13+
EmbeddingConfig,
1314
LoadConfig,
1415
PyEnvConfigs,
1516
ServerConfig,
@@ -69,6 +70,7 @@ class EngineConfig:
6970
arpc_config: ArpcConfig
7071
grpc_config: GrpcConfig
7172
load_config: LoadConfig
73+
embedding_config: EmbeddingConfig
7274

7375
def to_string(self) -> str:
7476
"""Return a formatted string representation of EngineConfig for debugging.
@@ -179,6 +181,12 @@ def to_string(self) -> str:
179181
else:
180182
lines.append(str(self.load_config))
181183

184+
lines.append("\n[EmbeddingConfig]")
185+
if hasattr(self.embedding_config, "to_string"):
186+
lines.append(self.embedding_config.to_string())
187+
else:
188+
lines.append(str(self.embedding_config))
189+
182190
lines.append("\n" + "=" * 80)
183191
return "\n".join(lines)
184192

@@ -227,6 +235,7 @@ def create(
227235
arpc_config = py_env_configs.arpc_config
228236
grpc_config = py_env_configs.grpc_config
229237
load_config = py_env_configs.load_config
238+
embedding_config = py_env_configs.embedding_config
230239

231240
# Setup pd_sep_config role_type based on vit_separation
232241
if (
@@ -267,6 +276,7 @@ def create(
267276
arpc_config=arpc_config,
268277
grpc_config=grpc_config,
269278
load_config=load_config,
279+
embedding_config=embedding_config,
270280
)
271281

272282
runtime_config.max_generate_batch_size = concurrency_config.concurrency_limit

rtp_llm/config/py_config_modules.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,11 +312,13 @@ class EmbeddingConfig:
312312
def __init__(self):
313313
self.embedding_model: int = 0
314314
self.extra_input_in_mm_embedding = ""
315+
self.embedding_arpc_rdma_mode: bool = False
315316

316317
def to_string(self):
317318
return (
318319
f"embedding_model: {self.embedding_model}\n"
319-
f"extra_input_in_mm_embedding: {self.extra_input_in_mm_embedding}"
320+
f"extra_input_in_mm_embedding: {self.extra_input_in_mm_embedding}\n"
321+
f"embedding_arpc_rdma_mode: {self.embedding_arpc_rdma_mode}"
320322
)
321323

322324

rtp_llm/cpp/embedding_engine/BUILD

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@ cc_library(
2929
name = "embedding_engine_arpc_server_header",
3030
hdrs = glob([
3131
"arpc/ArpcServerWrapper.h",
32-
"arpc/ArpcServiceCreator.h"
32+
"arpc/ArpcServiceCreator.h",
33+
"arpc/AnetArpcServerWrapper.h",
3334
]),
3435
srcs = glob([
35-
"arpc/ArpcServerWrapper.cc"
36+
"arpc/AnetArpcServerWrapper.cc",
3637
]),
3738
deps = [
3839
":embedding_engine"

rtp_llm/cpp/embedding_engine/arpc/ArpcServerWrapper.cc renamed to rtp_llm/cpp/embedding_engine/arpc/AnetArpcServerWrapper.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
#include "rtp_llm/cpp/embedding_engine/arpc/ArpcServerWrapper.h"
1+
#include "rtp_llm/cpp/embedding_engine/arpc/AnetArpcServerWrapper.h"
22
#include "rtp_llm/cpp/utils/Logger.h"
33
#include "rtp_llm/cpp/utils/AssertUtils.h"
44
#include "aios/network/arpc/arpc/metric/KMonitorANetMetricReporterConfig.h"
55
#include "aios/network/arpc/arpc/metric/KMonitorANetServerMetricReporter.h"
66

77
namespace rtp_llm {
88

9-
void ArpcServerWrapper::start() {
9+
void AnetArpcServerWrapper::start() {
1010
RTP_LLM_LOG_INFO("start arpc server with thread=%d, queue=%d, ioThreadNum=%d", threadNum_, queueNum_, ioThreadNum_);
1111
arpc_server_transport_.reset(new anet::Transport(ioThreadNum_, anet::SHARE_THREAD));
1212
arpc_server_.reset(new arpc::ANetRPCServer(arpc_server_transport_.get(), threadNum_, queueNum_));
@@ -30,7 +30,7 @@ void ArpcServerWrapper::start() {
3030
RTP_LLM_LOG_INFO("ARPC Server listening on %s", spec.c_str());
3131
}
3232

33-
void ArpcServerWrapper::stop() {
33+
void AnetArpcServerWrapper::stop() {
3434
if (arpc_server_) {
3535
arpc_server_->Close();
3636
arpc_server_->StopPrivateTransport();
@@ -42,4 +42,4 @@ void ArpcServerWrapper::stop() {
4242
RTP_LLM_LOG_INFO("ARPC Server stopped");
4343
}
4444

45-
} // namespace rtp_llm
45+
} // namespace rtp_llm
rtp_llm/cpp/embedding_engine/arpc/AnetArpcServerWrapper.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#pragma once
2+
3+
#include "rtp_llm/cpp/embedding_engine/arpc/ArpcServerWrapper.h"
4+
#include "aios/network/arpc/arpc/ANetRPCServer.h"
5+
6+
namespace rtp_llm {
7+
// ArpcServerWrapper implementation backed by an arpc::ANetRPCServer running on an
// anet::Transport. start() and stop() are only declared here and defined out of line.
class AnetArpcServerWrapper: public ArpcServerWrapper {
8+
public:
9+
// Forwards the service and every thread/queue/port argument to the ArpcServerWrapper
// base unchanged; this constructor does no other work.
AnetArpcServerWrapper(
10+
std::unique_ptr<::google::protobuf::Service> service, int threadNum, int queueNum, int ioThreadNum, int port):
11+
ArpcServerWrapper(std::move(service), threadNum, queueNum, ioThreadNum, port) {}
12+
// Implements the pure-virtual ArpcServerWrapper::start().
virtual void start() override;
13+
// Implements the pure-virtual ArpcServerWrapper::stop().
virtual void stop() override;
14+
15+
private:
16+
// NOTE(review): start() constructs the server from arpc_server_transport_.get(), i.e. a
// raw pointer to the transport below. Members are destroyed in reverse declaration order,
// so the transport is released before the server when this object is destroyed — confirm
// stop() always runs first or that this order is safe for ANetRPCServer.
std::unique_ptr<arpc::ANetRPCServer> arpc_server_;
17+
// Transport owned by this wrapper; created in start() from ioThreadNum_.
std::unique_ptr<anet::Transport> arpc_server_transport_;
18+
};
19+
20+
} // namespace rtp_llm
rtp_llm/cpp/embedding_engine/arpc/ArpcServerWrapper.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

3-
#include "aios/network/arpc/arpc/ANetRPCServer.h"
3+
#include <memory>
4+
#include <google/protobuf/service.h>
45

56
namespace rtp_llm {
67
class ArpcServerWrapper {
@@ -12,17 +13,16 @@ class ArpcServerWrapper {
1213
threadNum_(threadNum),
1314
queueNum_(queueNum),
1415
ioThreadNum_(ioThreadNum) {}
15-
void start();
16-
void stop();
16+
virtual ~ArpcServerWrapper() = default;
17+
virtual void start() = 0;
18+
virtual void stop() = 0;
1719

18-
private:
20+
protected:
1921
std::unique_ptr<::google::protobuf::Service> service_;
2022
int port_;
2123
int threadNum_;
2224
int queueNum_;
2325
int ioThreadNum_;
24-
std::unique_ptr<arpc::ANetRPCServer> arpc_server_;
25-
std::unique_ptr<anet::Transport> arpc_server_transport_;
2626
};
2727

28-
} // namespace rtp_llm
28+
} // namespace rtp_llm
rtp_llm/cpp/embedding_engine/arpc/ArpcServiceCreator.cc

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
#include <vector>
2+
#include <stdexcept>
23
#include <google/protobuf/service.h>
34
#include "rtp_llm/cpp/embedding_engine/EmbeddingEngine.h"
45
#include "rtp_llm/cpp/embedding_engine/arpc/ArpcServiceCreator.h"
6+
#include "rtp_llm/cpp/embedding_engine/arpc/ArpcServerWrapper.h"
7+
#include "rtp_llm/cpp/embedding_engine/arpc/AnetArpcServerWrapper.h"
58

69
namespace rtp_llm {
710

811
std::unique_ptr<::google::protobuf::Service>
9-
createEmbeddingArpcService(int64_t model_rpc_port,
10-
int64_t arpc_thread_num,
11-
int64_t arpc_queue_num,
12-
int64_t arpc_io_thread_num,
12+
createEmbeddingArpcService(int64_t model_rpc_port,
13+
int64_t arpc_thread_num,
14+
int64_t arpc_queue_num,
15+
int64_t arpc_io_thread_num,
1316
py::object py_render,
1417
py::object py_tokenizer,
1518
std::shared_ptr<rtp_llm::MultimodalProcessor> mm_processor,
@@ -18,4 +21,16 @@ createEmbeddingArpcService(int64_t model_rpc_port,
1821
return nullptr;
1922
}
2023

24+
// Builds the ArpcServerWrapper used for the embedding ARPC service.
//
// arpc_rdma_mode: when true this implementation throws std::runtime_error instead of
//                 creating a server.
// service:        protobuf service; ownership is moved into the returned wrapper.
// threadNum, queueNum, ioThreadNum, port: passed through to the wrapper constructor.
//
// Returns an AnetArpcServerWrapper that has been constructed but not started.
// NOTE(review): RtpEmbeddingOp::startRpcServer passes int64_t values for the int
// parameters, so they are implicitly narrowed at the call — confirm the ranges fit.
std::unique_ptr<ArpcServerWrapper> createArpcServerWrapper(bool arpc_rdma_mode,
25+
std::unique_ptr<::google::protobuf::Service> service,
26+
int threadNum,
27+
int queueNum,
28+
int ioThreadNum,
29+
int port) {
30+
// RDMA is rejected up front, before the service is moved anywhere.
if (arpc_rdma_mode) {
31+
throw std::runtime_error("rdma arpc mode not supported in open-source build");
32+
}
33+
return std::make_unique<AnetArpcServerWrapper>(std::move(service), threadNum, queueNum, ioThreadNum, port);
34+
}
35+
2136
} // namespace rtp_llm

rtp_llm/cpp/embedding_engine/arpc/ArpcServiceCreator.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,28 @@
44
#include <google/protobuf/service.h>
55
#include "rtp_llm/cpp/config/ConfigModules.h"
66
#include "rtp_llm/cpp/embedding_engine/EmbeddingEngine.h"
7+
#include "rtp_llm/cpp/embedding_engine/arpc/ArpcServerWrapper.h"
78
#include "rtp_llm/cpp/multimodal_processor/MultimodalProcessor.h"
89

910
namespace rtp_llm {
1011

1112
std::unique_ptr<::google::protobuf::Service>
12-
createEmbeddingArpcService(int64_t model_rpc_port,
13-
int64_t arpc_thread_num,
14-
int64_t arpc_queue_num,
15-
int64_t arpc_io_thread_num,
13+
createEmbeddingArpcService(int64_t model_rpc_port,
14+
int64_t arpc_thread_num,
15+
int64_t arpc_queue_num,
16+
int64_t arpc_io_thread_num,
1617
py::object py_render,
1718
py::object py_tokenizer,
1819
std::shared_ptr<rtp_llm::MultimodalProcessor> mm_processor,
1920
std::shared_ptr<rtp_llm::EmbeddingEngine> engine,
2021
kmonitor::MetricsReporterPtr reporter);
2122

23+
// Factory: open-source stub throws for RDMA; internal_source provides real RDMA impl.
24+
std::unique_ptr<ArpcServerWrapper> createArpcServerWrapper(bool arpc_rdma_mode,
25+
std::unique_ptr<::google::protobuf::Service> service,
26+
int threadNum,
27+
int queueNum,
28+
int ioThreadNum,
29+
int port);
30+
2231
} // namespace rtp_llm

rtp_llm/cpp/pybind/multi_gpu_gpt/RtpEmbeddingOp.cc

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,14 @@ void RtpEmbeddingOp::init(py::object model,
107107

108108
int64_t model_rpc_port = params.server_config.attr("rpc_server_port").cast<int64_t>();
109109
int64_t embedding_rpc_port = params.server_config.attr("embedding_rpc_server_port").cast<int64_t>();
110+
auto embedding_config = engine_config.attr("embedding_config");
111+
bool arpc_rdma_mode = embedding_config.attr("embedding_arpc_rdma_mode").cast<bool>();
112+
110113
startRpcServer(model_rpc_port,
111114
arpc_config.threadNum,
112115
arpc_config.queueNum,
113116
arpc_config.ioThreadNum,
117+
arpc_rdma_mode,
114118
py_render,
115119
py_tokenizer,
116120
params.metrics_reporter,
@@ -204,6 +208,7 @@ void RtpEmbeddingOp::startRpcServer(int64_t model_r
204208
int64_t arpc_thread_num,
205209
int64_t arpc_queue_num,
206210
int64_t arpc_io_thread_num,
211+
bool arpc_rdma_mode,
207212
py::object py_render,
208213
py::object py_tokenizer,
209214
kmonitor::MetricsReporterPtr reporter,
@@ -219,8 +224,12 @@ void RtpEmbeddingOp::startRpcServer(int64_t model_r
219224
reporter));
220225
if (arpc_service) {
221226
RTP_LLM_LOG_INFO("creating arpc service");
222-
embedding_rpc_service_.reset(new ArpcServerWrapper(
223-
std::move(arpc_service), arpc_thread_num, arpc_queue_num, arpc_io_thread_num, model_rpc_port));
227+
embedding_rpc_service_ = createArpcServerWrapper(arpc_rdma_mode,
228+
std::move(arpc_service),
229+
arpc_thread_num,
230+
arpc_queue_num,
231+
arpc_io_thread_num,
232+
model_rpc_port);
224233
embedding_rpc_service_->start();
225234
} else {
226235
RTP_LLM_LOG_INFO("Embedding RPC not supported, skip");

rtp_llm/cpp/pybind/multi_gpu_gpt/RtpEmbeddingOp.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@ class RtpEmbeddingOp: public th::jit::CustomClassHolder {
3535
std::vector<MultimodalInput> multimodal_inputs = {});
3636

3737
private:
38-
39-
void startRpcServer(int64_t model_rpc_port,
40-
int64_t arpc_thread_num,
41-
int64_t arpc_queue_num,
42-
int64_t arpc_io_thread_num,
43-
py::object py_render,
44-
py::object py_tokenizer,
45-
kmonitor::MetricsReporterPtr reporter,
38+
void startRpcServer(int64_t model_rpc_port,
39+
int64_t arpc_thread_num,
40+
int64_t arpc_queue_num,
41+
int64_t arpc_io_thread_num,
42+
bool arpc_rdma_mode,
43+
py::object py_render,
44+
py::object py_tokenizer,
45+
kmonitor::MetricsReporterPtr reporter,
4646
std::shared_ptr<MultimodalProcessor> mm_processor);
4747

4848
void startHttpServer(std::shared_ptr<EmbeddingEngine> embedding_engine,

0 commit comments

Comments
 (0)