Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions xllm_service/common/macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,12 @@ namespace xllm_service {

#define REQUIRES(...) std::enable_if_t<(__VA_ARGS__)>* = nullptr

// brpc/butil/macros.h may define the same macro; avoid redefinition warnings.
// If DISALLOW_COPY_AND_ASSIGN is already defined when this header is reached,
// the guard below skips this definition and the earlier one stays in effect.
//
// Expands to a deleted copy constructor and a deleted copy-assignment operator
// for TypeName. The expansion has no trailing ';', so the call site supplies
// it, e.g. `DISALLOW_COPY_AND_ASSIGN(CacheAwareRouting);`.
#ifndef DISALLOW_COPY_AND_ASSIGN
#define DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName&) = delete; \
void operator=(const TypeName&) = delete
#endif

// Define a macro to simplify adding elements from a vector to a repeated field
#define ADD_VECTOR_TO_PROTO(proto_field, vec) \
Expand Down
25 changes: 19 additions & 6 deletions xllm_service/common/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ inline const char* runtime_state_name(InstanceRuntimeState state) {
}

struct LoadMetrics {
LoadMetrics() : waiting_requests_num(0), gpu_cache_usage_perc(0) {};
LoadMetrics() : waiting_requests_num(0), gpu_cache_usage_perc(0){};
LoadMetrics(const uint64_t& waiting_reqs_num, const float& usage)
: waiting_requests_num(waiting_reqs_num), gpu_cache_usage_perc(usage) {};
: waiting_requests_num(waiting_reqs_num), gpu_cache_usage_perc(usage){};

uint64_t waiting_requests_num;
float gpu_cache_usage_perc;
Expand Down Expand Up @@ -399,18 +399,18 @@ struct OverlapScores {
};

struct LoadBalanceInfos {
OverlapScores overlap_scores;
std::unordered_map<std::string, LoadMetrics> prefill_load_metrics;
std::unordered_map<std::string, LoadMetrics> decode_load_metrics;
uint64_t prefill_max_waiting_requests_num = 0;
uint64_t decode_max_waiting_requests_num = 0;
std::unordered_map<std::string, RequestMetrics> request_metrics;
/// Topology snapshot for candidate instances (filled by
/// InstanceMgr::prepare_load_balance_candidates together with metrics below).
std::unordered_map<std::string, InstanceMetaInfo> instance_infos;

std::string debug_string() {
nlohmann::json json_val;

json_val["overlap_scores"] =
nlohmann::json::parse(overlap_scores.debug_string());

nlohmann::json prefill_json;
for (auto& [key, metrics] : prefill_load_metrics) {
prefill_json[key] = nlohmann::json::parse(metrics.debug_string());
Expand All @@ -432,6 +432,19 @@ struct LoadBalanceInfos {
}
};

/// Input handed to a load-balance policy for one scheduling attempt.
/// LoadBalancePolicy::load_balance_pre_process passes an instance of this
/// struct to InstanceMgr::prepare_load_balance_candidates to be filled in.
struct LoadBalanceCandidates {
/// Prefill-side candidates; the round-robin fallback indexes into this list
/// and copies the chosen entry into LoadBalanceResult::prefill_name.
std::vector<std::string> prefill_candidates;
/// Decode-side candidates; used the same way for
/// LoadBalanceResult::decode_name.
std::vector<std::string> decode_candidates;
/// Load metrics and instance metadata that policies read when scoring, and
/// that post-processing searches (instance_infos) for incarnation ids.
LoadBalanceInfos load_balance_infos;
};

/// Output of a load-balance policy: the chosen prefill/decode pair.
/// LoadBalancePolicy::load_balance_post_process treats an empty prefill_name
/// or decode_name as "no selection" and fails.
struct LoadBalanceResult {
/// Selected prefill instance; copied into Request::routing.prefill_name.
std::string prefill_name;
/// Selected decode instance; copied into Request::routing.decode_name.
std::string decode_name;
// NOTE(review): no code visible in this change writes the two fields below;
// post-processing copies incarnation ids from the candidate snapshot
// directly onto the request instead. Confirm whether policies are expected
// to populate them.
std::string prefill_incarnation_id;
std::string decode_incarnation_id;
};

// Function call related types
struct JsonFunction {
std::string name;
Expand Down
16 changes: 8 additions & 8 deletions xllm_service/rpc_service/service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,12 @@ InstanceMetaInfo XllmRpcServiceImpl::get_instance_info(
return scheduler_->get_instance_info(instance_name);
}

std::vector<std::string> XllmRpcServiceImpl::get_static_decode_list(
const std::string& instance_name) {
return scheduler_->get_static_decode_list(instance_name);
std::vector<std::string> XllmRpcServiceImpl::get_static_decode_list() {
return scheduler_->get_static_decode_list();
}

std::vector<std::string> XllmRpcServiceImpl::get_static_prefill_list(
const std::string& instance_name) {
return scheduler_->get_static_prefill_list(instance_name);
std::vector<std::string> XllmRpcServiceImpl::get_static_prefill_list() {
return scheduler_->get_static_prefill_list();
}

bool XllmRpcServiceImpl::handle_generation(
Expand Down Expand Up @@ -126,8 +124,9 @@ void XllmRpcService::GetStaticDecodeList(
proto::InstanceIDs* resp,
google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
(void)req;
std::vector<std::string> decode_list =
xllm_rpc_service_impl_->get_static_decode_list(req->name());
xllm_rpc_service_impl_->get_static_decode_list();
for (auto& d : decode_list) {
*(resp->mutable_names()->Add()) = std::move(d);
}
Expand All @@ -139,8 +138,9 @@ void XllmRpcService::GetStaticPrefillList(
proto::InstanceIDs* resp,
google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
(void)req;
std::vector<std::string> prefill_list =
xllm_rpc_service_impl_->get_static_prefill_list(req->name());
xllm_rpc_service_impl_->get_static_prefill_list();
for (auto& p : prefill_list) {
*(resp->mutable_names()->Add()) = std::move(p);
}
Expand Down
6 changes: 2 additions & 4 deletions xllm_service/rpc_service/service.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,9 @@ class XllmRpcServiceImpl final {

InstanceMetaInfo get_instance_info(const std::string& instance_name);

std::vector<std::string> get_static_decode_list(
const std::string& prefill_name);
std::vector<std::string> get_static_decode_list();

std::vector<std::string> get_static_prefill_list(
const std::string& decode_name);
std::vector<std::string> get_static_prefill_list();

public:
// handle generations from prefill/decode instance
Expand Down
1 change: 1 addition & 0 deletions xllm_service/scheduler/loadbalance_policy/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ cc_library(
round_robin.cpp
cache_aware_routing.cpp
slo_aware_policy.cpp
loadbalance_policy.cpp
DEPS
:chat_template
:common
Expand Down
45 changes: 18 additions & 27 deletions xllm_service/scheduler/loadbalance_policy/cache_aware_routing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,39 +19,30 @@ namespace xllm_service {

constexpr float MIN_SCORE = -2.0;

bool CacheAwareRouting::select_instances_pair(
std::shared_ptr<Request> request) {
LoadBalanceInfos lb_infos;
if (!request->token_ids.empty()) {
Slice<int32_t> token_ids(request->token_ids.data(),
request->token_ids.size());
global_kvcache_mgr_->match(token_ids, &lb_infos.overlap_scores);
DLOG(INFO) << lb_infos.debug_string();
}

instance_mgr_->get_load_metrics(&lb_infos);
DLOG(INFO) << lb_infos.debug_string();

if (lb_infos.prefill_load_metrics.size() == 0) {
LOG(INFO) << "No node available!";
bool CacheAwareRouting::load_balance(
const std::shared_ptr<const Request>& request,
const LoadBalanceCandidates* candidates,
LoadBalanceResult* result) {
OverlapScores overlap_scores;
instance_mgr_->kvcache_match(request->token_ids, &overlap_scores);
if (overlap_scores.instances.empty()) {
return false;
}

// find prefill
cost_function(lb_infos.overlap_scores.hbm_instance_score,
lb_infos.overlap_scores.max_block_num,
lb_infos.prefill_load_metrics,
lb_infos.prefill_max_waiting_requests_num,
&request->routing.prefill_name);
const auto& infos = candidates->load_balance_infos;
cost_function(overlap_scores.hbm_instance_score,
overlap_scores.max_block_num,
infos.prefill_load_metrics,
infos.prefill_max_waiting_requests_num,
&result->prefill_name);

// find decode
if (lb_infos.decode_load_metrics.size()) {
cost_function(lb_infos.overlap_scores.hbm_instance_score,
lb_infos.overlap_scores.max_block_num,
lb_infos.decode_load_metrics,
lb_infos.decode_max_waiting_requests_num,
&request->routing.decode_name);
}
cost_function(overlap_scores.hbm_instance_score,
overlap_scores.max_block_num,
infos.decode_load_metrics,
infos.decode_max_waiting_requests_num,
&result->decode_name);

return true;
}
Expand Down
13 changes: 5 additions & 8 deletions xllm_service/scheduler/loadbalance_policy/cache_aware_routing.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,19 @@ limitations under the License.

#include "common/macros.h"
#include "loadbalance_policy.h"
#include "scheduler/managers/global_kvcache_mgr.h"

namespace xllm_service {

class CacheAwareRouting final : public LoadBalancePolicy {
public:
CacheAwareRouting(std::shared_ptr<InstanceMgr> instance_mgr,
std::shared_ptr<GlobalKVCacheMgr> global_kvcache_mgr)
: global_kvcache_mgr_(global_kvcache_mgr),
LoadBalancePolicy(instance_mgr) {};
explicit CacheAwareRouting(std::shared_ptr<InstanceMgr> instance_mgr)
: LoadBalancePolicy(instance_mgr) {}

virtual ~CacheAwareRouting() = default;

bool select_instances_pair(std::shared_ptr<Request> request) override;
bool load_balance(const std::shared_ptr<const Request>& request,
const LoadBalanceCandidates* candidates,
LoadBalanceResult* result) override;

private:
DISALLOW_COPY_AND_ASSIGN(CacheAwareRouting);
Expand All @@ -41,8 +40,6 @@ class CacheAwareRouting final : public LoadBalancePolicy {
const std::unordered_map<std::string, LoadMetrics>& load_metrics,
const int64_t& max_waiting_requests_num,
std::string* best_choice);

std::shared_ptr<GlobalKVCacheMgr> global_kvcache_mgr_;
};

} // namespace xllm_service
141 changes: 141 additions & 0 deletions xllm_service/scheduler/loadbalance_policy/loadbalance_policy.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/* Copyright 2026 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm-service/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "loadbalance_policy.h"

#include <glog/logging.h>

namespace xllm_service {

bool LoadBalancePolicy::select_instances_pair(
std::shared_ptr<Request> request) {
constexpr int kMaxAttempts = 2;
for (int attempt = 0; attempt < kMaxAttempts; ++attempt) {
// 1. prepare load balance candidates, filter out non-schedulable
// instances,and take a snapshot of the "candidate list + instance_infos +
// load_balance_infos" as a single generation within a lock-holding
// sequence, subsequently, we should strive to perform load balancing based
// on this snapshot information
LoadBalanceCandidates candidates;
if (!load_balance_pre_process(request, &candidates)) {
return false;
}

// 2. load balance, select the best instances pair, each loadbalance policy
// covers the implementation of this method, if failed, use round robin
// policy instead.
LoadBalanceResult result;
if (!load_balance(request, &candidates, &result)) {
LOG(ERROR) << "Failed to load balance!, use round robin policy instead";
// the original round robin policy will not reach this point.
pick_round_robin_candidates(candidates, &result);
}

// 3. post process, update the request with the selected instances pair.
if (!load_balance_post_process(request, &candidates, &result)) {
return false;
}

// 4. validate the selected instances pair, the snapshot obtained in the
// first step may have undergone changes,if failed, retry.
if (instance_mgr_->validate_scheduled_routing(*request)) {
return true;
}

if (attempt + 1 < kMaxAttempts) {
LOG(WARNING)
<< "select_instances_pair: validate_scheduled_routing failed, "
"retrying once";
}
}

LOG(ERROR)
<< "select_instances_pair: validate_scheduled_routing failed after "
<< kMaxAttempts << " attempt(s)";
return false;
}

/// Out-of-line body for the pure-virtual `load_balance` hook declared in
/// loadbalance_policy.h. Concrete policies must still override it; this body
/// only runs when invoked explicitly as `LoadBalancePolicy::load_balance(...)`.
///
/// It selects nothing: `*result` is left untouched and success is reported.
/// Note that load_balance_post_process rejects a result whose names are empty.
///
/// @return always true.
bool LoadBalancePolicy::load_balance(
    const std::shared_ptr<const Request>& request,
    const LoadBalanceCandidates* candidates,
    LoadBalanceResult* result) {
  // The parameters are intentionally unused in the default body.
  (void)request;
  (void)candidates;
  (void)result;
  return true;
}

/// Default candidate filter used while building the load-balance snapshot.
///
/// @param request Unused by the default filter; available to overrides.
/// @param info    Metadata of the instance being considered.
/// @return false when the instance's runtime state is SUSPECT, true for any
///         other state.
bool LoadBalancePolicy::should_instance_schedulable(
    const std::shared_ptr<const Request>& request,
    const InstanceMetaInfo& info) const {
  (void)request;
  const bool is_suspect =
      info.runtime_state == InstanceRuntimeState::SUSPECT;
  return !is_suspect;
}

/// Asks the instance manager to fill `candidates`, using
/// should_instance_schedulable() as the per-instance filter.
///
/// @param request    Forwarded to the filter for every instance considered.
/// @param candidates Output; populated by
///                   InstanceMgr::prepare_load_balance_candidates.
/// @return true if the instance manager reports success; false (after
///         logging an error) otherwise.
bool LoadBalancePolicy::load_balance_pre_process(
    const std::shared_ptr<const Request>& request,
    LoadBalanceCandidates* candidates) {
  // `request` is captured by value so the predicate holds its own reference.
  if (instance_mgr_->prepare_load_balance_candidates(
          [this, request](const InstanceMetaInfo& info) {
            return should_instance_schedulable(request, info);
          },
          candidates)) {
    return true;
  }

  LOG(ERROR) << "No schedulable instances found!";
  return false;
}

/// Round-robin fallback: picks the next prefill and decode candidate from the
/// snapshot using the policy's running counters.
///
/// If either candidate list is empty, nothing is selected: `*result` and both
/// counters are left unchanged. Otherwise both names are written and both
/// counters advance by one.
///
/// @param candidates Snapshot holding the prefill/decode candidate lists.
/// @param result     Output; receives the selected instance names.
void LoadBalancePolicy::pick_round_robin_candidates(
    const LoadBalanceCandidates& candidates,
    LoadBalanceResult* result) {
  const auto& prefill_pool = candidates.prefill_candidates;
  const auto& decode_pool = candidates.decode_candidates;
  if (prefill_pool.empty() || decode_pool.empty()) {
    return;
  }

  // Reduce the counters modulo the current list sizes, which may differ from
  // one snapshot to the next.
  const uint64_t prefill_slot = next_prefill_index_ % prefill_pool.size();
  const uint64_t decode_slot = next_decode_index_ % decode_pool.size();

  result->prefill_name = prefill_pool[prefill_slot];
  result->decode_name = decode_pool[decode_slot];

  ++next_prefill_index_;
  ++next_decode_index_;
}

/// Writes the selected pair from `result` onto `request`.
///
/// The instance names are copied into `request->routing`. For each name that
/// has an entry in the snapshot's `instance_infos`, that entry's
/// incarnation id is copied onto the request as well.
///
/// @param request    Request to update.
/// @param candidates Snapshot whose `instance_infos` supplies incarnation ids.
/// @param result     Selection produced by the policy or the fallback.
/// @return false, leaving `request` unmodified, if either selected name is
///         empty; true otherwise.
bool LoadBalancePolicy::load_balance_post_process(
    std::shared_ptr<Request> request,
    const LoadBalanceCandidates* candidates,
    LoadBalanceResult* result) {
  const bool has_full_pair =
      !result->prefill_name.empty() && !result->decode_name.empty();
  if (!has_full_pair) {
    return false;
  }

  request->routing.prefill_name = result->prefill_name;
  request->routing.decode_name = result->decode_name;

  // NOTE(review): when a name is missing from the snapshot, the matching
  // incarnation id on the request keeps whatever value it already had
  // (possibly from an earlier attempt). Confirm this is intended.
  const auto& snapshot_infos = candidates->load_balance_infos.instance_infos;

  if (const auto prefill_entry = snapshot_infos.find(result->prefill_name);
      prefill_entry != snapshot_infos.end()) {
    request->prefill_incarnation_id = prefill_entry->second.incarnation_id;
  }

  if (const auto decode_entry = snapshot_infos.find(result->decode_name);
      decode_entry != snapshot_infos.end()) {
    request->decode_incarnation_id = decode_entry->second.incarnation_id;
  }

  return true;
}

} // namespace xllm_service
22 changes: 21 additions & 1 deletion xllm_service/scheduler/loadbalance_policy/loadbalance_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,29 @@ class LoadBalancePolicy {

virtual ~LoadBalancePolicy() = default;

virtual bool select_instances_pair(std::shared_ptr<Request> request) = 0;
bool select_instances_pair(std::shared_ptr<Request> request);

protected:
virtual bool load_balance(const std::shared_ptr<const Request>& request,
const LoadBalanceCandidates* candidates,
LoadBalanceResult* result) = 0;

bool load_balance_pre_process(const std::shared_ptr<const Request>& request,
LoadBalanceCandidates* candidates);

virtual bool should_instance_schedulable(
const std::shared_ptr<const Request>& request,
const InstanceMetaInfo& info) const;

bool load_balance_post_process(std::shared_ptr<Request> request,
const LoadBalanceCandidates* candidates,
LoadBalanceResult* result);

void pick_round_robin_candidates(const LoadBalanceCandidates& candidates,
LoadBalanceResult* result);

uint64_t next_prefill_index_ = 0;
uint64_t next_decode_index_ = 0;
std::shared_ptr<InstanceMgr> instance_mgr_;
};

Expand Down
Loading
Loading