Skip to content

Commit 465f334

Browse files
committed
refactor: implement fused copy operations and enhance tensor handling in CUDA graph
1 parent 02acb2f commit 465f334

27 files changed

Lines changed: 577 additions & 185 deletions

rtp_llm/cpp/core/BUILD

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -184,6 +184,7 @@ cc_library(
184184
":event",
185185
"//rtp_llm/cpp/config:config_modules",
186186
"//rtp_llm/cpp/models:stats",
187+
"//rtp_llm/models_py/bindings/common/kernels:fuse_copy_util",
187188
] + torch_deps() + select({
188189
"@//:using_rocm": ["@local_config_rocm//rocm:rocm_headers"],
189190
"//conditions:default": [],

rtp_llm/cpp/core/ExecOps.h

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
#include "rtp_llm/cpp/core/Event.h"
66
#include "rtp_llm/cpp/config/ConfigModules.h"
77
#include "rtp_llm/cpp/models/eplb/stats/ExpertStats.h"
8+
#include "rtp_llm/models_py/bindings/common/kernels/fuse_copy_util.h"
89

910
#include <memory>
1011
#include <atomic>
@@ -92,6 +93,9 @@ void execNoBlockCopy(const CopyParams& params);
9293
void execBatchCopy(const BatchCopyParams& params);
9394
void execMultiMergeCopy(const MultiMergeCopyParams& params);
9495

96+
void fusedCopy(const FusedD2DCopyParams& params);
97+
void fusedStridedCopy(const FusedStridedCopyParams& params);
98+
9599
// ===================================================================
96100
// Sample ops
97101
// ===================================================================

rtp_llm/cpp/cuda/ops/CudaFlashInfer.cc

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,10 @@ FlashInferAttnParams::allocateManyBuffer(const std::vector<std::vector<int64_t>>
7676
auto buf_options = torch::TensorOptions(torch::kInt32);
7777
if (atype == AllocationType::DEVICE) {
7878
buf_options = buf_options.device(torch::kCUDA);
79+
} else {
80+
buf_options = buf_options.device(torch::kCPU).pinned_memory(true);
7981
}
82+
8083
auto buf = torch::empty({(int64_t)total_size}, buf_options);
8184

8285
size_t offset = 0;
@@ -104,7 +107,8 @@ FlashInferAttnParams* FlashInferAttnParams::create(int batch_size, int input_tok
104107
params->float_workspace_ =
105108
torch::empty({128 * 1024 * 1024}, torch::TensorOptions(torch::kInt8).device(torch::kCUDA));
106109
params->int_workspace_ = torch::empty({8 * 1024 * 1024}, torch::TensorOptions(torch::kInt8).device(torch::kCUDA));
107-
params->int_host_workspace_ = torch::empty({8 * 1024 * 1024}, torch::kInt8);
110+
params->int_host_workspace_ =
111+
torch::empty({8 * 1024 * 1024}, torch::TensorOptions(torch::kInt8).device(torch::kCPU).pinned_memory(true));
108112

109113
params->float_workspace_d = params->float_workspace_;
110114
params->int_workspace_d = params->int_workspace_;

rtp_llm/cpp/cuda_graph/BUILD

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -66,9 +66,12 @@ cc_library(
6666
deps = torch_deps() + [
6767
":cuda_graph_base",
6868
":cuda_graph_hdrs_lib",
69+
"//rtp_llm/cpp/core:exec_ops_hdr",
6970
"//rtp_llm/cpp/utils:core_utils",
7071
"//rtp_llm/cpp/utils:profiling_scope",
7172
"//rtp_llm/models_py/bindings:op_defs",
73+
"//rtp_llm/models_py/bindings/common/kernels:fuse_copy_util",
74+
"//rtp_llm/models_py/bindings/common:fuse_copy_op",
7275
] + select({
7376
"//:using_cuda": [
7477
"//rtp_llm/cpp/cuda:cuda",

rtp_llm/cpp/cuda_graph/cuda_graph_prefill.cc

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -23,19 +23,21 @@ void CudaGraphRunner::capturePrefill() {
2323
inputs.attention_inputs.prefix_lengths.fill_(0);
2424
// Must set cu_seqlens/cu_kv_seqlens/input_lengths to match actual seq_len,
2525
// otherwise FlashInfer plans for max_seq_len tokens but q/k/v only have seq_len tokens
26-
inputs.attention_inputs.cu_seqlens.data_ptr<int>()[0] = 0;
27-
inputs.attention_inputs.cu_seqlens.data_ptr<int>()[1] = seq_len;
28-
inputs.attention_inputs.input_lengths.data_ptr<int>()[0] = seq_len;
26+
inputs.attention_inputs.cu_seqlens_host[0] = 0;
27+
inputs.attention_inputs.cu_seqlens_host[1] = seq_len;
28+
inputs.attention_inputs.cu_seqlens.copy_(inputs.attention_inputs.cu_seqlens_host, false);
29+
inputs.attention_inputs.input_lengths[0] = seq_len;
2930
} else {
30-
inputs.attention_inputs.cu_seqlens.fill_(seq_len);
31+
inputs.attention_inputs.cu_seqlens_host.fill_(seq_len);
32+
inputs.attention_inputs.cu_seqlens_host[0] = 0;
33+
inputs.attention_inputs.cu_seqlens.copy_(inputs.attention_inputs.cu_seqlens_host, false);
3134
inputs.attention_inputs.input_lengths.fill_(0);
3235
int kv_len = max_seq_len_ + seq_len;
3336
int prefix_len = kv_len;
3437
inputs.attention_inputs.cu_kv_seqlens.fill_(kv_len);
38+
inputs.attention_inputs.cu_kv_seqlens[0] = 0;
3539
inputs.attention_inputs.prefix_lengths.fill_(prefix_len);
36-
inputs.attention_inputs.cu_seqlens.data_ptr<int>()[0] = 0;
37-
inputs.attention_inputs.cu_kv_seqlens.data_ptr<int>()[0] = 0;
38-
inputs.attention_inputs.input_lengths.data_ptr<int>()[0] = seq_len;
40+
inputs.attention_inputs.input_lengths[0] = seq_len;
3941
}
4042

4143
inputs.attention_inputs.context_total_kv_length = seq_len;

0 commit comments

Comments
 (0)