Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion rtp_llm/config/engine_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

import torch
Expand All @@ -10,6 +10,7 @@
from rtp_llm.config.py_config_modules import (
MIN_WORKER_INFO_PORT_NUM,
WORKER_INFO_PORT_NUM,
EmbeddingConfig,
LoadConfig,
PyEnvConfigs,
ServerConfig,
Expand Down Expand Up @@ -69,6 +70,7 @@ class EngineConfig:
arpc_config: ArpcConfig
grpc_config: GrpcConfig
load_config: LoadConfig
embedding_config: EmbeddingConfig = field(default_factory=EmbeddingConfig)

def to_string(self) -> str:
"""Return a formatted string representation of EngineConfig for debugging.
Expand Down Expand Up @@ -179,6 +181,12 @@ def to_string(self) -> str:
else:
lines.append(str(self.load_config))

lines.append("\n[EmbeddingConfig]")
if hasattr(self.embedding_config, "to_string"):
lines.append(self.embedding_config.to_string())
else:
lines.append(str(self.embedding_config))

lines.append("\n" + "=" * 80)
return "\n".join(lines)

Expand Down Expand Up @@ -227,6 +235,7 @@ def create(
arpc_config = py_env_configs.arpc_config
grpc_config = py_env_configs.grpc_config
load_config = py_env_configs.load_config
embedding_config = py_env_configs.embedding_config

# Setup pd_sep_config role_type based on vit_separation
if (
Expand Down Expand Up @@ -267,6 +276,7 @@ def create(
arpc_config=arpc_config,
grpc_config=grpc_config,
load_config=load_config,
embedding_config=embedding_config,
)

runtime_config.max_generate_batch_size = concurrency_config.concurrency_limit
Expand Down
4 changes: 3 additions & 1 deletion rtp_llm/config/py_config_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,13 @@ class EmbeddingConfig:
def __init__(self):
# Initialise every embedding option to its default value.
# NOTE(review): meaning of the integer is not visible here (0 presumably means "not an embedding model") - confirm.
self.embedding_model: int = 0
# Extra input used in multimodal embedding; the CLI help lists "INDEX" as the accepted value.
self.extra_input_in_mm_embedding = ""
# Whether the embedding ARPC server should run in RDMA mode (off by default).
self.embedding_arpc_rdma_mode: bool = False

def to_string(self):
    """Return the embedding settings as newline-separated ``name: value`` lines.

    Each attribute is reported exactly once, in declaration order, with no
    trailing newline after the last entry.
    """
    return (
        f"embedding_model: {self.embedding_model}\n"
        f"extra_input_in_mm_embedding: {self.extra_input_in_mm_embedding}\n"
        f"embedding_arpc_rdma_mode: {self.embedding_arpc_rdma_mode}"
    )


Expand Down
5 changes: 3 additions & 2 deletions rtp_llm/cpp/embedding_engine/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ cc_library(
name = "embedding_engine_arpc_server_header",
hdrs = glob([
"arpc/ArpcServerWrapper.h",
"arpc/ArpcServiceCreator.h"
"arpc/ArpcServiceCreator.h",
"arpc/AnetArpcServerWrapper.h",
]),
srcs = glob([
"arpc/ArpcServerWrapper.cc"
"arpc/AnetArpcServerWrapper.cc",
]),
deps = [
":embedding_engine"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#include "rtp_llm/cpp/embedding_engine/arpc/ArpcServerWrapper.h"
#include "rtp_llm/cpp/embedding_engine/arpc/AnetArpcServerWrapper.h"
#include "rtp_llm/cpp/utils/Logger.h"
#include "rtp_llm/cpp/utils/AssertUtils.h"
#include "aios/network/arpc/arpc/metric/KMonitorANetMetricReporterConfig.h"
#include "aios/network/arpc/arpc/metric/KMonitorANetServerMetricReporter.h"

namespace rtp_llm {

void ArpcServerWrapper::start() {
void AnetArpcServerWrapper::start() {
RTP_LLM_LOG_INFO("start arpc server with thread=%d, queue=%d, ioThreadNum=%d", threadNum_, queueNum_, ioThreadNum_);
arpc_server_transport_.reset(new anet::Transport(ioThreadNum_, anet::SHARE_THREAD));
arpc_server_.reset(new arpc::ANetRPCServer(arpc_server_transport_.get(), threadNum_, queueNum_));
Expand All @@ -30,7 +30,7 @@ void ArpcServerWrapper::start() {
RTP_LLM_LOG_INFO("ARPC Server listening on %s", spec.c_str());
}

void ArpcServerWrapper::stop() {
void AnetArpcServerWrapper::stop() {
if (arpc_server_) {
arpc_server_->Close();
arpc_server_->StopPrivateTransport();
Expand All @@ -42,4 +42,4 @@ void ArpcServerWrapper::stop() {
RTP_LLM_LOG_INFO("ARPC Server stopped");
}

} // namespace rtp_llm
} // namespace rtp_llm
20 changes: 20 additions & 0 deletions rtp_llm/cpp/embedding_engine/arpc/AnetArpcServerWrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#pragma once

#include "rtp_llm/cpp/embedding_engine/arpc/ArpcServerWrapper.h"
#include "aios/network/arpc/arpc/ANetRPCServer.h"

namespace rtp_llm {
class AnetArpcServerWrapper: public ArpcServerWrapper {
public:
AnetArpcServerWrapper(
std::unique_ptr<::google::protobuf::Service> service, int threadNum, int queueNum, int ioThreadNum, int port):
ArpcServerWrapper(std::move(service), threadNum, queueNum, ioThreadNum, port) {}
virtual void start() override;
virtual void stop() override;

private:
std::unique_ptr<arpc::ANetRPCServer> arpc_server_;
std::unique_ptr<anet::Transport> arpc_server_transport_;
};

} // namespace rtp_llm
14 changes: 7 additions & 7 deletions rtp_llm/cpp/embedding_engine/arpc/ArpcServerWrapper.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "aios/network/arpc/arpc/ANetRPCServer.h"
#include <memory>
#include <google/protobuf/service.h>

namespace rtp_llm {
class ArpcServerWrapper {
Expand All @@ -12,17 +13,16 @@ class ArpcServerWrapper {
threadNum_(threadNum),
queueNum_(queueNum),
ioThreadNum_(ioThreadNum) {}
void start();
void stop();
virtual ~ArpcServerWrapper() = default;
virtual void start() = 0;
virtual void stop() = 0;

private:
protected:
std::unique_ptr<::google::protobuf::Service> service_;
int port_;
int threadNum_;
int queueNum_;
int ioThreadNum_;
std::unique_ptr<arpc::ANetRPCServer> arpc_server_;
std::unique_ptr<anet::Transport> arpc_server_transport_;
};

} // namespace rtp_llm
} // namespace rtp_llm
26 changes: 21 additions & 5 deletions rtp_llm/cpp/embedding_engine/arpc/ArpcServiceCreator.cc
Original file line number Diff line number Diff line change
@@ -1,21 +1,37 @@
#include <vector>
#include <stdexcept>
#include <google/protobuf/service.h>
#include "rtp_llm/cpp/embedding_engine/EmbeddingEngine.h"
#include "rtp_llm/cpp/embedding_engine/arpc/ArpcServiceCreator.h"
#include "rtp_llm/cpp/embedding_engine/arpc/ArpcServerWrapper.h"
#include "rtp_llm/cpp/embedding_engine/arpc/AnetArpcServerWrapper.h"

namespace rtp_llm {

std::unique_ptr<::google::protobuf::Service>
createEmbeddingArpcService(int64_t model_rpc_port,
int64_t arpc_thread_num,
int64_t arpc_queue_num,
int64_t arpc_io_thread_num,
createEmbeddingArpcService(int64_t model_rpc_port,
int64_t arpc_thread_num,
int64_t arpc_queue_num,
int64_t arpc_io_thread_num,
py::object py_render,
py::object py_tokenizer,
std::shared_ptr<rtp_llm::MultimodalProcessor> mm_processor,
std::shared_ptr<rtp_llm::EmbeddingEngine> engine,
kmonitor::MetricsReporterPtr reporter) {
kmonitor::MetricsReporterPtr reporter,
bool arpc_rdma_mode) {
return nullptr;
}

/// Builds the ARPC server wrapper used to expose the embedding service.
///
/// Only the ANet-backed wrapper is available here, so a request for RDMA mode
/// is rejected instead of being silently downgraded.
///
/// @param arpc_rdma_mode true to request an RDMA-backed server.
/// @param service        protobuf service; ownership moves into the wrapper.
/// @return an AnetArpcServerWrapper that has been constructed but not started.
/// @throws std::runtime_error when arpc_rdma_mode is true.
std::unique_ptr<ArpcServerWrapper> createArpcServerWrapper(bool                                         arpc_rdma_mode,
                                                           std::unique_ptr<::google::protobuf::Service> service,
                                                           int                                          threadNum,
                                                           int                                          queueNum,
                                                           int                                          ioThreadNum,
                                                           int                                          port) {
    if (!arpc_rdma_mode) {
        return std::make_unique<AnetArpcServerWrapper>(std::move(service), threadNum, queueNum, ioThreadNum, port);
    }
    throw std::runtime_error("rdma arpc mode not supported in open-source build");
}

} // namespace rtp_llm
20 changes: 15 additions & 5 deletions rtp_llm/cpp/embedding_engine/arpc/ArpcServiceCreator.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,29 @@
#include <google/protobuf/service.h>
#include "rtp_llm/cpp/config/ConfigModules.h"
#include "rtp_llm/cpp/embedding_engine/EmbeddingEngine.h"
#include "rtp_llm/cpp/embedding_engine/arpc/ArpcServerWrapper.h"
#include "rtp_llm/cpp/multimodal_processor/MultimodalProcessor.h"

namespace rtp_llm {

/// Creates the protobuf service that serves embedding requests over ARPC.
///
/// The result may be nullptr (the open-source definition always returns
/// nullptr), so callers must check it before wrapping it in a server.
///
/// @param arpc_rdma_mode whether RDMA mode is requested; defaults to false so
///                       existing nine-argument call sites keep compiling.
std::unique_ptr<::google::protobuf::Service>
createEmbeddingArpcService(int64_t                                       model_rpc_port,
                           int64_t                                       arpc_thread_num,
                           int64_t                                       arpc_queue_num,
                           int64_t                                       arpc_io_thread_num,
                           py::object                                    py_render,
                           py::object                                    py_tokenizer,
                           std::shared_ptr<rtp_llm::MultimodalProcessor> mm_processor,
                           std::shared_ptr<rtp_llm::EmbeddingEngine>     engine,
                           kmonitor::MetricsReporterPtr                  reporter,
                           bool                                          arpc_rdma_mode = false);

// Factory for the ARPC server wrapper that hosts `service`.
// The open-source stub throws std::runtime_error when arpc_rdma_mode is true and
// otherwise returns an AnetArpcServerWrapper; internal_source provides the real RDMA impl.
// Ownership of `service` is transferred to the returned wrapper, which is not yet started.
std::unique_ptr<ArpcServerWrapper> createArpcServerWrapper(bool arpc_rdma_mode,
std::unique_ptr<::google::protobuf::Service> service,
int threadNum,
int queueNum,
int ioThreadNum,
int port);

} // namespace rtp_llm
16 changes: 13 additions & 3 deletions rtp_llm/cpp/pybind/multi_gpu_gpt/RtpEmbeddingOp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,14 @@ void RtpEmbeddingOp::init(py::object model,

int64_t model_rpc_port = params.server_config.attr("rpc_server_port").cast<int64_t>();
int64_t embedding_rpc_port = params.server_config.attr("embedding_rpc_server_port").cast<int64_t>();
auto embedding_config = engine_config.attr("embedding_config");
bool arpc_rdma_mode = embedding_config.attr("embedding_arpc_rdma_mode").cast<bool>();

startRpcServer(model_rpc_port,
arpc_config.threadNum,
arpc_config.queueNum,
arpc_config.ioThreadNum,
arpc_rdma_mode,
py_render,
py_tokenizer,
params.metrics_reporter,
Expand Down Expand Up @@ -203,6 +207,7 @@ void RtpEmbeddingOp::startRpcServer(int64_t model_r
int64_t arpc_thread_num,
int64_t arpc_queue_num,
int64_t arpc_io_thread_num,
bool arpc_rdma_mode,
py::object py_render,
py::object py_tokenizer,
kmonitor::MetricsReporterPtr reporter,
Expand All @@ -215,11 +220,16 @@ void RtpEmbeddingOp::startRpcServer(int64_t model_r
py_tokenizer,
mm_processor,
embedding_engine_,
reporter));
reporter,
arpc_rdma_mode));
if (arpc_service) {
RTP_LLM_LOG_INFO("creating arpc service");
embedding_rpc_service_.reset(new ArpcServerWrapper(
std::move(arpc_service), arpc_thread_num, arpc_queue_num, arpc_io_thread_num, model_rpc_port));
embedding_rpc_service_ = createArpcServerWrapper(arpc_rdma_mode,
std::move(arpc_service),
arpc_thread_num,
arpc_queue_num,
arpc_io_thread_num,
model_rpc_port);
embedding_rpc_service_->start();
} else {
RTP_LLM_LOG_INFO("Embedding RPC not supported, skip");
Expand Down
16 changes: 8 additions & 8 deletions rtp_llm/cpp/pybind/multi_gpu_gpt/RtpEmbeddingOp.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ class RtpEmbeddingOp: public th::jit::CustomClassHolder {
std::vector<MultimodalInput> multimodal_inputs = {});

private:

void startRpcServer(int64_t model_rpc_port,
int64_t arpc_thread_num,
int64_t arpc_queue_num,
int64_t arpc_io_thread_num,
py::object py_render,
py::object py_tokenizer,
kmonitor::MetricsReporterPtr reporter,
void startRpcServer(int64_t model_rpc_port,
int64_t arpc_thread_num,
int64_t arpc_queue_num,
int64_t arpc_io_thread_num,
bool arpc_rdma_mode,
py::object py_render,
py::object py_tokenizer,
kmonitor::MetricsReporterPtr reporter,
std::shared_ptr<MultimodalProcessor> mm_processor);

void startHttpServer(std::shared_ptr<EmbeddingEngine> embedding_engine,
Expand Down
12 changes: 12 additions & 0 deletions rtp_llm/server/server_args/embedding_group_args.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from rtp_llm.server.server_args.util import str2bool


def init_embedding_group_args(parser, embedding_config):
##############################################################################################################
# Embedding Configuration
Expand All @@ -20,3 +23,12 @@ def init_embedding_group_args(parser, embedding_config):
default=None,
help='在多模态嵌入中使用额外的输入,可选值"INDEX"',
)

embedding_group.add_argument(
"--embedding_arpc_rdma_mode",
env_name="EMBEDDING_ARPC_RDMA_MODE",
bind_to=(embedding_config, 'embedding_arpc_rdma_mode'),
type=str2bool,
default=False,
help="控制 embedding ARPC 是否使用 RDMA 模式。",
)
Loading