Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 240 additions & 0 deletions custom_ops/iluvatar_ops/mixed_reorder_hidden_states.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

template <typename T, int VecSize>
__global__ void MixedReorderHiddenStatesKernel(const T* input,
T* output,
const int* seq_lens_encoder,
const int* seq_lens_decoder,
const int* seq_lens_this_time,
const int prefill_num_tokens,
const int hidden_dim,
const bool reverse) {
using LoadT = AlignedVector<T, VecSize>;

const int bid = blockIdx.x;
const int seq_len = seq_lens_this_time[bid];
if (seq_len <= 0) {
return;
}

const bool is_prefill = seq_lens_encoder[bid] > 0;
const bool is_decode = !is_prefill && seq_lens_decoder[bid] > 0;
if (!is_prefill && !is_decode) {
return;
}

int original_start = 0;
int reordered_start = 0;
int decode_rank = 0;
for (int i = 0; i < bid; ++i) {
original_start += seq_lens_this_time[i];
if (seq_lens_encoder[i] > 0) {
reordered_start += seq_lens_encoder[i];
} else if (seq_lens_decoder[i] > 0) {
++decode_rank;
}
}

int copy_seq_len = seq_len;
if (is_decode) {
reordered_start = prefill_num_tokens + decode_rank;
// Decode mixed attention consumes one query token for each decode request.
copy_seq_len = 1;

This comment was marked as outdated.

}

for (int idx = threadIdx.x * VecSize; idx < copy_seq_len * hidden_dim;
idx += blockDim.x * VecSize) {
const int token_offset = idx / hidden_dim;
const int hidden_offset = idx % hidden_dim;
const int original_offset =
(original_start + token_offset) * hidden_dim + hidden_offset;
const int reordered_offset =
(reordered_start + token_offset) * hidden_dim + hidden_offset;

LoadT src_vec;
if (reverse) {
Load<T, VecSize>(&input[reordered_offset], &src_vec);
Store<T, VecSize>(src_vec, &output[original_offset]);
} else {
Load<T, VecSize>(&input[original_offset], &src_vec);
Store<T, VecSize>(src_vec, &output[reordered_offset]);
}
}
}

template <paddle::DataType D, int VecSize>
void LaunchMixedReorderHiddenStates(const paddle::Tensor& hidden_states,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const int prefill_num_tokens,
const bool reverse,
paddle::Tensor* out) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;

auto dev_ctx = static_cast<const phi::CustomContext*>(
paddle::experimental::DeviceContextPool::Instance().Get(
hidden_states.place()));
auto stream = dev_ctx->stream();

const auto hidden_shape = hidden_states.shape();
const int hidden_dim = hidden_shape[1];

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 新增重排 kernel 把 hidden_shape[1]seq_lens_this_time.shape()[0] 收窄成 int,后续 copy_seq_len * hidden_dimoriginal_offsetreordered_offset 也全部用 32 位索引计算。

这些值来自 tensor shape 和调度 token 数,当前没有任何 <= INT_MAX / 乘法溢出校验;一旦 max_num_batched_tokens * hidden_dim 或 offset 超过 32 位范围,索引会回绕并读写错误位置。FastDeploy checklist 对 custom op 的 shape/offset 承接要求使用 int64_t

建议修复方式:
将 host 侧 hidden_dim/max_num_seqs 和 kernel 内 original_startreordered_startidx、offset 相关变量改为 int64_t 计算;如果 Iluvatar kernel 必须使用 32 位索引,则在 launch 前显式校验 hidden_shape[0] * hidden_shape[1] <= INT_MAXhidden_shape[1] <= INT_MAX 并给出清晰报错。

const int max_num_seqs = seq_lens_this_time.shape()[0];
const int block_size = 128;

MixedReorderHiddenStatesKernel<DataType_, VecSize>
<<<max_num_seqs, block_size, 0, stream>>>(
reinterpret_cast<const DataType_*>(hidden_states.data<data_t>()),
reinterpret_cast<DataType_*>(out->data<data_t>()),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
prefill_num_tokens,
hidden_dim,
reverse);
}

template <paddle::DataType D>
paddle::Tensor MixedReorderHiddenStatesImpl(
const paddle::Tensor& hidden_states,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const int prefill_num_tokens,
const bool reverse) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;

const auto hidden_shape = hidden_states.shape();
PADDLE_ENFORCE_EQ(hidden_shape.size(),
2,
common::errors::InvalidArgument(
"hidden_states must be a 2-D tensor, but got %d dims.",
hidden_shape.size()));

auto out = GetEmptyTensor({hidden_shape[0], hidden_shape[1]},
hidden_states.dtype(),
hidden_states.place());

constexpr int PackSize = VEC_16B / sizeof(DataType_);
if (hidden_shape[1] % PackSize == 0) {
LaunchMixedReorderHiddenStates<D, PackSize>(hidden_states,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
prefill_num_tokens,
reverse,
&out);
} else {
LaunchMixedReorderHiddenStates<D, 1>(hidden_states,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
prefill_num_tokens,
reverse,
&out);
}

return out;
}

paddle::Tensor MixedReorderHiddenStatesFunc(
const paddle::Tensor& hidden_states,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
int prefill_num_tokens,
bool reverse) {
switch (hidden_states.type()) {
case paddle::DataType::BFLOAT16: {
return MixedReorderHiddenStatesImpl<paddle::DataType::BFLOAT16>(
hidden_states,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
prefill_num_tokens,
reverse);
}
case paddle::DataType::FLOAT16: {
return MixedReorderHiddenStatesImpl<paddle::DataType::FLOAT16>(
hidden_states,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
prefill_num_tokens,
reverse);
}
case paddle::DataType::FLOAT32: {
return MixedReorderHiddenStatesImpl<paddle::DataType::FLOAT32>(
hidden_states,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
prefill_num_tokens,
reverse);
}
default: {
PD_THROW(
"NOT supported data type. "
"Only float16, bfloat16 and float32 are supported. ");
}
}
}

std::vector<paddle::Tensor> MixedReorderHiddenStates(
const paddle::Tensor& hidden_states,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
int prefill_num_tokens,
bool reverse) {
return {MixedReorderHiddenStatesFunc(hidden_states,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
prefill_num_tokens,
reverse)};
}

std::vector<std::vector<int64_t>> MixedReorderHiddenStatesInferShape(
const std::vector<int64_t>& hidden_states_shape,
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape) {
return {hidden_states_shape};
}

std::vector<paddle::DataType> MixedReorderHiddenStatesInferDtype(
const paddle::DataType& hidden_states_dtype,
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype) {
return {hidden_states_dtype};
}

PD_BUILD_STATIC_OP(mixed_reorder_hidden_states)
.Inputs({"hidden_states",
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time"})
.Outputs({"out"})
.Attrs({"prefill_num_tokens:int", "reverse:bool"})
.SetKernelFn(PD_KERNEL(MixedReorderHiddenStates))
.SetInferShapeFn(PD_INFER_SHAPE(MixedReorderHiddenStatesInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MixedReorderHiddenStatesInferDtype));
1 change: 1 addition & 0 deletions custom_ops/setup_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,7 @@ def find_end_files(directory, end_str):
"iluvatar_ops/paged_attn.cu",
"iluvatar_ops/prefill_fused_attn.cu",
"iluvatar_ops/mixed_fused_attn.cu",
"iluvatar_ops/mixed_reorder_hidden_states.cu",
"iluvatar_ops/w8a16_group_gemm.cu",
"iluvatar_ops/w8a16_group_gemv.cu",
"iluvatar_ops/wi4a16_group_gemm.cu",
Expand Down
12 changes: 6 additions & 6 deletions docs/get_started/installation/iluvatar_gpu.md
Original file line number Diff line number Diff line change
Expand Up @@ -476,9 +476,9 @@ python3 -m fastdeploy.entrypoints.openai.api_server \
--model /data1/fastdeploy/PaddleOCR-VL \
--port 8180 \
--metrics-port 8471 \
--max-model-len 16384 \
--max-num-batched-tokens 16384 \
--max-num-seqs 240 \
--max-model-len 4096 \
--max-num-batched-tokens 4096 \
--max-num-seqs 256 \
--block-size 16 \
--workers 2 \
--gpu-memory-utilization 0.7 \
Expand Down Expand Up @@ -552,9 +552,9 @@ python3 -m fastdeploy.entrypoints.openai.api_server \
--model /data1/fastdeploy/PaddleOCR-VL-1.6 \
--port 8180 \
--metrics-port 8471 \
--max-model-len 16384 \
--max-num-batched-tokens 16384 \
--max-num-seqs 240 \
--max-model-len 4096 \
--max-num-batched-tokens 4096 \
--max-num-seqs 256 \
--block-size 16 \
--workers 2 \
--gpu-memory-utilization 0.7 \
Expand Down
12 changes: 6 additions & 6 deletions docs/zh/get_started/installation/iluvatar_gpu.md
Original file line number Diff line number Diff line change
Expand Up @@ -476,9 +476,9 @@ python3 -m fastdeploy.entrypoints.openai.api_server \
--model /data1/fastdeploy/PaddleOCR-VL \
--port 8180 \
--metrics-port 8471 \
--max-model-len 16384 \
--max-num-batched-tokens 16384 \
--max-num-seqs 240 \
--max-model-len 4096 \
--max-num-batched-tokens 4096 \
--max-num-seqs 256 \
--block-size 16 \
--workers 2 \
--gpu-memory-utilization 0.7 \
Expand Down Expand Up @@ -549,9 +549,9 @@ python3 -m fastdeploy.entrypoints.openai.api_server \
--model /data1/fastdeploy/PaddleOCR-VL-1.6 \
--port 8180 \
--metrics-port 8471 \
--max-model-len 16384 \
--max-num-batched-tokens 16384 \
--max-num-seqs 240 \
--max-model-len 4096 \
--max-num-batched-tokens 4096 \
--max-num-seqs 256 \
--block-size 16 \
--workers 2 \
--gpu-memory-utilization 0.7 \
Expand Down
Loading
Loading