Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rtp_llm/cpp/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ cc_library(
":event",
"//rtp_llm/cpp/config:config_modules",
"//rtp_llm/cpp/models:stats",
"//rtp_llm/models_py/bindings/common/kernels:fuse_copy_util",
] + torch_deps() + select({
"@//:using_rocm": ["@local_config_rocm//rocm:rocm_headers"],
"//conditions:default": [],
Expand Down
4 changes: 4 additions & 0 deletions rtp_llm/cpp/core/ExecOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "rtp_llm/cpp/core/Event.h"
#include "rtp_llm/cpp/config/ConfigModules.h"
#include "rtp_llm/cpp/models/eplb/stats/ExpertStats.h"
#include "rtp_llm/models_py/bindings/common/kernels/fuse_copy_util.h"

#include <memory>
#include <atomic>
Expand Down Expand Up @@ -92,6 +93,9 @@ void execNoBlockCopy(const CopyParams& params);
void execBatchCopy(const BatchCopyParams& params);
void execMultiMergeCopy(const MultiMergeCopyParams& params);

void fusedCopy(const FusedD2DCopyParams& params);
void fusedStridedCopy(const FusedStridedCopyParams& params);

// ===================================================================
// Sample ops
// ===================================================================
Expand Down
6 changes: 5 additions & 1 deletion rtp_llm/cpp/cuda/ops/CudaFlashInfer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@ FlashInferAttnParams::allocateManyBuffer(const std::vector<std::vector<int64_t>>
auto buf_options = torch::TensorOptions(torch::kInt32);
if (atype == AllocationType::DEVICE) {
buf_options = buf_options.device(torch::kCUDA);
} else {
buf_options = buf_options.device(torch::kCPU).pinned_memory(true);
}

auto buf = torch::empty({(int64_t)total_size}, buf_options);

size_t offset = 0;
Expand Down Expand Up @@ -104,7 +107,8 @@ FlashInferAttnParams* FlashInferAttnParams::create(int batch_size, int input_tok
params->float_workspace_ =
torch::empty({128 * 1024 * 1024}, torch::TensorOptions(torch::kInt8).device(torch::kCUDA));
params->int_workspace_ = torch::empty({8 * 1024 * 1024}, torch::TensorOptions(torch::kInt8).device(torch::kCUDA));
params->int_host_workspace_ = torch::empty({8 * 1024 * 1024}, torch::kInt8);
params->int_host_workspace_ =
torch::empty({8 * 1024 * 1024}, torch::TensorOptions(torch::kInt8).device(torch::kCPU).pinned_memory(true));

params->float_workspace_d = params->float_workspace_;
params->int_workspace_d = params->int_workspace_;
Expand Down
3 changes: 3 additions & 0 deletions rtp_llm/cpp/cuda_graph/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,12 @@ cc_library(
deps = torch_deps() + [
":cuda_graph_base",
":cuda_graph_hdrs_lib",
"//rtp_llm/cpp/core:exec_ops_hdr",
"//rtp_llm/cpp/utils:core_utils",
"//rtp_llm/cpp/utils:profiling_scope",
"//rtp_llm/models_py/bindings:op_defs",
"//rtp_llm/models_py/bindings/common/kernels:fuse_copy_util",
"//rtp_llm/models_py/bindings/common:fuse_copy_op",
] + select({
"//:using_cuda": [
"//rtp_llm/cpp/cuda:cuda",
Expand Down
29 changes: 17 additions & 12 deletions rtp_llm/cpp/cuda_graph/cuda_graph_prefill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ void CudaGraphRunner::capturePrefill() {
inputs.attention_inputs.prefix_lengths.fill_(0);
// Must set cu_seqlens/cu_kv_seqlens/input_lengths to match actual seq_len,
// otherwise FlashInfer plans for max_seq_len tokens but q/k/v only have seq_len tokens
inputs.attention_inputs.cu_seqlens.data_ptr<int>()[0] = 0;
inputs.attention_inputs.cu_seqlens.data_ptr<int>()[1] = seq_len;
inputs.attention_inputs.input_lengths.data_ptr<int>()[0] = seq_len;
inputs.attention_inputs.cu_seqlens_host[0] = 0;
inputs.attention_inputs.cu_seqlens_host[1] = seq_len;
inputs.attention_inputs.cu_seqlens.copy_(inputs.attention_inputs.cu_seqlens_host, false);
inputs.attention_inputs.input_lengths[0] = seq_len;
} else {
// Draft model prefill: distribute seq_len tokens across batches (max num_tokens_per_bs_ each).
// All max_bs_ batches get prefix to ensure buffer allocation covers worst-case replay.
Expand All @@ -36,22 +37,26 @@ void CudaGraphRunner::capturePrefill() {
// Active batches get real input tokens, inactive batches get 0 input tokens.
inputs.attention_inputs.input_lengths.fill_(0);
inputs.attention_inputs.prefix_lengths.fill_(prefix_len);
auto* input_lengths_ptr = inputs.attention_inputs.input_lengths.data_ptr<int>();
auto& input_lengths = inputs.attention_inputs.input_lengths;
for (int b = 0; b < active_bs; b++) {
int tokens = (b < active_bs - 1) ? num_tokens_per_bs_ : (seq_len - b * num_tokens_per_bs_);
input_lengths_ptr[b] = tokens;
input_lengths[b] = tokens;
}

// Build cu_seqlens and cu_kv_seqlens as cumulative sums
auto* prefix_lengths_ptr = inputs.attention_inputs.prefix_lengths.data_ptr<int>();
auto* cu_seqlens_ptr = inputs.attention_inputs.cu_seqlens.data_ptr<int>();
auto* cu_kv_seqlens_ptr = inputs.attention_inputs.cu_kv_seqlens.data_ptr<int>();
cu_seqlens_ptr[0] = 0;
cu_kv_seqlens_ptr[0] = 0;
auto cu_seqlens_host = inputs.attention_inputs.cu_seqlens_host;
auto cu_kv_seqlens_host = inputs.attention_inputs.cu_kv_seqlens.cpu();
auto prefix_lengths = inputs.attention_inputs.prefix_lengths;

cu_seqlens_host[0] = 0;
cu_kv_seqlens_host[0] = 0;
for (int b = 0; b < max_bs_; b++) {
cu_seqlens_ptr[b + 1] = cu_seqlens_ptr[b] + input_lengths_ptr[b];
cu_kv_seqlens_ptr[b + 1] = cu_kv_seqlens_ptr[b] + input_lengths_ptr[b] + prefix_lengths_ptr[b];
cu_seqlens_host[b + 1] = cu_seqlens_host[b].item<int>() + input_lengths[b].item<int>();
cu_kv_seqlens_host[b + 1] = cu_kv_seqlens_host[b].item<int>() + input_lengths[b].item<int>() + prefix_lengths[b].item<int>();
}

inputs.attention_inputs.cu_seqlens.copy_(cu_seqlens_host);
inputs.attention_inputs.cu_kv_seqlens.copy_(cu_kv_seqlens_host);
}

inputs.attention_inputs.context_total_kv_length = seq_len;
Expand Down
Loading
Loading