|
| 1 | +#include <cstdint> |
| 2 | +#include <cstddef> |
| 3 | + |
| 4 | +#include "rtp_llm/cpp/kernels/fuse_copy/fuse_copy_kernel.h" |
| 5 | + |
| 6 | +namespace rtp_llm { |
| 7 | + |
// Launch shape shared by both fused-copy launchers in this file.
// Number of blocks along grid.x assigned to each copy task (grid.y indexes the task).
static constexpr int FUSED_COPY_BLOCKS_PER_TASK = 8;
// Threads per block used for every fused-copy launch.
static constexpr int FUSED_COPY_THREADS = 256;
| 10 | + |
// Performs several independent contiguous copies in one launch.
//
// Grid layout: blockIdx.y selects the copy task; all blocks along x share the
// work of that task. For task k, params.size[k] bytes are copied from
// params.src[k] to params.dst[k].
//
// When both pointers are 16-byte aligned the bulk is moved as int4 elements and
// the leftover (< 16) bytes are written by the x == 0 block; otherwise every
// byte is copied individually.
__global__ void fusedCopyKernel(FusedD2DCopyParams params) {
    const int task = blockIdx.y;
    if (task >= params.num_copies) {
        return;
    }

    const size_t nbytes        = params.size[task];
    const size_t first         = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const size_t threads_along = static_cast<size_t>(gridDim.x) * blockDim.x;

    const uintptr_t src_bits = reinterpret_cast<uintptr_t>(params.src[task]);
    const uintptr_t dst_bits = reinterpret_cast<uintptr_t>(params.dst[task]);

    const char* in_bytes  = reinterpret_cast<const char*>(src_bits);
    char*       out_bytes = reinterpret_cast<char*>(dst_bits);

    // sizeof(int4) is a power of two, so OR-ing the addresses tests both at once.
    const bool both_aligned = ((src_bits | dst_bits) % sizeof(int4)) == 0;
    if (!both_aligned) {
        // Unaligned pointers: plain byte copy spread over all threads of the task.
        for (size_t b = first; b < nbytes; b += threads_along) {
            out_bytes[b] = in_bytes[b];
        }
        return;
    }

    // Aligned pointers: move whole 16-byte elements first.
    const int4*  in_vec    = reinterpret_cast<const int4*>(src_bits);
    int4*        out_vec   = reinterpret_cast<int4*>(dst_bits);
    const size_t vec_count = nbytes / sizeof(int4);

    size_t v = first;
    while (v < vec_count) {
        out_vec[v] = in_vec[v];
        v += threads_along;
    }

    // Only the first block along x finishes the tail that does not fill an int4.
    if (blockIdx.x != 0) {
        return;
    }
    for (size_t b = vec_count * sizeof(int4) + threadIdx.x; b < nbytes; b += blockDim.x) {
        out_bytes[b] = in_bytes[b];
    }
}
| 50 | + |
// Enqueues fusedCopyKernel on `stream` for all copy tasks described by `params`.
// Does nothing when params.num_copies is zero or negative.
// The grid is (FUSED_COPY_BLOCKS_PER_TASK, num_copies) with FUSED_COPY_THREADS
// threads per block.
// NOTE(review): the launch status is not queried here — confirm callers check it.
void invokeFusedCopy(const FusedD2DCopyParams& params, cudaStream_t stream) {
    const auto task_count = params.num_copies;
    if (task_count > 0) {
        const dim3 blocks(FUSED_COPY_BLOCKS_PER_TASK, task_count);
        const dim3 threads(FUSED_COPY_THREADS);
        fusedCopyKernel<<<blocks, threads, 0, stream>>>(params);
    }
}
| 57 | + |
// Performs several independent strided (row-wise) copies in one launch.
//
// Grid layout: blockIdx.y selects the copy task; all blocks along x share the
// work of that task. For task k, params.num_rows[k] rows of params.row_bytes[k]
// bytes each are copied; consecutive rows start params.src_row_stride[k] bytes
// apart in the source and params.dst_row_stride[k] bytes apart in the
// destination.
//
// Fast path: when both base pointers, the row width and both row strides are
// multiples of 16 bytes, every 16-byte chunk of every row is aligned, so rows
// are moved as int4 elements (the same approach fusedCopyKernel uses for
// aligned contiguous copies). Otherwise the copy is done byte by byte.
// Both paths write exactly the same bytes.
__global__ void fusedStridedCopyKernel(FusedStridedCopyParams params) {
    const int copy_idx = blockIdx.y;
    if (copy_idx >= params.num_copies)
        return;

    const size_t nrows      = params.num_rows[copy_idx];
    const size_t rbytes     = params.row_bytes[copy_idx];
    const size_t src_stride = params.src_row_stride[copy_idx];
    const size_t dst_stride = params.dst_row_stride[copy_idx];
    const char*  src        = reinterpret_cast<const char*>(params.src[copy_idx]);
    char*        dst        = reinterpret_cast<char*>(params.dst[copy_idx]);

    const size_t global_tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const size_t stride     = static_cast<size_t>(gridDim.x) * blockDim.x;

    // sizeof(int4) is a power of two, so OR-ing all quantities lets a single
    // modulo test their alignment together.
    const size_t align_bits = reinterpret_cast<uintptr_t>(src) | reinterpret_cast<uintptr_t>(dst) | rbytes
                              | src_stride | dst_stride;
    if (align_bits % sizeof(int4) == 0) {
        const size_t row_vecs   = rbytes / sizeof(int4);
        const size_t total_vecs = nrows * row_vecs;
        // row_vecs is non-zero whenever the loop body runs (total_vecs > 0).
        for (size_t idx = global_tid; idx < total_vecs; idx += stride) {
            const size_t row       = idx / row_vecs;
            const size_t col_bytes = (idx % row_vecs) * sizeof(int4);
            // Byte offsets are formed exactly as in the byte path below; only
            // the access width differs.
            *reinterpret_cast<int4*>(dst + row * dst_stride + col_bytes) =
                *reinterpret_cast<const int4*>(src + row * src_stride + col_bytes);
        }
        return;
    }

    // Fallback: byte-by-byte copy for layouts that are not 16-byte aligned.
    // rbytes is non-zero whenever the loop body runs (total > 0).
    const size_t total = nrows * rbytes;
    for (size_t idx = global_tid; idx < total; idx += stride) {
        const size_t row = idx / rbytes;
        const size_t col = idx % rbytes;
        dst[row * dst_stride + col] = src[row * src_stride + col];
    }
}
| 80 | + |
// Enqueues fusedStridedCopyKernel on `stream` for all copy tasks described by
// `params`. Does nothing when params.num_copies is zero or negative.
// The grid is (FUSED_COPY_BLOCKS_PER_TASK, num_copies) with FUSED_COPY_THREADS
// threads per block.
// NOTE(review): the launch status is not queried here — confirm callers check it.
void invokeFusedStridedCopy(const FusedStridedCopyParams& params, cudaStream_t stream) {
    const auto task_count = params.num_copies;
    if (task_count > 0) {
        const dim3 blocks(FUSED_COPY_BLOCKS_PER_TASK, task_count);
        const dim3 threads(FUSED_COPY_THREADS);
        fusedStridedCopyKernel<<<blocks, threads, 0, stream>>>(params);
    }
}
| 87 | + |
| 88 | +} // namespace rtp_llm |
0 commit comments