Skip to content

Commit 22735d0

Browse files
committed
fix(fuse_copy): raise cap to 64 for micro-batch accumulation
1 parent f5be2b2 commit 22735d0

4 files changed

Lines changed: 200 additions & 5 deletions

File tree

rtp_llm/cpp/cuda_graph/cuda_graph_runner.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ void CudaGraphRunner::prepareInputs(const PyModelInputs& inputs, CudaGraphState&
6666
auto& py_model_inputs_ = graph_instances_[graph_idx].mem_hold_.py_model_inputs_;
6767
auto attn_pyobj = graph_instances_[graph_idx].mem_hold_.attn_pyobj_;
6868

69+
// Per-launch capacity contract: see fuse_copy_util.h sizing rationale.
70+
// Worst case here is ~8 contiguous + (1 + group_count) strided copies,
71+
// batched into one launch each. If new copies are added below — or if the
72+
// hybrid KV-cache group_count grows materially — re-check MAX_FUSED_*_COPIES.
6973
FusedD2DCopyParams d2d_copies;
7074
FusedStridedCopyParams strided_d2d_copies;
7175

rtp_llm/cpp/models/PyWrappedModel.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,15 @@ std::optional<PyCacheStoreInputs> PyWrappedModel::prepareWriteCacheParams(const
281281
GptModelOutputs PyWrappedModel::forwardMicroBatched(const GptModelInputs& inputs) {
282282
RTP_LLM_PROFILE_SCOPE("py_model.forwardMicroBatched");
283283

284+
// Per-launch capacity contract: see fuse_copy_util.h sizing rationale.
285+
// d2d_copies_ accumulates across ALL micro-batches before the single
286+
// fusedCopy() flush below. Per micro-batch this adds ~6 copies from
287+
// buildPyAttentionInputs + padding_offset, plus group_count from
288+
// setupKVCacheForAttentionInputs. With the planMicroBatches cap of 2
289+
// micro-batches and hybrid group_count of 4 the worst case is ~20.
290+
// If new tensorHoldHostAndToCuda call sites land below — or if
291+
// planMicroBatches starts producing >2 micro-batches — re-check
292+
// MAX_FUSED_D2D_COPIES.
284293
d2d_copies_.clear();
285294
if (pinned_check_remaining_ > 0) {
286295
--pinned_check_remaining_;

rtp_llm/models_py/bindings/common/kernels/fuse_copy_util.h

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,38 @@
11
#pragma once
22
#include <cstddef>
33
#include <stdexcept>
4+
#include <string>
45

56
namespace rtp_llm {
67

7-
// NOTE: Hardcoded limits for fused copies. It is enough for most cases. If you need more, please increase the limits.
8-
static constexpr int MAX_FUSED_D2D_COPIES = 16;
9-
static constexpr int MAX_FUSED_STRIDED_COPIES = 16;
8+
// Hard caps on copies fused into a single kernel launch. The structs below
9+
// are passed by value as kernel parameters, so the arrays must be sized at
10+
// compile time.
11+
//
12+
// Sizing rationale (worst-case callers as of 2026):
13+
// * cuda_graph_runner.cc::prepareInputs accumulates ~8 contiguous copies
14+
// plus 1 + group_count strided copies per launch (one launch per replay).
15+
// * PyWrappedModel.cc::forwardMicroBatched is the tightest path: it
16+
// accumulates across ALL micro-batches before a single flush. Per
17+
// micro-batch it adds ~6 contiguous copies (5 from buildPyAttentionInputs
18+
// plus 1 padding_offset) plus `group_count` per-group block-id copies.
19+
// With the current planMicroBatches cap of 2 micro-batches and a hybrid
20+
// KV-cache group_count of 4 that's (6 + 4) * 2 = 20 contiguous copies.
21+
//
22+
// 64 entries gives ~3x headroom over today's worst case (20 contiguous, 5
23+
// strided) and fits (6 + g) * 2 <= 64, i.e. up to 26 KV-cache groups. Each
24+
// FusedStridedCopyParams is 6 * 8 * 64 + 4 = 3076 bytes, under the 4 KB
25+
// kernel parameter limit of every GPU; the 32,764-byte limit additionally needs
26+
// CUDA 12.1+ on Volta and newer GPUs (all currently supported targets).
27+
//
28+
// If you need to raise these further: bump the constant, re-check the kernel
29+
// parameter buffer size for the lowest supported compute capability, and
30+
// extend the MaxFusedCopies / micro-batch unit tests accordingly. If the
31+
// upper bound ever needs to be unbounded, prefer adding a chunked-launch
32+
// helper (split into multiple param structs and launch each) over making the
33+
// arrays dynamic — the structs are kernel arguments passed by value, so they must stay fixed-size PODs.
34+
static constexpr int MAX_FUSED_D2D_COPIES = 64;
35+
static constexpr int MAX_FUSED_STRIDED_COPIES = 64;
1036

1137
inline void copyParamsAssert(bool value, const std::string& msg) {
1238
if (!value) {
@@ -22,7 +48,9 @@ struct FusedD2DCopyParams {
2248

2349
void add(const void* src_ptr, void* dst_ptr, size_t bytes) {
2450
copyParamsAssert(num_copies < MAX_FUSED_D2D_COPIES,
25-
"FusedD2DCopyParams: num_copies exceeds MAX_FUSED_D2D_COPIES");
51+
"FusedD2DCopyParams: num_copies (" + std::to_string(num_copies + 1)
52+
+ ") exceeds MAX_FUSED_D2D_COPIES (" + std::to_string(MAX_FUSED_D2D_COPIES)
53+
+ "). Bump the cap in fuse_copy_util.h after re-checking the sizing rationale.");
2654
src[num_copies] = src_ptr;
2755
dst[num_copies] = dst_ptr;
2856
size[num_copies] = bytes;
@@ -45,7 +73,9 @@ struct FusedStridedCopyParams {
4573

4674
void add(const void* src_ptr, void* dst_ptr, size_t rows, size_t row_b, size_t src_stride, size_t dst_stride) {
4775
copyParamsAssert(num_copies < MAX_FUSED_STRIDED_COPIES,
48-
"FusedStridedCopyParams: num_copies exceeds MAX_FUSED_STRIDED_COPIES");
76+
"FusedStridedCopyParams: num_copies (" + std::to_string(num_copies + 1)
77+
+ ") exceeds MAX_FUSED_STRIDED_COPIES (" + std::to_string(MAX_FUSED_STRIDED_COPIES)
78+
+ "). Bump the cap in fuse_copy_util.h after re-checking the sizing rationale.");
4979
src[num_copies] = src_ptr;
5080
dst[num_copies] = dst_ptr;
5181
num_rows[num_copies] = rows;

rtp_llm/models_py/bindings/common/kernels/test/fuse_copy_kernel_test.cc

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,17 @@ std::vector<T> deviceToHost(const T* d_ptr, size_t n) {
4848
return host;
4949
}
5050

51+
// Allocate page-locked (pinned) host memory and fill it with the given data.
52+
// With UVA the returned pointer is directly dereferenceable from a CUDA kernel,
53+
// so it can be passed straight into FusedD2DCopyParams as a source pointer.
54+
template<typename T>
55+
T* pinnedHostAlloc(const std::vector<T>& host_data) {
56+
T* h_pinned = nullptr;
57+
EXPECT_EQ(cudaHostAlloc(&h_pinned, host_data.size() * sizeof(T), cudaHostAllocMapped), cudaSuccess);
58+
std::memcpy(h_pinned, host_data.data(), host_data.size() * sizeof(T));
59+
return h_pinned;
60+
}
61+
5162
} // namespace
5263

5364
// ---------------------------------------------------------------------------
@@ -219,6 +230,117 @@ TEST_F(FusedCopyTest, MaxFusedCopies) {
219230
}
220231
}
221232

233+
// Documented worst-case contract: PyWrappedModel::forwardMicroBatched
234+
// accumulates copies across all micro-batches before a single flush. With
235+
// the planMicroBatches cap of 2 micro-batches and a hybrid KV-cache
236+
// group_count of 4, the total is (6 base + 4 group) * 2 = 20 copies.
237+
// The counts below are hard-coded copies of that accounting (update them when
238+
// call sites change), so a too-small MAX_FUSED_D2D_COPIES fails here, not in production.
239+
TEST_F(FusedCopyTest, MicroBatchedAccumulationWorstCase) {
240+
constexpr int NUM_MICRO_BATCHES = 2;
241+
constexpr int BASE_COPIES_PER_MB = 6;
242+
constexpr int GROUP_COUNT = 4;
243+
constexpr int COPIES_PER_MB = BASE_COPIES_PER_MB + GROUP_COUNT;
244+
constexpr int TOTAL_COPIES = NUM_MICRO_BATCHES * COPIES_PER_MB; // 20
245+
constexpr size_t N = 256;
246+
247+
static_assert(TOTAL_COPIES <= rtp_llm::MAX_FUSED_D2D_COPIES,
248+
"MAX_FUSED_D2D_COPIES is below the documented forwardMicroBatched worst case; "
249+
"see fuse_copy_util.h sizing rationale.");
250+
251+
std::vector<std::vector<uint8_t>> host_srcs(TOTAL_COPIES);
252+
std::vector<uint8_t*> d_srcs(TOTAL_COPIES);
253+
std::vector<uint8_t*> d_dsts(TOTAL_COPIES);
254+
255+
for (int c = 0; c < TOTAL_COPIES; ++c) {
256+
host_srcs[c].resize(N);
257+
for (size_t i = 0; i < N; ++i)
258+
host_srcs[c][i] = static_cast<uint8_t>((c * 19 + i) & 0xFF);
259+
d_srcs[c] = deviceAlloc(host_srcs[c]);
260+
d_dsts[c] = deviceAllocZero<uint8_t>(N);
261+
}
262+
263+
rtp_llm::FusedD2DCopyParams params;
264+
for (int c = 0; c < TOTAL_COPIES; ++c)
265+
params.add(d_srcs[c], d_dsts[c], N);
266+
267+
rtp_llm::invokeFusedCopy(params, stream_);
268+
CUDA_CHECK(cudaStreamSynchronize(stream_));
269+
270+
for (int c = 0; c < TOTAL_COPIES; ++c) {
271+
auto result = deviceToHost(d_dsts[c], N);
272+
for (size_t i = 0; i < N; ++i)
273+
ASSERT_EQ(result[i], host_srcs[c][i]) << "copy " << c << " mismatch at byte " << i;
274+
}
275+
276+
for (int c = 0; c < TOTAL_COPIES; ++c) {
277+
cudaFree(d_srcs[c]);
278+
cudaFree(d_dsts[c]);
279+
}
280+
}
281+
282+
// Copy from page-locked (pinned) host memory directly into device memory.
283+
// The kernel dereferences the source pointer on the GPU, so this exercises
284+
// the UVA path where pinned host memory is reachable from a CUDA kernel.
285+
TEST_F(FusedCopyTest, PinnedHostToDeviceCopy) {
286+
constexpr size_t N = 1024; // 16-byte aligned, hits the vectorised fast path
287+
std::vector<uint8_t> host_src(N);
288+
for (size_t i = 0; i < N; ++i)
289+
host_src[i] = static_cast<uint8_t>((i * 5 + 1) & 0xFF);
290+
291+
uint8_t* h_src_pinned = pinnedHostAlloc(host_src);
292+
uint8_t* d_dst = deviceAllocZero<uint8_t>(N);
293+
294+
rtp_llm::FusedD2DCopyParams params;
295+
params.add(h_src_pinned, d_dst, N);
296+
297+
rtp_llm::invokeFusedCopy(params, stream_);
298+
CUDA_CHECK(cudaStreamSynchronize(stream_));
299+
300+
auto result = deviceToHost(d_dst, N);
301+
for (size_t i = 0; i < N; ++i)
302+
ASSERT_EQ(result[i], host_src[i]) << "mismatch at byte " << i;
303+
304+
cudaFreeHost(h_src_pinned);
305+
cudaFree(d_dst);
306+
}
307+
308+
// Mixed sources in a single fused launch: some copies read from pinned host
309+
// memory, others from device memory. This is the realistic batched scenario.
310+
TEST_F(FusedCopyTest, MixedPinnedAndDeviceSrc) {
311+
constexpr size_t N = 512;
312+
313+
std::vector<uint8_t> host_a(N), host_b(N);
314+
for (size_t i = 0; i < N; ++i) {
315+
host_a[i] = static_cast<uint8_t>((i + 11) & 0xFF);
316+
host_b[i] = static_cast<uint8_t>((i * 3 + 7) & 0xFF);
317+
}
318+
319+
uint8_t* h_src_pinned = pinnedHostAlloc(host_a); // pinned host source
320+
uint8_t* d_src_dev = deviceAlloc(host_b); // device source
321+
uint8_t* d_dst_a = deviceAllocZero<uint8_t>(N);
322+
uint8_t* d_dst_b = deviceAllocZero<uint8_t>(N);
323+
324+
rtp_llm::FusedD2DCopyParams params;
325+
params.add(h_src_pinned, d_dst_a, N);
326+
params.add(d_src_dev, d_dst_b, N);
327+
328+
rtp_llm::invokeFusedCopy(params, stream_);
329+
CUDA_CHECK(cudaStreamSynchronize(stream_));
330+
331+
auto result_a = deviceToHost(d_dst_a, N);
332+
auto result_b = deviceToHost(d_dst_b, N);
333+
for (size_t i = 0; i < N; ++i) {
334+
ASSERT_EQ(result_a[i], host_a[i]) << "pinned-src mismatch at byte " << i;
335+
ASSERT_EQ(result_b[i], host_b[i]) << "device-src mismatch at byte " << i;
336+
}
337+
338+
cudaFreeHost(h_src_pinned);
339+
cudaFree(d_src_dev);
340+
cudaFree(d_dst_a);
341+
cudaFree(d_dst_b);
342+
}
343+
222344
// ---------------------------------------------------------------------------
223345
// FusedStridedCopy tests (invokeFusedStridedCopy)
224346
// ---------------------------------------------------------------------------
@@ -382,6 +504,36 @@ TEST_F(FusedStridedCopyTest, SingleRowCopy) {
382504
cudaFree(d_dst);
383505
}
384506

507+
// Strided copy from pinned host memory directly into device memory.
508+
TEST_F(FusedStridedCopyTest, PinnedHostToDeviceCopy) {
509+
constexpr size_t NROWS = 8;
510+
constexpr size_t ROW_BYTES = 32;
511+
constexpr size_t SRC_STRIDE = 64; // pinned source has padding per row
512+
constexpr size_t DST_STRIDE = ROW_BYTES; // compact device destination
513+
514+
std::vector<uint8_t> host_src(NROWS * SRC_STRIDE, 0xCD);
515+
for (size_t r = 0; r < NROWS; ++r)
516+
for (size_t b = 0; b < ROW_BYTES; ++b)
517+
host_src[r * SRC_STRIDE + b] = static_cast<uint8_t>((r * ROW_BYTES + b * 2) & 0xFF);
518+
519+
uint8_t* h_src_pinned = pinnedHostAlloc(host_src);
520+
uint8_t* d_dst = deviceAllocZero<uint8_t>(NROWS * DST_STRIDE);
521+
522+
rtp_llm::FusedStridedCopyParams params;
523+
params.add(h_src_pinned, d_dst, NROWS, ROW_BYTES, SRC_STRIDE, DST_STRIDE);
524+
525+
rtp_llm::invokeFusedStridedCopy(params, stream_);
526+
CUDA_CHECK(cudaStreamSynchronize(stream_));
527+
528+
auto result = deviceToHost(d_dst, NROWS * DST_STRIDE);
529+
for (size_t r = 0; r < NROWS; ++r)
530+
for (size_t b = 0; b < ROW_BYTES; ++b)
531+
ASSERT_EQ(result[r * DST_STRIDE + b], host_src[r * SRC_STRIDE + b]) << "row " << r << " col " << b;
532+
533+
cudaFreeHost(h_src_pinned);
534+
cudaFree(d_dst);
535+
}
536+
385537
int main(int argc, char** argv) {
386538
::testing::InitGoogleTest(&argc, argv);
387539
return RUN_ALL_TESTS();

0 commit comments

Comments
 (0)