Skip to content

Commit 5719306

Browse files
committed
[Iluvatar] Refactor transpose and reverse_transpose
1 parent 165f827 commit 5719306

6 files changed

Lines changed: 296 additions & 64 deletions

File tree

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "helper.h"
16+
17+
template <typename T, int VecSize>
18+
__global__ void MixedReorderHiddenStatesKernel(const T* input,
19+
T* output,
20+
const int* seq_lens_encoder,
21+
const int* seq_lens_decoder,
22+
const int* seq_lens_this_time,
23+
const int prefill_num_tokens,
24+
const int hidden_dim,
25+
const bool reverse) {
26+
using LoadT = AlignedVector<T, VecSize>;
27+
28+
const int bid = blockIdx.x;
29+
const int seq_len = seq_lens_this_time[bid];
30+
if (seq_len <= 0) {
31+
return;
32+
}
33+
34+
const bool is_prefill = seq_lens_encoder[bid] > 0;
35+
const bool is_decode = !is_prefill && seq_lens_decoder[bid] > 0;
36+
if (!is_prefill && !is_decode) {
37+
return;
38+
}
39+
40+
int original_start = 0;
41+
int reordered_start = 0;
42+
int decode_rank = 0;
43+
for (int i = 0; i < bid; ++i) {
44+
original_start += seq_lens_this_time[i];
45+
if (seq_lens_encoder[i] > 0) {
46+
reordered_start += seq_lens_encoder[i];
47+
} else if (seq_lens_decoder[i] > 0) {
48+
++decode_rank;
49+
}
50+
}
51+
52+
int copy_seq_len = seq_len;
53+
if (is_decode) {
54+
reordered_start = prefill_num_tokens + decode_rank;
55+
// Decode mixed attention consumes one query token for each decode request.
56+
copy_seq_len = 1;
57+
}
58+
59+
for (int idx = threadIdx.x * VecSize; idx < copy_seq_len * hidden_dim;
60+
idx += blockDim.x * VecSize) {
61+
const int token_offset = idx / hidden_dim;
62+
const int hidden_offset = idx % hidden_dim;
63+
const int original_offset =
64+
(original_start + token_offset) * hidden_dim + hidden_offset;
65+
const int reordered_offset =
66+
(reordered_start + token_offset) * hidden_dim + hidden_offset;
67+
68+
LoadT src_vec;
69+
if (reverse) {
70+
Load<T, VecSize>(&input[reordered_offset], &src_vec);
71+
Store<T, VecSize>(src_vec, &output[original_offset]);
72+
} else {
73+
Load<T, VecSize>(&input[original_offset], &src_vec);
74+
Store<T, VecSize>(src_vec, &output[reordered_offset]);
75+
}
76+
}
77+
}
78+
79+
template <paddle::DataType D, int VecSize>
80+
void LaunchMixedReorderHiddenStates(const paddle::Tensor& hidden_states,
81+
const paddle::Tensor& seq_lens_encoder,
82+
const paddle::Tensor& seq_lens_decoder,
83+
const paddle::Tensor& seq_lens_this_time,
84+
const int prefill_num_tokens,
85+
const bool reverse,
86+
paddle::Tensor* out) {
87+
typedef PDTraits<D> traits_;
88+
typedef typename traits_::DataType DataType_;
89+
typedef typename traits_::data_t data_t;
90+
91+
auto dev_ctx = static_cast<const phi::CustomContext*>(
92+
paddle::experimental::DeviceContextPool::Instance().Get(
93+
hidden_states.place()));
94+
auto stream = dev_ctx->stream();
95+
96+
const auto hidden_shape = hidden_states.shape();
97+
const int hidden_dim = hidden_shape[1];
98+
const int max_num_seqs = seq_lens_this_time.shape()[0];
99+
const int block_size = 128;
100+
101+
MixedReorderHiddenStatesKernel<DataType_, VecSize>
102+
<<<max_num_seqs, block_size, 0, stream>>>(
103+
reinterpret_cast<const DataType_*>(hidden_states.data<data_t>()),
104+
reinterpret_cast<DataType_*>(out->data<data_t>()),
105+
seq_lens_encoder.data<int>(),
106+
seq_lens_decoder.data<int>(),
107+
seq_lens_this_time.data<int>(),
108+
prefill_num_tokens,
109+
hidden_dim,
110+
reverse);
111+
}
112+
113+
template <paddle::DataType D>
114+
paddle::Tensor MixedReorderHiddenStatesImpl(
115+
const paddle::Tensor& hidden_states,
116+
const paddle::Tensor& seq_lens_encoder,
117+
const paddle::Tensor& seq_lens_decoder,
118+
const paddle::Tensor& seq_lens_this_time,
119+
const int prefill_num_tokens,
120+
const bool reverse) {
121+
typedef PDTraits<D> traits_;
122+
typedef typename traits_::DataType DataType_;
123+
124+
const auto hidden_shape = hidden_states.shape();
125+
PADDLE_ENFORCE_EQ(hidden_shape.size(),
126+
2,
127+
common::errors::InvalidArgument(
128+
"hidden_states must be a 2-D tensor, but got %d dims.",
129+
hidden_shape.size()));
130+
131+
auto out = GetEmptyTensor({hidden_shape[0], hidden_shape[1]},
132+
hidden_states.dtype(),
133+
hidden_states.place());
134+
135+
constexpr int PackSize = VEC_16B / sizeof(DataType_);
136+
if (hidden_shape[1] % PackSize == 0) {
137+
LaunchMixedReorderHiddenStates<D, PackSize>(hidden_states,
138+
seq_lens_encoder,
139+
seq_lens_decoder,
140+
seq_lens_this_time,
141+
prefill_num_tokens,
142+
reverse,
143+
&out);
144+
} else {
145+
LaunchMixedReorderHiddenStates<D, 1>(hidden_states,
146+
seq_lens_encoder,
147+
seq_lens_decoder,
148+
seq_lens_this_time,
149+
prefill_num_tokens,
150+
reverse,
151+
&out);
152+
}
153+
154+
return out;
155+
}
156+
157+
paddle::Tensor MixedReorderHiddenStatesFunc(
158+
const paddle::Tensor& hidden_states,
159+
const paddle::Tensor& seq_lens_encoder,
160+
const paddle::Tensor& seq_lens_decoder,
161+
const paddle::Tensor& seq_lens_this_time,
162+
int prefill_num_tokens,
163+
bool reverse) {
164+
switch (hidden_states.type()) {
165+
case paddle::DataType::BFLOAT16: {
166+
return MixedReorderHiddenStatesImpl<paddle::DataType::BFLOAT16>(
167+
hidden_states,
168+
seq_lens_encoder,
169+
seq_lens_decoder,
170+
seq_lens_this_time,
171+
prefill_num_tokens,
172+
reverse);
173+
}
174+
case paddle::DataType::FLOAT16: {
175+
return MixedReorderHiddenStatesImpl<paddle::DataType::FLOAT16>(
176+
hidden_states,
177+
seq_lens_encoder,
178+
seq_lens_decoder,
179+
seq_lens_this_time,
180+
prefill_num_tokens,
181+
reverse);
182+
}
183+
case paddle::DataType::FLOAT32: {
184+
return MixedReorderHiddenStatesImpl<paddle::DataType::FLOAT32>(
185+
hidden_states,
186+
seq_lens_encoder,
187+
seq_lens_decoder,
188+
seq_lens_this_time,
189+
prefill_num_tokens,
190+
reverse);
191+
}
192+
default: {
193+
PD_THROW(
194+
"NOT supported data type. "
195+
"Only float16, bfloat16 and float32 are supported. ");
196+
}
197+
}
198+
}
199+
200+
std::vector<paddle::Tensor> MixedReorderHiddenStates(
201+
const paddle::Tensor& hidden_states,
202+
const paddle::Tensor& seq_lens_encoder,
203+
const paddle::Tensor& seq_lens_decoder,
204+
const paddle::Tensor& seq_lens_this_time,
205+
int prefill_num_tokens,
206+
bool reverse) {
207+
return {MixedReorderHiddenStatesFunc(hidden_states,
208+
seq_lens_encoder,
209+
seq_lens_decoder,
210+
seq_lens_this_time,
211+
prefill_num_tokens,
212+
reverse)};
213+
}
214+
215+
std::vector<std::vector<int64_t>> MixedReorderHiddenStatesInferShape(
216+
const std::vector<int64_t>& hidden_states_shape,
217+
const std::vector<int64_t>& seq_lens_encoder_shape,
218+
const std::vector<int64_t>& seq_lens_decoder_shape,
219+
const std::vector<int64_t>& seq_lens_this_time_shape) {
220+
return {hidden_states_shape};
221+
}
222+
223+
std::vector<paddle::DataType> MixedReorderHiddenStatesInferDtype(
224+
const paddle::DataType& hidden_states_dtype,
225+
const paddle::DataType& seq_lens_encoder_dtype,
226+
const paddle::DataType& seq_lens_decoder_dtype,
227+
const paddle::DataType& seq_lens_this_time_dtype) {
228+
return {hidden_states_dtype};
229+
}
230+
231+
PD_BUILD_STATIC_OP(mixed_reorder_hidden_states)
232+
.Inputs({"hidden_states",
233+
"seq_lens_encoder",
234+
"seq_lens_decoder",
235+
"seq_lens_this_time"})
236+
.Outputs({"out"})
237+
.Attrs({"prefill_num_tokens:int", "reverse:bool"})
238+
.SetKernelFn(PD_KERNEL(MixedReorderHiddenStates))
239+
.SetInferShapeFn(PD_INFER_SHAPE(MixedReorderHiddenStatesInferShape))
240+
.SetInferDtypeFn(PD_INFER_DTYPE(MixedReorderHiddenStatesInferDtype));

custom_ops/setup_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,7 @@ def find_end_files(directory, end_str):
645645
"iluvatar_ops/paged_attn.cu",
646646
"iluvatar_ops/prefill_fused_attn.cu",
647647
"iluvatar_ops/mixed_fused_attn.cu",
648+
"iluvatar_ops/mixed_reorder_hidden_states.cu",
648649
"iluvatar_ops/w8a16_group_gemm.cu",
649650
"iluvatar_ops/w8a16_group_gemv.cu",
650651
"iluvatar_ops/wi4a16_group_gemm.cu",

docs/get_started/installation/iluvatar_gpu.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -476,9 +476,9 @@ python3 -m fastdeploy.entrypoints.openai.api_server \
476476
--model /data1/fastdeploy/PaddleOCR-VL \
477477
--port 8180 \
478478
--metrics-port 8471 \
479-
--max-model-len 16384 \
480-
--max-num-batched-tokens 16384 \
481-
--max-num-seqs 240 \
479+
--max-model-len 4096 \
480+
--max-num-batched-tokens 4096 \
481+
--max-num-seqs 256 \
482482
--block-size 16 \
483483
--workers 2 \
484484
--gpu-memory-utilization 0.7 \
@@ -552,9 +552,9 @@ python3 -m fastdeploy.entrypoints.openai.api_server \
552552
--model /data1/fastdeploy/PaddleOCR-VL-1.6 \
553553
--port 8180 \
554554
--metrics-port 8471 \
555-
--max-model-len 16384 \
556-
--max-num-batched-tokens 16384 \
557-
--max-num-seqs 240 \
555+
--max-model-len 4096 \
556+
--max-num-batched-tokens 4096 \
557+
--max-num-seqs 256 \
558558
--block-size 16 \
559559
--workers 2 \
560560
--gpu-memory-utilization 0.7 \

docs/zh/get_started/installation/iluvatar_gpu.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -476,9 +476,9 @@ python3 -m fastdeploy.entrypoints.openai.api_server \
476476
--model /data1/fastdeploy/PaddleOCR-VL \
477477
--port 8180 \
478478
--metrics-port 8471 \
479-
--max-model-len 16384 \
480-
--max-num-batched-tokens 16384 \
481-
--max-num-seqs 240 \
479+
--max-model-len 4096 \
480+
--max-num-batched-tokens 4096 \
481+
--max-num-seqs 256 \
482482
--block-size 16 \
483483
--workers 2 \
484484
--gpu-memory-utilization 0.7 \
@@ -549,9 +549,9 @@ python3 -m fastdeploy.entrypoints.openai.api_server \
549549
--model /data1/fastdeploy/PaddleOCR-VL-1.6 \
550550
--port 8180 \
551551
--metrics-port 8471 \
552-
--max-model-len 16384 \
553-
--max-num-batched-tokens 16384 \
554-
--max-num-seqs 240 \
552+
--max-model-len 4096 \
553+
--max-num-batched-tokens 4096 \
554+
--max-num-seqs 256 \
555555
--block-size 16 \
556556
--workers 2 \
557557
--gpu-memory-utilization 0.7 \

0 commit comments

Comments
 (0)