From 5707312b8ea466d4443aadb42c887c20e8f98ed7 Mon Sep 17 00:00:00 2001 From: "huzetao.hzt" Date: Thu, 16 Apr 2026 14:47:41 +0800 Subject: [PATCH 1/4] refactor: implement fused copy operations and enhance tensor handling in CUDA graph --- rtp_llm/cpp/core/BUILD | 1 + rtp_llm/cpp/core/ExecOps.h | 4 + rtp_llm/cpp/cuda/ops/CudaFlashInfer.cc | 6 +- rtp_llm/cpp/cuda_graph/BUILD | 3 + rtp_llm/cpp/cuda_graph/cuda_graph_prefill.cc | 29 +- rtp_llm/cpp/cuda_graph/cuda_graph_runner.cc | 320 +++++++++++------- rtp_llm/cpp/cuda_graph/cuda_graph_runner.h | 1 - rtp_llm/cpp/cuda_graph/cuda_graph_utils.cc | 1 + rtp_llm/cpp/cuda_graph/cuda_graph_utils.h | 3 + .../tests/cuda_graph_decode_padding.py | 2 +- .../cpp/embedding_engine/EmbeddingExecutor.cc | 14 +- rtp_llm/cpp/models/BUILD | 3 + rtp_llm/cpp/models/PyWrappedModel.cc | 59 +++- rtp_llm/cpp/models/PyWrappedModel.h | 8 +- .../speculative/MtpBatchStreamProcessor.cc | 2 +- .../normal_engine/speculative/MtpExecutor.cc | 20 +- rtp_llm/models_py/bindings/OpDefs.cc | 1 + rtp_llm/models_py/bindings/OpDefs.h | 7 +- rtp_llm/models_py/bindings/OpDefsUtils.h | 5 +- rtp_llm/models_py/bindings/common/BUILD | 28 +- .../models_py/bindings/common/FusedCopyOp.cc | 41 +++ .../models_py/bindings/common/kernels/BUILD | 27 ++ .../bindings/common/kernels/copy_utils.h | 2 +- .../common/kernels/cuda_graph_copy_kernel.cu | 24 +- .../common/kernels/fuse_copy_kernel.cu | 88 +++++ .../common/kernels/fuse_copy_kernel.h | 19 ++ .../bindings/common/kernels/fuse_copy_util.h | 62 ++++ rtp_llm/models_py/bindings/cuda/BUILD | 1 + 28 files changed, 589 insertions(+), 192 deletions(-) create mode 100644 rtp_llm/models_py/bindings/common/FusedCopyOp.cc create mode 100644 rtp_llm/models_py/bindings/common/kernels/fuse_copy_kernel.cu create mode 100644 rtp_llm/models_py/bindings/common/kernels/fuse_copy_kernel.h create mode 100644 rtp_llm/models_py/bindings/common/kernels/fuse_copy_util.h diff --git a/rtp_llm/cpp/core/BUILD b/rtp_llm/cpp/core/BUILD index 62156381da..6f6b6a8728 100644 --- a/rtp_llm/cpp/core/BUILD +++ b/rtp_llm/cpp/core/BUILD @@ -184,6 +184,7 @@ cc_library( ":event", "//rtp_llm/cpp/config:config_modules", "//rtp_llm/cpp/models:stats", + "//rtp_llm/models_py/bindings/common/kernels:fuse_copy_util", ] + torch_deps() + select({ "@//:using_rocm": ["@local_config_rocm//rocm:rocm_headers"], "//conditions:default": [], diff --git a/rtp_llm/cpp/core/ExecOps.h b/rtp_llm/cpp/core/ExecOps.h index 7cede3790e..03b39f6b62 100644 --- a/rtp_llm/cpp/core/ExecOps.h +++ b/rtp_llm/cpp/core/ExecOps.h @@ -5,6 +5,7 @@ #include "rtp_llm/cpp/core/Event.h" #include "rtp_llm/cpp/config/ConfigModules.h" #include "rtp_llm/cpp/models/eplb/stats/ExpertStats.h" +#include "rtp_llm/models_py/bindings/common/kernels/fuse_copy_util.h" #include #include @@ -92,6 +93,9 @@ void execNoBlockCopy(const CopyParams& params); void execBatchCopy(const BatchCopyParams& params); void execMultiMergeCopy(const MultiMergeCopyParams& params); +void fusedCopy(const FusedD2DCopyParams& params); +void fusedStridedCopy(const FusedStridedCopyParams& params); + // =================================================================== // Sample ops // =================================================================== diff --git a/rtp_llm/cpp/cuda/ops/CudaFlashInfer.cc b/rtp_llm/cpp/cuda/ops/CudaFlashInfer.cc index 3d83dd54b6..c63b1a1699 100644 --- a/rtp_llm/cpp/cuda/ops/CudaFlashInfer.cc +++ b/rtp_llm/cpp/cuda/ops/CudaFlashInfer.cc @@ -76,7 +76,10 @@ FlashInferAttnParams::allocateManyBuffer(const std::vector> auto buf_options = torch::TensorOptions(torch::kInt32); if (atype == AllocationType::DEVICE) { buf_options = buf_options.device(torch::kCUDA); + } else { + buf_options = buf_options.device(torch::kCPU).pinned_memory(true); } + auto buf = torch::empty({(int64_t)total_size}, buf_options); size_t offset = 0; @@ -104,7 +107,8 @@ FlashInferAttnParams* FlashInferAttnParams::create(int batch_size, int input_tok params->float_workspace_ = torch::empty({128 * 1024 * 1024}, torch::TensorOptions(torch::kInt8).device(torch::kCUDA)); params->int_workspace_ = torch::empty({8 * 1024 * 1024}, torch::TensorOptions(torch::kInt8).device(torch::kCUDA)); - params->int_host_workspace_ = torch::empty({8 * 1024 * 1024}, torch::kInt8); + params->int_host_workspace_ = + torch::empty({8 * 1024 * 1024}, torch::TensorOptions(torch::kInt8).device(torch::kCPU).pinned_memory(true)); params->float_workspace_d = params->float_workspace_; params->int_workspace_d = params->int_workspace_; diff --git a/rtp_llm/cpp/cuda_graph/BUILD b/rtp_llm/cpp/cuda_graph/BUILD index 9479137799..92e31d4262 100644 --- a/rtp_llm/cpp/cuda_graph/BUILD +++ b/rtp_llm/cpp/cuda_graph/BUILD @@ -66,9 +66,12 @@ cc_library( deps = torch_deps() + [ ":cuda_graph_base", ":cuda_graph_hdrs_lib", + "//rtp_llm/cpp/core:exec_ops_hdr", "//rtp_llm/cpp/utils:core_utils", "//rtp_llm/cpp/utils:profiling_scope", "//rtp_llm/models_py/bindings:op_defs", + "//rtp_llm/models_py/bindings/common/kernels:fuse_copy_util", + "//rtp_llm/models_py/bindings/common:fuse_copy_op", ] + select({ "//:using_cuda": [ "//rtp_llm/cpp/cuda:cuda", diff --git a/rtp_llm/cpp/cuda_graph/cuda_graph_prefill.cc b/rtp_llm/cpp/cuda_graph/cuda_graph_prefill.cc index 3b97fe74be..0c4e8b8f11 100644 --- a/rtp_llm/cpp/cuda_graph/cuda_graph_prefill.cc +++ b/rtp_llm/cpp/cuda_graph/cuda_graph_prefill.cc @@ -23,9 +23,10 @@ void CudaGraphRunner::capturePrefill() { inputs.attention_inputs.prefix_lengths.fill_(0); // Must set cu_seqlens/cu_kv_seqlens/input_lengths to match actual seq_len, // otherwise FlashInfer plans for max_seq_len tokens but q/k/v only have seq_len tokens - inputs.attention_inputs.cu_seqlens.data_ptr()[0] = 0; - inputs.attention_inputs.cu_seqlens.data_ptr()[1] = seq_len; - inputs.attention_inputs.input_lengths.data_ptr()[0] = seq_len; + inputs.attention_inputs.cu_seqlens_host[0] = 0; + inputs.attention_inputs.cu_seqlens_host[1] = seq_len; + inputs.attention_inputs.cu_seqlens.copy_(inputs.attention_inputs.cu_seqlens_host, false); + inputs.attention_inputs.input_lengths[0] = seq_len; } else { // Draft model prefill: distribute seq_len tokens across batches (max num_tokens_per_bs_ each). // All max_bs_ batches get prefix to ensure buffer allocation covers worst-case replay. @@ -36,22 +37,26 @@ void CudaGraphRunner::capturePrefill() { // Active batches get real input tokens, inactive batches get 0 input tokens. inputs.attention_inputs.input_lengths.fill_(0); inputs.attention_inputs.prefix_lengths.fill_(prefix_len); - auto* input_lengths_ptr = inputs.attention_inputs.input_lengths.data_ptr(); + auto& input_lengths = inputs.attention_inputs.input_lengths; for (int b = 0; b < active_bs; b++) { int tokens = (b < active_bs - 1) ? num_tokens_per_bs_ : (seq_len - b * num_tokens_per_bs_); - input_lengths_ptr[b] = tokens; + input_lengths[b] = tokens; } // Build cu_seqlens and cu_kv_seqlens as cumulative sums - auto* prefix_lengths_ptr = inputs.attention_inputs.prefix_lengths.data_ptr(); - auto* cu_seqlens_ptr = inputs.attention_inputs.cu_seqlens.data_ptr(); - auto* cu_kv_seqlens_ptr = inputs.attention_inputs.cu_kv_seqlens.data_ptr(); - cu_seqlens_ptr[0] = 0; - cu_kv_seqlens_ptr[0] = 0; + auto cu_seqlens_host = inputs.attention_inputs.cu_seqlens_host; + auto cu_kv_seqlens_host = inputs.attention_inputs.cu_kv_seqlens.cpu(); + auto prefix_lengths = inputs.attention_inputs.prefix_lengths; + + cu_seqlens_host[0] = 0; + cu_kv_seqlens_host[0] = 0; for (int b = 0; b < max_bs_; b++) { - cu_seqlens_ptr[b + 1] = cu_seqlens_ptr[b] + input_lengths_ptr[b]; - cu_kv_seqlens_ptr[b + 1] = cu_kv_seqlens_ptr[b] + input_lengths_ptr[b] + prefix_lengths_ptr[b]; + cu_seqlens_host[b + 1] = cu_seqlens_host[b].item() + input_lengths[b].item(); + cu_kv_seqlens_host[b + 1] = cu_kv_seqlens_host[b].item() + input_lengths[b].item() + prefix_lengths[b].item(); } + + inputs.attention_inputs.cu_seqlens.copy_(cu_seqlens_host); + inputs.attention_inputs.cu_kv_seqlens.copy_(cu_kv_seqlens_host); } inputs.attention_inputs.context_total_kv_length = seq_len; diff --git a/rtp_llm/cpp/cuda_graph/cuda_graph_runner.cc b/rtp_llm/cpp/cuda_graph/cuda_graph_runner.cc index 807dc7705d..f80880846c 100644 --- a/rtp_llm/cpp/cuda_graph/cuda_graph_runner.cc +++ b/rtp_llm/cpp/cuda_graph/cuda_graph_runner.cc @@ -4,6 +4,8 @@ #include #include "rtp_llm/cpp/cuda_graph/cuda_graph_device_shims.h" #include "rtp_llm/cpp/utils/ProfilingScope.h" +#include "torch/csrc/autograd/generated/variable_factories.h" +#include "rtp_llm/cpp/core/ExecOps.h" using namespace torch_ext; namespace rtp_llm { @@ -25,10 +27,12 @@ namespace rtp_llm { // Helper function for optimized tensor copy using async operations with current CUDA stream void optimizedCopyAsync(const torch::Tensor& src, torch::Tensor& dst, size_t size) { - if (!src.defined() || src.numel() <= 0) { + if (!src.defined() || !dst.defined() || src.numel() <= 0) { return; } + RTP_LLM_PROFILE_SCOPE("optimizedCopyAsync"); + void* stream = reinterpret_cast(cuda_graph::graphGetCurrentStream().stream()); if (src.is_cuda() && dst.is_cuda()) { cuda_graph::graphMemcpyAsync(dst.data_ptr(), src.data_ptr(), size, cuda_graph::GraphMemcpyKind::D2D, stream); @@ -41,37 +45,6 @@ void optimizedCopyAsync(const torch::Tensor& src, torch::Tensor& dst, size_t siz } } -// column dimension -void CudaGraphRunner::copySmallerIntoLarger(const torch::Tensor& source_tensor, torch::Tensor& target_tensor) { - if (!source_tensor.defined() || source_tensor.numel() <= 0) { - return; - } - if (!target_tensor.defined() || target_tensor.numel() <= 0) { - return; - } - if (source_tensor.dim() != target_tensor.dim()) { - throw std::runtime_error( - "Error: Source and target tensors must have the same number of dimensions. source.dim()=" - + std::to_string(source_tensor.dim()) + ", target.dim()=" + std::to_string(target_tensor.dim()) + "."); - } - for (int i = 0; i < source_tensor.dim(); ++i) { - if (source_tensor.size(i) > target_tensor.size(i)) { - throw std::runtime_error( - "Error: Target tensor dimension " + std::to_string(i) + " (" + std::to_string(target_tensor.size(i)) - + ") is smaller than source tensor dimension " + std::to_string(i) + " (" - + std::to_string(source_tensor.size(i)) + "). This violates the function's guarantee."); - } - } - - torch::Tensor target_slice = target_tensor; - - for (int i = 0; i < source_tensor.dim(); ++i) { - target_slice = target_slice.slice(i, 0, source_tensor.size(i)); - } - - target_slice.copy_(source_tensor); -} - void CudaGraphRunner::prepareInputs(const PyModelInputs& inputs, CudaGraphState& state) { RTP_LLM_PROFILE_SCOPE("cuda_graph.prepareInputs"); // 1. non spec cuda graph: @@ -88,131 +61,209 @@ void CudaGraphRunner::prepareInputs(const PyModelInputs& inputs, CudaGraphState& // should wait last forward done before prepare inputs forward_event_.synchronize(); - // Get the appropriate graph instance based on mode const size_t graph_idx = is_prefill_cuda_graph_mode_ ? state.current_real_graph_seq_len : state.current_real_graph_bs; auto& py_model_inputs_ = graph_instances_[graph_idx].mem_hold_.py_model_inputs_; auto attn_pyobj = graph_instances_[graph_idx].mem_hold_.attn_pyobj_; - // Clear kv_cache block ids to prevent cache block pollution + FusedD2DCopyParams d2d_copies; + FusedStridedCopyParams strided_d2d_copies; + + auto tryAddD2DCopy = [&d2d_copies](const torch::Tensor& src, torch::Tensor& dst, size_t bytes) { + if (src.defined() && src.numel() > 0) { + d2d_copies.add(src.data_ptr(), dst.data_ptr(), bytes); + } + }; + + // Collect a strided 2D D2D copy: copies src[0..rows, 0..cols] into dst[0..rows, 0..cols] + // where src and dst may have different column strides (copySmallerIntoLarger semantics). + // For 1D tensors, falls back to a contiguous D2D copy to avoid silent data loss. + auto tryAddStridedD2DCopy = [&strided_d2d_copies, &d2d_copies](const torch::Tensor& src, torch::Tensor& dst) { + if (!src.defined() || src.numel() <= 0) + return; + if (src.dim() < 2) { + d2d_copies.add(src.data_ptr(), dst.data_ptr(), src.numel() * src.element_size()); + return; + } + strided_d2d_copies.add(src.data_ptr(), + dst.data_ptr(), + src.size(0), + src.size(1) * src.element_size(), + src.stride(0) * src.element_size(), + dst.stride(0) * dst.element_size()); + }; + + // H2H strided 2D copy via row-by-row memcpy (cannot use GPU kernel for host memory). + // For 1D tensors, falls back to a contiguous memcpy. + auto stridedCopyHost = [](const torch::Tensor& src, torch::Tensor& dst) { + if (!src.defined() || src.numel() <= 0) + return; + RTP_LLM_PROFILE_SCOPE("stridedCopyHost"); + if (src.dim() < 2) { + memcpy(dst.data_ptr(), src.data_ptr(), src.numel() * src.element_size()); + return; + } + const size_t nrows = src.size(0); + const size_t row_bytes = src.size(1) * src.element_size(); + const size_t src_stride = src.stride(0) * src.element_size(); + const size_t dst_stride = dst.stride(0) * dst.element_size(); + const char* src_ptr = reinterpret_cast(src.data_ptr()); + char* dst_ptr = reinterpret_cast(dst.data_ptr()); + for (size_t r = 0; r < nrows; ++r) { + memcpy(dst_ptr + r * dst_stride, src_ptr + r * src_stride, row_bytes); + } + }; + + // clear kv_cache_kernel_block_id_device, otherwise it will cause the cache block pollution py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_device.fill_(0); - if (py_model_inputs_.attention_inputs.kv_cache_block_id_device.defined()) { - py_model_inputs_.attention_inputs.kv_cache_block_id_device.fill_(0); + + // NOTE: kv_cache_block_id_{host,device} are physical block IDs dedicated for cache store + // (see OpDefs.h). They are NOT consumed by any GPU attention kernel during CUDA graph replay; + // attention kernels only use kv_cache_kernel_block_id_{host,device}. Cache store operations + // run outside the CUDA graph and read from the original (non-graph) inputs directly. + + // Common device copy + int token_num = is_prefill_cuda_graph_mode_ ? state.current_seq_len : inputs.input_ids.size(0); + + tryAddD2DCopy(inputs.input_ids, py_model_inputs_.input_ids, token_num * sizeof(int)); + tryAddD2DCopy(inputs.input_hiddens, + py_model_inputs_.input_hiddens, + inputs.input_hiddens.numel() * inputs.input_hiddens.element_size()); + tryAddD2DCopy(inputs.attention_inputs.cu_seqlens, + py_model_inputs_.attention_inputs.cu_seqlens, + (state.current_batch_size + 1) * sizeof(int)); + tryAddD2DCopy(inputs.attention_inputs.cu_kv_seqlens, + py_model_inputs_.attention_inputs.cu_kv_seqlens, + (state.current_batch_size + 1) * sizeof(int)); + tryAddD2DCopy(inputs.attention_inputs.input_lengths_d, + py_model_inputs_.attention_inputs.input_lengths_d, + state.current_batch_size * sizeof(int)); + // Strided 2D D2D copy for flat kv_cache_block_id + tryAddStridedD2DCopy(inputs.attention_inputs.kv_cache_kernel_block_id_device, + py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_device); + + if (!is_prefill_cuda_graph_mode_) { + // D2D copies — collected for single batched kernel launch + tryAddD2DCopy(inputs.attention_inputs.prefix_lengths_d, + py_model_inputs_.attention_inputs.prefix_lengths_d, + state.current_batch_size * sizeof(int)); + tryAddD2DCopy(inputs.attention_inputs.sequence_lengths_plus_1_d, + py_model_inputs_.attention_inputs.sequence_lengths_plus_1_d, + state.current_batch_size * sizeof(int)); + tryAddD2DCopy(inputs.attention_inputs.decode_cu_seqlens_d, + py_model_inputs_.attention_inputs.decode_cu_seqlens_d, + (state.current_batch_size + 1) * sizeof(int)); + } else { + // D2D copy + if (inputs.bert_embedding_inputs.position_encoding.numel() > 0) { + tryAddD2DCopy(inputs.bert_embedding_inputs.combo_position_ids, + py_model_inputs_.bert_embedding_inputs.combo_position_ids, + state.current_seq_len * sizeof(int)); + tryAddD2DCopy(inputs.bert_embedding_inputs.combo_tokens_type_ids, + py_model_inputs_.bert_embedding_inputs.combo_tokens_type_ids, + state.current_seq_len * sizeof(int)); + } } - if (py_model_inputs_.attention_inputs.kv_cache_block_id_host.defined()) { - py_model_inputs_.attention_inputs.kv_cache_block_id_host.fill_(0); + + // Hybrid cache: collect per-group D2D strided copies + const bool has_hybrid_cache = !inputs.attention_inputs.kv_cache_kernel_block_id_device_by_group.empty() + && !inputs.attention_inputs.kv_cache_kernel_block_id_host_by_group.empty() + && !py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_device_by_group.empty() + && !py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_host_by_group.empty(); + size_t hybrid_cache_group = 0; + + if (has_hybrid_cache) { + RTP_LLM_CHECK_WITH_INFO( + inputs.attention_inputs.kv_cache_kernel_block_id_device_by_group.size() + == py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_device_by_group.size(), + "kv_cache_kernel_block_id_device_by_group size mismatch"); + hybrid_cache_group = inputs.attention_inputs.kv_cache_kernel_block_id_device_by_group.size(); + RTP_LLM_CHECK_WITH_INFO(inputs.attention_inputs.kv_cache_kernel_block_id_host_by_group.size() + == hybrid_cache_group + && py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_host_by_group.size() + == hybrid_cache_group, + "kv_cache_kernel_block_id_host_by_group size mismatch"); + for (size_t g = 0; g < hybrid_cache_group; ++g) { + py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_device_by_group[g].fill_(0); + py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_host_by_group[g].fill_(0); + tryAddStridedD2DCopy(inputs.attention_inputs.kv_cache_kernel_block_id_device_by_group[g], + py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_device_by_group[g]); + } } - // Common copies: input_ids, input_hiddens, attention lengths, kv_cache blocks - int token_num = is_prefill_cuda_graph_mode_ ? state.current_seq_len : inputs.input_ids.size(0); - optimizedCopyAsync(inputs.input_ids, py_model_inputs_.input_ids, token_num * sizeof(int)); - optimizedCopyAsync(inputs.input_hiddens, - py_model_inputs_.input_hiddens, - inputs.input_hiddens.numel() * inputs.input_hiddens.element_size()); - optimizedCopyAsync(inputs.attention_inputs.prefix_lengths, - py_model_inputs_.attention_inputs.prefix_lengths, - state.current_batch_size * sizeof(int)); + // Launch ALL D2D copies (contiguous + strided) in two fused kernels + fusedCopy(d2d_copies); + fusedStridedCopy(strided_d2d_copies); + + // NOTE: we do H2H after D2D copies to let GPU finish the D2D copies as soon as possible, + // so that the GPU can start the kernel launch as soon as possible. + + // H2H copies (common to both modes) + optimizedCopyAsync(inputs.attention_inputs.cu_seqlens_host, + py_model_inputs_.attention_inputs.cu_seqlens_host, + (state.current_batch_size + 1) * sizeof(int)); + optimizedCopyAsync(inputs.attention_inputs.input_lengths, py_model_inputs_.attention_inputs.input_lengths, state.current_batch_size * sizeof(int)); - optimizedCopyAsync(inputs.attention_inputs.cu_seqlens, - py_model_inputs_.attention_inputs.cu_seqlens, - (state.current_batch_size + 1) * sizeof(int)); - optimizedCopyAsync(inputs.attention_inputs.cu_kv_seqlens, - py_model_inputs_.attention_inputs.cu_kv_seqlens, - (state.current_batch_size + 1) * sizeof(int)); - if (inputs.attention_inputs.kv_cache_block_id_device.defined()) { - copySmallerIntoLarger(inputs.attention_inputs.kv_cache_block_id_device, - py_model_inputs_.attention_inputs.kv_cache_block_id_device); - } - if (inputs.attention_inputs.kv_cache_block_id_host.defined()) { - copySmallerIntoLarger(inputs.attention_inputs.kv_cache_block_id_host, - py_model_inputs_.attention_inputs.kv_cache_block_id_host); - } - copySmallerIntoLarger(inputs.attention_inputs.kv_cache_kernel_block_id_device, - py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_device); - copySmallerIntoLarger(inputs.attention_inputs.kv_cache_kernel_block_id_host, - py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_host); + optimizedCopyAsync(inputs.attention_inputs.prefix_lengths, + py_model_inputs_.attention_inputs.prefix_lengths, + state.current_batch_size * sizeof(int)); + + // Common H2H strided copies for kv_cache block tables (both decode & prefill) + stridedCopyHost(inputs.attention_inputs.kv_cache_kernel_block_id_host, + py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_host); + + optimizedCopyAsync(inputs.attention_inputs.kv_cache_layer_to_group, + py_model_inputs_.attention_inputs.kv_cache_layer_to_group, + inputs.attention_inputs.kv_cache_layer_to_group.numel() * sizeof(int32_t)); - // Mode-specific copies if (!is_prefill_cuda_graph_mode_) { optimizedCopyAsync(inputs.attention_inputs.sequence_lengths, py_model_inputs_.attention_inputs.sequence_lengths, state.current_batch_size * sizeof(int)); - optimizedCopyAsync(inputs.attention_inputs.sequence_lengths_plus_1_d, - py_model_inputs_.attention_inputs.sequence_lengths_plus_1_d, - state.current_batch_size * sizeof(int)); - optimizedCopyAsync(inputs.attention_inputs.decode_cu_seqlens_d, - py_model_inputs_.attention_inputs.decode_cu_seqlens_d, - (state.current_batch_size + 1) * sizeof(int)); - auto attn_pyobj = graph_instances_[state.current_real_graph_bs].mem_hold_.attn_pyobj_; - // decode padding - attn_pyobj.attr("prepare_cuda_graph")(py_model_inputs_.attention_inputs); } else { optimizedCopyAsync(inputs.attention_inputs.padding_offset, py_model_inputs_.attention_inputs.padding_offset, state.current_seq_len * sizeof(int)); if (py_model_inputs_.attention_inputs.prefill_cuda_graph_copy_params) { - (*(py_model_inputs_.attention_inputs.prefill_cuda_graph_copy_params->cuda_graph_prefill_batch_size - .data_ptr())) = state.current_batch_size; + auto* batch_size_ptr = py_model_inputs_.attention_inputs.prefill_cuda_graph_copy_params + ->cuda_graph_prefill_batch_size.data_ptr(); + *batch_size_ptr = state.current_batch_size; } + } - if (inputs.bert_embedding_inputs.position_encoding.numel() > 0) { - optimizedCopyAsync(inputs.bert_embedding_inputs.combo_position_ids, - py_model_inputs_.bert_embedding_inputs.combo_position_ids, - state.current_seq_len * sizeof(int)); - optimizedCopyAsync(inputs.bert_embedding_inputs.combo_tokens_type_ids, - py_model_inputs_.bert_embedding_inputs.combo_tokens_type_ids, - state.current_seq_len * sizeof(int)); + // Hybrid cache: H2H strided copies for per-group block tables + if (has_hybrid_cache) { + for (size_t g = 0; g < hybrid_cache_group; ++g) { + stridedCopyHost(inputs.attention_inputs.kv_cache_kernel_block_id_host_by_group[g], + py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_host_by_group[g]); } + } - // Reset unused batch portions to prevent stale data + // Reset unused batch portions to prevent stale data (prefill only) + if (is_prefill_cuda_graph_mode_) { if (state.current_batch_size < max_bs_) { py_model_inputs_.attention_inputs.prefix_lengths.slice(0, state.current_batch_size, max_bs_).fill_(0); py_model_inputs_.attention_inputs.input_lengths.slice(0, state.current_batch_size, max_bs_).fill_(0); } int last_valid = state.current_seq_len; + py_model_inputs_.attention_inputs.cu_seqlens_host.slice(0, state.current_batch_size + 1, max_bs_ + 1) + .fill_(last_valid); py_model_inputs_.attention_inputs.cu_seqlens.slice(0, state.current_batch_size + 1, max_bs_ + 1) .fill_(last_valid); py_model_inputs_.attention_inputs.cu_kv_seqlens.slice(0, state.current_batch_size + 1, max_bs_ + 1) .fill_(last_valid); } - attn_pyobj.attr("prepare_cuda_graph")(py_model_inputs_.attention_inputs); - - // Hybrid cache: update per-group block tables (including group 0). - if (!inputs.attention_inputs.kv_cache_kernel_block_id_device_by_group.empty() - && !inputs.attention_inputs.kv_cache_kernel_block_id_host_by_group.empty() - && !py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_device_by_group.empty() - && !py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_host_by_group.empty()) { - RTP_LLM_CHECK_WITH_INFO( - inputs.attention_inputs.kv_cache_kernel_block_id_device_by_group.size() - == py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_device_by_group.size(), - "kv_cache_kernel_block_id_device_by_group size mismatch"); - const size_t group = inputs.attention_inputs.kv_cache_kernel_block_id_device_by_group.size(); - RTP_LLM_CHECK_WITH_INFO(inputs.attention_inputs.kv_cache_kernel_block_id_host_by_group.size() == group - && py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_host_by_group.size() - == group, - "kv_cache_kernel_block_id_host_by_group size mismatch"); - for (size_t g = 0; g < group; ++g) { - // Clear per-group block tables before copying real entries. - // Without this, padding entries retain stale block IDs from previous calls, - // causing linear attention (GatedDeltaNet) to corrupt SSM/conv states of stale blocks. - py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_device_by_group[g].fill_(0); - py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_host_by_group[g].fill_(0); - copySmallerIntoLarger(inputs.attention_inputs.kv_cache_kernel_block_id_device_by_group[g], - py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_device_by_group[g]); - copySmallerIntoLarger(inputs.attention_inputs.kv_cache_kernel_block_id_host_by_group[g], - py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_host_by_group[g]); - } + // launch prepare_cuda_graph when attention inputs are ready + { + RTP_LLM_PROFILE_SCOPE("cuda_graph.prepareInputs(prepare_cuda_graph)"); + attn_pyobj.attr("prepare_cuda_graph")(py_model_inputs_.attention_inputs); } - - optimizedCopyAsync(inputs.attention_inputs.kv_cache_layer_to_group, - py_model_inputs_.attention_inputs.kv_cache_layer_to_group, - inputs.attention_inputs.kv_cache_layer_to_group.numel() * sizeof(int32_t)); } PyModelOutputs CudaGraphRunner::forward(const PyModelInputs& inputs, CudaGraphState& state) { @@ -341,9 +392,9 @@ bool CudaGraphRunner::canRun(const PyModelInputs& inputs, CudaGraphState& state) void CudaGraphRunner::initKernelInternalMemory() { torch::Tensor cu_seqlens = - torch::zeros({int(max_bs_ + 1)}, torch::TensorOptions(torch::kInt32).device(torch::kCPU)); + torch::zeros({int(max_bs_ + 1)}, torch::TensorOptions(torch::kInt32).device(torch::kCPU)).pin_memory(); torch::Tensor cu_kv_seqlens = - torch::zeros({int(max_bs_ + 1)}, torch::TensorOptions(torch::kInt32).device(torch::kCPU)); + torch::zeros({int(max_bs_ + 1)}, torch::TensorOptions(torch::kInt32).device(torch::kCPU).pinned_memory(true)); auto input_lengths = capture_mem_hold_.py_model_inputs_.attention_inputs.input_lengths; auto prefix_lengths = capture_mem_hold_.py_model_inputs_.attention_inputs.prefix_lengths; @@ -351,8 +402,9 @@ void CudaGraphRunner::initKernelInternalMemory() { if (prefix_lengths.defined()) { cu_kv_seqlens.slice(0, 1, max_bs_ + 1) = input_lengths.add(prefix_lengths).cumsum(0); } - capture_mem_hold_.py_model_inputs_.attention_inputs.cu_seqlens = cu_seqlens.pin_memory(); - capture_mem_hold_.py_model_inputs_.attention_inputs.cu_kv_seqlens = cu_kv_seqlens.pin_memory(); + capture_mem_hold_.py_model_inputs_.attention_inputs.cu_seqlens_host = cu_seqlens; + capture_mem_hold_.py_model_inputs_.attention_inputs.cu_seqlens = cu_seqlens.cuda(); + capture_mem_hold_.py_model_inputs_.attention_inputs.cu_kv_seqlens = cu_kv_seqlens.cuda(); } int CudaGraphRunner::getCurrentRealGraphBs(const CudaGraphState& state) const { @@ -366,8 +418,9 @@ void CudaGraphRunner::initCaptureAttentionInputs(PyModelInputs& inputs, int max_ // input_ids [tokens_nums] = [batch_size * num_tokens_per_bs] inputs.input_ids = torch::zeros({max_num_token_}, options_cuda_int32_); // input_lengths [batch_size, int32] (decode only) - inputs.attention_inputs.input_lengths = torch::full({int(max_bs_)}, num_tokens_per_bs_, options_cpu_int32_); - inputs.attention_inputs.input_lengths = inputs.attention_inputs.input_lengths.pin_memory(); + inputs.attention_inputs.input_lengths = torch::full({int(max_bs_)}, num_tokens_per_bs_, options_cpu_int32_); + inputs.attention_inputs.input_lengths = inputs.attention_inputs.input_lengths.pin_memory(); + inputs.attention_inputs.input_lengths_d = inputs.attention_inputs.input_lengths.cuda(); // sequence_lengths [batch_size, int32] (decode only) // sequence_length should in pinned memory inputs.attention_inputs.sequence_lengths = torch::ones({int(max_bs_)}, options_cpu_int32_); @@ -415,8 +468,10 @@ void CudaGraphRunner::initCaptureAttentionInputs(PyModelInputs& inputs, int max_ if (num_tokens_per_bs_ > 1 && !is_prefill_cuda_graph_mode_) { inputs.attention_inputs.prefix_lengths = torch::full({int(max_bs_)}, max_seq_len_ - num_tokens_per_bs_, options_cpu_int32_).pin_memory(); + inputs.attention_inputs.prefix_lengths_d = inputs.attention_inputs.prefix_lengths.cuda(); } else if (is_prefill_cuda_graph_mode_) { - inputs.attention_inputs.prefix_lengths = torch::zeros({int(max_bs_)}, options_cpu_int32_).pin_memory(); + inputs.attention_inputs.prefix_lengths = torch::zeros({int(max_bs_)}, options_cpu_int32_).pin_memory(); + inputs.attention_inputs.prefix_lengths_d = inputs.attention_inputs.prefix_lengths.cuda(); } // padding_offset [max_num_token_, int32] (for attention padding) inputs.attention_inputs.padding_offset = torch::zeros({int(max_seq_len_ * max_bs_)}, options_cpu_int32_); @@ -537,16 +592,23 @@ void CudaGraphRunner::initCapture() { if (is_prefill_cuda_graph_mode_) { RTP_LLM_LOG_INFO("initCapture forward post check start for prefill"); - capture_mem_hold_.py_model_inputs_.attention_inputs.cu_seqlens.data_ptr()[1] = max_num_token_; - capture_mem_hold_.py_model_inputs_.attention_inputs.cu_kv_seqlens.data_ptr()[1] = max_num_token_; - capture_mem_hold_.py_model_inputs_.attention_inputs.input_lengths.data_ptr()[0] = max_num_token_; + capture_mem_hold_.py_model_inputs_.attention_inputs.cu_seqlens_host[1] = max_num_token_; + capture_mem_hold_.py_model_inputs_.attention_inputs.cu_seqlens[1] = max_num_token_; + capture_mem_hold_.py_model_inputs_.attention_inputs.cu_kv_seqlens[1] = max_num_token_; + capture_mem_hold_.py_model_inputs_.attention_inputs.input_lengths[0] = max_num_token_; + capture_mem_hold_.py_model_inputs_.attention_inputs.input_lengths_d[0] = max_num_token_; + PyModelInputs inputs = capture_mem_hold_.py_model_inputs_; + inputs.attention_inputs.cu_seqlens_host = + capture_mem_hold_.py_model_inputs_.attention_inputs.cu_seqlens_host.slice(0, 0, 2); inputs.attention_inputs.cu_seqlens = capture_mem_hold_.py_model_inputs_.attention_inputs.cu_seqlens.slice(0, 0, 2); inputs.attention_inputs.cu_kv_seqlens = capture_mem_hold_.py_model_inputs_.attention_inputs.cu_kv_seqlens.slice(0, 0, 2); inputs.attention_inputs.input_lengths = capture_mem_hold_.py_model_inputs_.attention_inputs.input_lengths.slice(0, 0, 1); + inputs.attention_inputs.input_lengths_d = + capture_mem_hold_.py_model_inputs_.attention_inputs.input_lengths_d.slice(0, 0, 1); inputs.attention_inputs.kv_cache_kernel_block_id_device = capture_mem_hold_.py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_device.slice(0, 0, 1); inputs.attention_inputs.kv_cache_kernel_block_id_host = @@ -648,6 +710,8 @@ void CudaGraphRunner::prepareCaptureInputs(PyModelInputs& inputs, int batch_size inputs.input_hiddens = capture_mem_hold_.py_model_inputs_.input_hiddens.slice(0, 0, seq_len_or_tokens); inputs.attention_inputs.input_lengths = capture_mem_hold_.py_model_inputs_.attention_inputs.input_lengths.slice(0, 0, batch_size); + inputs.attention_inputs.input_lengths_d = + capture_mem_hold_.py_model_inputs_.attention_inputs.input_lengths_d.slice(0, 0, batch_size); inputs.attention_inputs.padding_offset = capture_mem_hold_.py_model_inputs_.attention_inputs.padding_offset.slice(0, 0, seq_len_or_tokens); @@ -655,6 +719,8 @@ void CudaGraphRunner::prepareCaptureInputs(PyModelInputs& inputs, int batch_size if (capture_mem_hold_.py_model_inputs_.attention_inputs.prefix_lengths.defined()) { inputs.attention_inputs.prefix_lengths = capture_mem_hold_.py_model_inputs_.attention_inputs.prefix_lengths.slice(0, 0, batch_size); + inputs.attention_inputs.prefix_lengths_d = + capture_mem_hold_.py_model_inputs_.attention_inputs.prefix_lengths_d.slice(0, 0, batch_size); } inputs.attention_inputs.sequence_lengths = capture_mem_hold_.py_model_inputs_.attention_inputs.sequence_lengths.slice(0, 0, batch_size); @@ -671,6 +737,8 @@ void CudaGraphRunner::prepareCaptureInputs(PyModelInputs& inputs, int batch_size capture_mem_hold_.py_model_inputs_.attention_inputs.kv_cache_block_id_host.defined() ? capture_mem_hold_.py_model_inputs_.attention_inputs.kv_cache_block_id_host.slice(0, 0, batch_size) : torch::Tensor(); + inputs.attention_inputs.cu_seqlens_host = + capture_mem_hold_.py_model_inputs_.attention_inputs.cu_seqlens_host.slice(0, 0, batch_size + 1); inputs.attention_inputs.cu_seqlens = capture_mem_hold_.py_model_inputs_.attention_inputs.cu_seqlens.slice(0, 0, batch_size + 1); inputs.attention_inputs.cu_kv_seqlens = diff --git a/rtp_llm/cpp/cuda_graph/cuda_graph_runner.h b/rtp_llm/cpp/cuda_graph/cuda_graph_runner.h index 50bd1db844..90fb0c4e57 100644 --- a/rtp_llm/cpp/cuda_graph/cuda_graph_runner.h +++ b/rtp_llm/cpp/cuda_graph/cuda_graph_runner.h @@ -102,7 +102,6 @@ class CudaGraphRunner: public GraphBase { void setInputEmbeddingScalar(float input_embedding_scalar) override; private: - void copySmallerIntoLarger(const torch::Tensor& source_tensor, torch::Tensor& target_tensor); std::vector getDecodeBatchSizesToCapture(); std::vector getPrefillSequenceLengthsToCapture(); /// Select graph key for decode; false if no captured graph can serve current_batch_size (e.g. lower_bound hit end). diff --git a/rtp_llm/cpp/cuda_graph/cuda_graph_utils.cc b/rtp_llm/cpp/cuda_graph/cuda_graph_utils.cc index efc0989b17..6483d48508 100644 --- a/rtp_llm/cpp/cuda_graph/cuda_graph_utils.cc +++ b/rtp_llm/cpp/cuda_graph/cuda_graph_utils.cc @@ -103,6 +103,7 @@ void debugPrintPyModelInputs(const torch_ext::PyModelInputs& inputs) { printTensorInfo("kv_cache_block_id_host", inputs.attention_inputs.kv_cache_block_id_host, 40); printTensorInfo("kv_cache_block_id_device", inputs.attention_inputs.kv_cache_block_id_device, 40); printTensorInfo("cu_seqlens", inputs.attention_inputs.cu_seqlens); + printTensorInfo("cu_seqlens_host", inputs.attention_inputs.cu_seqlens_host); printTensorInfo("cu_kv_seqlens", inputs.attention_inputs.cu_kv_seqlens); printTensorInfo("sequence_lengths_plus_1_d", inputs.attention_inputs.sequence_lengths_plus_1_d); printTensorInfo("decode_cu_seqlens_d", inputs.attention_inputs.decode_cu_seqlens_d); diff --git a/rtp_llm/cpp/cuda_graph/cuda_graph_utils.h b/rtp_llm/cpp/cuda_graph/cuda_graph_utils.h index ba1603fcef..ed2ac63ba2 100644 --- a/rtp_llm/cpp/cuda_graph/cuda_graph_utils.h +++ b/rtp_llm/cpp/cuda_graph/cuda_graph_utils.h @@ -27,6 +27,7 @@ class CaptureMemoryHold { CaptureMemoryHold(at::Tensor hidden_states, torch_ext::PyModelInputs& inputs, bool is_embedding): decoder_layer_hidden_states_(hidden_states) { py_model_inputs_.attention_inputs.input_lengths = inputs.attention_inputs.input_lengths; + py_model_inputs_.attention_inputs.input_lengths_d = inputs.attention_inputs.input_lengths_d; py_model_inputs_.attention_inputs.sequence_lengths = inputs.attention_inputs.sequence_lengths; py_model_inputs_.attention_inputs.kv_cache_kernel_block_id_device = inputs.attention_inputs.kv_cache_kernel_block_id_device; @@ -40,11 +41,13 @@ class CaptureMemoryHold { inputs.attention_inputs.kv_cache_kernel_block_id_host_by_group; py_model_inputs_.attention_inputs.kv_cache_layer_to_group = inputs.attention_inputs.kv_cache_layer_to_group; py_model_inputs_.attention_inputs.prefix_lengths = inputs.attention_inputs.prefix_lengths; + py_model_inputs_.attention_inputs.prefix_lengths_d = inputs.attention_inputs.prefix_lengths_d; py_model_inputs_.input_ids = inputs.input_ids; // for spec py_model_inputs_.input_hiddens = inputs.input_hiddens; py_model_inputs_.attention_inputs.cu_seqlens = inputs.attention_inputs.cu_seqlens; + py_model_inputs_.attention_inputs.cu_seqlens_host = inputs.attention_inputs.cu_seqlens_host; py_model_inputs_.attention_inputs.cu_kv_seqlens = inputs.attention_inputs.cu_kv_seqlens; py_model_inputs_.attention_inputs.padding_offset = inputs.attention_inputs.padding_offset; py_model_inputs_.attention_inputs.is_prefill = inputs.attention_inputs.is_prefill; diff --git a/rtp_llm/cpp/cuda_graph/tests/cuda_graph_decode_padding.py b/rtp_llm/cpp/cuda_graph/tests/cuda_graph_decode_padding.py index 56b50cf642..6dca72de5a 100644 --- a/rtp_llm/cpp/cuda_graph/tests/cuda_graph_decode_padding.py +++ b/rtp_llm/cpp/cuda_graph/tests/cuda_graph_decode_padding.py @@ -141,7 +141,7 @@ def build_inputs(self, batch_size: int, max_seq_len: int, seq_size_per_block: in # cu_seqlens cu_len = batch_size + 1 - cu_seqlens = torch.zeros(cu_len, dtype=torch.int32, device="cpu").pin_memory() + cu_seqlens = torch.zeros(cu_len, dtype=torch.int32, device="cuda") attention_inputs.cu_seqlens = cu_seqlens attention_inputs.cu_kv_seqlens = cu_seqlens.clone() diff --git a/rtp_llm/cpp/embedding_engine/EmbeddingExecutor.cc b/rtp_llm/cpp/embedding_engine/EmbeddingExecutor.cc index b6cabd4d9f..7b71e96e3d 100644 --- a/rtp_llm/cpp/embedding_engine/EmbeddingExecutor.cc +++ b/rtp_llm/cpp/embedding_engine/EmbeddingExecutor.cc @@ -94,12 +94,14 @@ absl::StatusOr EmbeddingExecutor::gatherModelInput(const std::li int64_t batch_size = 0; calcTokenNum(streams, token_num, batch_size); GptModelInputs model_input; - model_input.combo_tokens = torch::empty({token_num}, torch::kInt32); - model_input.combo_tokens_type_ids = torch::empty({token_num}, torch::kInt32); - model_input.combo_position_ids = torch::empty({token_num}, torch::kInt32); - model_input.input_lengths = torch::empty({batch_size}, torch::kInt32); - model_input.sequence_lengths = torch::empty({0}, torch::kInt32); - model_input.prefix_lengths = torch::zeros({batch_size}, torch::kInt32); + auto i32_options = torch::TensorOptions(torch::kInt32).pinned_memory(true); + + model_input.combo_tokens = torch::empty({token_num}, i32_options); + model_input.combo_tokens_type_ids = torch::empty({token_num}, i32_options); + model_input.combo_position_ids = torch::empty({token_num}, i32_options); + model_input.input_lengths = torch::empty({batch_size}, i32_options); + model_input.sequence_lengths = torch::empty({0}, i32_options); + model_input.prefix_lengths = torch::zeros({batch_size}, i32_options); int* merged_tokens = model_input.combo_tokens.data_ptr(); int* input_lengths = model_input.input_lengths.data_ptr(); int* merged_positon_ids = model_input.combo_position_ids.data_ptr(); diff --git a/rtp_llm/cpp/models/BUILD b/rtp_llm/cpp/models/BUILD index 643389fa22..e272fb325a 100644 --- a/rtp_llm/cpp/models/BUILD +++ b/rtp_llm/cpp/models/BUILD @@ -129,6 +129,9 @@ cc_library( "//rtp_llm/cpp/utils:debug_utils", "//rtp_llm/cpp/utils:profiling_scope", "//rtp_llm/cpp/cuda_graph:cuda_graph_impl", + "//rtp_llm/cpp/cuda_graph:cuda_graph_base", + "//rtp_llm/cpp/cuda_graph:cuda_graph_hdrs_lib", + "//rtp_llm/models_py/bindings/common:fuse_copy_op", ], visibility = ["//visibility:public"], ) diff --git a/rtp_llm/cpp/models/PyWrappedModel.cc b/rtp_llm/cpp/models/PyWrappedModel.cc index 9f6c46c838..dca975a9d3 100644 --- a/rtp_llm/cpp/models/PyWrappedModel.cc +++ b/rtp_llm/cpp/models/PyWrappedModel.cc @@ -30,7 +30,22 @@ torch::Tensor PyWrappedModel::tensorHoldHostAndToCuda(const torch::Tensor& tenso } buffer_holder_.hold_host(tensor); - return tensor.to(torch::kCUDA, /*non_blocking=*/true, /*copy=*/false); + + if (tensor.numel() == 0) { + return torch::empty(tensor.sizes(), torch::TensorOptions(tensor.dtype()).device(torch::kCUDA)); + } + + // NOTE: since is_pinned() operation costs a lot cpu time, we only check it when pinned_check_remaining_ > 0. + if (pinned_check_remaining_ > 0) { + RTP_LLM_CHECK_WITH_INFO(tensor.is_pinned(), "tensor is not pinned, fused copy requires pinned memory"); + } + + // create tensor on cuda + auto cuda_tensor = torch::empty(tensor.sizes(), torch::TensorOptions(tensor.dtype()).device(torch::kCUDA)); + + d2d_copies_.add(tensor.data_ptr(), cuda_tensor.data_ptr(), tensor.nbytes()); + + return cuda_tensor; } void PyWrappedModel::releaseBuffers() { @@ -69,10 +84,10 @@ torch_ext::PyAttentionInputs PyWrappedModel::buildPyAttentionInputs(const GptMod py_attn_inputs.input_lengths = inputs.input_lengths; if (inputs.kv_cache_kernel_block_id.defined()) { - py_attn_inputs.kv_cache_kernel_block_id_host = inputs.kv_cache_kernel_block_id.clone(); + py_attn_inputs.kv_cache_kernel_block_id_host = inputs.kv_cache_kernel_block_id.clone().pin_memory(); } if (inputs.kv_cache_block_id.defined()) { - py_attn_inputs.kv_cache_block_id_host = inputs.kv_cache_block_id.clone(); + py_attn_inputs.kv_cache_block_id_host = inputs.kv_cache_block_id.clone().pin_memory(); } if (inputs.kv_cache_layer_to_group.defined()) { py_attn_inputs.kv_cache_layer_to_group = inputs.kv_cache_layer_to_group; @@ -94,9 +109,9 @@ torch_ext::PyAttentionInputs PyWrappedModel::buildPyAttentionInputs(const GptMod if (context_batch_size > 0) { torch::Tensor cu_seqlens = - torch::zeros({batch_size + 1}, torch::TensorOptions(torch::kInt32).device(torch::kCPU)); + torch::zeros({batch_size + 1}, torch::TensorOptions(torch::kInt32).device(torch::kCPU).pinned_memory(true)); torch::Tensor cu_kv_seqlens = - torch::zeros({batch_size + 1}, torch::TensorOptions(torch::kInt32).device(torch::kCPU)); + torch::zeros({batch_size + 1}, torch::TensorOptions(torch::kInt32).device(torch::kCPU).pinned_memory(true)); cu_seqlens.slice(0, 1, context_batch_size + 1) = py_attn_inputs.input_lengths.cumsum(0); cu_kv_seqlens.slice(0, 1, context_batch_size + 1) = @@ -104,16 +119,22 @@ torch_ext::PyAttentionInputs PyWrappedModel::buildPyAttentionInputs(const GptMod py_attn_inputs.context_total_kv_length = cu_kv_seqlens[context_batch_size].item(); py_attn_inputs.total_tokens = cu_seqlens[batch_size].item(); + py_attn_inputs.cu_seqlens_host = cu_seqlens; py_attn_inputs.cu_seqlens = tensorHoldHostAndToCuda(cu_seqlens); py_attn_inputs.cu_kv_seqlens = tensorHoldHostAndToCuda(cu_kv_seqlens); } else { py_attn_inputs.total_tokens = 0; + py_attn_inputs.cu_seqlens_host = + torch::zeros({batch_size + 1}, torch::TensorOptions(torch::kInt32).device(torch::kCPU).pinned_memory(true)); py_attn_inputs.cu_seqlens = torch::zeros({batch_size + 1}, torch::TensorOptions(torch::kInt32).device(torch::kCUDA)); py_attn_inputs.cu_kv_seqlens = torch::zeros({batch_size + 1}, torch::TensorOptions(torch::kInt32).device(torch::kCUDA)); - torch::Tensor decode_cu_seqlens = torch::arange( - 0, py_attn_inputs.sequence_lengths.size(0) + 1, 1, torch::TensorOptions(torch::kInt32).device(torch::kCPU)); + torch::Tensor decode_cu_seqlens = + torch::arange(0, + py_attn_inputs.sequence_lengths.size(0) + 1, + 1, + torch::TensorOptions(torch::kInt32).device(torch::kCPU).pinned_memory(true)); py_attn_inputs.decode_cu_seqlens_host = decode_cu_seqlens; py_attn_inputs.decode_cu_seqlens_d = tensorHoldHostAndToCuda(decode_cu_seqlens); } @@ -124,9 +145,11 @@ torch_ext::PyAttentionInputs PyWrappedModel::buildPyAttentionInputs(const GptMod // In qwen3-next target verify mode, sequence_lengths_plus_1_d uses prefix_lengths if (py_attn_inputs.is_target_verify) { - py_attn_inputs.sequence_lengths_plus_1_d = tensorHoldHostAndToCuda(py_attn_inputs.prefix_lengths + 1); + auto sequence_lengths_plus_1 = (py_attn_inputs.prefix_lengths + 1).pin_memory(); + py_attn_inputs.sequence_lengths_plus_1_d = tensorHoldHostAndToCuda(sequence_lengths_plus_1); } else { - py_attn_inputs.sequence_lengths_plus_1_d = tensorHoldHostAndToCuda(py_attn_inputs.sequence_lengths + 1); + auto sequence_lengths_plus_1 = (py_attn_inputs.sequence_lengths + 1).pin_memory(); + py_attn_inputs.sequence_lengths_plus_1_d = tensorHoldHostAndToCuda(sequence_lengths_plus_1); } return py_attn_inputs; @@ -154,8 +177,7 @@ void PyWrappedModel::setupKVCacheForAttentionInputs(torch_ext::PyAttentionInputs // group view: [batch, kernel_blocks] on HOST auto group_view = inputs.kv_cache_kernel_block_id[g]; py_attn_inputs.kv_cache_kernel_block_id_host_by_group.push_back(group_view); - py_attn_inputs.kv_cache_kernel_block_id_device_by_group.push_back( - group_view.to(torch::kCUDA, /*non_blocking=*/true)); + py_attn_inputs.kv_cache_kernel_block_id_device_by_group.push_back(tensorHoldHostAndToCuda(group_view)); } // Legacy 2-D fields default to group 0. @@ -259,6 +281,11 @@ std::optional PyWrappedModel::prepareWriteCacheParams(const GptModelOutputs PyWrappedModel::forwardMicroBatched(const GptModelInputs& inputs) { RTP_LLM_PROFILE_SCOPE("py_model.forwardMicroBatched"); + d2d_copies_.clear(); + if (pinned_check_remaining_ > 0) { + --pinned_check_remaining_; + } + { py::gil_scoped_acquire gil; if (device_props_.ffn_as_service) { @@ -296,6 +323,8 @@ GptModelOutputs PyWrappedModel::forwardMicroBatched(const GptModelInputs& inputs cache_store_async_writer_->init(); } + fusedCopy(d2d_copies_); + std::vector py_model_outputs; { py::gil_scoped_acquire gil; @@ -351,9 +380,12 @@ GptModelOutputs PyWrappedModel::forwardMicroBatched(const GptModelInputs& inputs GptModelOutputs PyWrappedModel::forward(const GptModelInputs& inputs) { RTP_LLM_PROFILE_SCOPE("py_model.forward"); + d2d_copies_.clear(); DevicePerfWrapper wrapper(enable_device_perf_, "py model forward"); holdInputsHostBuffers(inputs); - + if (pinned_check_remaining_ > 0) { + --pinned_check_remaining_; + } try { RTP_LLM_LOG_DEBUG("Calling forward method on Python object instance."); @@ -387,6 +419,9 @@ GptModelOutputs PyWrappedModel::forward(const GptModelInputs& inputs) { calculatePaddingOffset(attention_inputs); attention_inputs.padding_offset = tensorHoldHostAndToCuda(attention_inputs.padding_offset); + // launch fused copy + fusedCopy(d2d_copies_); + auto py_model_inputs = PyModelInputs({token_ids, input_hiddens, attention_inputs, bert_embedding_inputs}); PyModelOutputs py_model_outputs; torch::Tensor hidden_states; diff --git a/rtp_llm/cpp/models/PyWrappedModel.h b/rtp_llm/cpp/models/PyWrappedModel.h index 8bae97cfeb..092efb91eb 100644 --- a/rtp_llm/cpp/models/PyWrappedModel.h +++ b/rtp_llm/cpp/models/PyWrappedModel.h @@ -16,7 +16,6 @@ #if USING_CUDA || USING_ROCM #include "rtp_llm/cpp/cuda_graph/cuda_graph_runner.h" #endif - #include "rtp_llm/cpp/models/context_parallel/ContextParallelProcessorBase.h" #include "rtp_llm/cpp/core/DeviceData.h" #include "rtp_llm/cpp/core/ExecOps.h" @@ -94,6 +93,13 @@ class PyWrappedModel: public ModelBase { std::unique_ptr context_parallel_processor_{nullptr}; std::unique_ptr cache_store_async_writer_; + + // Accumulated H2D copies from tensorHoldHostAndToCuda(); flushed as one kernel per forward. + FusedD2DCopyParams d2d_copies_; + + // is_pinned() is expensive on CPU; only assert during first N forwards as a sanity check. + static constexpr int kPinnedCheckForwardCount = 3; + int pinned_check_remaining_{kPinnedCheckForwardCount}; }; // NOTE(wangyin): constructor can not be compiled correctly when placed in cc file. diff --git a/rtp_llm/cpp/normal_engine/speculative/MtpBatchStreamProcessor.cc b/rtp_llm/cpp/normal_engine/speculative/MtpBatchStreamProcessor.cc index 14b3af6f4d..3da39adb55 100644 --- a/rtp_llm/cpp/normal_engine/speculative/MtpBatchStreamProcessor.cc +++ b/rtp_llm/cpp/normal_engine/speculative/MtpBatchStreamProcessor.cc @@ -166,7 +166,7 @@ void MtpBatchStreamProcessor::prepareOneStepSpecDecodeModelInput(const StreamGro size_t batch_size = stream_groups.size(); // prepare target model input buffer - auto target_prefix_lengths = model_input.sequence_lengths.cpu().clone(); + auto target_prefix_lengths = model_input.sequence_lengths.cpu().clone().pin_memory(); // allocate target_combo_tokens shape [batch_size, propose_step_ + 1] auto target_combo_tokens = diff --git a/rtp_llm/cpp/normal_engine/speculative/MtpExecutor.cc b/rtp_llm/cpp/normal_engine/speculative/MtpExecutor.cc index fed7505b0a..d24c43603f 100644 --- a/rtp_llm/cpp/normal_engine/speculative/MtpExecutor.cc +++ b/rtp_llm/cpp/normal_engine/speculative/MtpExecutor.cc @@ -810,7 +810,7 @@ void MtpExecutor::draftModelDecode(GptModelInputs& model_input, // update TP > 0 batch_size size_t batch_size = model_input.combo_tokens.size(0); - spec_prefix_lengths = model_input.sequence_lengths.cpu().clone(); + spec_prefix_lengths = model_input.sequence_lengths.cpu().clone().pin_memory(); auto pre_propose_token_t_raw = model_input.combo_tokens.to(torch::kCUDA).clone(); @@ -864,8 +864,11 @@ void MtpExecutor::draftModelDecode(GptModelInputs& model_input, draft_token_ids_t = torch::cat(draft_token_ids_list, 1).reshape({(int)batch_size, (int)(propose_step_ + 1)}).contiguous(); - auto lm_output_indexes = torch::empty({(int64_t)(batch_size * (propose_step_ + 1))}, torch::kInt32); - auto input_lengths = torch::empty({(int64_t)batch_size}, torch::kInt32); + auto lm_output_indexes = + torch::empty({(int64_t)(batch_size * (propose_step_ + 1))}, + torch::TensorOptions(torch::kInt32).device(torch::kCPU).pinned_memory(true)); + auto input_lengths = torch::empty({(int64_t)batch_size}, + torch::TensorOptions(torch::kInt32).device(torch::kCPU).pinned_memory(true)); for (int i = 0; i < batch_size; i++) { input_lengths.data_ptr()[i] = propose_step_ + 1; @@ -874,11 +877,12 @@ void MtpExecutor::draftModelDecode(GptModelInputs& model_input, lm_output_indexes.data_ptr()[i] = i; } - model_input.input_lengths = std::move(input_lengths); - model_input.lm_output_indexes = std::move(lm_output_indexes); - model_input.prefix_lengths = spec_prefix_lengths; - model_input.combo_tokens = draft_token_ids_t.reshape({(int64_t)(batch_size * (propose_step_ + 1))}); - model_input.sequence_lengths = torch::empty({0}, torch::kInt32); + model_input.input_lengths = std::move(input_lengths); + model_input.lm_output_indexes = std::move(lm_output_indexes); + model_input.prefix_lengths = spec_prefix_lengths; + model_input.combo_tokens = draft_token_ids_t.reshape({(int64_t)(batch_size * (propose_step_ + 1))}); + model_input.sequence_lengths = + torch::empty({0}, torch::TensorOptions(torch::kInt32).device(torch::kCPU).pinned_memory(true)); model_input.last_hidden_states = torch::Tensor(); // Since other tp ranks don't have streams, its combo_tokens' first token is not correct. diff --git a/rtp_llm/models_py/bindings/OpDefs.cc b/rtp_llm/models_py/bindings/OpDefs.cc index e6dd576dd4..a053fb4415 100644 --- a/rtp_llm/models_py/bindings/OpDefs.cc +++ b/rtp_llm/models_py/bindings/OpDefs.cc @@ -109,6 +109,7 @@ void registerPyOpDefs(pybind11::module& m) { .def_readwrite("kv_cache_layer_to_group", &PyAttentionInputs::kv_cache_layer_to_group) .def_readwrite("dtype", &PyAttentionInputs::dtype) .def_readwrite("cu_seqlens", &PyAttentionInputs::cu_seqlens) + .def_readwrite("cu_seqlens_host", &PyAttentionInputs::cu_seqlens_host) .def_readwrite("cu_kv_seqlens", &PyAttentionInputs::cu_kv_seqlens) .def_readwrite("context_total_kv_length", &PyAttentionInputs::context_total_kv_length) .def_readwrite("total_tokens", &PyAttentionInputs::total_tokens) diff --git a/rtp_llm/models_py/bindings/OpDefs.h b/rtp_llm/models_py/bindings/OpDefs.h index 80e78b2759..64ec924923 100644 --- a/rtp_llm/models_py/bindings/OpDefs.h +++ b/rtp_llm/models_py/bindings/OpDefs.h @@ -182,8 +182,11 @@ struct PyAttentionInputs { std::vector kv_cache_kernel_block_id_device_by_group; torch::Tensor kv_cache_layer_to_group; caffe2::TypeMeta dtype; - // for `FusedRopeKVCacheDecodeOp`. + // Cumulative sequence lengths for attention kernels (e.g. FusedRopeKVCacheDecodeOp). + // cu_seqlens lives on CUDA device; cu_seqlens_host is its pinned-memory CPU mirror + // used for CUDA graph replay (write host → async copy to device, avoiding GPU-side fills). torch::Tensor cu_seqlens; + torch::Tensor cu_seqlens_host; torch::Tensor cu_kv_seqlens; torch::Tensor decode_cu_seqlens_host; @@ -197,7 +200,7 @@ struct PyAttentionInputs { std::optional prefill_cuda_graph_copy_params; bool is_s_padded = false; - // deivce tensor + // Device-side mirrors of host tensors, managed by C++ for fused D2D copy in CUDA graph. torch::Tensor prefix_lengths_d; torch::Tensor sequence_lengths_plus_1_d; torch::Tensor input_lengths_d; diff --git a/rtp_llm/models_py/bindings/OpDefsUtils.h b/rtp_llm/models_py/bindings/OpDefsUtils.h index ac3f0c6a3a..cf515d1067 100644 --- a/rtp_llm/models_py/bindings/OpDefsUtils.h +++ b/rtp_llm/models_py/bindings/OpDefsUtils.h @@ -35,8 +35,9 @@ inline void calculatePaddingOffset(torch_ext::PyAttentionInputs& py_attn_inputs) // inputs_length: [1,2,1,1] ,total_tokens = 5 // padding_offsets: [0,1,1,1,2] - int max_seq_len = py_attn_inputs.input_lengths.max().item(); - auto padding_offset_host = torch::zeros({total_tokens}, torch::TensorOptions(torch::kInt32).device(torch::kCPU)); + int max_seq_len = py_attn_inputs.input_lengths.max().item(); + auto padding_offset_host = + torch::zeros({total_tokens}, torch::TensorOptions(torch::kInt32).device(torch::kCPU).pinned_memory(true)); if (total_tokens > 0) { getPaddingOffset(padding_offset_host.data_ptr(), diff --git a/rtp_llm/models_py/bindings/common/BUILD b/rtp_llm/models_py/bindings/common/BUILD index 08228611e5..740d4818e1 100644 --- a/rtp_llm/models_py/bindings/common/BUILD +++ b/rtp_llm/models_py/bindings/common/BUILD @@ -2,15 +2,41 @@ load("//bazel:arch_select.bzl", "torch_deps") load("//:def.bzl", "copts") +cc_library( + name = "fuse_copy_op", + srcs = ["FusedCopyOp.cc"], + deps = [ + "//rtp_llm/cpp/core:exec_ops_hdr", + "//rtp_llm/cpp/core:types_hdr", + "//rtp_llm/models_py/bindings/common/kernels:fuse_copy_kernel", + ] + torch_deps() + select({ + "//:using_cuda": [ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cudart", + ], + "//:using_rocm": [ + "@local_config_rocm//rocm:rocm_headers", + "@local_config_rocm//rocm:rocm", + "//rtp_llm/cpp/rocm:rocm_types_hdr", + ], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + copts = copts(), + alwayslink = 1, +) + cc_library( name = "common", - srcs = glob(["*.cc"]), + srcs = glob(["*.cc"], exclude = ["FusedCopyOp.cc"]), hdrs = glob(["*.h"]), deps = [ + ":fuse_copy_op", "//rtp_llm/models_py/bindings:op_defs", "//rtp_llm/cpp/core:op_data", "//rtp_llm/cpp/core:exec_ops_hdr", "//rtp_llm/cpp/core:cache_store_async_writer", + "//rtp_llm/models_py/bindings/common/kernels:fuse_copy_kernel", ] + torch_deps() + select({ "//:using_cuda": [ "//rtp_llm/models_py/bindings/common/kernels:kernels_cu", diff --git a/rtp_llm/models_py/bindings/common/FusedCopyOp.cc b/rtp_llm/models_py/bindings/common/FusedCopyOp.cc new file mode 100644 index 0000000000..7e4822fa44 --- /dev/null +++ b/rtp_llm/models_py/bindings/common/FusedCopyOp.cc @@ -0,0 +1,41 @@ +#include "rtp_llm/cpp/core/ExecOps.h" +#include "rtp_llm/cpp/core/Types.h" +#include "rtp_llm/models_py/bindings/common/kernels/fuse_copy_kernel.h" + +#if USING_CUDA +#include +#include +#endif +#if USING_ROCM +#include +#include "rtp_llm/cpp/rocm/cuda_shims.h" +#include +#endif + +namespace rtp_llm { + +void fusedCopy(const FusedD2DCopyParams& params) { +#if USING_CUDA + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); +#elif USING_ROCM + hipStream_t stream = at::hip::getCurrentHIPStream(); +#else + throw std::runtime_error("No supported GPU backend found for fusedCopy"); + return; +#endif + invokeFusedCopy(params, stream); +} + +void fusedStridedCopy(const FusedStridedCopyParams& params) { +#if USING_CUDA + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); +#elif USING_ROCM + hipStream_t stream = at::hip::getCurrentHIPStream(); +#else + throw std::runtime_error("No supported GPU backend found for fusedStridedCopy"); + return; +#endif + invokeFusedStridedCopy(params, stream); +} + +} // namespace rtp_llm diff --git a/rtp_llm/models_py/bindings/common/kernels/BUILD b/rtp_llm/models_py/bindings/common/kernels/BUILD index 9b676ae8c5..82cfb17f9b 100644 --- a/rtp_llm/models_py/bindings/common/kernels/BUILD +++ b/rtp_llm/models_py/bindings/common/kernels/BUILD @@ -210,6 +210,33 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "fuse_copy_util", + hdrs = [ + "fuse_copy_util.h", + ], + deps = any_cuda_deps, + copts = any_cuda_copts(), + include_prefix = "src", + visibility = ["//visibility:public"], +) + +cc_library( + name = "fuse_copy_kernel", + srcs = [ + "fuse_copy_kernel.cu", + ], + hdrs = [ + "fuse_copy_kernel.h", + ], + deps = any_cuda_deps + [ + ":fuse_copy_util", + ], + copts = any_cuda_copts(), + include_prefix = "src", + visibility = ["//visibility:public"], +) + cc_library( name = "kernels_cu", deps = [ diff --git a/rtp_llm/models_py/bindings/common/kernels/copy_utils.h b/rtp_llm/models_py/bindings/common/kernels/copy_utils.h index e9093164f4..9b6ea8bcaa 100644 --- a/rtp_llm/models_py/bindings/common/kernels/copy_utils.h +++ b/rtp_llm/models_py/bindings/common/kernels/copy_utils.h @@ -3,7 +3,7 @@ #include #include -#if USEING_CUDA +#if USING_CUDA #include #endif diff --git a/rtp_llm/models_py/bindings/common/kernels/cuda_graph_copy_kernel.cu b/rtp_llm/models_py/bindings/common/kernels/cuda_graph_copy_kernel.cu index 589f5e404e..e066905fe8 100644 --- a/rtp_llm/models_py/bindings/common/kernels/cuda_graph_copy_kernel.cu +++ b/rtp_llm/models_py/bindings/common/kernels/cuda_graph_copy_kernel.cu @@ -109,14 +109,10 @@ void invokeCudaGraphCopySmall2Large(T* input_tensor, return; } - // Calculate grid and block dimensions - // Use cu_seq_len[batch_size] which contains total token count - const int total_elements = cu_seq_len[*batch_size] * hidden_size; - dim3 block(256); - const int grid_size = min((total_elements + block.x - 1) / block.x, 65536); - dim3 grid(grid_size); - - // Launch kernel + // use fixed block and grid size for cuda graph + dim3 block(256); + dim3 grid(1024); + cudaGraphCopySmall2LargeKernel<<>>( input_tensor, output_tensor, input_lengths, batch_size, max_seq_len, hidden_size, cu_seq_len); } @@ -177,21 +173,15 @@ void invokeCudaGraphCopyLarge2Small(T* input_tensor, #elif USING_ROCM hipStream_t stream) { #endif - // Validate input parameters if (input_tensor == nullptr || output_tensor == nullptr || input_lengths == nullptr || *batch_size <= 0 || max_seq_len <= 0 || hidden_size <= 0 || cu_seq_len == nullptr) { return; } - // Calculate grid and block dimensions - // Use cu_seq_len[batch_size] which contains total token count - const int total_elements = cu_seq_len[*batch_size] * hidden_size; - - dim3 block(256); - const int grid_size = min((total_elements + block.x - 1) / block.x, 65536); - dim3 grid(grid_size); + // use fixed block and grid size for cuda graph + dim3 block(256); + dim3 grid(1024); - // Launch kernel cudaGraphCopyLarge2SmallKernel<<>>( input_tensor, output_tensor, input_lengths, batch_size, max_seq_len, hidden_size, cu_seq_len); } diff --git a/rtp_llm/models_py/bindings/common/kernels/fuse_copy_kernel.cu b/rtp_llm/models_py/bindings/common/kernels/fuse_copy_kernel.cu new file mode 100644 index 0000000000..1d6b34dc18 --- /dev/null +++ b/rtp_llm/models_py/bindings/common/kernels/fuse_copy_kernel.cu @@ -0,0 +1,88 @@ +#include +#include + +#include "rtp_llm/models_py/bindings/common/kernels/fuse_copy_kernel.h" + +namespace rtp_llm { + +static constexpr int FUSED_COPY_BLOCKS_PER_TASK = 8; +static constexpr int FUSED_COPY_THREADS = 256; + +__global__ void fusedCopyKernel(FusedD2DCopyParams params) { + const int copy_idx = blockIdx.y; + if (copy_idx >= params.num_copies) + return; + + const size_t total_bytes = params.size[copy_idx]; + const size_t global_tid = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const size_t global_stride = static_cast(gridDim.x) * blockDim.x; + + const auto src_addr = reinterpret_cast(params.src[copy_idx]); + const auto dst_addr = reinterpret_cast(params.dst[copy_idx]); + + if ((src_addr % sizeof(int4) == 0) && (dst_addr % sizeof(int4) == 0)) { + // Fast path: 16-byte vectorized bulk copy + const int4* src = reinterpret_cast(src_addr); + int4* dst = reinterpret_cast(dst_addr); + const size_t n16 = total_bytes / sizeof(int4); + + for (size_t i = global_tid; i < n16; i += global_stride) { + dst[i] = src[i]; + } + + if (blockIdx.x == 0) { + const size_t rem_start = n16 * sizeof(int4); + const char* src_byte = reinterpret_cast(src_addr); + char* dst_byte = reinterpret_cast(dst_addr); + for (size_t i = rem_start + threadIdx.x; i < total_bytes; i += blockDim.x) { + dst_byte[i] = src_byte[i]; + } + } + } else { + // Slow path: byte-by-byte copy for unaligned pointers + const char* src_byte = reinterpret_cast(src_addr); + char* dst_byte = reinterpret_cast(dst_addr); + for (size_t i = global_tid; i < total_bytes; i += global_stride) { + dst_byte[i] = src_byte[i]; + } + } +} + +void invokeFusedCopy(const FusedD2DCopyParams& params, cudaStream_t stream) { + if (params.num_copies <= 0) + return; + dim3 grid(FUSED_COPY_BLOCKS_PER_TASK, params.num_copies); + fusedCopyKernel<<>>(params); +} + +__global__ void fusedStridedCopyKernel(FusedStridedCopyParams params) { + const int copy_idx = blockIdx.y; + if (copy_idx >= params.num_copies) + return; + + const size_t nrows = params.num_rows[copy_idx]; + const size_t rbytes = params.row_bytes[copy_idx]; + const size_t src_stride = params.src_row_stride[copy_idx]; + const size_t dst_stride = params.dst_row_stride[copy_idx]; + const char* src = reinterpret_cast(params.src[copy_idx]); + char* dst = reinterpret_cast(params.dst[copy_idx]); + + const size_t total = nrows * rbytes; + const size_t global_tid = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const size_t stride = static_cast(gridDim.x) * blockDim.x; + + for (size_t idx = global_tid; idx < total; idx += stride) { + const size_t row = idx / rbytes; + const size_t col = idx % rbytes; + dst[row * dst_stride + col] = src[row * src_stride + col]; + } +} + +void invokeFusedStridedCopy(const FusedStridedCopyParams& params, cudaStream_t stream) { + if (params.num_copies <= 0) + return; + dim3 grid(FUSED_COPY_BLOCKS_PER_TASK, params.num_copies); + fusedStridedCopyKernel<<>>(params); +} + +} // namespace rtp_llm diff --git a/rtp_llm/models_py/bindings/common/kernels/fuse_copy_kernel.h b/rtp_llm/models_py/bindings/common/kernels/fuse_copy_kernel.h new file mode 100644 index 0000000000..07f69ce1cc --- /dev/null +++ b/rtp_llm/models_py/bindings/common/kernels/fuse_copy_kernel.h @@ -0,0 +1,19 @@ +#pragma once +#include "rtp_llm/models_py/bindings/common/kernels/fuse_copy_util.h" + +#if USING_CUDA +#include +#endif + +#if USING_ROCM +#include +#include "rtp_llm/cpp/rocm/cuda_shims.h" +#endif + +namespace rtp_llm { + +void invokeFusedCopy(const FusedD2DCopyParams& params, cudaStream_t stream); + +void invokeFusedStridedCopy(const FusedStridedCopyParams& params, cudaStream_t stream); + +} // namespace rtp_llm diff --git a/rtp_llm/models_py/bindings/common/kernels/fuse_copy_util.h b/rtp_llm/models_py/bindings/common/kernels/fuse_copy_util.h new file mode 100644 index 0000000000..ec62ca9bfb --- /dev/null +++ b/rtp_llm/models_py/bindings/common/kernels/fuse_copy_util.h @@ -0,0 +1,62 @@ +#pragma once +#include +#include + +namespace rtp_llm { + +// NOTE: Hardcoded limits for fused copies. It is enough for most cases. If you need more, please increase the limits. +static constexpr int MAX_FUSED_D2D_COPIES = 16; +static constexpr int MAX_FUSED_STRIDED_COPIES = 16; + +inline void copyParamsAssert(bool value, const std::string& msg) { + if (!value) { + throw std::runtime_error(msg); + } +} + +struct FusedD2DCopyParams { + const void* src[MAX_FUSED_D2D_COPIES]; + void* dst[MAX_FUSED_D2D_COPIES]; + size_t size[MAX_FUSED_D2D_COPIES]; + int num_copies = 0; + + void add(const void* src_ptr, void* dst_ptr, size_t bytes) { + copyParamsAssert(num_copies < MAX_FUSED_D2D_COPIES, + "FusedD2DCopyParams: num_copies exceeds MAX_FUSED_D2D_COPIES"); + src[num_copies] = src_ptr; + dst[num_copies] = dst_ptr; + size[num_copies] = bytes; + ++num_copies; + } + + void clear() { + num_copies = 0; + } +}; + +struct FusedStridedCopyParams { + const void* src[MAX_FUSED_STRIDED_COPIES]; + void* dst[MAX_FUSED_STRIDED_COPIES]; + size_t num_rows[MAX_FUSED_STRIDED_COPIES]; + size_t row_bytes[MAX_FUSED_STRIDED_COPIES]; + size_t src_row_stride[MAX_FUSED_STRIDED_COPIES]; + size_t dst_row_stride[MAX_FUSED_STRIDED_COPIES]; + int num_copies = 0; + + void add(const void* src_ptr, void* dst_ptr, size_t rows, size_t row_b, size_t src_stride, size_t dst_stride) { + copyParamsAssert(num_copies < MAX_FUSED_STRIDED_COPIES, + "FusedStridedCopyParams: num_copies exceeds MAX_FUSED_STRIDED_COPIES"); + src[num_copies] = src_ptr; + dst[num_copies] = dst_ptr; + num_rows[num_copies] = rows; + row_bytes[num_copies] = row_b; + src_row_stride[num_copies] = src_stride; + dst_row_stride[num_copies] = dst_stride; + ++num_copies; + } + + void clear() { + num_copies = 0; + } +}; +} // namespace rtp_llm diff --git a/rtp_llm/models_py/bindings/cuda/BUILD b/rtp_llm/models_py/bindings/cuda/BUILD index f4104b0642..5fdbece723 100644 --- a/rtp_llm/models_py/bindings/cuda/BUILD +++ b/rtp_llm/models_py/bindings/cuda/BUILD @@ -54,6 +54,7 @@ cc_library( "//rtp_llm/models_py/bindings/common:common", "//rtp_llm/cpp/config:model_config", "//rtp_llm/models_py/bindings/cuda/kernels:user_buffers", + "//rtp_llm/cpp/utils:profiling_scope", ] + select({ "@//:using_cuda12_9": [ "//rtp_llm/models_py/bindings/cuda/kernels:scaled_fp4_quant", From 64e797c985845d6c6c5c6da2cb9577a2a406e012 Mon Sep 17 00:00:00 2001 From: "huzetao.hzt" Date: Thu, 16 Apr 2026 14:56:07 +0800 Subject: [PATCH 2/4] test: add fused copy kernel tests --- .../bindings/common/kernels/test/BUILD | 27 ++ .../kernels/test/fuse_copy_kernel_test.cc | 388 ++++++++++++++++++ 2 files changed, 415 insertions(+) create mode 100644 rtp_llm/models_py/bindings/common/kernels/test/BUILD create mode 100644 rtp_llm/models_py/bindings/common/kernels/test/fuse_copy_kernel_test.cc diff --git a/rtp_llm/models_py/bindings/common/kernels/test/BUILD b/rtp_llm/models_py/bindings/common/kernels/test/BUILD new file mode 100644 index 0000000000..d83a0d48cd --- /dev/null +++ b/rtp_llm/models_py/bindings/common/kernels/test/BUILD @@ -0,0 +1,27 @@ +load("//:def.bzl", "cc_test_wrapper", "any_cuda_copts") + +cc_test = cc_test_wrapper + +any_cuda_deps = select({ + "@//:using_cuda": [ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cudart", + ], + "@//:using_rocm": [ + "@local_config_rocm//rocm:rocm_headers", + "@local_config_rocm//rocm:hip", + "//rtp_llm/cpp/rocm:rocm_types_hdr", + ], + "//conditions:default": [], +}) + +cc_test( + name = "fuse_copy_kernel_test", + srcs = ["fuse_copy_kernel_test.cc"], + copts = any_cuda_copts(), + deps = any_cuda_deps + [ + "//rtp_llm/models_py/bindings/common/kernels:fuse_copy_kernel", + "@com_google_googletest//:gtest", + ], + exec_properties = {"gpu": "H20"}, +) diff --git a/rtp_llm/models_py/bindings/common/kernels/test/fuse_copy_kernel_test.cc b/rtp_llm/models_py/bindings/common/kernels/test/fuse_copy_kernel_test.cc new file mode 100644 index 0000000000..bdb0209871 --- /dev/null +++ b/rtp_llm/models_py/bindings/common/kernels/test/fuse_copy_kernel_test.cc @@ -0,0 +1,388 @@ +#include +#include +#include + +#if USING_CUDA +#include +#elif USING_ROCM +#include +#include "rtp_llm/cpp/rocm/cuda_shims.h" +#endif + +#include "rtp_llm/models_py/bindings/common/kernels/fuse_copy_kernel.h" + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- +namespace { + +#define CUDA_CHECK(expr) \ + do { \ + cudaError_t _e = (expr); \ + ASSERT_EQ(_e, cudaSuccess) << "CUDA error: " << cudaGetErrorString(_e); \ + } while (0) + +// Allocate device memory and copy host data to it. +template +T* deviceAlloc(const std::vector& host_data) { + T* d_ptr = nullptr; + EXPECT_EQ(cudaMalloc(&d_ptr, host_data.size() * sizeof(T)), cudaSuccess); + EXPECT_EQ(cudaMemcpy(d_ptr, host_data.data(), host_data.size() * sizeof(T), cudaMemcpyHostToDevice), cudaSuccess); + return d_ptr; +} + +// Allocate zero-initialised device memory. +template +T* deviceAllocZero(size_t n) { + T* d_ptr = nullptr; + EXPECT_EQ(cudaMalloc(&d_ptr, n * sizeof(T)), cudaSuccess); + EXPECT_EQ(cudaMemset(d_ptr, 0, n * sizeof(T)), cudaSuccess); + return d_ptr; +} + +// Copy device data back to a host vector. +template +std::vector deviceToHost(const T* d_ptr, size_t n) { + std::vector host(n); + EXPECT_EQ(cudaMemcpy(host.data(), d_ptr, n * sizeof(T), cudaMemcpyDeviceToHost), cudaSuccess); + return host; +} + +} // namespace + +// --------------------------------------------------------------------------- +// FusedCopy tests (invokeFusedCopy) +// --------------------------------------------------------------------------- + +class FusedCopyTest: public ::testing::Test { +protected: + cudaStream_t stream_{}; + + void SetUp() override { + CUDA_CHECK(cudaStreamCreate(&stream_)); + } + void TearDown() override { + cudaStreamDestroy(stream_); + } +}; + +// num_copies == 0 should be a no-op; ensure no crash. +TEST_F(FusedCopyTest, ZeroCopies) { + rtp_llm::FusedD2DCopyParams params; + rtp_llm::invokeFusedCopy(params, stream_); + CUDA_CHECK(cudaStreamSynchronize(stream_)); +} + +// Single aligned copy (16-byte vectorised fast path). +TEST_F(FusedCopyTest, SingleAlignedCopy) { + constexpr size_t N = 1024; // 1024 bytes, 16-byte aligned + std::vector host_src(N); + for (size_t i = 0; i < N; ++i) + host_src[i] = static_cast(i & 0xFF); + + uint8_t* d_src = deviceAlloc(host_src); + uint8_t* d_dst = deviceAllocZero(N); + + rtp_llm::FusedD2DCopyParams params; + params.add(d_src, d_dst, N); + + rtp_llm::invokeFusedCopy(params, stream_); + CUDA_CHECK(cudaStreamSynchronize(stream_)); + + auto result = deviceToHost(d_dst, N); + for (size_t i = 0; i < N; ++i) + ASSERT_EQ(result[i], host_src[i]) << "mismatch at byte " << i; + + cudaFree(d_src); + cudaFree(d_dst); +} + +// Unaligned copy: shift dst by 1 byte so the slow (byte-by-byte) path triggers. +TEST_F(FusedCopyTest, UnalignedDstCopy) { + constexpr size_t N = 128; + + std::vector host_src(N); + for (size_t i = 0; i < N; ++i) + host_src[i] = static_cast((i * 7 + 3) & 0xFF); + + // Allocate a buffer that is 1 byte larger, then offset dst by 1 to break alignment. + uint8_t* d_src = deviceAlloc(host_src); + uint8_t* d_dst_base = deviceAllocZero(N + 1); + uint8_t* d_dst = d_dst_base + 1; // unaligned + + rtp_llm::FusedD2DCopyParams params; + params.add(d_src, d_dst, N); + + rtp_llm::invokeFusedCopy(params, stream_); + CUDA_CHECK(cudaStreamSynchronize(stream_)); + + auto result = deviceToHost(d_dst, N); + for (size_t i = 0; i < N; ++i) + ASSERT_EQ(result[i], host_src[i]) << "mismatch at byte " << i; + + cudaFree(d_src); + cudaFree(d_dst_base); +} + +// Copy where size is not a multiple of 16 — exercises the remainder loop. +TEST_F(FusedCopyTest, NonMultipleOf16Size) { + constexpr size_t N = 37; // deliberately not a multiple of 16 + + std::vector host_src(N); + for (size_t i = 0; i < N; ++i) + host_src[i] = static_cast(i); + + uint8_t* d_src = deviceAlloc(host_src); + uint8_t* d_dst = deviceAllocZero(N); + + rtp_llm::FusedD2DCopyParams params; + params.add(d_src, d_dst, N); + + rtp_llm::invokeFusedCopy(params, stream_); + CUDA_CHECK(cudaStreamSynchronize(stream_)); + + auto result = deviceToHost(d_dst, N); + for (size_t i = 0; i < N; ++i) + ASSERT_EQ(result[i], host_src[i]) << "mismatch at byte " << i; + + cudaFree(d_src); + cudaFree(d_dst); +} + +// Multiple copies batched into one kernel launch. +TEST_F(FusedCopyTest, MultipleCopies) { + const std::vector sizes = {64, 128, 256, 512}; + + std::vector> host_srcs(sizes.size()); + std::vector d_srcs(sizes.size()); + std::vector d_dsts(sizes.size()); + + for (size_t c = 0; c < sizes.size(); ++c) { + host_srcs[c].resize(sizes[c]); + for (size_t i = 0; i < sizes[c]; ++i) + host_srcs[c][i] = static_cast((c * 13 + i) & 0xFF); + d_srcs[c] = deviceAlloc(host_srcs[c]); + d_dsts[c] = deviceAllocZero(sizes[c]); + } + + rtp_llm::FusedD2DCopyParams params; + for (size_t c = 0; c < sizes.size(); ++c) + params.add(d_srcs[c], d_dsts[c], sizes[c]); + + rtp_llm::invokeFusedCopy(params, stream_); + CUDA_CHECK(cudaStreamSynchronize(stream_)); + + for (size_t c = 0; c < sizes.size(); ++c) { + auto result = deviceToHost(d_dsts[c], sizes[c]); + for (size_t i = 0; i < sizes[c]; ++i) + ASSERT_EQ(result[i], host_srcs[c][i]) << "copy " << c << " mismatch at byte " << i; + } + + for (size_t c = 0; c < sizes.size(); ++c) { + cudaFree(d_srcs[c]); + cudaFree(d_dsts[c]); + } +} + +// Fill MAX_FUSED_D2D_COPIES copies to stress the capacity limit. +TEST_F(FusedCopyTest, MaxFusedCopies) { + constexpr size_t N = 256; + + std::vector> host_srcs(rtp_llm::MAX_FUSED_D2D_COPIES); + std::vector d_srcs(rtp_llm::MAX_FUSED_D2D_COPIES); + std::vector d_dsts(rtp_llm::MAX_FUSED_D2D_COPIES); + + for (int c = 0; c < rtp_llm::MAX_FUSED_D2D_COPIES; ++c) { + host_srcs[c].resize(N); + for (size_t i = 0; i < N; ++i) + host_srcs[c][i] = static_cast((c * 17 + i) & 0xFF); + d_srcs[c] = deviceAlloc(host_srcs[c]); + d_dsts[c] = deviceAllocZero(N); + } + + rtp_llm::FusedD2DCopyParams params; + for (int c = 0; c < rtp_llm::MAX_FUSED_D2D_COPIES; ++c) + params.add(d_srcs[c], d_dsts[c], N); + + rtp_llm::invokeFusedCopy(params, stream_); + CUDA_CHECK(cudaStreamSynchronize(stream_)); + + for (int c = 0; c < rtp_llm::MAX_FUSED_D2D_COPIES; ++c) { + auto result = deviceToHost(d_dsts[c], N); + for (size_t i = 0; i < N; ++i) + ASSERT_EQ(result[i], host_srcs[c][i]) << "copy " << c << " mismatch at byte " << i; + } + + for (int c = 0; c < rtp_llm::MAX_FUSED_D2D_COPIES; ++c) { + cudaFree(d_srcs[c]); + cudaFree(d_dsts[c]); + } +} + +// --------------------------------------------------------------------------- +// FusedStridedCopy tests (invokeFusedStridedCopy) +// --------------------------------------------------------------------------- + +class FusedStridedCopyTest: public ::testing::Test { +protected: + cudaStream_t stream_{}; + + void SetUp() override { + CUDA_CHECK(cudaStreamCreate(&stream_)); + } + void TearDown() override { + cudaStreamDestroy(stream_); + } +}; + +// num_copies == 0 should be a no-op. +TEST_F(FusedStridedCopyTest, ZeroCopies) { + rtp_llm::FusedStridedCopyParams params; + rtp_llm::invokeFusedStridedCopy(params, stream_); + CUDA_CHECK(cudaStreamSynchronize(stream_)); +} + +// Basic strided copy: src_stride > row_bytes (skip padding bytes in source). +// Layout: src has rows of src_stride bytes, only row_bytes are valid data. +// dst is compact (dst_stride == row_bytes). +TEST_F(FusedStridedCopyTest, SingleStridedCopy) { + constexpr size_t NROWS = 8; + constexpr size_t ROW_BYTES = 32; + constexpr size_t SRC_STRIDE = 64; // each source row is 64 bytes wide + constexpr size_t DST_STRIDE = ROW_BYTES; // compact destination + + std::vector host_src(NROWS * SRC_STRIDE, 0xAB); + // Fill only the valid data region. + for (size_t r = 0; r < NROWS; ++r) + for (size_t b = 0; b < ROW_BYTES; ++b) + host_src[r * SRC_STRIDE + b] = static_cast((r * ROW_BYTES + b) & 0xFF); + + uint8_t* d_src = deviceAlloc(host_src); + uint8_t* d_dst = deviceAllocZero(NROWS * DST_STRIDE); + + rtp_llm::FusedStridedCopyParams params; + params.add(d_src, d_dst, NROWS, ROW_BYTES, SRC_STRIDE, DST_STRIDE); + + rtp_llm::invokeFusedStridedCopy(params, stream_); + CUDA_CHECK(cudaStreamSynchronize(stream_)); + + auto result = deviceToHost(d_dst, NROWS * DST_STRIDE); + for (size_t r = 0; r < NROWS; ++r) + for (size_t b = 0; b < ROW_BYTES; ++b) + ASSERT_EQ(result[r * DST_STRIDE + b], host_src[r * SRC_STRIDE + b]) << "row " << r << " col " << b; + + cudaFree(d_src); + cudaFree(d_dst); +} + +// Compact-to-strided copy: dst has a larger stride than row_bytes. +TEST_F(FusedStridedCopyTest, CompactToStrided) { + constexpr size_t NROWS = 4; + constexpr size_t ROW_BYTES = 16; + constexpr size_t SRC_STRIDE = ROW_BYTES; // compact source + constexpr size_t DST_STRIDE = 48; // padded destination + + std::vector host_src(NROWS * SRC_STRIDE); + for (size_t i = 0; i < host_src.size(); ++i) + host_src[i] = static_cast(i & 0xFF); + + uint8_t* d_src = deviceAlloc(host_src); + uint8_t* d_dst = deviceAllocZero(NROWS * DST_STRIDE); + + rtp_llm::FusedStridedCopyParams params; + params.add(d_src, d_dst, NROWS, ROW_BYTES, SRC_STRIDE, DST_STRIDE); + + rtp_llm::invokeFusedStridedCopy(params, stream_); + CUDA_CHECK(cudaStreamSynchronize(stream_)); + + auto result = deviceToHost(d_dst, NROWS * DST_STRIDE); + for (size_t r = 0; r < NROWS; ++r) + for (size_t b = 0; b < ROW_BYTES; ++b) + ASSERT_EQ(result[r * DST_STRIDE + b], host_src[r * SRC_STRIDE + b]) << "row " << r << " col " << b; + + cudaFree(d_src); + cudaFree(d_dst); +} + +// Multiple strided copies in one launch. +TEST_F(FusedStridedCopyTest, MultipleStridedCopies) { + struct CopySpec { + size_t nrows, row_bytes, src_stride, dst_stride; + }; + const std::vector specs = { + {4, 16, 32, 16}, + {8, 32, 64, 32}, + {2, 64, 128, 64}, + }; + + std::vector> host_srcs(specs.size()); + std::vector d_srcs(specs.size()); + std::vector d_dsts(specs.size()); + + for (size_t c = 0; c < specs.size(); ++c) { + const auto& s = specs[c]; + host_srcs[c].resize(s.nrows * s.src_stride, 0); + for (size_t r = 0; r < s.nrows; ++r) + for (size_t b = 0; b < s.row_bytes; ++b) + host_srcs[c][r * s.src_stride + b] = static_cast((c * 31 + r * s.row_bytes + b) & 0xFF); + + d_srcs[c] = deviceAlloc(host_srcs[c]); + d_dsts[c] = deviceAllocZero(s.nrows * s.dst_stride); + } + + rtp_llm::FusedStridedCopyParams params; + for (size_t c = 0; c < specs.size(); ++c) { + const auto& s = specs[c]; + params.add(d_srcs[c], d_dsts[c], s.nrows, s.row_bytes, s.src_stride, s.dst_stride); + } + + rtp_llm::invokeFusedStridedCopy(params, stream_); + CUDA_CHECK(cudaStreamSynchronize(stream_)); + + for (size_t c = 0; c < specs.size(); ++c) { + const auto& s = specs[c]; + auto result = deviceToHost(d_dsts[c], s.nrows * s.dst_stride); + for (size_t r = 0; r < s.nrows; ++r) + for (size_t b = 0; b < s.row_bytes; ++b) + ASSERT_EQ(result[r * s.dst_stride + b], host_srcs[c][r * s.src_stride + b]) + << "copy " << c << " row " << r << " col " << b; + } + + for (size_t c = 0; c < specs.size(); ++c) { + cudaFree(d_srcs[c]); + cudaFree(d_dsts[c]); + } +} + +// Single-row strided copy (edge case: nrows == 1). +TEST_F(FusedStridedCopyTest, SingleRowCopy) { + constexpr size_t NROWS = 1; + constexpr size_t ROW_BYTES = 100; + constexpr size_t SRC_STRIDE = 256; + constexpr size_t DST_STRIDE = ROW_BYTES; + + std::vector host_src(NROWS * SRC_STRIDE, 0); + for (size_t b = 0; b < ROW_BYTES; ++b) + host_src[b] = static_cast(b & 0xFF); + + uint8_t* d_src = deviceAlloc(host_src); + uint8_t* d_dst = deviceAllocZero(NROWS * DST_STRIDE); + + rtp_llm::FusedStridedCopyParams params; + params.add(d_src, d_dst, NROWS, ROW_BYTES, SRC_STRIDE, DST_STRIDE); + + rtp_llm::invokeFusedStridedCopy(params, stream_); + CUDA_CHECK(cudaStreamSynchronize(stream_)); + + auto result = deviceToHost(d_dst, NROWS * DST_STRIDE); + for (size_t b = 0; b < ROW_BYTES; ++b) + ASSERT_EQ(result[b], host_src[b]) << "mismatch at byte " << b; + + cudaFree(d_src); + cudaFree(d_dst); +} + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} From edfaea1c7f176fe0c2c37ec3a255e47b1482af8d Mon Sep 17 00:00:00 2001 From: "huzetao.hzt" Date: Fri, 17 Apr 2026 11:44:13 +0800 Subject: [PATCH 3/4] chore: use pin mem for cp input --- .../context_parallel/ContextParallelProcessorBase.cc | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/rtp_llm/cpp/models/context_parallel/ContextParallelProcessorBase.cc b/rtp_llm/cpp/models/context_parallel/ContextParallelProcessorBase.cc index feb1df8c8b..e5e1713dc1 100644 --- a/rtp_llm/cpp/models/context_parallel/ContextParallelProcessorBase.cc +++ b/rtp_llm/cpp/models/context_parallel/ContextParallelProcessorBase.cc @@ -15,16 +15,18 @@ void IContextParallelProcessor::handleInputs(GptModelInputs& int prefill_cp_size = parallelism_config_.tp_size; int cp_align_size = prefill_cp_size * 2; + static const auto pinned_i32 = torch::TensorOptions(torch::kInt32).pinned_memory(true); + auto& total_input_tokens = model_input.combo_tokens; auto& input_lengths = model_input.input_lengths; auto& sequence_lengths = model_input.sequence_lengths; - auto input_lengths_cpu_tensor = input_lengths.clone(); + auto input_lengths_cpu_tensor = input_lengths.clone().pin_memory(); size_t num_decode_stream = sequence_lengths.size(0); size_t num_prefill_stream = input_lengths.size(0) - num_decode_stream; - auto prefill_cp_padding_lengths = torch::empty({(int64_t)num_prefill_stream}, torch::kInt32); - auto prefill_cp_chunk_lengths = torch::empty({(int64_t)num_prefill_stream}, torch::kInt32); + auto prefill_cp_padding_lengths = torch::empty({(int64_t)num_prefill_stream}, pinned_i32); + auto prefill_cp_chunk_lengths = torch::empty({(int64_t)num_prefill_stream}, pinned_i32); int* padding_lengths = prefill_cp_padding_lengths.data_ptr(); int* chunk_lengths = prefill_cp_chunk_lengths.data_ptr(); @@ -42,8 +44,8 @@ void IContextParallelProcessor::handleInputs(GptModelInputs& } auto cp_split_input_tokens = - torch::empty({(int64_t)(num_decode_stream + prefill_cp_split_tokens_size)}, torch::kInt32); - auto prefill_shuffle_indices = torch::empty({(int64_t)prefill_cp_split_tokens_size}, torch::kInt32); + torch::empty({(int64_t)(num_decode_stream + prefill_cp_split_tokens_size)}, pinned_i32); + auto prefill_shuffle_indices = torch::empty({(int64_t)prefill_cp_split_tokens_size}, pinned_i32); int* input_token_ptr = cp_split_input_tokens.data_ptr(); int* input_length_ptr = input_lengths.data_ptr(); From f5e96ba9fb9891d6df9ac95d6d8c6372ea5a0d82 Mon Sep 17 00:00:00 2001 From: "huzetao.hzt" Date: Thu, 23 Apr 2026 14:41:08 +0800 Subject: [PATCH 4/4] fix(fuse_copy): raise cap to 64 for micro-batch accumulation --- rtp_llm/cpp/cuda_graph/cuda_graph_runner.cc | 4 + rtp_llm/cpp/models/PyWrappedModel.cc | 9 ++ .../bindings/common/kernels/fuse_copy_util.h | 40 ++++- .../kernels/test/fuse_copy_kernel_test.cc | 152 ++++++++++++++++++ 4 files changed, 200 insertions(+), 5 deletions(-) diff --git a/rtp_llm/cpp/cuda_graph/cuda_graph_runner.cc b/rtp_llm/cpp/cuda_graph/cuda_graph_runner.cc index f80880846c..92d2f6b5bc 100644 --- a/rtp_llm/cpp/cuda_graph/cuda_graph_runner.cc +++ b/rtp_llm/cpp/cuda_graph/cuda_graph_runner.cc @@ -66,6 +66,10 @@ void CudaGraphRunner::prepareInputs(const PyModelInputs& inputs, CudaGraphState& auto& py_model_inputs_ = graph_instances_[graph_idx].mem_hold_.py_model_inputs_; auto attn_pyobj = graph_instances_[graph_idx].mem_hold_.attn_pyobj_; + // Per-launch capacity contract: see fuse_copy_util.h sizing rationale. + // Worst case here is ~8 contiguous + (1 + group_count) strided copies, + // batched into one launch each. If new copies are added below — or if the + // hybrid KV-cache group_count grows materially — re-check MAX_FUSED_*_COPIES. FusedD2DCopyParams d2d_copies; FusedStridedCopyParams strided_d2d_copies; diff --git a/rtp_llm/cpp/models/PyWrappedModel.cc b/rtp_llm/cpp/models/PyWrappedModel.cc index dca975a9d3..985f4d2a30 100644 --- a/rtp_llm/cpp/models/PyWrappedModel.cc +++ b/rtp_llm/cpp/models/PyWrappedModel.cc @@ -281,6 +281,15 @@ std::optional PyWrappedModel::prepareWriteCacheParams(const GptModelOutputs PyWrappedModel::forwardMicroBatched(const GptModelInputs& inputs) { RTP_LLM_PROFILE_SCOPE("py_model.forwardMicroBatched"); + // Per-launch capacity contract: see fuse_copy_util.h sizing rationale. + // d2d_copies_ accumulates across ALL micro-batches before the single + // fusedCopy() flush below. Per micro-batch this adds ~6 copies from + // buildPyAttentionInputs + padding_offset, plus group_count from + // setupKVCacheForAttentionInputs. With the planMicroBatches cap of 2 + // micro-batches and hybrid group_count of 4 the worst case is ~20. + // If new tensorHoldHostAndToCuda call sites land below — or if + // planMicroBatches starts producing >2 micro-batches — re-check + // MAX_FUSED_D2D_COPIES. d2d_copies_.clear(); if (pinned_check_remaining_ > 0) { --pinned_check_remaining_; diff --git a/rtp_llm/models_py/bindings/common/kernels/fuse_copy_util.h b/rtp_llm/models_py/bindings/common/kernels/fuse_copy_util.h index ec62ca9bfb..9111eb87b3 100644 --- a/rtp_llm/models_py/bindings/common/kernels/fuse_copy_util.h +++ b/rtp_llm/models_py/bindings/common/kernels/fuse_copy_util.h @@ -1,12 +1,38 @@ #pragma once #include #include +#include namespace rtp_llm { -// NOTE: Hardcoded limits for fused copies. It is enough for most cases. If you need more, please increase the limits. -static constexpr int MAX_FUSED_D2D_COPIES = 16; -static constexpr int MAX_FUSED_STRIDED_COPIES = 16; +// Hard caps on copies fused into a single kernel launch. The structs below +// are passed by value as kernel parameters, so the arrays must be sized at +// compile time. +// +// Sizing rationale (worst-case callers as of 2026): +// * cuda_graph_runner.cc::prepareInputs accumulates ~8 contiguous copies +// plus 1 + group_count strided copies per launch (one launch per replay). +// * PyWrappedModel.cc::forwardMicroBatched is the tightest path: it +// accumulates across ALL micro-batches before a single flush. Per +// micro-batch it adds ~6 contiguous copies (5 from buildPyAttentionInputs +// plus 1 padding_offset) plus `group_count` per-group block-id copies. +// With the current planMicroBatches cap of 2 micro-batches and a hybrid +// KV-cache group_count of 4 that's (6 + 4) * 2 = 20 contiguous copies. +// +// 64 entries gives ~3x headroom over today's worst case (20 contiguous, 5 +// strided) and accommodates ~30 KV-cache groups before hitting the cap. Each +// FusedStridedCopyParams is 6 * 8 * 64 + 4 = 3076 bytes, well under the 32 KB +// kernel parameter buffer available on Volta and newer GPUs (all currently +// supported targets). +// +// If you need to raise these further: bump the constant, re-check the kernel +// parameter buffer size for the lowest supported compute capability, and +// extend the MaxFusedCopies / micro-batch unit tests accordingly. If the +// upper bound ever needs to be unbounded, prefer adding a chunked-launch +// helper (split into multiple param structs and launch each) over making the +// arrays dynamic — the kernel signature must stay POD for grid launch. +static constexpr int MAX_FUSED_D2D_COPIES = 64; +static constexpr int MAX_FUSED_STRIDED_COPIES = 64; inline void copyParamsAssert(bool value, const std::string& msg) { if (!value) { @@ -22,7 +48,9 @@ struct FusedD2DCopyParams { void add(const void* src_ptr, void* dst_ptr, size_t bytes) { copyParamsAssert(num_copies < MAX_FUSED_D2D_COPIES, - "FusedD2DCopyParams: num_copies exceeds MAX_FUSED_D2D_COPIES"); + "FusedD2DCopyParams: num_copies (" + std::to_string(num_copies + 1) + + ") exceeds MAX_FUSED_D2D_COPIES (" + std::to_string(MAX_FUSED_D2D_COPIES) + + "). Bump the cap in fuse_copy_util.h after re-checking the sizing rationale."); src[num_copies] = src_ptr; dst[num_copies] = dst_ptr; size[num_copies] = bytes; @@ -45,7 +73,9 @@ struct FusedStridedCopyParams { void add(const void* src_ptr, void* dst_ptr, size_t rows, size_t row_b, size_t src_stride, size_t dst_stride) { copyParamsAssert(num_copies < MAX_FUSED_STRIDED_COPIES, - "FusedStridedCopyParams: num_copies exceeds MAX_FUSED_STRIDED_COPIES"); + "FusedStridedCopyParams: num_copies (" + std::to_string(num_copies + 1) + + ") exceeds MAX_FUSED_STRIDED_COPIES (" + std::to_string(MAX_FUSED_STRIDED_COPIES) + + "). Bump the cap in fuse_copy_util.h after re-checking the sizing rationale."); src[num_copies] = src_ptr; dst[num_copies] = dst_ptr; num_rows[num_copies] = rows; diff --git a/rtp_llm/models_py/bindings/common/kernels/test/fuse_copy_kernel_test.cc b/rtp_llm/models_py/bindings/common/kernels/test/fuse_copy_kernel_test.cc index bdb0209871..e3c3d2f4e0 100644 --- a/rtp_llm/models_py/bindings/common/kernels/test/fuse_copy_kernel_test.cc +++ b/rtp_llm/models_py/bindings/common/kernels/test/fuse_copy_kernel_test.cc @@ -48,6 +48,17 @@ std::vector deviceToHost(const T* d_ptr, size_t n) { return host; } +// Allocate page-locked (pinned) host memory and fill it with the given data. +// With UVA the returned pointer is directly dereferenceable from a CUDA kernel, +// so it can be passed straight into FusedD2DCopyParams as a source pointer. +template +T* pinnedHostAlloc(const std::vector& host_data) { + T* h_pinned = nullptr; + EXPECT_EQ(cudaHostAlloc(&h_pinned, host_data.size() * sizeof(T), cudaHostAllocMapped), cudaSuccess); + std::memcpy(h_pinned, host_data.data(), host_data.size() * sizeof(T)); + return h_pinned; +} + } // namespace // --------------------------------------------------------------------------- @@ -219,6 +230,117 @@ TEST_F(FusedCopyTest, MaxFusedCopies) { } } +// Documented worst-case contract: PyWrappedModel::forwardMicroBatched +// accumulates copies across all micro-batches before a single flush. With +// the planMicroBatches cap of 2 micro-batches and a hybrid KV-cache +// group_count of 4, the total is (6 base + 4 group) * 2 = 20 copies. +// This test pins that scenario down so any regression in the accounting +// (or in MAX_FUSED_D2D_COPIES) fails here rather than at production runtime. +TEST_F(FusedCopyTest, MicroBatchedAccumulationWorstCase) { + constexpr int NUM_MICRO_BATCHES = 2; + constexpr int BASE_COPIES_PER_MB = 6; + constexpr int GROUP_COUNT = 4; + constexpr int COPIES_PER_MB = BASE_COPIES_PER_MB + GROUP_COUNT; + constexpr int TOTAL_COPIES = NUM_MICRO_BATCHES * COPIES_PER_MB; // 20 + constexpr size_t N = 256; + + static_assert(TOTAL_COPIES <= rtp_llm::MAX_FUSED_D2D_COPIES, + "MAX_FUSED_D2D_COPIES is below the documented forwardMicroBatched worst case; " + "see fuse_copy_util.h sizing rationale."); + + std::vector> host_srcs(TOTAL_COPIES); + std::vector d_srcs(TOTAL_COPIES); + std::vector d_dsts(TOTAL_COPIES); + + for (int c = 0; c < TOTAL_COPIES; ++c) { + host_srcs[c].resize(N); + for (size_t i = 0; i < N; ++i) + host_srcs[c][i] = static_cast((c * 19 + i) & 0xFF); + d_srcs[c] = deviceAlloc(host_srcs[c]); + d_dsts[c] = deviceAllocZero(N); + } + + rtp_llm::FusedD2DCopyParams params; + for (int c = 0; c < TOTAL_COPIES; ++c) + params.add(d_srcs[c], d_dsts[c], N); + + rtp_llm::invokeFusedCopy(params, stream_); + CUDA_CHECK(cudaStreamSynchronize(stream_)); + + for (int c = 0; c < TOTAL_COPIES; ++c) { + auto result = deviceToHost(d_dsts[c], N); + for (size_t i = 0; i < N; ++i) + ASSERT_EQ(result[i], host_srcs[c][i]) << "copy " << c << " mismatch at byte " << i; + } + + for (int c = 0; c < TOTAL_COPIES; ++c) { + cudaFree(d_srcs[c]); + cudaFree(d_dsts[c]); + } +} + +// Copy from page-locked (pinned) host memory directly into device memory. +// The kernel dereferences the source pointer on the GPU, so this exercises +// the UVA path where pinned host memory is reachable from a CUDA kernel. +TEST_F(FusedCopyTest, PinnedHostToDeviceCopy) { + constexpr size_t N = 1024; // 16-byte aligned, hits the vectorised fast path + std::vector host_src(N); + for (size_t i = 0; i < N; ++i) + host_src[i] = static_cast((i * 5 + 1) & 0xFF); + + uint8_t* h_src_pinned = pinnedHostAlloc(host_src); + uint8_t* d_dst = deviceAllocZero(N); + + rtp_llm::FusedD2DCopyParams params; + params.add(h_src_pinned, d_dst, N); + + rtp_llm::invokeFusedCopy(params, stream_); + CUDA_CHECK(cudaStreamSynchronize(stream_)); + + auto result = deviceToHost(d_dst, N); + for (size_t i = 0; i < N; ++i) + ASSERT_EQ(result[i], host_src[i]) << "mismatch at byte " << i; + + cudaFreeHost(h_src_pinned); + cudaFree(d_dst); +} + +// Mixed sources in a single fused launch: some copies read from pinned host +// memory, others from device memory. This is the realistic batched scenario. +TEST_F(FusedCopyTest, MixedPinnedAndDeviceSrc) { + constexpr size_t N = 512; + + std::vector host_a(N), host_b(N); + for (size_t i = 0; i < N; ++i) { + host_a[i] = static_cast((i + 11) & 0xFF); + host_b[i] = static_cast((i * 3 + 7) & 0xFF); + } + + uint8_t* h_src_pinned = pinnedHostAlloc(host_a); // pinned host source + uint8_t* d_src_dev = deviceAlloc(host_b); // device source + uint8_t* d_dst_a = deviceAllocZero(N); + uint8_t* d_dst_b = deviceAllocZero(N); + + rtp_llm::FusedD2DCopyParams params; + params.add(h_src_pinned, d_dst_a, N); + params.add(d_src_dev, d_dst_b, N); + + rtp_llm::invokeFusedCopy(params, stream_); + CUDA_CHECK(cudaStreamSynchronize(stream_)); + + auto result_a = deviceToHost(d_dst_a, N); + auto result_b = deviceToHost(d_dst_b, N); + for (size_t i = 0; i < N; ++i) { + ASSERT_EQ(result_a[i], host_a[i]) << "pinned-src mismatch at byte " << i; + ASSERT_EQ(result_b[i], host_b[i]) << "device-src mismatch at byte " << i; + } + + cudaFreeHost(h_src_pinned); + cudaFree(d_src_dev); + cudaFree(d_dst_a); + cudaFree(d_dst_b); +} + // --------------------------------------------------------------------------- // FusedStridedCopy tests (invokeFusedStridedCopy) // --------------------------------------------------------------------------- @@ -382,6 +504,36 @@ TEST_F(FusedStridedCopyTest, SingleRowCopy) { cudaFree(d_dst); } +// Strided copy from pinned host memory directly into device memory. +TEST_F(FusedStridedCopyTest, PinnedHostToDeviceCopy) { + constexpr size_t NROWS = 8; + constexpr size_t ROW_BYTES = 32; + constexpr size_t SRC_STRIDE = 64; // pinned source has padding per row + constexpr size_t DST_STRIDE = ROW_BYTES; // compact device destination + + std::vector host_src(NROWS * SRC_STRIDE, 0xCD); + for (size_t r = 0; r < NROWS; ++r) + for (size_t b = 0; b < ROW_BYTES; ++b) + host_src[r * SRC_STRIDE + b] = static_cast((r * ROW_BYTES + b * 2) & 0xFF); + + uint8_t* h_src_pinned = pinnedHostAlloc(host_src); + uint8_t* d_dst = deviceAllocZero(NROWS * DST_STRIDE); + + rtp_llm::FusedStridedCopyParams params; + params.add(h_src_pinned, d_dst, NROWS, ROW_BYTES, SRC_STRIDE, DST_STRIDE); + + rtp_llm::invokeFusedStridedCopy(params, stream_); + CUDA_CHECK(cudaStreamSynchronize(stream_)); + + auto result = deviceToHost(d_dst, NROWS * DST_STRIDE); + for (size_t r = 0; r < NROWS; ++r) + for (size_t b = 0; b < ROW_BYTES; ++b) + ASSERT_EQ(result[r * DST_STRIDE + b], host_src[r * SRC_STRIDE + b]) << "row " << r << " col " << b; + + cudaFreeHost(h_src_pinned); + cudaFree(d_dst); +} + int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS();