Skip to content

Commit 7d3e720

Browse files
committed
refactor: implement fused copy operations and enhance tensor handling in CUDA graph
1 parent 5a0b770 commit 7d3e720

21 files changed

Lines changed: 610 additions & 131 deletions

File tree

rtp_llm/cpp/core/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ cc_library(
184184
":event",
185185
"//rtp_llm/cpp/config:config_modules",
186186
"//rtp_llm/cpp/models:stats",
187+
"//rtp_llm/cpp/kernels:fuse_copy_util",
187188
] + torch_deps() + select({
188189
"@//:using_rocm": ["@local_config_rocm//rocm:rocm_headers"],
189190
"//conditions:default": [],

rtp_llm/cpp/core/ExecOps.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "rtp_llm/cpp/core/Event.h"
66
#include "rtp_llm/cpp/config/ConfigModules.h"
77
#include "rtp_llm/cpp/models/eplb/stats/ExpertStats.h"
8+
#include "rtp_llm/cpp/kernels/fuse_copy/util.h"
89

910
#include <memory>
1011
#include <atomic>
@@ -93,6 +94,9 @@ void execNoBlockCopy(const MultiCopyParams& params);
9394
void execBatchCopy(const BatchCopyParams& params);
9495
void execMultiMergeCopy(const MultiMergeCopyParams& params);
9596

97+
void fusedCopy(const FusedD2DCopyParams& params);
98+
void fusedStridedCopy(const FusedStridedCopyParams& params);
99+
96100
// ===================================================================
97101
// Sample ops
98102
// ===================================================================

rtp_llm/cpp/cuda/CudaGraph/CudaGraphRunner.cc

Lines changed: 164 additions & 104 deletions
Large diffs are not rendered by default.

rtp_llm/cpp/cuda/CudaGraph/CudaGraphRunner.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ class CudaGraphRunner: public GraphBase {
102102
void setInputEmbeddingScalar(float input_embedding_scalar) override;
103103

104104
private:
105-
void copySmallerIntoLarger(const torch::Tensor& source_tensor, torch::Tensor& target_tensor);
106105
std::vector<int> getDecodeBatchSizesToCapture();
107106
std::vector<int> getPrefillSequenceLengthsToCapture();
108107
/// Select graph key for decode; false if no captured graph can serve current_batch_size (e.g. lower_bound hit end).

rtp_llm/cpp/cuda/CudaGraph/CudaGraphUtils.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ void debugPrintPyModelInputs(const PyModelInputs& inputs) {
103103
printTensorInfo("kv_cache_block_id_host", inputs.attention_inputs.kv_cache_block_id_host, 40);
104104
printTensorInfo("kv_cache_block_id_device", inputs.attention_inputs.kv_cache_block_id_device, 40);
105105
printTensorInfo("cu_seqlens", inputs.attention_inputs.cu_seqlens);
106+
printTensorInfo("cu_seqlens_host", inputs.attention_inputs.cu_seqlens_host);
106107
printTensorInfo("cu_kv_seqlens", inputs.attention_inputs.cu_kv_seqlens);
107108
printTensorInfo("sequence_lengths_plus_1_d", inputs.attention_inputs.sequence_lengths_plus_1_d);
108109
printTensorInfo("decode_cu_seqlens_d", inputs.attention_inputs.decode_cu_seqlens_d);

rtp_llm/cpp/cuda/CudaGraph/CudaGraphUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class CaptureMemoryHold {
2929
CaptureMemoryHold(at::Tensor hidden_states, PyModelInputs& inputs, bool is_embedding):
3030
decoder_layer_hidden_states_(hidden_states) {
3131
py_model_inputs_.attention_inputs.input_lengths = inputs.attention_inputs.input_lengths;
32+
py_model_inputs_.attention_inputs.input_lengths_d = inputs.attention_inputs.input_lengths_d;
3233
py_model_inputs_.attention_inputs.sequence_lengths = inputs.attention_inputs.sequence_lengths;
3334
py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_device =
3435
inputs.attention_inputs.kv_cache_kernel_block_id_device;
@@ -40,11 +41,13 @@ class CaptureMemoryHold {
4041
inputs.attention_inputs.kv_cache_kernel_block_id_host_by_group;
4142
py_model_inputs_.attention_inputs.kv_cache_layer_to_group = inputs.attention_inputs.kv_cache_layer_to_group;
4243
py_model_inputs_.attention_inputs.prefix_lengths = inputs.attention_inputs.prefix_lengths;
44+
py_model_inputs_.attention_inputs.prefix_lengths_d = inputs.attention_inputs.prefix_lengths_d;
4345
py_model_inputs_.input_ids = inputs.input_ids;
4446

4547
// for spec
4648
py_model_inputs_.input_hiddens = inputs.input_hiddens;
4749
py_model_inputs_.attention_inputs.cu_seqlens = inputs.attention_inputs.cu_seqlens;
50+
py_model_inputs_.attention_inputs.cu_seqlens_host = inputs.attention_inputs.cu_seqlens_host;
4851
py_model_inputs_.attention_inputs.cu_kv_seqlens = inputs.attention_inputs.cu_kv_seqlens;
4952
py_model_inputs_.attention_inputs.padding_offset = inputs.attention_inputs.padding_offset;
5053
py_model_inputs_.attention_inputs.is_prefill = inputs.attention_inputs.is_prefill;

rtp_llm/cpp/cuda/ops/CudaFlashInfer.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,10 @@ FlashInferAttnParams::allocateManyBuffer(const std::vector<std::vector<int64_t>>
7676
auto buf_options = torch::TensorOptions(torch::kInt32);
7777
if (atype == AllocationType::DEVICE) {
7878
buf_options = buf_options.device(torch::kCUDA);
79+
} else {
80+
buf_options = buf_options.device(torch::kCPU).pinned_memory(true);
7981
}
82+
8083
auto buf = torch::empty({(int64_t)total_size}, buf_options);
8184

8285
size_t offset = 0;
@@ -104,7 +107,8 @@ FlashInferAttnParams* FlashInferAttnParams::create(int batch_size, int input_tok
104107
params->float_workspace_ =
105108
torch::empty({128 * 1024 * 1024}, torch::TensorOptions(torch::kInt8).device(torch::kCUDA));
106109
params->int_workspace_ = torch::empty({8 * 1024 * 1024}, torch::TensorOptions(torch::kInt8).device(torch::kCUDA));
107-
params->int_host_workspace_ = torch::empty({8 * 1024 * 1024}, torch::kInt8);
110+
params->int_host_workspace_ =
111+
torch::empty({8 * 1024 * 1024}, torch::TensorOptions(torch::kInt8).device(torch::kCPU).pinned_memory(true));
108112

109113
params->float_workspace_d = params->float_workspace_;
110114
params->int_workspace_d = params->int_workspace_;

rtp_llm/cpp/kernels/BUILD

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,45 @@ cc_library(
489489
visibility = ["//visibility:public"],
490490
)
491491

492+
cc_library(
493+
name = "fuse_copy_util",
494+
hdrs = [
495+
"fuse_copy/util.h",
496+
],
497+
deps = any_cuda_deps + [
498+
"//rtp_llm/cpp/utils:core_utils",
499+
],
500+
copts = any_cuda_copts(),
501+
include_prefix = "src",
502+
visibility = ["//visibility:public"],
503+
)
504+
505+
cc_library(
506+
name = "fuse_copy_kernel",
507+
srcs = [
508+
"fuse_copy/fuse_copy_kernel.cu",
509+
],
510+
hdrs = [
511+
"fuse_copy/fuse_copy_kernel.h",
512+
],
513+
deps = any_cuda_deps + [
514+
":fuse_copy_util",
515+
"//rtp_llm/cpp/utils:core_utils",
516+
],
517+
copts = any_cuda_copts(),
518+
include_prefix = "src",
519+
visibility = ["//visibility:public"],
520+
)
521+
522+
cc_library(
523+
name = "kernels_fused_copy",
524+
deps = [
525+
":fuse_copy_kernel",
526+
":fuse_copy_util",
527+
],
528+
visibility = ["//visibility:public"],
529+
)
530+
492531
cc_library(
493532
name = "kernels_copy",
494533
srcs = [
@@ -576,6 +615,7 @@ cc_library(
576615
":kernels_embedding",
577616
":kernels_tensor_ops",
578617
":kernels_kv_cache",
618+
":kernels_fused_copy",
579619
":kernels_copy",
580620
":kernels_moe",
581621
":kernels_mla",

rtp_llm/cpp/kernels/copy_utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include <assert.h>
44
#include <vector>
55

6-
#if USEING_CUDA
6+
#if USING_CUDA
77
#include <cuda_runtime.h>
88
#endif
99

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
#include <cstdint>
2+
#include <cstddef>
3+
4+
#include "rtp_llm/cpp/kernels/fuse_copy/fuse_copy_kernel.h"
5+
6+
namespace rtp_llm {
7+
8+
static constexpr int FUSED_COPY_BLOCKS_PER_TASK = 8;
9+
static constexpr int FUSED_COPY_THREADS = 256;
10+
11+
__global__ void fusedCopyKernel(FusedD2DCopyParams params) {
12+
const int copy_idx = blockIdx.y;
13+
if (copy_idx >= params.num_copies)
14+
return;
15+
16+
const size_t total_bytes = params.size[copy_idx];
17+
const size_t global_tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
18+
const size_t global_stride = static_cast<size_t>(gridDim.x) * blockDim.x;
19+
20+
const auto src_addr = reinterpret_cast<uintptr_t>(params.src[copy_idx]);
21+
const auto dst_addr = reinterpret_cast<uintptr_t>(params.dst[copy_idx]);
22+
23+
if ((src_addr % sizeof(int4) == 0) && (dst_addr % sizeof(int4) == 0)) {
24+
// Fast path: 16-byte vectorized bulk copy
25+
const int4* src = reinterpret_cast<const int4*>(src_addr);
26+
int4* dst = reinterpret_cast<int4*>(dst_addr);
27+
const size_t n16 = total_bytes / sizeof(int4);
28+
29+
for (size_t i = global_tid; i < n16; i += global_stride) {
30+
dst[i] = src[i];
31+
}
32+
33+
if (blockIdx.x == 0) {
34+
const size_t rem_start = n16 * sizeof(int4);
35+
const char* src_byte = reinterpret_cast<const char*>(src_addr);
36+
char* dst_byte = reinterpret_cast<char*>(dst_addr);
37+
for (size_t i = rem_start + threadIdx.x; i < total_bytes; i += blockDim.x) {
38+
dst_byte[i] = src_byte[i];
39+
}
40+
}
41+
} else {
42+
// Slow path: byte-by-byte copy for unaligned pointers
43+
const char* src_byte = reinterpret_cast<const char*>(src_addr);
44+
char* dst_byte = reinterpret_cast<char*>(dst_addr);
45+
for (size_t i = global_tid; i < total_bytes; i += global_stride) {
46+
dst_byte[i] = src_byte[i];
47+
}
48+
}
49+
}
50+
51+
void invokeFusedCopy(const FusedD2DCopyParams& params, cudaStream_t stream) {
52+
if (params.num_copies <= 0)
53+
return;
54+
dim3 grid(FUSED_COPY_BLOCKS_PER_TASK, params.num_copies);
55+
fusedCopyKernel<<<grid, FUSED_COPY_THREADS, 0, stream>>>(params);
56+
}
57+
58+
__global__ void fusedStridedCopyKernel(FusedStridedCopyParams params) {
59+
const int copy_idx = blockIdx.y;
60+
if (copy_idx >= params.num_copies)
61+
return;
62+
63+
const size_t nrows = params.num_rows[copy_idx];
64+
const size_t rbytes = params.row_bytes[copy_idx];
65+
const size_t src_stride = params.src_row_stride[copy_idx];
66+
const size_t dst_stride = params.dst_row_stride[copy_idx];
67+
const char* src = reinterpret_cast<const char*>(params.src[copy_idx]);
68+
char* dst = reinterpret_cast<char*>(params.dst[copy_idx]);
69+
70+
const size_t total = nrows * rbytes;
71+
const size_t global_tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
72+
const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
73+
74+
for (size_t idx = global_tid; idx < total; idx += stride) {
75+
const size_t row = idx / rbytes;
76+
const size_t col = idx % rbytes;
77+
dst[row * dst_stride + col] = src[row * src_stride + col];
78+
}
79+
}
80+
81+
void invokeFusedStridedCopy(const FusedStridedCopyParams& params, cudaStream_t stream) {
82+
if (params.num_copies <= 0)
83+
return;
84+
dim3 grid(FUSED_COPY_BLOCKS_PER_TASK, params.num_copies);
85+
fusedStridedCopyKernel<<<grid, FUSED_COPY_THREADS, 0, stream>>>(params);
86+
}
87+
88+
} // namespace rtp_llm

0 commit comments

Comments
 (0)