@@ -44,13 +44,49 @@ __global__ void FillEncoderDecoderResKernel(T *encoder_res_data,
4444 return ;
4545 }
4646
47- const int load_idx =
48- ((cu_seq_q[bidb] + token_id) * head_num + bidh) * head_dim + land_id * 4 ;
47+ const int base_idx =
48+ ((cu_seq_q[bidb] + token_id) * head_num + bidh) * head_dim;
4949
50- *reinterpret_cast <float2 *>(encoder_res_data + load_idx) =
51- *reinterpret_cast <float2 *>(decoder_res_data + load_idx);
50+ if (head_dim == 128 ) {
51+ const int load_idx = base_idx + land_id * 4 ;
52+ *reinterpret_cast <float2 *>(encoder_res_data + load_idx) =
53+ *reinterpret_cast <float2 *>(decoder_res_data + load_idx);
54+ } else if (head_dim == 192 ) {
55+ const int load_idx = base_idx + land_id * 4 ;
56+ *reinterpret_cast <float2 *>(encoder_res_data + load_idx) =
57+ *reinterpret_cast <float2 *>(decoder_res_data + load_idx);
58+ if (land_id < 16 ) {
59+ *reinterpret_cast <float2 *>(encoder_res_data + load_idx + 128 ) =
60+ *reinterpret_cast <float2 *>(decoder_res_data + load_idx + 128 );
61+ }
62+ } else if (head_dim == 256 ) {
63+ // float4 = 单条LDG.128,性能最优
64+ const int load_idx = base_idx + land_id * 8 ;
65+ *reinterpret_cast <float4 *>(encoder_res_data + load_idx) =
66+ *reinterpret_cast <float4 *>(decoder_res_data + load_idx);
67+ }
5268}
5369
70+ #define LAUNCH_KERNEL (T, WARPS ) \
71+ FillEncoderDecoderResKernel<WARPS > \
72+ <<<grid_dims, head_dim, 0 , encoder_res.stream()>>> ( \
73+ const_cast <T *>(encoder_res.data<T>()), \
74+ const_cast <T *>(decoder_res.data<T>()), \
75+ seq_lens_encoder.data<int >(), \
76+ seq_lens_decoder.data<int >(), \
77+ seq_lens_this_time.data<int >(), \
78+ cu_seq_q.data<int >(), \
79+ head_num, \
80+ head_dim)
81+
82+ #define LAUNCH_KERNEL_BY_HEAD_DIM (T ) \
83+ if (head_dim == 128 ) \
84+ LAUNCH_KERNEL (T, 4 ); \
85+ else if (head_dim == 192 ) \
86+ LAUNCH_KERNEL (T, 6 ); \
87+ else if (head_dim == 256 ) \
88+ LAUNCH_KERNEL (T, 8 )
89+
5490void MergePrefillDecodeOutput(const paddle::Tensor &encoder_res,
5591 const paddle::Tensor &decoder_res,
5692 const paddle::Tensor &seq_lens_encoder,
@@ -60,41 +96,20 @@ void MergePrefillDecodeOutput(const paddle::Tensor &encoder_res,
6096 const int head_num,
6197 const int head_dim,
6298 const int max_token) {
63- if (head_dim != 128 ) {
64- PD_THROW (" Only supported head_dim = 128" );
99+ if (head_dim != 128 && head_dim != 192 && head_dim != 256 ) {
100+ PD_THROW (" Only supported head_dim = 128, 192 or 256 " );
65101 }
66102 const int batch_size = seq_lens_encoder.shape ()[0 ];
67- constexpr int warps = 4 ;
103+ const int warps = head_dim / 32 ;
68104 const int tokens_block = (max_token + warps - 1 ) / warps;
69- dim3 grid_dims;
70- grid_dims.x = batch_size;
71- grid_dims.y = head_num;
72- grid_dims.z = tokens_block;
105+ dim3 grid_dims (batch_size, head_num, tokens_block);
73106
74107 if (encoder_res.dtype () == paddle::DataType::FLOAT16 ) {
75108 using T = phi::dtype::float16;
76- FillEncoderDecoderResKernel<warps>
77- <<<grid_dims, 128 , 0 , encoder_res.stream()>>> (
78- const_cast <T *>(encoder_res.data <T>()),
79- const_cast <T *>(decoder_res.data <T>()),
80- seq_lens_encoder.data <int >(),
81- seq_lens_decoder.data <int >(),
82- seq_lens_this_time.data <int >(),
83- cu_seq_q.data <int >(),
84- head_num,
85- head_dim);
109+ LAUNCH_KERNEL_BY_HEAD_DIM (T);
86110 } else if (encoder_res.dtype () == paddle::DataType::BFLOAT16 ) {
87111 using T = phi::dtype::bfloat16;
88- FillEncoderDecoderResKernel<warps>
89- <<<grid_dims, 128 , 0 , encoder_res.stream()>>> (
90- const_cast <T *>(encoder_res.data <T>()),
91- const_cast <T *>(decoder_res.data <T>()),
92- seq_lens_encoder.data <int >(),
93- seq_lens_decoder.data <int >(),
94- seq_lens_this_time.data <int >(),
95- cu_seq_q.data <int >(),
96- head_num,
97- head_dim);
112+ LAUNCH_KERNEL_BY_HEAD_DIM (T);
98113 }
99114}
100115
0 commit comments