Skip to content

Commit 554b37d

Browse files
committed
[Iluvatar] Refactor transpose and reverse_transpose
1 parent 165f827 commit 554b37d

3 files changed

Lines changed: 269 additions & 51 deletions

File tree

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "helper.h"
16+
17+
template <typename T, int VecSize>
18+
__global__ void MixedReorderHiddenStatesKernel(
19+
const T* input,
20+
T* output,
21+
const int* seq_lens_encoder,
22+
const int* seq_lens_decoder,
23+
const int* seq_lens_this_time,
24+
const int prefill_num_tokens,
25+
const int hidden_dim,
26+
const bool reverse) {
27+
using LoadT = AlignedVector<T, VecSize>;
28+
29+
const int bid = blockIdx.x;
30+
const int seq_len = seq_lens_this_time[bid];
31+
if (seq_len <= 0) {
32+
return;
33+
}
34+
35+
const bool is_prefill = seq_lens_encoder[bid] > 0;
36+
const bool is_decode = !is_prefill && seq_lens_decoder[bid] > 0;
37+
if (!is_prefill && !is_decode) {
38+
return;
39+
}
40+
41+
int original_start = 0;
42+
int reordered_start = 0;
43+
int decode_rank = 0;
44+
for (int i = 0; i < bid; ++i) {
45+
original_start += seq_lens_this_time[i];
46+
if (seq_lens_encoder[i] > 0) {
47+
reordered_start += seq_lens_encoder[i];
48+
} else if (seq_lens_decoder[i] > 0) {
49+
++decode_rank;
50+
}
51+
}
52+
53+
int copy_seq_len = seq_len;
54+
if (is_decode) {
55+
reordered_start = prefill_num_tokens + decode_rank;
56+
// Decode mixed attention consumes one query token for each decode request.
57+
copy_seq_len = 1;
58+
}
59+
60+
for (int idx = threadIdx.x * VecSize; idx < copy_seq_len * hidden_dim;
61+
idx += blockDim.x * VecSize) {
62+
const int token_offset = idx / hidden_dim;
63+
const int hidden_offset = idx % hidden_dim;
64+
const int original_offset =
65+
(original_start + token_offset) * hidden_dim + hidden_offset;
66+
const int reordered_offset =
67+
(reordered_start + token_offset) * hidden_dim + hidden_offset;
68+
69+
LoadT src_vec;
70+
if (reverse) {
71+
Load<T, VecSize>(&input[reordered_offset], &src_vec);
72+
Store<T, VecSize>(src_vec, &output[original_offset]);
73+
} else {
74+
Load<T, VecSize>(&input[original_offset], &src_vec);
75+
Store<T, VecSize>(src_vec, &output[reordered_offset]);
76+
}
77+
}
78+
}
79+
80+
template <paddle::DataType D, int VecSize>
81+
void LaunchMixedReorderHiddenStates(
82+
const paddle::Tensor& hidden_states,
83+
const paddle::Tensor& seq_lens_encoder,
84+
const paddle::Tensor& seq_lens_decoder,
85+
const paddle::Tensor& seq_lens_this_time,
86+
const int prefill_num_tokens,
87+
const bool reverse,
88+
paddle::Tensor* out) {
89+
typedef PDTraits<D> traits_;
90+
typedef typename traits_::DataType DataType_;
91+
typedef typename traits_::data_t data_t;
92+
93+
auto dev_ctx = static_cast<const phi::CustomContext*>(
94+
paddle::experimental::DeviceContextPool::Instance().Get(
95+
hidden_states.place()));
96+
auto stream = dev_ctx->stream();
97+
98+
const auto hidden_shape = hidden_states.shape();
99+
const int hidden_dim = hidden_shape[1];
100+
const int max_num_seqs = seq_lens_this_time.shape()[0];
101+
const int block_size = 128;
102+
103+
MixedReorderHiddenStatesKernel<DataType_, VecSize>
104+
<<<max_num_seqs, block_size, 0, stream>>>(
105+
reinterpret_cast<const DataType_*>(hidden_states.data<data_t>()),
106+
reinterpret_cast<DataType_*>(out->data<data_t>()),
107+
seq_lens_encoder.data<int>(),
108+
seq_lens_decoder.data<int>(),
109+
seq_lens_this_time.data<int>(),
110+
prefill_num_tokens,
111+
hidden_dim,
112+
reverse);
113+
}
114+
115+
template <paddle::DataType D>
116+
paddle::Tensor MixedReorderHiddenStatesImpl(
117+
const paddle::Tensor& hidden_states,
118+
const paddle::Tensor& seq_lens_encoder,
119+
const paddle::Tensor& seq_lens_decoder,
120+
const paddle::Tensor& seq_lens_this_time,
121+
const int prefill_num_tokens,
122+
const bool reverse) {
123+
typedef PDTraits<D> traits_;
124+
typedef typename traits_::DataType DataType_;
125+
126+
const auto hidden_shape = hidden_states.shape();
127+
PADDLE_ENFORCE_EQ(
128+
hidden_shape.size(),
129+
2,
130+
common::errors::InvalidArgument(
131+
"hidden_states must be a 2-D tensor, but got %d dims.",
132+
hidden_shape.size()));
133+
134+
auto out = GetEmptyTensor(
135+
{hidden_shape[0], hidden_shape[1]}, hidden_states.dtype(), hidden_states.place());
136+
137+
constexpr int PackSize = VEC_16B / sizeof(DataType_);
138+
if (hidden_shape[1] % PackSize == 0) {
139+
LaunchMixedReorderHiddenStates<D, PackSize>(hidden_states,
140+
seq_lens_encoder,
141+
seq_lens_decoder,
142+
seq_lens_this_time,
143+
prefill_num_tokens,
144+
reverse,
145+
&out);
146+
} else {
147+
LaunchMixedReorderHiddenStates<D, 1>(hidden_states,
148+
seq_lens_encoder,
149+
seq_lens_decoder,
150+
seq_lens_this_time,
151+
prefill_num_tokens,
152+
reverse,
153+
&out);
154+
}
155+
156+
return out;
157+
}
158+
159+
// Dispatches the reorder on the element type of `hidden_states`.
//
// Supported dtypes: bfloat16, float16 and float32. Any other dtype raises via
// PD_THROW. All remaining arguments are forwarded unchanged to
// MixedReorderHiddenStatesImpl; the returned tensor is the one it produces.
paddle::Tensor MixedReorderHiddenStatesFunc(
    const paddle::Tensor& hidden_states,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& seq_lens_this_time,
    int prefill_num_tokens,
    bool reverse) {
  // dtype() is used instead of type() so the dispatch agrees with the
  // dtype() query MixedReorderHiddenStatesImpl uses to allocate the output.
  switch (hidden_states.dtype()) {
    case paddle::DataType::BFLOAT16: {
      return MixedReorderHiddenStatesImpl<paddle::DataType::BFLOAT16>(
          hidden_states,
          seq_lens_encoder,
          seq_lens_decoder,
          seq_lens_this_time,
          prefill_num_tokens,
          reverse);
    }
    case paddle::DataType::FLOAT16: {
      return MixedReorderHiddenStatesImpl<paddle::DataType::FLOAT16>(
          hidden_states,
          seq_lens_encoder,
          seq_lens_decoder,
          seq_lens_this_time,
          prefill_num_tokens,
          reverse);
    }
    case paddle::DataType::FLOAT32: {
      return MixedReorderHiddenStatesImpl<paddle::DataType::FLOAT32>(
          hidden_states,
          seq_lens_encoder,
          seq_lens_decoder,
          seq_lens_this_time,
          prefill_num_tokens,
          reverse);
    }
    default: {
      PD_THROW(
          "NOT supported data type. "
          "Only float16, bfloat16 and float32 are supported. ");
    }
  }
}
201+
202+
std::vector<paddle::Tensor> MixedReorderHiddenStates(
203+
const paddle::Tensor& hidden_states,
204+
const paddle::Tensor& seq_lens_encoder,
205+
const paddle::Tensor& seq_lens_decoder,
206+
const paddle::Tensor& seq_lens_this_time,
207+
int prefill_num_tokens,
208+
bool reverse) {
209+
return {MixedReorderHiddenStatesFunc(hidden_states,
210+
seq_lens_encoder,
211+
seq_lens_decoder,
212+
seq_lens_this_time,
213+
prefill_num_tokens,
214+
reverse)};
215+
}
216+
217+
// Shape inference: the single output has exactly the shape of hidden_states;
// the seq_lens_* shapes do not influence it.
std::vector<std::vector<int64_t>> MixedReorderHiddenStatesInferShape(
    const std::vector<int64_t>& hidden_states_shape,
    const std::vector<int64_t>& seq_lens_encoder_shape,
    const std::vector<int64_t>& seq_lens_decoder_shape,
    const std::vector<int64_t>& seq_lens_this_time_shape) {
  std::vector<std::vector<int64_t>> output_shapes;
  output_shapes.push_back(hidden_states_shape);
  return output_shapes;
}
224+
225+
std::vector<paddle::DataType> MixedReorderHiddenStatesInferDtype(
226+
const paddle::DataType& hidden_states_dtype,
227+
const paddle::DataType& seq_lens_encoder_dtype,
228+
const paddle::DataType& seq_lens_decoder_dtype,
229+
const paddle::DataType& seq_lens_this_time_dtype) {
230+
return {hidden_states_dtype};
231+
}
232+
233+
// Registers the custom op `mixed_reorder_hidden_states`.
// Inputs are bound positionally to the tensor parameters of
// MixedReorderHiddenStates, and the two attributes (prefill_num_tokens,
// reverse) to its trailing int/bool parameters. The single output "out" takes
// its shape and dtype from the InferShape / InferDtype functions named below.
PD_BUILD_STATIC_OP(mixed_reorder_hidden_states)
234+
.Inputs({"hidden_states",
235+
"seq_lens_encoder",
236+
"seq_lens_decoder",
237+
"seq_lens_this_time"})
238+
.Outputs({"out"})
239+
.Attrs({"prefill_num_tokens:int", "reverse:bool"})
240+
.SetKernelFn(PD_KERNEL(MixedReorderHiddenStates))
241+
.SetInferShapeFn(PD_INFER_SHAPE(MixedReorderHiddenStatesInferShape))
242+
.SetInferDtypeFn(PD_INFER_DTYPE(MixedReorderHiddenStatesInferDtype));

custom_ops/setup_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,7 @@ def find_end_files(directory, end_str):
645645
"iluvatar_ops/paged_attn.cu",
646646
"iluvatar_ops/prefill_fused_attn.cu",
647647
"iluvatar_ops/mixed_fused_attn.cu",
648+
"iluvatar_ops/mixed_reorder_hidden_states.cu",
648649
"iluvatar_ops/w8a16_group_gemm.cu",
649650
"iluvatar_ops/w8a16_group_gemv.cu",
650651
"iluvatar_ops/wi4a16_group_gemm.cu",

fastdeploy/model_executor/layers/backends/iluvatar/attention/mha_attn_backend.py

Lines changed: 26 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
)
3131
from fastdeploy.model_executor.ops.iluvatar import (
3232
mixed_fused_paged_attention,
33+
mixed_reorder_hidden_states,
3334
paged_attention,
3435
prefill_fused_paged_attention,
3536
)
@@ -134,7 +135,12 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
134135
self.prefill_info_dict = {}
135136
self.decode_info_dict = {}
136137
prefill_batch_ids = paddle.where(forward_meta.seq_lens_encoder)[0]
137-
decode_batch_ids = paddle.where(forward_meta.seq_lens_decoder)[0]
138+
decode_batch_ids = paddle.where(
139+
paddle.logical_and(
140+
forward_meta.seq_lens_encoder == 0,
141+
forward_meta.seq_lens_decoder > 0,
142+
)
143+
)[0]
138144
self.prefill_len = len(prefill_batch_ids)
139145
self.decode_len = len(decode_batch_ids)
140146
# only prefill
@@ -169,50 +175,9 @@ def init_attention_metadata(self, forward_meta: ForwardMeta):
169175
self.attention_metadata.decode_seq_lens.copy_(forward_meta.seq_lens_decoder[decode_batch_ids] + 1)
170176
self.attention_metadata.decode_block_tables.copy_(forward_meta.block_tables[decode_batch_ids, :])
171177

172-
self.tmp_buffer = paddle.zeros(
173-
[self.prefill_num_tokens + self.decode_len, self.hidden_dim], dtype=self.dtype
174-
)
175-
prefill_start, decode_start, start = 0, self.prefill_num_tokens, 0
176-
non_zeros_ids = paddle.where(forward_meta.seq_lens_this_time)[0]
177-
non_zeros_seq_lens = forward_meta.seq_lens_this_time[non_zeros_ids]
178-
end = non_zeros_seq_lens[0]
179-
if end > 1:
180-
last_stage = "prefill"
181-
prefill_end = end
182-
decode_end = decode_start
183-
else:
184-
last_stage = "decode"
185-
prefill_end = 0
186-
decode_end = decode_start + end
187-
188-
self.id_group = []
189-
self.reverse_id_group = []
190-
for seq_len in non_zeros_seq_lens[1:]:
191-
if seq_len > 1:
192-
if last_stage == "decode":
193-
self.id_group.append((decode_start, decode_end))
194-
self.reverse_id_group.append((start, end))
195-
decode_start = decode_end
196-
start = end
197-
last_stage = "prefill"
198-
prefill_end += seq_len
199-
end += seq_len
200-
else:
201-
if last_stage == "prefill":
202-
self.id_group.append((prefill_start, prefill_end))
203-
self.reverse_id_group.append((start, end))
204-
prefill_start = prefill_end
205-
start = end
206-
last_stage = "decode"
207-
decode_end += seq_len
208-
end += seq_len
209-
210-
if prefill_start < prefill_end:
211-
self.id_group.append((prefill_start, prefill_end))
212-
self.reverse_id_group.append((start, end))
213-
if decode_start < decode_end:
214-
self.id_group.append((decode_start, decode_end))
215-
self.reverse_id_group.append((start, end))
178+
self.seq_lens_encoder = forward_meta.seq_lens_encoder
179+
self.seq_lens_decoder = forward_meta.seq_lens_decoder
180+
self.seq_lens_this_time = forward_meta.seq_lens_this_time
216181

217182
def get_attention_meta(self):
218183
"""get_attention_meta"""
@@ -231,14 +196,24 @@ def get_kv_cache_shape(
231196
return key_cache_shape, value_cache_shape
232197

233198
def transpose(self, hidden_states):
234-
for ids, reverse_ids in zip(self.id_group, self.reverse_id_group):
235-
self.tmp_buffer[ids[0] : ids[1], :] = hidden_states[reverse_ids[0] : reverse_ids[1], :]
236-
return self.tmp_buffer
199+
return mixed_reorder_hidden_states(
200+
hidden_states,
201+
self.seq_lens_encoder,
202+
self.seq_lens_decoder,
203+
self.seq_lens_this_time,
204+
self.prefill_num_tokens,
205+
False,
206+
)
237207

238208
def reverse_transpose(self, hidden_states):
239-
for ids, reverse_ids in zip(self.id_group, self.reverse_id_group):
240-
self.tmp_buffer[reverse_ids[0] : reverse_ids[1], :] = hidden_states[ids[0] : ids[1], :]
241-
return self.tmp_buffer
209+
return mixed_reorder_hidden_states(
210+
hidden_states,
211+
self.seq_lens_encoder,
212+
self.seq_lens_decoder,
213+
self.seq_lens_this_time,
214+
self.prefill_num_tokens,
215+
True,
216+
)
242217

243218
def forward_mixed(
244219
self,

0 commit comments

Comments
 (0)