@@ -48,6 +48,17 @@ std::vector<T> deviceToHost(const T* d_ptr, size_t n) {
4848 return host;
4949}
5050
51+ // Allocate page-locked (pinned) host memory and fill it with the given data.
52+ // With UVA the returned pointer is directly dereferenceable from a CUDA kernel,
53+ // so it can be passed straight into FusedD2DCopyParams as a source pointer.
54+ template <typename T>
55+ T* pinnedHostAlloc (const std::vector<T>& host_data) {
56+ T* h_pinned = nullptr ;
57+ EXPECT_EQ (cudaHostAlloc (&h_pinned, host_data.size () * sizeof (T), cudaHostAllocMapped), cudaSuccess);
58+ std::memcpy (h_pinned, host_data.data (), host_data.size () * sizeof (T));
59+ return h_pinned;
60+ }
61+
5162} // namespace
5263
5364// ---------------------------------------------------------------------------
@@ -219,6 +230,117 @@ TEST_F(FusedCopyTest, MaxFusedCopies) {
219230 }
220231}
221232
233+ // Documented worst-case contract: PyWrappedModel::forwardMicroBatched
234+ // accumulates copies across all micro-batches before a single flush. With
235+ // the planMicroBatches cap of 2 micro-batches and a hybrid KV-cache
236+ // group_count of 4, the total is (6 base + 4 group) * 2 = 20 copies.
237+ // This test pins that scenario down so any regression in the accounting
238+ // (or in MAX_FUSED_D2D_COPIES) fails here rather than at production runtime.
239+ TEST_F (FusedCopyTest, MicroBatchedAccumulationWorstCase) {
240+ constexpr int NUM_MICRO_BATCHES = 2 ;
241+ constexpr int BASE_COPIES_PER_MB = 6 ;
242+ constexpr int GROUP_COUNT = 4 ;
243+ constexpr int COPIES_PER_MB = BASE_COPIES_PER_MB + GROUP_COUNT;
244+ constexpr int TOTAL_COPIES = NUM_MICRO_BATCHES * COPIES_PER_MB; // 20
245+ constexpr size_t N = 256 ;
246+
247+ static_assert (TOTAL_COPIES <= rtp_llm::MAX_FUSED_D2D_COPIES,
248+ " MAX_FUSED_D2D_COPIES is below the documented forwardMicroBatched worst case; "
249+ " see fuse_copy_util.h sizing rationale." );
250+
251+ std::vector<std::vector<uint8_t >> host_srcs (TOTAL_COPIES);
252+ std::vector<uint8_t *> d_srcs (TOTAL_COPIES);
253+ std::vector<uint8_t *> d_dsts (TOTAL_COPIES);
254+
255+ for (int c = 0 ; c < TOTAL_COPIES; ++c) {
256+ host_srcs[c].resize (N);
257+ for (size_t i = 0 ; i < N; ++i)
258+ host_srcs[c][i] = static_cast <uint8_t >((c * 19 + i) & 0xFF );
259+ d_srcs[c] = deviceAlloc (host_srcs[c]);
260+ d_dsts[c] = deviceAllocZero<uint8_t >(N);
261+ }
262+
263+ rtp_llm::FusedD2DCopyParams params;
264+ for (int c = 0 ; c < TOTAL_COPIES; ++c)
265+ params.add (d_srcs[c], d_dsts[c], N);
266+
267+ rtp_llm::invokeFusedCopy (params, stream_);
268+ CUDA_CHECK (cudaStreamSynchronize (stream_));
269+
270+ for (int c = 0 ; c < TOTAL_COPIES; ++c) {
271+ auto result = deviceToHost (d_dsts[c], N);
272+ for (size_t i = 0 ; i < N; ++i)
273+ ASSERT_EQ (result[i], host_srcs[c][i]) << " copy " << c << " mismatch at byte " << i;
274+ }
275+
276+ for (int c = 0 ; c < TOTAL_COPIES; ++c) {
277+ cudaFree (d_srcs[c]);
278+ cudaFree (d_dsts[c]);
279+ }
280+ }
281+
282+ // Copy from page-locked (pinned) host memory directly into device memory.
283+ // The kernel dereferences the source pointer on the GPU, so this exercises
284+ // the UVA path where pinned host memory is reachable from a CUDA kernel.
285+ TEST_F (FusedCopyTest, PinnedHostToDeviceCopy) {
286+ constexpr size_t N = 1024 ; // 16-byte aligned, hits the vectorised fast path
287+ std::vector<uint8_t > host_src (N);
288+ for (size_t i = 0 ; i < N; ++i)
289+ host_src[i] = static_cast <uint8_t >((i * 5 + 1 ) & 0xFF );
290+
291+ uint8_t * h_src_pinned = pinnedHostAlloc (host_src);
292+ uint8_t * d_dst = deviceAllocZero<uint8_t >(N);
293+
294+ rtp_llm::FusedD2DCopyParams params;
295+ params.add (h_src_pinned, d_dst, N);
296+
297+ rtp_llm::invokeFusedCopy (params, stream_);
298+ CUDA_CHECK (cudaStreamSynchronize (stream_));
299+
300+ auto result = deviceToHost (d_dst, N);
301+ for (size_t i = 0 ; i < N; ++i)
302+ ASSERT_EQ (result[i], host_src[i]) << " mismatch at byte " << i;
303+
304+ cudaFreeHost (h_src_pinned);
305+ cudaFree (d_dst);
306+ }
307+
308+ // Mixed sources in a single fused launch: some copies read from pinned host
309+ // memory, others from device memory. This is the realistic batched scenario.
310+ TEST_F (FusedCopyTest, MixedPinnedAndDeviceSrc) {
311+ constexpr size_t N = 512 ;
312+
313+ std::vector<uint8_t > host_a (N), host_b (N);
314+ for (size_t i = 0 ; i < N; ++i) {
315+ host_a[i] = static_cast <uint8_t >((i + 11 ) & 0xFF );
316+ host_b[i] = static_cast <uint8_t >((i * 3 + 7 ) & 0xFF );
317+ }
318+
319+ uint8_t * h_src_pinned = pinnedHostAlloc (host_a); // pinned host source
320+ uint8_t * d_src_dev = deviceAlloc (host_b); // device source
321+ uint8_t * d_dst_a = deviceAllocZero<uint8_t >(N);
322+ uint8_t * d_dst_b = deviceAllocZero<uint8_t >(N);
323+
324+ rtp_llm::FusedD2DCopyParams params;
325+ params.add (h_src_pinned, d_dst_a, N);
326+ params.add (d_src_dev, d_dst_b, N);
327+
328+ rtp_llm::invokeFusedCopy (params, stream_);
329+ CUDA_CHECK (cudaStreamSynchronize (stream_));
330+
331+ auto result_a = deviceToHost (d_dst_a, N);
332+ auto result_b = deviceToHost (d_dst_b, N);
333+ for (size_t i = 0 ; i < N; ++i) {
334+ ASSERT_EQ (result_a[i], host_a[i]) << " pinned-src mismatch at byte " << i;
335+ ASSERT_EQ (result_b[i], host_b[i]) << " device-src mismatch at byte " << i;
336+ }
337+
338+ cudaFreeHost (h_src_pinned);
339+ cudaFree (d_src_dev);
340+ cudaFree (d_dst_a);
341+ cudaFree (d_dst_b);
342+ }
343+
222344// ---------------------------------------------------------------------------
223345// FusedStridedCopy tests (invokeFusedStridedCopy)
224346// ---------------------------------------------------------------------------
@@ -382,6 +504,36 @@ TEST_F(FusedStridedCopyTest, SingleRowCopy) {
382504 cudaFree (d_dst);
383505}
384506
507+ // Strided copy from pinned host memory directly into device memory.
508+ TEST_F (FusedStridedCopyTest, PinnedHostToDeviceCopy) {
509+ constexpr size_t NROWS = 8 ;
510+ constexpr size_t ROW_BYTES = 32 ;
511+ constexpr size_t SRC_STRIDE = 64 ; // pinned source has padding per row
512+ constexpr size_t DST_STRIDE = ROW_BYTES; // compact device destination
513+
514+ std::vector<uint8_t > host_src (NROWS * SRC_STRIDE, 0xCD );
515+ for (size_t r = 0 ; r < NROWS; ++r)
516+ for (size_t b = 0 ; b < ROW_BYTES; ++b)
517+ host_src[r * SRC_STRIDE + b] = static_cast <uint8_t >((r * ROW_BYTES + b * 2 ) & 0xFF );
518+
519+ uint8_t * h_src_pinned = pinnedHostAlloc (host_src);
520+ uint8_t * d_dst = deviceAllocZero<uint8_t >(NROWS * DST_STRIDE);
521+
522+ rtp_llm::FusedStridedCopyParams params;
523+ params.add (h_src_pinned, d_dst, NROWS, ROW_BYTES, SRC_STRIDE, DST_STRIDE);
524+
525+ rtp_llm::invokeFusedStridedCopy (params, stream_);
526+ CUDA_CHECK (cudaStreamSynchronize (stream_));
527+
528+ auto result = deviceToHost (d_dst, NROWS * DST_STRIDE);
529+ for (size_t r = 0 ; r < NROWS; ++r)
530+ for (size_t b = 0 ; b < ROW_BYTES; ++b)
531+ ASSERT_EQ (result[r * DST_STRIDE + b], host_src[r * SRC_STRIDE + b]) << " row " << r << " col " << b;
532+
533+ cudaFreeHost (h_src_pinned);
534+ cudaFree (d_dst);
535+ }
536+
385537int main (int argc, char ** argv) {
386538 ::testing::InitGoogleTest (&argc, argv);
387539 return RUN_ALL_TESTS ();
0 commit comments