Skip to content

Commit 8a1e71d

Browse files
authored
[PD] PD send cache via storage & Refine swap_cache_layout op (#7839)
* PD send cache via storage & Refine swap_cache_layout op * skip messager * up * consider write cache error * fix ci * up
1 parent 261041b commit 8a1e71d

19 files changed

Lines changed: 705 additions & 142 deletions

custom_ops/gpu_ops/swap_cache_layout.cu

Lines changed: 247 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -15,90 +15,279 @@
1515
#include "helper.h"
1616
#include "paddle/extension.h"
1717

18-
// #define SWAP_DEBUG
18+
// D2H gather: copies whole cache blocks from the per-layer GPU tensors into
// one contiguous destination buffer.
//
// Launch layout: 1-D grid, one thread block per swapped cache block
// (blockIdx.x in [0, n_blocks)); threads of a block stride over the float4
// words of each layer in turn. No shared memory is used.
//
// Destination layout, in elements of T:
//   cpu_buffer[(swap_idx * layer_num + layer) * block_stride + e]
// so all layers of one swapped block are written back to back.
//
// Preconditions (not checked here, the caller must guarantee them):
//   * block_stride * sizeof(T) is a multiple of sizeof(float4); any remainder
//     bytes would be silently skipped by the integer division below.
//   * every source and destination address formed below is 16-byte aligned.
//   * cpu_buffer is dereferenced by device code, so it must be an address the
//     GPU can write to.
//
// layer_ptrs    : device array of layer_num base pointers, one per layer.
// gpu_block_ids : device array of n_blocks source block indices.
// block_stride  : number of T elements in one cache block of one layer.
template <typename T>
__global__ void swap_d2h_kernel(T** __restrict__ layer_ptrs,
                                T* __restrict__ cpu_buffer,
                                const int64_t* __restrict__ gpu_block_ids,
                                int n_blocks,
                                int layer_num,
                                int64_t block_stride) {
  const int swap_idx = blockIdx.x;
  if (swap_idx < n_blocks) {
    const int64_t src_block = gpu_block_ids[swap_idx];
    const int64_t vecs_per_layer = (block_stride * sizeof(T)) / sizeof(float4);

    // First element of this swapped block inside the destination buffer.
    T* out_layer = cpu_buffer + (int64_t)swap_idx * layer_num * block_stride;

    for (int layer = 0; layer < layer_num; ++layer) {
      const float4* in_vec = reinterpret_cast<const float4*>(
          layer_ptrs[layer] + src_block * block_stride);
      float4* out_vec = reinterpret_cast<float4*>(out_layer);

      // Consecutive threads touch consecutive 16-byte words.
      for (int64_t v = threadIdx.x; v < vecs_per_layer; v += blockDim.x) {
        out_vec[v] = in_vec[v];
      }

      // The next layer of the same swapped block follows immediately.
      out_layer += block_stride;
    }
  }
}
47+
48+
// H2D scatter: copies cache blocks from a contiguous staging buffer into the
// per-layer GPU tensors.
//
// Launch layout: 1-D grid of n_blocks * layer_num thread blocks; blockIdx.x
// encodes one (swapped block, layer) pair as swap_idx * layer_num + layer.
// Threads of a block stride over the float4 words of that single pair.
// No shared memory is used.
//
// Staging layout, in elements of T:
//   staging[(swap_idx * layer_num + layer) * block_stride + e]
//
// Preconditions (not checked here, the caller must guarantee them):
//   * block_stride * sizeof(T) is a multiple of sizeof(float4); any remainder
//     bytes would be silently skipped by the integer division below.
//   * every source and destination address formed below is 16-byte aligned.
//
// layer_ptrs    : device array of layer_num base pointers, one per layer.
// gpu_block_ids : device array of n_blocks destination block indices.
// block_stride  : number of T elements in one cache block of one layer.
template <typename T>
__global__ void scatter_blocks_kernel(T** __restrict__ layer_ptrs,
                                      const T* __restrict__ staging,
                                      const int64_t* __restrict__ gpu_block_ids,
                                      int n_blocks,
                                      int layer_num,
                                      int64_t block_stride) {
  const int work_item = blockIdx.x;
  const int swap_idx = work_item / layer_num;
  const int layer = work_item % layer_num;

  // Thread blocks beyond the last (block, layer) pair have nothing to do.
  if (swap_idx >= n_blocks) return;

  const int64_t dst_block = gpu_block_ids[swap_idx];

  const float4* in_vec = reinterpret_cast<const float4*>(
      staging + ((int64_t)swap_idx * layer_num + layer) * block_stride);
  float4* out_vec = reinterpret_cast<float4*>(layer_ptrs[layer] +
                                              dst_block * block_stride);

  const int64_t vec_count = (block_stride * sizeof(T)) / sizeof(float4);
  for (int64_t v = threadIdx.x; v < vec_count; v += blockDim.x) {
    out_vec[v] = in_vec[v];
  }
}
75+
76+
// File-scope device scratch buffers, grown on demand and reused across calls
// so that a swap does not pay for cudaMalloc/cudaFree every time. Each pointer
// is paired with the capacity (in bytes) of its current allocation.
//
// NOTE(review): the buffers are allocated on whichever device is current at
// the time of the call and are not guarded by any lock; this presumably relies
// on one device and one calling thread per process -- confirm with callers.
static void* g_staging_buffer = nullptr;
static size_t g_staging_buffer_size = 0;
static void* g_device_block_ids = nullptr;
static size_t g_device_block_ids_size = 0;
static void* g_device_layer_ptrs = nullptr;
static size_t g_device_layer_ptrs_size = 0;

// Makes sure *buffer points to a device allocation of at least required_size
// bytes, reallocating (without preserving contents) when it is too small.
//
// buffer / capacity : the cached pointer and its recorded size, updated in
//                     place.
// required_size     : minimum number of bytes needed.
// what              : buffer description inserted into the error message.
//
// Raises phi::errors::External when cudaMalloc fails. In that case the cached
// state is left as "no buffer" (nullptr, 0) so that a later call allocates
// again instead of reusing a pointer that has already been freed.
static void ensure_device_buffer(void** buffer,
                                 size_t* capacity,
                                 size_t required_size,
                                 const char* what) {
  if (*capacity >= required_size) return;

  if (*buffer != nullptr) {
    // The return code is deliberately not enforced: a failing cudaFree leaves
    // nothing to recover, and the cudaMalloc below is checked.
    (void)cudaFree(*buffer);
  }
  // Forget the old allocation before attempting the new one, so a failure
  // below cannot leave a dangling pointer with a stale non-zero capacity.
  *buffer = nullptr;
  *capacity = 0;

  void* fresh = nullptr;
  cudaError_t err = cudaMalloc(&fresh, required_size);
  PADDLE_ENFORCE_EQ(err,
                    cudaSuccess,
                    phi::errors::External("cudaMalloc %s failed: %s",
                                          what,
                                          cudaGetErrorString(err)));
  *buffer = fresh;
  *capacity = required_size;
}

// Ensures g_staging_buffer holds at least required_size bytes.
static void ensure_staging_buffer(size_t required_size) {
  ensure_device_buffer(&g_staging_buffer,
                       &g_staging_buffer_size,
                       required_size,
                       "staging buffer");
}

// Ensures g_device_block_ids holds at least required_size bytes.
static void ensure_device_block_ids(size_t required_size) {
  ensure_device_buffer(&g_device_block_ids,
                       &g_device_block_ids_size,
                       required_size,
                       "device block_ids");
}

// Ensures g_device_layer_ptrs holds at least required_size bytes.
static void ensure_device_layer_ptrs(size_t required_size) {
  ensure_device_buffer(&g_device_layer_ptrs,
                       &g_device_layer_ptrs_size,
                       required_size,
                       "device layer_ptrs");
}
121+
122+
// Returns true when cpu_block_ids is a run of consecutive, ascending ids
// (id[k] == id[0] + k for every k). An empty list and a single-element list
// both count as sequential.
static bool is_cpu_block_ids_sequential(
    const std::vector<int64_t>& cpu_block_ids) {
  const size_t count = cpu_block_ids.size();
  if (count == 0) return true;

  const int64_t first_id = cpu_block_ids.front();
  size_t offset = 1;
  // Advance while each entry equals the first id plus its position.
  while (offset < count &&
         cpu_block_ids[offset] == first_id + static_cast<int64_t>(offset)) {
    ++offset;
  }
  // The scan reaches the end only if no entry broke the run.
  return offset == count;
}
19131

20132
// Copies whole cache blocks between the per-layer GPU tensors and a single
// host buffer, then waits for the stream to finish.
//
// cache_gpu_tensors : one GPU tensor per layer; block b of a layer starts at
//                     element b * block_stride of that tensor's data.
// cache_cpu_pointer : address of the host buffer, laid out as
//                     [cpu_block][layer][block_stride elements].
// cache_shape       : block_stride is the product of cache_shape[1..].
// gpu_block_ids     : block indices on the GPU side; its length decides how
//                     many blocks are copied.
// cpu_block_ids     : matching block indices in the host buffer; must contain
//                     at least as many entries as gpu_block_ids.
// mode              : 0 copies GPU -> CPU, any other value copies CPU -> GPU.
//
// Two strategies are used:
//   * Kernel path, taken when the CPU block ids are consecutive and every
//     address involved is 16-byte aligned: block ids and layer pointers are
//     uploaded once, then a single float4 copy kernel moves all the data
//     (GPU -> CPU writes straight into the host buffer; CPU -> GPU first
//     copies the contiguous host range into a device staging buffer).
//   * Fallback path: one cudaMemcpyAsync per (layer, block) pair.
//
// NOTE(review): on the GPU -> CPU kernel path device code dereferences the
// host buffer directly, so the buffer is presumably pinned/mapped host
// memory -- confirm with the code that allocates it.
//
// Raises phi::errors::InvalidArgument on inconsistent block-id lists or an
// oversized launch, and phi::errors::External on any CUDA failure.
template <paddle::DataType D>
void SwapCacheImpLayout(const std::vector<paddle::Tensor>& cache_gpu_tensors,
                        const int64_t& cache_cpu_pointer,
                        const std::vector<int64_t>& cache_shape,
                        const std::vector<int64_t>& gpu_block_ids,
                        const std::vector<int64_t>& cpu_block_ids,
                        int mode) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  const int64_t layer_number = cache_gpu_tensors.size();
  int64_t cache_block_stride = 1;
  for (size_t i = 1; i < cache_shape.size(); i++) {
    cache_block_stride *= cache_shape[i];
  }

  const int n_blocks = gpu_block_ids.size();
  // Nothing to copy; also avoids touching cache_gpu_tensors[0] when empty.
  if (n_blocks == 0 || layer_number == 0) return;

  // Every GPU block needs a CPU counterpart. Without this check an empty or
  // short cpu_block_ids would be indexed out of bounds below.
  PADDLE_ENFORCE_GE(
      cpu_block_ids.size(),
      gpu_block_ids.size(),
      phi::errors::InvalidArgument(
          "cpu_block_ids must provide one entry per gpu block, but got %d "
          "cpu block ids for %d gpu block ids",
          cpu_block_ids.size(),
          gpu_block_ids.size()));

  auto stream = cache_gpu_tensors[0].stream();
  const size_t block_bytes = cache_block_stride * sizeof(DataType_);
  const size_t total_bytes = (size_t)n_blocks * layer_number * block_bytes;

  // Gather the per-layer base pointers once; the kernel path uploads them and
  // their alignment takes part in choosing the path.
  std::vector<DataType_*> h_layer_ptrs(layer_number);
  bool layers_aligned = true;
  for (int64_t i = 0; i < layer_number; i++) {
    h_layer_ptrs[i] = reinterpret_cast<DataType_*>(
        const_cast<data_t*>(cache_gpu_tensors[i].data<data_t>()));
    if (reinterpret_cast<uintptr_t>(h_layer_ptrs[i]) % sizeof(float4) != 0) {
      layers_aligned = false;
    }
  }

  // The float4 kernels need consecutive CPU blocks, a block size that is a
  // multiple of 16 bytes, and 16-byte aligned base addresses on both sides.
  bool use_optimized = is_cpu_block_ids_sequential(cpu_block_ids) &&
                       (block_bytes % sizeof(float4) == 0) && layers_aligned;

  DataType_* cache_cpu_base = nullptr;
  if (use_optimized) {
    // Start of the first requested CPU block inside the host buffer.
    cache_cpu_base = reinterpret_cast<DataType_*>(cache_cpu_pointer) +
                     cpu_block_ids[0] * layer_number * cache_block_stride;
    if (reinterpret_cast<uintptr_t>(cache_cpu_base) % sizeof(float4) != 0) {
      use_optimized = false;
    }
  }

  if (use_optimized) {
    ensure_device_block_ids(n_blocks * sizeof(int64_t));
    ensure_device_layer_ptrs(layer_number * sizeof(DataType_*));

    cudaError_t status = cudaMemcpyAsync(g_device_block_ids,
                                         gpu_block_ids.data(),
                                         n_blocks * sizeof(int64_t),
                                         cudaMemcpyHostToDevice,
                                         stream);
    PADDLE_ENFORCE_EQ(
        status,
        cudaSuccess,
        phi::errors::External("cudaMemcpyAsync block_ids H2D failed: %s",
                              cudaGetErrorString(status)));

    status = cudaMemcpyAsync(g_device_layer_ptrs,
                             h_layer_ptrs.data(),
                             layer_number * sizeof(DataType_*),
                             cudaMemcpyHostToDevice,
                             stream);
    PADDLE_ENFORCE_EQ(
        status,
        cudaSuccess,
        phi::errors::External("cudaMemcpyAsync layer_ptrs H2D failed: %s",
                              cudaGetErrorString(status)));

    if (mode == 0) {
      // GPU -> CPU: one thread block per swapped block, each covering all
      // layers, writing directly into the host buffer.
      swap_d2h_kernel<DataType_><<<n_blocks, 512, 0, stream>>>(
          reinterpret_cast<DataType_**>(g_device_layer_ptrs),
          cache_cpu_base,
          reinterpret_cast<int64_t*>(g_device_block_ids),
          n_blocks,
          layer_number,
          cache_block_stride);
    } else {
      // CPU -> GPU: one thread block per (block, layer) pair. The pair count
      // must fit both gridDim.x and the int index used inside the kernel.
      const int64_t grid_blocks =
          static_cast<int64_t>(n_blocks) * layer_number;
      PADDLE_ENFORCE_LE(
          grid_blocks,
          static_cast<int64_t>(2147483647),
          phi::errors::InvalidArgument(
              "n_blocks * layer_number (%d) exceeds the maximum grid size",
              grid_blocks));

      // Bulk copy of the contiguous host range, then scatter on the device.
      ensure_staging_buffer(total_bytes);

      status = cudaMemcpyAsync(g_staging_buffer,
                               cache_cpu_base,
                               total_bytes,
                               cudaMemcpyHostToDevice,
                               stream);
      PADDLE_ENFORCE_EQ(status,
                        cudaSuccess,
                        phi::errors::External("cudaMemcpyAsync H2D failed: %s",
                                              cudaGetErrorString(status)));

      scatter_blocks_kernel<DataType_>
          <<<static_cast<unsigned int>(grid_blocks), 256, 0, stream>>>(
              reinterpret_cast<DataType_**>(g_device_layer_ptrs),
              reinterpret_cast<const DataType_*>(g_staging_buffer),
              reinterpret_cast<int64_t*>(g_device_block_ids),
              n_blocks,
              layer_number,
              cache_block_stride);
    }

    // A kernel launch returns no status; an invalid launch configuration is
    // only visible through cudaGetLastError.
    status = cudaGetLastError();
    PADDLE_ENFORCE_EQ(status,
                      cudaSuccess,
                      phi::errors::External("swap cache kernel launch failed: %s",
                                            cudaGetErrorString(status)));
  } else {
    const cudaMemcpyKind copy_kind =
        (mode == 0) ? cudaMemcpyDeviceToHost : cudaMemcpyHostToDevice;
    for (int64_t layer_idx = 0; layer_idx < layer_number; layer_idx++) {
      const paddle::Tensor& cache_gpu = cache_gpu_tensors[layer_idx];
      data_t* cache_gpu_ptr = const_cast<data_t*>(cache_gpu.data<data_t>());
      auto* cache_cpu_ptr = reinterpret_cast<data_t*>(cache_cpu_pointer);

      for (int block_idx = 0; block_idx < n_blocks; block_idx++) {
        auto cur_gpu_block_id = gpu_block_ids[block_idx];
        auto cur_cpu_block_id = cpu_block_ids[block_idx];
        auto* cache_gpu_ptr_now =
            cache_gpu_ptr + cur_gpu_block_id * cache_block_stride;
        // Host side: [cpu_block][layer] ordering.
        auto* cache_cpu_ptr_now =
            cache_cpu_ptr +
            cur_cpu_block_id * cache_block_stride * layer_number +
            layer_idx * cache_block_stride;

        cudaError_t status = cudaMemcpyAsync(
            (copy_kind == cudaMemcpyDeviceToHost) ? cache_cpu_ptr_now
                                                  : cache_gpu_ptr_now,
            (copy_kind == cudaMemcpyDeviceToHost) ? cache_gpu_ptr_now
                                                  : cache_cpu_ptr_now,
            block_bytes,
            copy_kind,
            stream);
        PADDLE_ENFORCE_EQ(status,
                          cudaSuccess,
                          phi::errors::External("cudaMemcpyAsync failed: %s",
                                                cudaGetErrorString(status)));
      }
    }
  }

  // Block until every copy and kernel queued above has completed; errors
  // raised while they executed surface here.
  cudaError_t sync_status = cudaStreamSynchronize(stream);
  PADDLE_ENFORCE_EQ(sync_status,
                    cudaSuccess,
                    phi::errors::External("cudaStreamSynchronize failed: %s",
                                          cudaGetErrorString(sync_status)));
}
92282

93-
void SwapCacheLayout(
94-
const std::vector<paddle::Tensor>& cache_gpu_tensors, // gpu
95-
const int64_t& cache_cpu_ptrs, // cpu memory pointer
96-
const std::vector<int64_t>& cache_shape,
97-
const std::vector<int64_t>& gpu_block_ids,
98-
const std::vector<int64_t>& cpu_block_ids,
99-
int rank,
100-
int mode) {
101-
cudaSetDevice(rank); // used for distributed launch
283+
void SwapCacheLayout(const std::vector<paddle::Tensor>& cache_gpu_tensors,
284+
const int64_t& cache_cpu_ptrs,
285+
const std::vector<int64_t>& cache_shape,
286+
const std::vector<int64_t>& gpu_block_ids,
287+
const std::vector<int64_t>& cpu_block_ids,
288+
int rank,
289+
int mode) {
290+
cudaSetDevice(rank);
102291
assert(cache_gpu_tensors.size() > 0);
103292
switch (cache_gpu_tensors[0].dtype()) {
104293
case paddle::DataType::BFLOAT16:

examples/cache_storage/run_03b_pd_storage.sh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ metadata_port=15002
1818

1919
export MOONCAKE_MASTER_SERVER_ADDR="${master_ip}:${master_port}"
2020
export MOONCAKE_METADATA_SERVER="http://${master_ip}:${metadata_port}/metadata"
21-
export MOONCAKE_GLOBAL_SEGMENT_SIZE="50000000000"
21+
export MOONCAKE_GLOBAL_SEGMENT_SIZE="50000000000" # 50GB
2222
# export MOONCAKE_PROTOCOL="tcp"
2323
export MOONCAKE_PROTOCOL="rdma"
2424
# export MOONCAKE_RDMA_DEVICES="mlx5_0"

fastdeploy/cache_manager/cache_messager.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -705,6 +705,7 @@ def prefill_layerwise_send_cache_thread(self):
705705
try:
706706
batch_engine_signals = self.cache_prefilled_engine_ids_queue.get()
707707
self.engine_worker_queue.begin_send_cache_barrier.wait()
708+
708709
block_start_end_list = []
709710
current_prefilled_token_num_list = []
710711
for engine_index, current_step_prefilled_token_num in batch_engine_signals:

0 commit comments

Comments
 (0)